"""run_tf_glue.py

Fine-tune a TensorFlow 2.0 BERT model on the GLUE MRPC paraphrase task with
tf.keras, save the trained model, then reload it in PyTorch for a quick check.
"""
import os
import tensorflow as tf
import tensorflow_datasets
from transformers import BertTokenizer, TFBertForSequenceClassification, glue_convert_examples_to_features, BertForSequenceClassification

# script parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2

# Load tokenizer and model from pretrained model/vocabulary
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')

# Load dataset via TensorFlow Datasets
data, info = tensorflow_datasets.load('glue/mrpc', with_info=True)
train_examples = info.splits['train'].num_examples
valid_examples = info.splits['validation'].num_examples

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task='mrpc')
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task='mrpc')
# Shuffle and batch the training set and repeat it indefinitely;
# fit() below bounds each epoch with steps_per_epoch
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
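
# Optional check, not part of the original script: with recent TF 2.x, printing
# element_spec confirms the structure produced above - each element is a
# (features, label) pair, where features is a dict of 'input_ids',
# 'attention_mask' and 'token_type_ids' tensors.
print(train_dataset.element_spec)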

# Prepare training: compile the tf.keras model with an optimizer, loss and metric
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# Train and evaluate using tf.keras.Model.fit()
# Derive step counts from the split sizes, since the datasets are batched streams
train_steps = train_examples // BATCH_SIZE
valid_steps = valid_examples // EVAL_BATCH_SIZE

history = model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
                    validation_data=valid_dataset, validation_steps=valid_steps)
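
# Optional, not in the original script: fit() above already reports validation
# metrics per epoch, but a standalone tf.keras.Model.evaluate() pass makes a
# convenient final sanity check.
eval_loss, eval_acc = model.evaluate(valid_dataset, steps=valid_steps)
print('validation loss:', eval_loss, '- validation accuracy:', eval_acc)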

# Save TF2 model
os.makedirs('./save/', exist_ok=True)
model.save_pretrained('./save/')
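# Optional extra (not in the original script): save the tokenizer alongside the
# model so './save/' can be reloaded as a self-contained checkpoint.
tokenizer.save_pretrained('./save/')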

# Load the TensorFlow model in PyTorch for inspection
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
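# Not in the original script: PyTorch modules load in training mode, so switch
# to eval mode to disable dropout and make the predictions below deterministic.
pytorch_model.eval()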

# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
sentence_0 = 'This research was consistent with his findings.'
sentence_1 = 'His findings were compatible with this research.'
sentence_2 = 'His findings were not compatible with this research.'
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')

# argmax over the two output logits; for MRPC, label 1 means 'paraphrase'
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
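
# Optional sketch (not in the original script): inspect the softmax probabilities
# behind the argmax predictions above to see how confident the model is.
import torch  # local import, only needed for this optional check

with torch.no_grad():
    probs_1 = torch.softmax(pytorch_model(**inputs_1)[0], dim=-1)
    probs_2 = torch.softmax(pytorch_model(**inputs_2)[0], dim=-1)
print('paraphrase probability for sentence_1:', probs_1[0, 1].item())
print('paraphrase probability for sentence_2:', probs_2[0, 1].item())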