Commit 7207422d authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 273653001
parent dc93d9e5
@@ -56,6 +56,10 @@ def define_common_bert_flags():
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
       'loss function.')
+  flags.DEFINE_boolean(
+      'use_keras_compile_fit', False,
+      'If True, uses Keras compile/fit() API for training logic. Otherwise '
+      'use custom training loop.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
......
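For context, the new flag follows the standard absl pattern. Below is a minimal, self-contained sketch, not part of this commit (the script name and print statements are hypothetical), showing how a boolean flag such as `use_keras_compile_fit` is defined and consumed:

```python
# toy_flags.py -- hypothetical standalone example, not from the repo.
from absl import app
from absl import flags

flags.DEFINE_boolean(
    'use_keras_compile_fit', False,
    'If True, uses Keras compile/fit() API for training logic. Otherwise '
    'use custom training loop.')

FLAGS = flags.FLAGS


def main(_):
  # absl parses argv before main() runs, so the flag is ready to read here.
  if FLAGS.use_keras_compile_fit:
    print('Would train with the Keras compile/fit API.')
  else:
    print('Would train with the custom training loop.')


if __name__ == '__main__':
  app.run(main)
```

Invoking the script with `--use_keras_compile_fit` (or `--nouse_keras_compile_fit`) toggles the path; absl generates both spellings for boolean flags.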
@@ -21,6 +21,7 @@ from __future__ import print_function
 import functools
 import json
 import math
+import os
 from absl import app
 from absl import flags
@@ -82,7 +83,7 @@ def get_loss_fn(num_classes, loss_factor=1.0):
   return classification_loss_fn
 
-def run_customized_training(strategy,
+def run_bert_classifier(strategy,
                         bert_config,
                         input_meta_data,
                         model_dir,
@@ -144,6 +145,27 @@ def run_customized_training(strategy,
     return tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)
 
+  if FLAGS.use_keras_compile_fit:
+    # Start training using Keras compile/fit API.
+    logging.info('Training using TF 2.0 Keras compile/fit API with '
+                 'distributed strategy.')
+    return run_keras_compile_fit(
+        model_dir,
+        strategy,
+        _get_classifier_model,
+        train_input_fn,
+        eval_input_fn,
+        loss_fn,
+        metric_fn,
+        init_checkpoint,
+        epochs,
+        steps_per_epoch,
+        eval_steps,
+        custom_callbacks=None)
+
+  # Use user-defined loop to start training.
+  logging.info('Training using customized training loop TF 2.0 with '
+               'distributed strategy.')
   return model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_classifier_model,
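The fallthrough branch above delegates to `model_training_utils.run_customized_training_loop`. For readers comparing the two paths, here is a rough, toy illustration of what a custom distributed training loop looks like in TF 2.x; this is an assumption-laden sketch (toy model, synthetic data, `strategy.run` from the newer distribute API), not that utility's actual implementation:

```python
# Toy sketch of a custom distributed training loop; not the repo's
# model_training_utils implementation.
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
  optimizer = tf.keras.optimizers.SGD(0.1)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(32, 4).astype('float32'),
     np.random.randint(0, 2, size=(32,)))).batch(8)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.reduce_mean(loss_obj(labels, logits))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_loss = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)


for epoch in range(2):
  for batch in dist_dataset:
    loss = train_step(batch)
  print('epoch', epoch, 'last step loss', float(loss))
```

The trade-off the new flag exposes: the custom loop gives fine-grained control over stepping, logging, and loss scaling, while compile/fit gets callbacks, progress bars, and validation handling for free.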
@@ -161,6 +183,52 @@
       run_eagerly=run_eagerly)
 
 
+def run_keras_compile_fit(model_dir,
+                          strategy,
+                          model_fn,
+                          train_input_fn,
+                          eval_input_fn,
+                          loss_fn,
+                          metric_fn,
+                          init_checkpoint,
+                          epochs,
+                          steps_per_epoch,
+                          eval_steps,
+                          custom_callbacks=None):
+  """Runs BERT classifier model using Keras compile/fit API."""
+  with strategy.scope():
+    training_dataset = train_input_fn()
+    evaluation_dataset = eval_input_fn()
+    bert_model, sub_model = model_fn()
+    optimizer = bert_model.optimizer
+
+    if init_checkpoint:
+      checkpoint = tf.train.Checkpoint(model=sub_model)
+      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
+
+    bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
+
+    summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
+    checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
+    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
+
+    if custom_callbacks is not None:
+      custom_callbacks += [summary_callback, checkpoint_callback]
+    else:
+      custom_callbacks = [summary_callback, checkpoint_callback]
+
+    bert_model.fit(
+        x=training_dataset,
+        validation_data=evaluation_dataset,
+        steps_per_epoch=steps_per_epoch,
+        epochs=epochs,
+        validation_steps=eval_steps,
+        callbacks=custom_callbacks)
+
+  return bert_model
+
 
 def export_classifier(model_export_path, input_meta_data):
   """Exports a trained model as a `SavedModel` for inference.
@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')
 
-  # Runs customized training loop.
-  logging.info('Training using customized training loop TF 2.0 with distrubuted'
-               'strategy.')
-  trained_model = run_customized_training(
+  trained_model = run_bert_classifier(
       strategy,
       bert_config,
       input_meta_data,
......
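One detail of `run_keras_compile_fit` worth calling out: when `init_checkpoint` is set, only `sub_model` (the BERT encoder inside the classifier) is restored, and `assert_existing_objects_matched()` verifies that every object in the restoring graph found a counterpart in the checkpoint. A minimal sketch of that warm-start pattern with toy models (the path and layer size are made up):

```python
# Toy warm-start sketch mirroring the init_checkpoint logic above.
import tensorflow as tf

# Stand-in for a pretrained encoder; input_shape makes variables exist now.
sub_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])

# Save a checkpoint for the sub-model alone.
ckpt_path = tf.train.Checkpoint(model=sub_model).save('/tmp/sub_model_ckpt')

# Restore into a fresh copy and assert the restore was complete.
restored = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
status = tf.train.Checkpoint(model=restored).restore(ckpt_path)
status.assert_existing_objects_matched()
```

Restoring only the encoder lets the classification head start from fresh initialization while the pretrained weights are verified rather than silently skipped.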