"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "65e128250383d55527c87ff3fe331930c857da8e"
Commit 38486725 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 302977474
parent 7e6167a9
...@@ -269,6 +269,7 @@ python run_classifier.py \ ...@@ -269,6 +269,7 @@ python run_classifier.py \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \ --train_batch_size=32 \
--eval_batch_size=32 \ --eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \ --learning_rate=2e-5 \
--num_train_epochs=3 \ --num_train_epochs=3 \
--model_dir=${MODEL_DIR} \ --model_dir=${MODEL_DIR} \
...@@ -276,6 +277,10 @@ python run_classifier.py \ ...@@ -276,6 +277,10 @@ python run_classifier.py \
--tpu=grpc://${TPU_IP_ADDRESS}:8470 --tpu=grpc://${TPU_IP_ADDRESS}:8470
``` ```
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
training steps inside a `tf.function` can significantly increase TPU utilization
and callbacks will not be called inside the loop.
### SQuAD 1.1 ### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering The Stanford Question Answering Dataset (SQuAD) is a popular question answering
......
...@@ -57,7 +57,7 @@ def define_common_bert_flags(): ...@@ -57,7 +57,7 @@ def define_common_bert_flags():
flags.DEFINE_integer('num_train_epochs', 3, flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.') 'Total number of training epochs to perform.')
flags.DEFINE_integer( flags.DEFINE_integer(
'steps_per_loop', 200, 'steps_per_loop', 1,
'Number of steps per graph-mode loop. Only training step ' 'Number of steps per graph-mode loop. Only training step '
'happens inside the loop. Callbacks will not be called ' 'happens inside the loop. Callbacks will not be called '
'inside.') 'inside.')
......
...@@ -156,6 +156,7 @@ def run_bert_classifier(strategy, ...@@ -156,6 +156,7 @@ def run_bert_classifier(strategy,
init_checkpoint, init_checkpoint,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
steps_per_loop,
eval_steps, eval_steps,
custom_callbacks=custom_callbacks) custom_callbacks=custom_callbacks)
...@@ -189,6 +190,7 @@ def run_keras_compile_fit(model_dir, ...@@ -189,6 +190,7 @@ def run_keras_compile_fit(model_dir,
init_checkpoint, init_checkpoint,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
steps_per_loop,
eval_steps, eval_steps,
custom_callbacks=None): custom_callbacks=None):
"""Runs BERT classifier model using Keras compile/fit API.""" """Runs BERT classifier model using Keras compile/fit API."""
...@@ -203,7 +205,11 @@ def run_keras_compile_fit(model_dir, ...@@ -203,7 +205,11 @@ def run_keras_compile_fit(model_dir,
checkpoint = tf.train.Checkpoint(model=sub_model) checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched() checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()]) bert_model.compile(
optimizer=optimizer,
loss=loss_fn,
metrics=[metric_fn()],
experimental_steps_per_execution=steps_per_loop)
summary_dir = os.path.join(model_dir, 'summaries') summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment