"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "6d89995687e7853946175314a893a11fdc695a0c"
Commit abf60128 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Deprecate old customized training loop for run_classifier.py as compile/fit...

Deprecate old customized training loop for run_classifier.py as compile/fit fully satisfy needs/performance.

PiperOrigin-RevId: 313660745
parent 94b1efc1
...@@ -93,7 +93,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -93,7 +93,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
max_seq_length, max_seq_length,
FLAGS.eval_batch_size, FLAGS.eval_batch_size,
is_training=False) is_training=False)
run_classifier.run_bert_classifier( _, summary = run_classifier.run_bert_classifier(
strategy, strategy,
bert_config, bert_config,
input_meta_data, input_meta_data,
...@@ -107,7 +107,9 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -107,7 +107,9 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
FLAGS.init_checkpoint, FLAGS.init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
training_callbacks=False,
custom_callbacks=callbacks) custom_callbacks=callbacks)
return summary
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
...@@ -142,12 +144,10 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): ...@@ -142,12 +144,10 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
use_ds=True): use_ds=True):
"""Starts BERT performance benchmark test.""" """Starts BERT performance benchmark test."""
start_time_sec = time.time() start_time_sec = time.time()
self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds) summary = self._run_bert_classifier(
callbacks=[self.timer_callback], use_ds=use_ds)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
# Since we do not load from any pretrained checkpoints, we ignore all # Since we do not load from any pretrained checkpoints, we ignore all
# accuracy metrics. # accuracy metrics.
summary.pop('eval_metrics', None) summary.pop('eval_metrics', None)
...@@ -246,8 +246,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): ...@@ -246,8 +246,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
self._run_and_report_benchmark(summary_path, use_ds=False) self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_8_gpu_amp_mrpc(self): def benchmark_8_gpu_amp_mrpc(self):
"""Test BERT model performance with 8 GPUs with automatic mixed precision. """Test BERT model performance with 8 GPUs with automatic mixed precision."""
"""
self._setup() self._setup()
self.num_gpus = 8 self.num_gpus = 8
...@@ -308,12 +307,9 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase): ...@@ -308,12 +307,9 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Starts BERT accuracy benchmark test.""" """Starts BERT accuracy benchmark test."""
start_time_sec = time.time() start_time_sec = time.time()
self._run_bert_classifier(callbacks=[self.timer_callback]) summary = self._run_bert_classifier(callbacks=[self.timer_callback])
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(BertClassifyAccuracy, self)._report_benchmark( super(BertClassifyAccuracy, self)._report_benchmark(
stats=summary, stats=summary,
wall_time_sec=wall_time_sec, wall_time_sec=wall_time_sec,
......
...@@ -33,7 +33,6 @@ from official.nlp.bert import common_flags ...@@ -33,7 +33,6 @@ from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
...@@ -110,9 +109,8 @@ def run_bert_classifier(strategy, ...@@ -110,9 +109,8 @@ def run_bert_classifier(strategy,
init_checkpoint, init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
custom_callbacks=None, training_callbacks=True,
run_eagerly=False, custom_callbacks=None):
use_keras_compile_fit=False):
"""Run BERT classifier training using low-level API.""" """Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels'] num_classes = input_meta_data['num_labels']
...@@ -142,46 +140,26 @@ def run_bert_classifier(strategy, ...@@ -142,46 +140,26 @@ def run_bert_classifier(strategy,
# correct device and strategy scope. # correct device and strategy scope.
def metric_fn(): def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy( return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32) 'accuracy', dtype=tf.float32)
if use_keras_compile_fit: # Start training using Keras compile/fit API.
# Start training using Keras compile/fit API. logging.info('Training using TF 2.x Keras compile/fit API with '
logging.info('Training using TF 2.0 Keras compile/fit API with '
'distribution strategy.')
return run_keras_compile_fit(
model_dir,
strategy,
_get_classifier_model,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
custom_callbacks=custom_callbacks)
# Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with '
'distribution strategy.') 'distribution strategy.')
return model_training_utils.run_customized_training_loop( return run_keras_compile_fit(
strategy=strategy, model_dir,
model_fn=_get_classifier_model, strategy,
loss_fn=loss_fn, _get_classifier_model,
model_dir=model_dir, train_input_fn,
steps_per_epoch=steps_per_epoch, eval_input_fn,
steps_per_loop=steps_per_loop, loss_fn,
epochs=epochs, metric_fn,
train_input_fn=train_input_fn, init_checkpoint,
eval_input_fn=eval_input_fn, epochs,
eval_steps=eval_steps, steps_per_epoch,
init_checkpoint=init_checkpoint, steps_per_loop,
sub_model_export_name=FLAGS.sub_model_export_name, eval_steps,
metric_fn=metric_fn, training_callbacks=training_callbacks,
custom_callbacks=custom_callbacks, custom_callbacks=custom_callbacks)
run_eagerly=run_eagerly)
def run_keras_compile_fit(model_dir, def run_keras_compile_fit(model_dir,
...@@ -196,6 +174,7 @@ def run_keras_compile_fit(model_dir, ...@@ -196,6 +174,7 @@ def run_keras_compile_fit(model_dir,
steps_per_epoch, steps_per_epoch,
steps_per_loop, steps_per_loop,
eval_steps, eval_steps,
training_callbacks=True,
custom_callbacks=None): custom_callbacks=None):
"""Runs BERT classifier model using Keras compile/fit API.""" """Runs BERT classifier model using Keras compile/fit API."""
...@@ -226,20 +205,25 @@ def run_keras_compile_fit(model_dir, ...@@ -226,20 +205,25 @@ def run_keras_compile_fit(model_dir,
checkpoint_interval=0) checkpoint_interval=0)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager) checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
if custom_callbacks is not None: if training_callbacks:
custom_callbacks += [summary_callback, checkpoint_callback] if custom_callbacks is not None:
else: custom_callbacks += [summary_callback, checkpoint_callback]
custom_callbacks = [summary_callback, checkpoint_callback] else:
custom_callbacks = [summary_callback, checkpoint_callback]
bert_model.fit( history = bert_model.fit(
x=training_dataset, x=training_dataset,
validation_data=evaluation_dataset, validation_data=evaluation_dataset,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
epochs=epochs, epochs=epochs,
validation_steps=eval_steps, validation_steps=eval_steps,
callbacks=custom_callbacks) callbacks=custom_callbacks)
stats = {'total_training_steps': steps_per_epoch * epochs}
return bert_model if 'loss' in history.history:
stats['train_loss'] = history.history['loss'][-1]
if 'val_accuracy' in history.history:
stats['eval_metrics'] = history.history['val_accuracy'][-1]
return bert_model, stats
def get_predictions_and_labels(strategy, def get_predictions_and_labels(strategy,
...@@ -364,7 +348,7 @@ def run_bert(strategy, ...@@ -364,7 +348,7 @@ def run_bert(strategy,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir)) logdir=FLAGS.model_dir))
trained_model = run_bert_classifier( trained_model, _ = run_bert_classifier(
strategy, strategy,
model_config, model_config,
input_meta_data, input_meta_data,
...@@ -378,8 +362,6 @@ def run_bert(strategy, ...@@ -378,8 +362,6 @@ def run_bert(strategy,
init_checkpoint or FLAGS.init_checkpoint, init_checkpoint or FLAGS.init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit,
custom_callbacks=custom_callbacks) custom_callbacks=custom_callbacks)
if FLAGS.model_export_path: if FLAGS.model_export_path:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment