Commit 1ca9e3e4 authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Use host python training loop in GPU BERT tests

PiperOrigin-RevId: 295869937
parent b95064af
...@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
distribution_strategy='mirrored' if use_ds else 'off', distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus) num_gpus=self.num_gpus)
steps_per_loop = 1 steps_per_loop = 100
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
train_input_fn = run_classifier.get_dataset_fn( train_input_fn = run_classifier.get_dataset_fn(
......
...@@ -46,7 +46,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback): ...@@ -46,7 +46,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
def on_batch_end(self, batch, logs=None): def on_batch_end(self, batch, logs=None):
self.batch_stop_times[batch] = time.time() self.batch_stop_times[batch] = time.time()
def get_examples_per_sec(self, batch_size, num_batches_to_skip=10): def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
batch_durations = [] batch_durations = []
for batch in self.batch_start_times: for batch in self.batch_start_times:
if batch in self.batch_stop_times and batch >= num_batches_to_skip: if batch in self.batch_stop_times and batch >= num_batches_to_skip:
...@@ -92,7 +92,8 @@ class BertBenchmarkBase(PerfZeroBenchmark): ...@@ -92,7 +92,8 @@ class BertBenchmarkBase(PerfZeroBenchmark):
'name': 'name':
'exp_per_second', 'exp_per_second',
'value': 'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size) self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
FLAGS.steps_per_loop)
}) })
else: else:
metrics.append({ metrics.append({
......
...@@ -155,7 +155,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -155,7 +155,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 1 FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -414,7 +414,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase): ...@@ -414,7 +414,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.num_train_epochs = 2 FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 1 FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -508,7 +508,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase): ...@@ -508,7 +508,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.num_train_epochs = 2 FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 1 FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -581,7 +581,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase): ...@@ -581,7 +581,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 1 FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
......
...@@ -329,12 +329,12 @@ def run_customized_training_loop( ...@@ -329,12 +329,12 @@ def run_customized_training_loop(
for callback in custom_callbacks: for callback in custom_callbacks:
callback.on_batch_begin(batch) callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch): def _run_callbacks_on_batch_end(batch, logs):
"""Runs custom callbacks at the end of every step.""" """Runs custom callbacks at the end of every step."""
if not custom_callbacks: if not custom_callbacks:
return return
for callback in custom_callbacks: for callback in custom_callbacks:
callback.on_batch_end(batch) callback.on_batch_end(batch, logs)
# Training loop starts here. # Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
...@@ -363,18 +363,19 @@ def run_customized_training_loop( ...@@ -363,18 +363,19 @@ def run_customized_training_loop(
# Runs several steps in the host while loop. # Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop) steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
if steps == 1: if tf.test.is_built_with_cuda():
# TODO(zongweiz): merge with train_steps once tf.while_loop # TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed. # GPU performance bugs are fixed.
for _ in range(steps):
train_single_step(train_iterator) train_single_step(train_iterator)
else: else:
# Converts steps to a Tensor to avoid tf.function retracing. # Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iterator, train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32)) tf.convert_to_tensor(steps, dtype=tf.int32))
_run_callbacks_on_batch_end(current_step) train_loss = _float_metric_value(train_loss_metric)
_run_callbacks_on_batch_end(current_step, {'loss': train_loss})
current_step += steps current_step += steps
train_loss = _float_metric_value(train_loss_metric)
# Updates training logging. # Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % ( training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps, train_loss) current_step, total_training_steps, train_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment