Commit 38800db3 authored by André Susano Pinto's avatar André Susano Pinto Committed by A. Unique TensorFlower
Browse files

Add --sub_model_export_name to run_classifier and run_squad.

This allows one to finetune a BERT model on one task before using it for
another task, e.g. SQuAD before finetuning on another QA-type task.

PiperOrigin-RevId: 313145768
parent 0108405d
...@@ -73,6 +73,9 @@ def define_common_bert_flags(): ...@@ -73,6 +73,9 @@ def define_common_bert_flags():
'If specified, init_checkpoint flag should not be used.') 'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_bool('hub_module_trainable', True, flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.') 'True to make keras layers in the hub module trainable.')
flags.DEFINE_string('sub_model_export_name', None,
'If set, `sub_model` checkpoints are exported into '
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
flags_core.define_log_steps() flags_core.define_log_steps()
......
...@@ -178,6 +178,7 @@ def run_bert_classifier(strategy, ...@@ -178,6 +178,7 @@ def run_bert_classifier(strategy,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
eval_steps=eval_steps, eval_steps=eval_steps,
init_checkpoint=init_checkpoint, init_checkpoint=init_checkpoint,
sub_model_export_name=FLAGS.sub_model_export_name,
metric_fn=metric_fn, metric_fn=metric_fn,
custom_callbacks=custom_callbacks, custom_callbacks=custom_callbacks,
run_eagerly=run_eagerly) run_eagerly=run_eagerly)
......
...@@ -49,12 +49,14 @@ def train_squad(strategy, ...@@ -49,12 +49,14 @@ def train_squad(strategy,
input_meta_data, input_meta_data,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=False, run_eagerly=False,
init_checkpoint=None): init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training.""" """Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config, run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly, init_checkpoint) custom_callbacks, run_eagerly, init_checkpoint,
sub_model_export_name=sub_model_export_name)
def predict_squad(strategy, input_meta_data): def predict_squad(strategy, input_meta_data):
...@@ -125,6 +127,7 @@ def main(_): ...@@ -125,6 +127,7 @@ def main(_):
input_meta_data, input_meta_data,
custom_callbacks=custom_callbacks, custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly, run_eagerly=FLAGS.run_eagerly,
sub_model_export_name=FLAGS.sub_model_export_name,
) )
if 'predict' in FLAGS.mode: if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data) predict_squad(strategy, input_meta_data)
......
...@@ -221,7 +221,8 @@ def train_squad(strategy, ...@@ -221,7 +221,8 @@ def train_squad(strategy,
bert_config, bert_config,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=False, run_eagerly=False,
init_checkpoint=None): init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training.""" """Run bert squad training."""
if strategy: if strategy:
logging.info('Training using customized training loop with distribution' logging.info('Training using customized training loop with distribution'
...@@ -279,6 +280,7 @@ def train_squad(strategy, ...@@ -279,6 +280,7 @@ def train_squad(strategy,
epochs=epochs, epochs=epochs,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint, init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
sub_model_export_name=sub_model_export_name,
run_eagerly=run_eagerly, run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks, custom_callbacks=custom_callbacks,
explicit_allreduce=False, explicit_allreduce=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment