Commit 1ca9e3e4 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Use host python training loop in GPU BERT tests

PiperOrigin-RevId: 295869937
parent b95064af
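In effect, GPU builds of the custom training loop now drive each of the steps_per_loop training steps from host Python instead of one device-side tf.while_loop, and the per-loop loss is handed to the callbacks. A minimal sketch of that dispatch, distilled from the model_training_utils hunk near the end of this diff (the function arguments here are stand-ins for the loop's real closures, not the repository's API):

import tensorflow as tf

def run_steps_sketch(steps, train_single_step, train_steps, train_iterator):
  """Illustrative dispatch only; mirrors the hunk below, not the real code."""
  if tf.test.is_built_with_cuda():
    # GPU build: run each step from host Python to sidestep tf.while_loop
    # GPU performance issues.
    for _ in range(steps):
      train_single_step(train_iterator)
  else:
    # Other builds: run all `steps` inside one device-side loop; a Tensor
    # argument (rather than a Python int) avoids retracing the tf.function.
    train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))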
@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
         distribution_strategy='mirrored' if use_ds else 'off',
         num_gpus=self.num_gpus)

-    steps_per_loop = 1
+    steps_per_loop = 100

     max_seq_length = input_meta_data['max_seq_length']
     train_input_fn = run_classifier.get_dataset_fn(
@@ -46,7 +46,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
   def on_batch_end(self, batch, logs=None):
     self.batch_stop_times[batch] = time.time()

-  def get_examples_per_sec(self, batch_size, num_batches_to_skip=10):
+  def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
     batch_durations = []
     for batch in self.batch_start_times:
       if batch in self.batch_stop_times and batch >= num_batches_to_skip:
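Because each logged "batch" now corresponds to one host-loop iteration of steps_per_loop training steps, skipping a single warm-up interval is enough. A minimal sketch of how such a timer callback can turn its per-batch timestamps into a throughput figure (illustrative only; it mirrors, but does not reproduce, the benchmark utility):

import time
import tensorflow as tf

class TimerCallbackSketch(tf.keras.callbacks.Callback):
  """Records wall-clock time around each logged interval (illustrative)."""

  def __init__(self):
    super().__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    self.batch_stop_times[batch] = time.time()

  def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
    # Drop the first interval(s) so tracing/warm-up does not skew the rate.
    durations = [
        self.batch_stop_times[b] - self.batch_start_times[b]
        for b in self.batch_start_times
        if b in self.batch_stop_times and b >= num_batches_to_skip
    ]
    if not durations:
      return 0.0
    return batch_size / (sum(durations) / len(durations))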
@@ -92,7 +92,8 @@ class BertBenchmarkBase(PerfZeroBenchmark):
           'name':
               'exp_per_second',
           'value':
-              self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
+              self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
+                                                       FLAGS.steps_per_loop)
       })
     else:
       metrics.append({
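With the host loop in place, every on_batch_end interval spans steps_per_loop optimizer steps, so the examples attributed to one timed interval are train_batch_size * steps_per_loop rather than a single batch. A quick worked example with illustrative numbers (not taken from the benchmark configs):

# Illustrative values only.
train_batch_size = 32    # examples consumed by one training step
steps_per_loop = 100     # training steps per host-loop iteration / callback
interval_seconds = 4.0   # measured wall-clock time of one interval

examples_per_interval = train_batch_size * steps_per_loop    # 3200
examples_per_sec = examples_per_interval / interval_seconds  # 800.0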
@@ -155,7 +155,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
     FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
     FLAGS.num_train_epochs = 1
-    FLAGS.steps_per_loop = 1
+    FLAGS.steps_per_loop = 100

   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
@@ -414,7 +414,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
     FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
     FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
     FLAGS.num_train_epochs = 2
-    FLAGS.steps_per_loop = 1
+    FLAGS.steps_per_loop = 100

   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
@@ -508,7 +508,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
     FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
     FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
     FLAGS.num_train_epochs = 2
-    FLAGS.steps_per_loop = 1
+    FLAGS.steps_per_loop = 100

   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
@@ -581,7 +581,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
     FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
     FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
     FLAGS.num_train_epochs = 1
-    FLAGS.steps_per_loop = 1
+    FLAGS.steps_per_loop = 100

   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
@@ -329,12 +329,12 @@ def run_customized_training_loop(
     for callback in custom_callbacks:
       callback.on_batch_begin(batch)

-  def _run_callbacks_on_batch_end(batch):
+  def _run_callbacks_on_batch_end(batch, logs):
     """Runs custom callbacks at the end of every step."""
     if not custom_callbacks:
       return
     for callback in custom_callbacks:
-      callback.on_batch_end(batch)
+      callback.on_batch_end(batch, logs)

   # Training loop starts here.
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
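Passing logs through matches the standard tf.keras.callbacks.Callback signature on_batch_end(self, batch, logs=None), so callbacks can read the loss the custom loop now reports. A minimal, hypothetical consumer:

import tensorflow as tf

class LossLoggerSketch(tf.keras.callbacks.Callback):
  """Hypothetical callback reading the logs dict passed by the custom loop."""

  def on_batch_end(self, batch, logs=None):
    # The loop calls this once per host-loop iteration with
    # logs={'loss': <float>} (see the following hunk).
    if logs and 'loss' in logs:
      print('step %d: loss = %.4f' % (batch, logs['loss']))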
@@ -363,18 +363,19 @@ def run_customized_training_loop(
       # Runs several steps in the host while loop.
       steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

-      if steps == 1:
+      if tf.test.is_built_with_cuda():
         # TODO(zongweiz): merge with train_steps once tf.while_loop
         # GPU performance bugs are fixed.
-        train_single_step(train_iterator)
+        for _ in range(steps):
+          train_single_step(train_iterator)
       else:
         # Converts steps to a Tensor to avoid tf.function retracing.
         train_steps(train_iterator,
                     tf.convert_to_tensor(steps, dtype=tf.int32))
-      _run_callbacks_on_batch_end(current_step)
+      train_loss = _float_metric_value(train_loss_metric)
+      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
       current_step += steps

-      train_loss = _float_metric_value(train_loss_metric)
       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
           current_step, total_training_steps, train_loss)
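The "Converts steps to a Tensor to avoid tf.function retracing" comment refers to a general tf.function behavior: a new concrete function is traced for every distinct Python scalar argument, while a Tensor argument of fixed dtype and shape reuses a single trace. A small standalone illustration (not code from this repository):

import tensorflow as tf

@tf.function
def train_steps_sketch(num_steps):
  # Stand-in body; the real train_steps consumes a dataset iterator.
  total = tf.constant(0)
  for _ in tf.range(num_steps):
    total += 1
  return total

train_steps_sketch(100)                                  # traces for 100
train_steps_sketch(200)                                  # retraces for 200
train_steps_sketch(tf.convert_to_tensor(100, tf.int32))  # traces for int32 Tensor
train_steps_sketch(tf.convert_to_tensor(200, tf.int32))  # reuses that trace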