# run_tf_glue.py

import tensorflow as tf
import tensorflow_datasets
from transformers import (BertForSequenceClassification, BertTokenizer,
                          TFBertForSequenceClassification,
                          glue_convert_examples_to_features)

# Load dataset, tokenizer, model from pretrained model/vocabulary
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
dataset = tensorflow_datasets.load('glue/mrpc')
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
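
# Optional sanity check (not part of the original recipe): each raw MRPC
# example is a dict with 'sentence1', 'sentence2' and a binary 'label'
# (1 = paraphrase, 0 = not a paraphrase)
for example in dataset['train'].take(1):
    print(example['sentence1'].numpy(), example['sentence2'].numpy(), example['label'].numpy())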

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(dataset['train'], tokenizer, task='mrpc')
valid_dataset = glue_convert_examples_to_features(dataset['validation'], tokenizer, task='mrpc')
train_dataset = train_dataset.shuffle(100).batch(32).repeat(3)
valid_dataset = valid_dataset.batch(64)
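
# Optional peek at one prepared batch (illustrative): the pipeline yields
# (features, labels) pairs where features holds 'input_ids', 'attention_mask'
# and 'token_type_ids'; the exact keys can vary across transformers versions
for features, labels in train_dataset.take(1):
    print(features['input_ids'].shape, labels.shape)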

# Prepare training: compile the tf.keras model with an optimizer, a loss and a learning rate schedule
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(2e-5, 345, end_learning_rate=0)  # 345 = 115 steps/epoch * 3 epochs
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])

# Train and evaluate using tf.keras.Model.fit()
# 115 steps/epoch covers MRPC's 3,668 training pairs at batch size 32;
# 7 validation steps cover its 408 dev pairs at batch size 64
model.fit(train_dataset, epochs=3, steps_per_epoch=115,
          validation_data=valid_dataset, validation_steps=7)
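
# Optionally report the final validation metrics explicitly; model.evaluate()
# is standard tf.keras, though this call is not in the original script
eval_loss, eval_acc = model.evaluate(valid_dataset)
print("Validation loss: %.4f, accuracy: %.4f" % (eval_loss, eval_acc))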

# Save the TensorFlow model and load it in PyTorch
model.save_pretrained('./save/')
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
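
# The reverse direction also works (a sketch, reusing the same './save/'
# directory): write the PyTorch weights, then reload them in TensorFlow
# with from_pt=True
pytorch_model.save_pretrained('./save/')
tf_model = TFBertForSequenceClassification.from_pretrained('./save/', from_pt=True)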

# Quickly inspect a few predictions - MRPC is a paraphrasing task
inputs = tokenizer.encode_plus("The company is doing great",
                               "The company has good results",
                               add_special_tokens=True,
                               return_tensors='pt')
pred = pytorch_model(**inputs)[0]  # the model returns a tuple; the logits come first
print("Paraphrase" if pred.argmax().item() == 1 else "Not paraphrase")  # MRPC label 1 means "paraphrase"