Commit 396fd9de authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 332061237
parent 543755a0
......@@ -164,9 +164,9 @@ def run_customized_training_loop(
evaluation is skipped.
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
is not none.
metric_fn: A metrics function that returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
metric_fn: A metrics function that returns either a Keras Metric object or
a list of Keras Metric objects to record evaluation result using
evaluation dataset or with training dataset after every epoch.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during
......@@ -291,7 +291,9 @@ def run_customized_training_loop(
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else []
eval_metrics = metric_fn() if metric_fn else []
if not isinstance(eval_metrics, list):
eval_metrics = [eval_metrics]
# If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation.
train_metrics = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment