# run_tf_glue.py — fine-tune TFBertForSequenceClassification on GLUE MRPC with
# tf.keras, then reload the saved weights in PyTorch.
# (Last committed by thomwolf.)
import math
import os

import tensorflow as tf
import tensorflow_datasets
from pytorch_transformers import BertTokenizer, BertForSequenceClassification, TFBertForSequenceClassification, glue_convert_examples_to_features

# Hyper-parameters. The LR-decay horizon and the per-epoch step counts are
# derived from these values below, so they cannot drift out of sync.
MAX_SEQ_LENGTH = 128
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 64
EPOCHS = 3
LEARNING_RATE = 2e-5
OUTPUT_DIR = './runs/'

# Load tokenizer, model, dataset.
# with_info=True also returns split sizes and label names, which are used below.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
tf_model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
dataset, info = tensorflow_datasets.load("glue/mrpc", with_info=True)

# Number of batches per pass over each split. Rounding up counts the final,
# partial batch.
train_steps_per_epoch = math.ceil(info.splits['train'].num_examples / TRAIN_BATCH_SIZE)
valid_steps = math.ceil(info.splits['validation'].num_examples / EVAL_BATCH_SIZE)

# Prepare dataset for GLUE
train_dataset = glue_convert_examples_to_features(dataset['train'], tokenizer, task='mrpc', max_length=MAX_SEQ_LENGTH)
valid_dataset = glue_convert_examples_to_features(dataset['validation'], tokenizer, task='mrpc', max_length=MAX_SEQ_LENGTH)
# repeat(EPOCHS) supplies one full pass of the data for each Keras epoch.
train_dataset = train_dataset.shuffle(100).batch(TRAIN_BATCH_SIZE).repeat(EPOCHS)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)

# Compile tf.keras model for training.
# Decay the learning rate to 0 over the total number of training steps.
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    LEARNING_RATE, EPOCHS * train_steps_per_epoch, end_learning_rate=0)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
# from_logits=True: the classifier outputs are treated as unnormalized scores.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
tf_model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])

# Train and evaluate using tf.keras.Model.fit()
tf_model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps_per_epoch,
             validation_data=valid_dataset, validation_steps=valid_steps)

# Save the model and load it in PyTorch.
# save_pretrained() expects the target directory to already exist.
os.makedirs(OUTPUT_DIR, exist_ok=True)
tf_model.save_pretrained(OUTPUT_DIR)
pt_model = BertForSequenceClassification.from_pretrained(OUTPUT_DIR, from_tf=True)

# Quickly inspect a prediction.
# With return_tensors='pt', encode_plus already returns PyTorch tensors, so the
# ids are passed straight to the model. token_type_ids is passed by keyword to
# avoid relying on the forward() positional order.
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results",
                               add_special_tokens=True, return_tensors='pt')
logits = pt_model(inputs['input_ids'], token_type_ids=inputs['token_type_ids'])[0]
pred = logits.argmax(dim=-1).item()
print('Predicted label:', info.features['label'].int2str(pred))