Commit 7207422d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 273653001
parent dc93d9e5
...@@ -56,6 +56,10 @@ def define_common_bert_flags(): ...@@ -56,6 +56,10 @@ def define_common_bert_flags():
'scale_loss', False, 'scale_loss', False,
'Whether to divide the loss by number of replica inside the per-replica ' 'Whether to divide the loss by number of replica inside the per-replica '
'loss function.') 'loss function.')
flags.DEFINE_boolean(
'use_keras_compile_fit', False,
'If True, uses Keras compile/fit() API for training logic. Otherwise '
'use custom training loop.')
# Adds flags for mixed precision training. # Adds flags for mixed precision training.
flags_core.define_performance( flags_core.define_performance(
......
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
import functools import functools
import json import json
import math import math
import os
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -82,19 +83,19 @@ def get_loss_fn(num_classes, loss_factor=1.0): ...@@ -82,19 +83,19 @@ def get_loss_fn(num_classes, loss_factor=1.0):
return classification_loss_fn return classification_loss_fn
def run_customized_training(strategy, def run_bert_classifier(strategy,
bert_config, bert_config,
input_meta_data, input_meta_data,
model_dir, model_dir,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
steps_per_loop, steps_per_loop,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
initial_lr, initial_lr,
init_checkpoint, init_checkpoint,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=False): run_eagerly=False):
"""Run BERT classifier training using low-level API.""" """Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels'] num_classes = input_meta_data['num_labels']
...@@ -144,6 +145,27 @@ def run_customized_training(strategy, ...@@ -144,6 +145,27 @@ def run_customized_training(strategy,
return tf.keras.metrics.SparseCategoricalAccuracy( return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32) 'test_accuracy', dtype=tf.float32)
if FLAGS.use_keras_compile_fit:
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.0 Keras compile/fit API with '
'distrubuted strategy.')
return run_keras_compile_fit(
model_dir,
strategy,
_get_classifier_model,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
eval_steps,
custom_callbacks=None)
# Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with '
'distrubuted strategy.')
return model_training_utils.run_customized_training_loop( return model_training_utils.run_customized_training_loop(
strategy=strategy, strategy=strategy,
model_fn=_get_classifier_model, model_fn=_get_classifier_model,
...@@ -161,6 +183,52 @@ def run_customized_training(strategy, ...@@ -161,6 +183,52 @@ def run_customized_training(strategy,
run_eagerly=run_eagerly) run_eagerly=run_eagerly)
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          eval_steps,
                          custom_callbacks=None):
  """Runs a BERT classifier using the Keras compile/fit API.

  Builds the datasets and model inside the distribution strategy scope,
  optionally restores a pre-trained checkpoint into the core (sub) model,
  then trains with `model.fit`, logging TensorBoard summaries and writing a
  checkpoint per epoch under `model_dir`.

  Args:
    model_dir: Directory for TensorBoard summaries and per-epoch checkpoints.
    strategy: Distribution strategy; model/datasets are created in its scope.
    model_fn: Zero-arg callable returning `(keras_model, sub_model)`; the
      returned `keras_model` is expected to carry its optimizer already
      (read from `keras_model.optimizer` below).
    train_input_fn: Zero-arg callable returning the training dataset.
    eval_input_fn: Zero-arg callable returning the evaluation dataset.
    loss_fn: Loss passed to `compile`.
    metric_fn: Zero-arg factory producing the metric passed to `compile`.
    init_checkpoint: Optional checkpoint path restored into `sub_model`.
    epochs: Number of training epochs.
    steps_per_epoch: Training steps per epoch.
    eval_steps: Validation steps per epoch.
    custom_callbacks: Optional list of extra Keras callbacks. The list is
      not mutated; summary/checkpoint callbacks are appended to a copy.

  Returns:
    The trained Keras model.
  """
  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn()
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      # Restore only the core encoder weights; fail loudly if the
      # checkpoint does not match the model's existing objects.
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])

    summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
    # '{epoch:02d}' is expanded by ModelCheckpoint, one checkpoint per epoch.
    checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
    # Build a fresh list instead of `+=` so a caller-supplied
    # `custom_callbacks` list is never mutated in place.
    callbacks = list(custom_callbacks or []) + [
        summary_callback, checkpoint_callback
    ]

    bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=callbacks)

    return bert_model
def export_classifier(model_export_path, input_meta_data): def export_classifier(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference. """Exports a trained model as a `SavedModel` for inference.
...@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data): ...@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
if not strategy: if not strategy:
raise ValueError('Distribution strategy has not been specified.') raise ValueError('Distribution strategy has not been specified.')
# Runs customized training loop.
logging.info('Training using customized training loop TF 2.0 with distrubuted' trained_model = run_bert_classifier(
'strategy.')
trained_model = run_customized_training(
strategy, strategy,
bert_config, bert_config,
input_meta_data, input_meta_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment