Commit e3ae4b5e authored by André Susano Pinto's avatar André Susano Pinto Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307998953
parent 543fb4cf
...@@ -88,6 +88,7 @@ def get_loss_fn(): ...@@ -88,6 +88,7 @@ def get_loss_fn():
def run_customized_training(strategy, def run_customized_training(strategy,
bert_config, bert_config,
init_checkpoint,
max_seq_length, max_seq_length,
max_predictions_per_seq, max_predictions_per_seq,
model_dir, model_dir,
...@@ -128,6 +129,7 @@ def run_customized_training(strategy, ...@@ -128,6 +129,7 @@ def run_customized_training(strategy,
loss_fn=get_loss_fn(), loss_fn=get_loss_fn(),
scale_loss=FLAGS.scale_loss, scale_loss=FLAGS.scale_loss,
model_dir=model_dir, model_dir=model_dir,
init_checkpoint=init_checkpoint,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop, steps_per_loop=steps_per_loop,
...@@ -153,6 +155,7 @@ def run_bert_pretrain(strategy): ...@@ -153,6 +155,7 @@ def run_bert_pretrain(strategy):
return run_customized_training( return run_customized_training(
strategy, strategy,
bert_config, bert_config,
FLAGS.init_checkpoint, # Used to initialize only the BERT submodel.
FLAGS.max_seq_length, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, FLAGS.max_predictions_per_seq,
FLAGS.model_dir, FLAGS.model_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment