Commit 40fdb50c authored by Maxim Neumann's avatar Maxim Neumann Committed by A. Unique TensorFlower
Browse files

Allow to provide custom metrics to run_classifier.

PiperOrigin-RevId: 316134518
parent 2dfd1e63
...@@ -125,7 +125,8 @@ def run_bert_classifier(strategy, ...@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
training_callbacks=True, training_callbacks=True,
custom_callbacks=None): custom_callbacks=None,
custom_metrics=None):
"""Run BERT classifier training using low-level API.""" """Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data.get('num_labels', 1) num_classes = input_meta_data.get('num_labels', 1)
...@@ -159,7 +160,9 @@ def run_bert_classifier(strategy, ...@@ -159,7 +160,9 @@ def run_bert_classifier(strategy,
# Defines evaluation metrics function, which will create metrics in the # Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope. # correct device and strategy scope.
if is_regression: if custom_metrics:
metric_fn = custom_metrics
elif is_regression:
metric_fn = functools.partial( metric_fn = functools.partial(
tf.keras.metrics.MeanSquaredError, tf.keras.metrics.MeanSquaredError,
'mean_squared_error', 'mean_squared_error',
...@@ -216,10 +219,12 @@ def run_keras_compile_fit(model_dir, ...@@ -216,10 +219,12 @@ def run_keras_compile_fit(model_dir,
checkpoint = tf.train.Checkpoint(model=sub_model) checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched() checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
if not isinstance(metric_fn, (list, tuple)):
metric_fn = [metric_fn]
bert_model.compile( bert_model.compile(
optimizer=optimizer, optimizer=optimizer,
loss=loss_fn, loss=loss_fn,
metrics=[metric_fn()], metrics=[fn() for fn in metric_fn],
experimental_steps_per_execution=steps_per_loop) experimental_steps_per_execution=steps_per_loop)
summary_dir = os.path.join(model_dir, 'summaries') summary_dir = os.path.join(model_dir, 'summaries')
...@@ -350,7 +355,8 @@ def run_bert(strategy, ...@@ -350,7 +355,8 @@ def run_bert(strategy,
train_input_fn=None, train_input_fn=None,
eval_input_fn=None, eval_input_fn=None,
init_checkpoint=None, init_checkpoint=None,
custom_callbacks=None): custom_callbacks=None,
custom_metrics=None):
"""Run BERT training.""" """Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU. # Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_session_config(FLAGS.enable_xla) keras_utils.set_session_config(FLAGS.enable_xla)
...@@ -391,7 +397,8 @@ def run_bert(strategy, ...@@ -391,7 +397,8 @@ def run_bert(strategy,
init_checkpoint or FLAGS.init_checkpoint, init_checkpoint or FLAGS.init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
custom_callbacks=custom_callbacks) custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
if FLAGS.model_export_path: if FLAGS.model_export_path:
model_saving_utils.export_bert_model( model_saving_utils.export_bert_model(
...@@ -399,11 +406,12 @@ def run_bert(strategy, ...@@ -399,11 +406,12 @@ def run_bert(strategy,
return trained_model return trained_model
def custom_main(custom_callbacks=None): def custom_main(custom_callbacks=None, custom_metrics=None):
"""Run classification or regression. """Run classification or regression.
Args: Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop. custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_metrics: list of metrics passed to the training loop.
""" """
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
...@@ -474,11 +482,12 @@ def custom_main(custom_callbacks=None): ...@@ -474,11 +482,12 @@ def custom_main(custom_callbacks=None):
bert_config, bert_config,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
custom_callbacks=custom_callbacks) custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
def main(_): def main(_):
custom_main(custom_callbacks=None) custom_main(custom_callbacks=None, custom_metrics=None)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment