Unverified commit e38d570e authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7404)

262178259  by hongkuny<hongkuny@google.com>:

    We should call the model with training=True in the CTL train step.

--
262081759  by akuegel<akuegel@google.com>:

    Internal change

PiperOrigin-RevId: 262178259
parent b68a6503
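
For context: Keras layers such as dropout and batch normalization behave
differently at train and inference time, so a custom training loop (CTL) must
pass training=True explicitly when it calls the model. A minimal sketch of the
corrected pattern follows; the toy model, loss, and optimizer are illustrative
only and not taken from this repo:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation='relu'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.Dense(1),
    ])
    loss_fn = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()

    def train_step(inputs, labels):
      with tf.GradientTape() as tape:
        # Without training=True the Dropout layer runs in inference mode
        # even inside the training loop, which is the bug this commit fixes.
        outputs = model(inputs, training=True)
        loss = loss_fn(labels, outputs)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss
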
@@ -81,6 +81,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   @flagsaver.flagsaver
   def _train_squad(self, use_ds=True, run_eagerly=False):
     """Runs BERT SQuAD training."""
+    assert tf.version.VERSION.startswith('2.')
     input_meta_data = self._read_input_meta_data_from_file()
     strategy = self._get_distribution_strategy(use_ds)
@@ -93,6 +94,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   @flagsaver.flagsaver
   def _evaluate_squad(self, use_ds=True):
     """Runs BERT SQuAD evaluation."""
+    assert tf.version.VERSION.startswith('2.')
     input_meta_data = self._read_input_meta_data_from_file()
     strategy = self._get_distribution_strategy(use_ds)
@@ -160,7 +162,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     self._setup()
     self.num_gpus = 1
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
-    FLAGS.train_batch_size = 4
+    # XLA runs out of memory when running with batch size 4.
+    FLAGS.train_batch_size = 3
     FLAGS.enable_xla = True
     self._run_and_report_benchmark()
...
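
The batch-size reduction above works around XLA memory overhead. For
reference, XLA compilation can also be requested per function in recent
TF 2.x releases (older releases spelled the argument experimental_compile);
this sketch is illustrative and independent of the repo's --enable_xla flag,
which presumably configures XLA globally:

    import tensorflow as tf

    # Hypothetical example: JIT-compile a single hot function with XLA.
    @tf.function(jit_compile=True)
    def dense_relu(x, w):
      return tf.nn.relu(tf.matmul(x, w))
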
@@ -45,8 +45,11 @@ def define_common_bert_flags():
       'inside.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
+  flags.DEFINE_boolean(
+      'run_eagerly', False,
+      'Run the model op by op without building a model function.')
-  # add flags for mixed precision training.
+  # Adds flags for mixed precision training.
   flags_core.define_performance(
       num_parallel_calls=False,
       inter_op=False,
...
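
One plausible way a run_eagerly flag like this is consumed downstream,
sketched with a hypothetical helper (make_train_step is not a name from
this repo):

    import tensorflow as tf

    def make_train_step(step_fn, run_eagerly=False):
      # With run_eagerly=True the step runs op by op, so it can be stepped
      # through in a debugger; otherwise tf.function traces it into a graph
      # for speed.
      return step_fn if run_eagerly else tf.function(step_fn)
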
@@ -243,7 +243,7 @@ def run_customized_training_loop(
       inputs, labels = inputs
       with tf.GradientTape() as tape:
-        model_outputs = model(inputs)
+        model_outputs = model(inputs, training=True)
         loss = loss_fn(labels, model_outputs)
         if use_float16:
           scaled_loss = optimizer.get_scaled_loss(loss)
...
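
The use_float16 branch above is part of the standard loss-scaling pattern
for mixed-precision training. A self-contained sketch using the Keras
LossScaleOptimizer wrapper (in 2019-era TF this lived under
tf.keras.mixed_precision.experimental; the surrounding model and loss setup
is assumed, not from the repo):

    import tensorflow as tf

    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.SGD())

    def scaled_train_step(model, loss_fn, inputs, labels):
      with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
        # Scale the loss up before differentiating so small float16
        # gradients do not underflow to zero.
        scaled_loss = optimizer.get_scaled_loss(loss)
      scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
      # Undo the scaling before applying the update.
      grads = optimizer.get_unscaled_gradients(scaled_grads)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss
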
@@ -95,7 +95,8 @@ def run_customized_training(strategy,
                             initial_lr,
                             init_checkpoint,
                             use_remote_tpu=False,
-                            custom_callbacks=None):
+                            custom_callbacks=None,
+                            run_eagerly=False):
   """Run BERT classifier training using low-level API."""
   max_seq_length = input_meta_data['max_seq_length']
   num_classes = input_meta_data['num_labels']
@@ -143,7 +144,8 @@ def run_customized_training(strategy,
       init_checkpoint=init_checkpoint,
       metric_fn=metric_fn,
       use_remote_tpu=use_remote_tpu,
-      custom_callbacks=custom_callbacks)
+      custom_callbacks=custom_callbacks,
+      run_eagerly=run_eagerly)

 def export_classifier(model_export_path, input_meta_data):
@@ -204,7 +206,8 @@ def run_bert(strategy, input_meta_data):
         warmup_steps,
         FLAGS.learning_rate,
         FLAGS.init_checkpoint,
-        use_remote_tpu=use_remote_tpu)
+        use_remote_tpu=use_remote_tpu,
+        run_eagerly=FLAGS.run_eagerly)
   if FLAGS.model_export_path:
     with tf.device(model_training_utils.get_primary_cpu_task(use_remote_tpu)):
...
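
Taken together, the last three hunks thread the new flag end to end: a
--run_eagerly value set on the command line presumably flows from run_bert
through run_customized_training into the training utilities, where it would
control whether the train step is traced by tf.function or executed op by op.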