Commit b523f160 authored by Jeremiah Harmsen's avatar Jeremiah Harmsen Committed by A. Unique TensorFlower
Browse files

Thread custom callbacks through BERT classifier. This allows injection of...

Thread custom callbacks through BERT classifier.  This allows injection of custom callbacks for library users of run_classifier.

PiperOrigin-RevId: 306966758
parent 8fd50e97
...@@ -331,7 +331,8 @@ def run_bert(strategy, ...@@ -331,7 +331,8 @@ def run_bert(strategy,
model_config, model_config,
train_input_fn=None, train_input_fn=None,
eval_input_fn=None, eval_input_fn=None,
init_checkpoint=None): init_checkpoint=None,
custom_callbacks=None):
"""Run BERT training.""" """Run BERT training."""
if FLAGS.mode == 'export_only': if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API # As Keras ModelCheckpoint callback used with Keras compile/fit() API
...@@ -358,14 +359,14 @@ def run_bert(strategy, ...@@ -358,14 +359,14 @@ def run_bert(strategy,
if not strategy: if not strategy:
raise ValueError('Distribution strategy has not been specified.') raise ValueError('Distribution strategy has not been specified.')
if not custom_callbacks:
custom_callbacks = []
if FLAGS.log_steps: if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory( custom_callbacks.append(keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size, batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir, logdir=FLAGS.model_dir))
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier( trained_model = run_bert_classifier(
strategy, strategy,
...@@ -396,9 +397,12 @@ def run_bert(strategy, ...@@ -396,9 +397,12 @@ def run_bert(strategy,
return trained_model return trained_model
def main(_): def custom_main(custom_callbacks=None):
# Users should always run this script under TF 2.x """Run classification.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
"""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
...@@ -423,7 +427,12 @@ def main(_): ...@@ -423,7 +427,12 @@ def main(_):
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn, run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn) eval_input_fn, custom_callbacks=custom_callbacks)
def main(_):
# Users should always run this script under TF 2.x
custom_main(custom_callbacks=None)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment