Commit b523f160 authored by Jeremiah Harmsen's avatar Jeremiah Harmsen Committed by A. Unique TensorFlower
Browse files

Thread custom callbacks through BERT classifier. This allows injection of...

Thread custom callbacks through BERT classifier.  This allows injection of custom callbacks for library users of run_classifier.

PiperOrigin-RevId: 306966758
parent 8fd50e97
......@@ -331,7 +331,8 @@ def run_bert(strategy,
model_config,
train_input_fn=None,
eval_input_fn=None,
init_checkpoint=None):
init_checkpoint=None,
custom_callbacks=None):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
......@@ -358,14 +359,14 @@ def run_bert(strategy,
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if not custom_callbacks:
custom_callbacks = []
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
custom_callbacks.append(keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
logdir=FLAGS.model_dir))
trained_model = run_bert_classifier(
strategy,
......@@ -396,9 +397,12 @@ def run_bert(strategy,
return trained_model
def main(_):
# Users should always run this script under TF 2.x
def custom_main(custom_callbacks=None):
"""Run classification.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
"""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......@@ -423,7 +427,12 @@ def main(_):
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn)
eval_input_fn, custom_callbacks=custom_callbacks)
def main(_):
# Users should always run this script under TF 2.x
custom_main(custom_callbacks=None)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment