Commit 40fdb50c authored by Maxim Neumann's avatar Maxim Neumann Committed by A. Unique TensorFlower
Browse files

Allow to provide custom metrics to run_classifier.

PiperOrigin-RevId: 316134518
parent 2dfd1e63
......@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
train_input_fn,
eval_input_fn,
training_callbacks=True,
custom_callbacks=None):
custom_callbacks=None,
custom_metrics=None):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data.get('num_labels', 1)
......@@ -159,7 +160,9 @@ def run_bert_classifier(strategy,
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
if is_regression:
if custom_metrics:
metric_fn = custom_metrics
elif is_regression:
metric_fn = functools.partial(
tf.keras.metrics.MeanSquaredError,
'mean_squared_error',
......@@ -216,10 +219,12 @@ def run_keras_compile_fit(model_dir,
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
if not isinstance(metric_fn, (list, tuple)):
metric_fn = [metric_fn]
bert_model.compile(
optimizer=optimizer,
loss=loss_fn,
metrics=[metric_fn()],
metrics=[fn() for fn in metric_fn],
experimental_steps_per_execution=steps_per_loop)
summary_dir = os.path.join(model_dir, 'summaries')
......@@ -350,7 +355,8 @@ def run_bert(strategy,
train_input_fn=None,
eval_input_fn=None,
init_checkpoint=None,
custom_callbacks=None):
custom_callbacks=None,
custom_metrics=None):
"""Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_session_config(FLAGS.enable_xla)
......@@ -391,7 +397,8 @@ def run_bert(strategy,
init_checkpoint or FLAGS.init_checkpoint,
train_input_fn,
eval_input_fn,
custom_callbacks=custom_callbacks)
custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
......@@ -399,11 +406,12 @@ def run_bert(strategy,
return trained_model
def custom_main(custom_callbacks=None):
def custom_main(custom_callbacks=None, custom_metrics=None):
"""Run classification or regression.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_metrics: list of metrics passed to the training loop.
"""
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
......@@ -474,11 +482,12 @@ def custom_main(custom_callbacks=None):
bert_config,
train_input_fn,
eval_input_fn,
custom_callbacks=custom_callbacks)
custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
def main(_):
custom_main(custom_callbacks=None)
custom_main(custom_callbacks=None, custom_metrics=None)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment