# run_tf_glue.py
import os
import tensorflow as tf
import tensorflow_datasets
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    TFBertForSequenceClassification,
    glue_convert_examples_to_features,
    glue_processors,
)

# script parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2

USE_XLA = False  # compile the model graph with XLA JIT
USE_AMP = False  # run with automatic mixed precision (float16 compute)

EPOCHS = 3

TASK = "mrpc"

# transformers uses dashed task names ("sst-2", "sts-b"); TFDS drops the dash
if TASK == "sst-2":
    TFDS_TASK = "sst2"
elif TASK == "sts-b":
    TFDS_TASK = "stsb"
else:
    TFDS_TASK = TASK
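
# For MNLI, TFDS exposes "validation_matched" and "validation_mismatched"
# rather than a single "validation" split. A minimal sketch of picking one
# (illustrative only; the script below keeps the plain 'validation' split):
# VALIDATION_SPLIT = "validation_matched" if TASK == "mnli" else "validation"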

num_labels = len(glue_processors[TASK]().get_labels())
print(num_labels)

# Apply the XLA / mixed-precision options to the TF graph optimizer
tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', config=config)
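
# Sanity check (a sketch, not part of the original script): the classification
# head emits one logit per label, so a single encoded pair should come out as
# shape (1, num_labels). Uncomment to verify:
# ids = tf.constant([tokenizer.encode('a', 'b', add_special_tokens=True)])
# print(model(ids)[0].shape)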

# Load dataset via TensorFlow Datasets
data, info = tensorflow_datasets.load(f'glue/{TFDS_TASK}', with_info=True)
train_examples = info.splits['train'].num_examples

# MNLI expects either validation_matched or validation_mismatched
valid_examples = info.splits['validation'].num_examples

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, TASK)

# MNLI expects either validation_matched or validation_mismatched
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, TASK)
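
# The positional arguments above map to max_length=128 and task=TASK; the same
# call with explicit keywords (a sketch, equivalent under this version's API):
# glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task=TASK)
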
# Shuffle and batch; repeat(-1) cycles indefinitely since fit() bounds each epoch via steps_per_epoch
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
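
# Each batched element is a (features, labels) pair, where features is a dict
# of int32 tensors keyed by 'input_ids', 'attention_mask' and 'token_type_ids'
# (an assumption about this version's feature names). Inspection sketch:
# for features, labels in train_dataset.take(1):
#     print({k: v.shape for k, v in features.items()}, labels.shape)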

# Prepare training: compile the tf.keras model with an optimizer and loss
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if USE_AMP:
    # loss scaling is currently required when using mixed precision
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
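    # Note: tf.keras.mixed_precision.experimental is the TF 2.0/2.1 location of
    # this API; on newer TF releases (an assumption about your install) the
    # equivalent is tf.keras.mixed_precision.LossScaleOptimizer(opt) combined
    # with tf.keras.mixed_precision.set_global_policy('mixed_float16').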


# STS-B (num_labels == 1) is a regression task, so use mean squared error;
# the classification tasks use sparse categorical cross-entropy over logits
if num_labels == 1:
    loss = tf.keras.losses.MeanSquaredError()
else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# accuracy is only meaningful for classification; track MSE for regression
if num_labels == 1:
    metric = tf.keras.metrics.MeanSquaredError('mse')
else:
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=opt, loss=loss, metrics=[metric])

# Train and evaluate using tf.keras.Model.fit()
train_steps = train_examples // BATCH_SIZE
valid_steps = valid_examples // EVAL_BATCH_SIZE

history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
                    validation_data=valid_dataset, validation_steps=valid_steps)
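
# Final validation metrics could also be reported explicitly (a sketch; fit()
# above already validates once per epoch):
# results = model.evaluate(valid_dataset, steps=valid_steps)
# print(dict(zip(model.metrics_names, results)))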

# Save TF2 model
os.makedirs('./save/', exist_ok=True)
model.save_pretrained('./save/')
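
# The saved directory can be reloaded in TF2 the same way (sketch):
# model = TFBertForSequenceClassification.from_pretrained('./save/')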

if TASK == "mrpc":
    # Load the TensorFlow model in PyTorch for inspection
    pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)

    # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
    sentence_0 = 'This research was consistent with his findings.'
    sentence_1 = 'His findings were compatible with this research.'
    sentence_2 = 'His findings were not compatible with this research.'
    inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
    inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')

    del inputs_1["special_tokens_mask"]
    del inputs_2["special_tokens_mask"]

    pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
    pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
    print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
    print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
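
    # To inspect confidence instead of just the argmax (a sketch; assumes torch
    # is importable, which it must be for the PyTorch model above):
    # import torch
    # print(torch.softmax(pytorch_model(**inputs_1)[0], dim=-1))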