Commit e67a2064 authored by A. Unique TensorFlower

Change the summary directory and model checkpoint directory so that training via Keras compile/fit() and the custom training loop is consistent.

PiperOrigin-RevId: 274202793
parent ad1a37c9
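
For orientation, a minimal sketch (not part of the commit) of the on-disk layout this change standardizes; model_dir and the /tmp path are assumed example values standing in for the FLAGS-provided model directory.

import os

# Illustrative layout only; the path names mirror the hunks below.
model_dir = '/tmp/bert_model'                       # assumed example directory
summary_dir = os.path.join(model_dir, 'summaries')  # TensorBoard events and training_summary.txt
summary_txt = os.path.join(summary_dir, 'training_summary.txt')
keras_ckpt = os.path.join(model_dir, 'checkpoint')  # Keras ModelCheckpoint(save_weights_only=True)
# The custom training loop keeps writing tf.train.Checkpoint files (ctl_step_*)
# directly under model_dir, so both code paths now share the same summary
# directory layout.
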
@@ -148,7 +148,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_xla(self):
@@ -165,7 +166,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.eval_batch_size = 4
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_no_dist_strat(self):
@@ -181,7 +183,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2_gpu_mrpc(self):
@@ -197,7 +200,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 8
FLAGS.eval_batch_size = 8
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_4_gpu_mrpc(self):
@@ -212,7 +216,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 16
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu_mrpc(self):
@@ -225,7 +230,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
@@ -243,7 +249,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_8_gpu_amp_mrpc(self):
@@ -262,7 +269,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
@@ -320,7 +328,8 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu_mrpc_xla(self):
@@ -328,7 +337,8 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
......
@@ -52,7 +52,8 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
def _read_training_summary_from_file(self):
"""Reads the training summary from a file."""
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
return json.loads(reader.read().decode('utf-8'))
......
@@ -122,7 +122,8 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
......
@@ -72,9 +72,9 @@ def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
return steps_per_loop
def write_txt_summary(training_summary, model_dir):
def write_txt_summary(training_summary, summary_dir):
"""Writes a summary text file to record stats."""
summary_path = os.path.join(model_dir, _SUMMARY_TXT)
summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
logging.info('Training Summary: \n%s', str(training_summary))
f.write(json.dumps(training_summary, indent=4))
@@ -221,13 +221,14 @@ def run_customized_training_loop(
]
# Create summary writers
summary_dir = os.path.join(model_dir, 'summaries')
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/eval'))
os.path.join(summary_dir, 'eval'))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summary when the stats are collected sufficiently over
# enough steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/train'))
os.path.join(summary_dir, 'train'))
else:
train_summary_writer = None
@@ -415,6 +416,6 @@ def run_customized_training_loop(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
write_txt_summary(training_summary, model_dir)
write_txt_summary(training_summary, summary_dir)
return model
@@ -185,7 +185,8 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
# Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'training_summary*')))
tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*')))
# Loss and accuracy values should be written into summaries.
self.assertTrue(
......
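
As a usage note on the hunks above, here is a rough, hedged sketch of how the custom training loop now derives every summary artifact from a single summaries/ subdirectory; the literal 'training_summary.txt' file name matches the _SUMMARY_TXT constant used by write_txt_summary(), and the /tmp path is a placeholder.

import json
import os
import tensorflow as tf

model_dir = '/tmp/bert_model'                       # assumed example directory
summary_dir = os.path.join(model_dir, 'summaries')  # single root for all summaries
tf.io.gfile.makedirs(summary_dir)

# Event files for TensorBoard go under summaries/eval and summaries/train.
eval_summary_writer = tf.summary.create_file_writer(os.path.join(summary_dir, 'eval'))
train_summary_writer = tf.summary.create_file_writer(os.path.join(summary_dir, 'train'))

# write_txt_summary() now receives summary_dir instead of model_dir, so the
# JSON text summary lands next to the TensorBoard event files.
training_summary = {'total_training_steps': 10, 'train_loss': 0.5}
with tf.io.gfile.GFile(os.path.join(summary_dir, 'training_summary.txt'), 'wb') as f:
  f.write(json.dumps(training_summary, indent=4))
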
@@ -26,10 +26,10 @@ import tensorflow as tf
import typing
def export_bert_model(
model_export_path: typing.Text,
model: tf.keras.Model,
checkpoint_dir: typing.Optional[typing.Text] = None) -> None:
def export_bert_model(model_export_path: typing.Text,
model: tf.keras.Model,
checkpoint_dir: typing.Optional[typing.Text] = None,
restore_model_using_load_weights: bool = False) -> None:
"""Export BERT model for serving which does not include the optimizer.
Arguments:
@@ -37,6 +37,14 @@ def export_bert_model(
model: Keras model object to export.
checkpoint_dir: Path from which model weights will be loaded, if
specified.
restore_model_using_load_weights: Whether to restore the model with the
model.load_weights() API (for a checkpoint written by Keras
model.save_weights()) instead of the tf.train.Checkpoint.restore() API.
There are two different ways to save checkpoints: one uses
tf.train.Checkpoint and the other uses Keras model.save_weights().
The custom training loop implementation uses the tf.train.Checkpoint API,
while the Keras ModelCheckpoint callback internally uses the
model.save_weights() API. Since these two APIs cannot be used together,
the model loading logic must take into account how the model checkpoint
was saved.
Raises:
ValueError when either model_export_path or model is not specified.
@@ -47,13 +55,24 @@
raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir:
# Restores the model from latest checkpoint.
checkpoint = tf.train.Checkpoint(model=model)
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
# Keras compile/fit() was used to save the checkpoint via
# model.save_weights().
if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else:
checkpoint = tf.train.Checkpoint(model=model)
# Restores the model from latest checkpoint.
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
model.save(model_export_path, include_optimizer=False, save_format='tf')
......
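
To make the new export_bert_model() docstring concrete, a hedged sketch of the two restore paths follows; restore_for_export is a made-up helper name for illustration, while the APIs and the 'checkpoint' file name follow the diff above.

import os
import tensorflow as tf

def restore_for_export(model, checkpoint_dir, restore_model_using_load_weights):
  """Loads weights the same way they were saved (sketch, not the library API)."""
  if restore_model_using_load_weights:
    # Keras compile/fit() path: ModelCheckpoint(save_weights_only=True) wrote
    # the weights via model.save_weights(), so mirror it with load_weights().
    model.load_weights(os.path.join(checkpoint_dir, 'checkpoint'))
  else:
    # Custom training loop path: an object-based tf.train.Checkpoint was
    # written, so restore through the same API.
    checkpoint = tf.train.Checkpoint(model=model)
    latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
    checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
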
@@ -92,7 +92,8 @@ def run_bert_classifier(strategy,
initial_lr,
init_checkpoint,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
use_keras_compile_fit=False):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
@@ -142,7 +143,7 @@
return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
if FLAGS.use_keras_compile_fit:
if use_keras_compile_fit:
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.0 Keras compile/fit API with '
'distributed strategy.')
@@ -206,9 +207,11 @@ def run_keras_compile_fit(model_dir,
bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
if custom_callbacks is not None:
custom_callbacks += [summary_callback, checkpoint_callback]
@@ -226,12 +229,21 @@
return bert_model
def export_classifier(model_export_path, input_meta_data):
def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to restore the model with the
model.load_weights() API (for a checkpoint written by Keras
model.save_weights()) instead of the tf.train.Checkpoint.restore() API.
There are two different ways to save checkpoints: one uses
tf.train.Checkpoint and the other uses Keras model.save_weights().
The custom training loop implementation uses the tf.train.Checkpoint API,
while the Keras ModelCheckpoint callback internally uses the
model.save_weights() API. Since these two APIs cannot be used together,
the model loading logic must take into account how the model checkpoint
was saved.
Raises:
Export path is not specified, got an empty string or None.
@@ -243,14 +255,22 @@ def export_classifier(model_export_path, input_meta_data):
classifier_model = bert_models.classifier_model(
bert_config, tf.float32, input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=FLAGS.model_dir)
model_export_path,
model=classifier_model,
checkpoint_dir=FLAGS.model_dir,
restore_model_using_load_weights=restore_model_using_load_weights)
def run_bert(strategy, input_meta_data):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data)
# As the Keras ModelCheckpoint callback used with the Keras compile/fit() API
# internally saves checkpoints via model.save_weights(), we must
# restore with model.load_weights() when Keras compile/fit() is used.
export_classifier(FLAGS.model_export_path, input_meta_data,
FLAGS.use_keras_compile_fit)
return
if FLAGS.mode != 'train_and_eval':
@@ -281,11 +301,17 @@ def run_bert(strategy, input_meta_data):
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
run_eagerly=FLAGS.run_eagerly)
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit)
if FLAGS.model_export_path:
# As the Keras ModelCheckpoint callback used with the Keras compile/fit() API
# internally saves checkpoints via model.save_weights(), we must
# restore with model.load_weights() when Keras compile/fit() is used.
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
FLAGS.model_export_path,
model=trained_model,
restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
return trained_model
......
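
Finally, a toy end-to-end sketch of the compile/fit() path as wired in run_keras_compile_fit() above: summaries under model_dir/summaries, a weights-only checkpoint at model_dir/checkpoint, and a reload via model.load_weights(). The tiny Dense model and the /tmp path are placeholders for illustration, not the BERT classifier.

import os
import tensorflow as tf

model_dir = '/tmp/toy_model'  # assumed example directory
tf.io.gfile.makedirs(model_dir)
summary_callback = tf.keras.callbacks.TensorBoard(os.path.join(model_dir, 'summaries'))
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(model_dir, 'checkpoint'), save_weights_only=True)

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
model.compile(optimizer='adam', loss='mse')
model.fit(tf.zeros([8, 4]), tf.zeros([8, 2]), epochs=1,
          callbacks=[summary_callback, checkpoint_callback])

# Because the checkpoint was written through model.save_weights(), export-time
# loading uses model.load_weights() rather than tf.train.Checkpoint.restore().
model.load_weights(os.path.join(model_dir, 'checkpoint'))
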