Commit 376e65a6 authored by Timothy Liu's avatar Timothy Liu
Browse files

Added automatic mixed precision and XLA options to run_tf_glue.py

parent 86f23a19
...@@ -6,6 +6,11 @@ from transformers import BertTokenizer, TFBertForSequenceClassification, glue_co ...@@ -6,6 +6,11 @@ from transformers import BertTokenizer, TFBertForSequenceClassification, glue_co
# script parameters # script parameters
BATCH_SIZE = 32 BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2 EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False
tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
# Load tokenizer and model from pretrained model/vocabulary # Load tokenizer and model from pretrained model/vocabulary
tokenizer = BertTokenizer.from_pretrained('bert-base-cased') tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
...@@ -23,10 +28,13 @@ train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1) ...@@ -23,10 +28,13 @@ train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE) valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08) opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if USE_AMP:
# loss scaling is currently required when using mixed precision
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy') metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) model.compile(optimizer=opt, loss=loss, metrics=[metric])
# Train and evaluate using tf.keras.Model.fit() # Train and evaluate using tf.keras.Model.fit()
train_steps = train_examples//BATCH_SIZE train_steps = train_examples//BATCH_SIZE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment