import os

import tensorflow as tf
import tensorflow_datasets

from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    TFBertForSequenceClassification,
    glue_convert_examples_to_features,
    glue_processors,
)


# script parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False
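# USE_XLA enables XLA JIT compilation and USE_AMP enables automatic mixed
# precision; both are optional speed optimizations (applied below), off by default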
EPOCHS = 3

TASK = "mrpc"

if TASK == "sst-2":
    TFDS_TASK = "sst2"
elif TASK == "sts-b":
    TFDS_TASK = "stsb"
else:
    TFDS_TASK = TASK

num_labels = len(glue_processors[TASK]().get_labels())
print(num_labels)
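# Note: STS-B is the one GLUE regression task, so num_labels is 1 there;
# every other task is classification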

tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", config=config)

# Load dataset via TensorFlow Datasets
data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
train_examples = info.splits["train"].num_examples

# MNLI expects either validation_matched or validation_mismatched
valid_examples = info.splits["validation"].num_examples

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, 128, TASK)

# MNLI expects either validation_matched or validation_mismatched
valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, 128, TASK)
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
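# repeat(-1) repeats the training stream indefinitely, so model.fit() below
# relies on steps_per_epoch to bound each epoch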

# Prepare training: compile the tf.keras model with an optimizer and a loss
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if USE_AMP:
    # loss scaling is currently required when using mixed precision
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")


if num_labels == 1:
    loss = tf.keras.losses.MeanSquaredError()
else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
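# Note: the accuracy metric below is only meaningful for classification tasks;
# for STS-B (regression) you would track e.g. mean squared error instead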

metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model.compile(optimizer=opt, loss=loss, metrics=[metric])

# Train and evaluate using tf.keras.Model.fit()
train_steps = train_examples // BATCH_SIZE
valid_steps = valid_examples // EVAL_BATCH_SIZE
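# Since the training stream repeats indefinitely, train_examples // BATCH_SIZE
# steps approximate one full pass over the training set per epoch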

history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=train_steps,
    validation_data=valid_dataset,
    validation_steps=valid_steps,
)

# Save TF2 model
os.makedirs("./save/", exist_ok=True)
model.save_pretrained("./save/")
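# The checkpoint can be reloaded later with the same API, e.g.:
#   model = TFBertForSequenceClassification.from_pretrained("./save/")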

if TASK == "mrpc":
    # Load the TensorFlow model in PyTorch for inspection.
    # This demonstrates the interoperability between the two frameworks; you don't
    # have to do this in practice (you can run inference on the TF model directly).
    pytorch_model = BertForSequenceClassification.from_pretrained("./save/", from_tf=True)

    # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
    sentence_0 = "This research was consistent with his findings."
    sentence_1 = "His findings were compatible with this research."
    sentence_2 = "His findings were not compatible with this research."
    inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors="pt")
    inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors="pt")
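    # return_tensors="pt" yields PyTorch tensors, matching the PyTorch model above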

    # encode_plus in this version also returns a special_tokens_mask entry, which
    # the model's forward pass does not accept, so drop it before the call
    del inputs_1["special_tokens_mask"]
    del inputs_2["special_tokens_mask"]

    pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
    pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
    print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
    print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")