"git@developer.sourcefind.cn:modelzoo/chatglm.git" did not exist on "3191e623ab0650dd32b4406db945cd8b1821792e"
Commit b139a84b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 332061237
parent 81fa82ef
...@@ -164,9 +164,9 @@ def run_customized_training_loop( ...@@ -164,9 +164,9 @@ def run_customized_training_loop(
evaluation is skipped. evaluation is skipped.
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn` eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
is not none. is not none.
metric_fn: A metrics function that returns a Keras Metric object to record metric_fn: A metrics function that returns either a Keras Metric object or
evaluation result using evaluation dataset or with training dataset a list of Keras Metric objects to record evaluation result using
after every epoch. evaluation dataset or with training dataset after every epoch.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`. `model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during custom_callbacks: A list of Keras Callbacks objects to run during
...@@ -291,7 +291,9 @@ def run_customized_training_loop( ...@@ -291,7 +291,9 @@ def run_customized_training_loop(
logging.info('Loading from checkpoint file completed') logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32) train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else [] eval_metrics = metric_fn() if metric_fn else []
if not isinstance(eval_metrics, list):
eval_metrics = [eval_metrics]
# If evaluation is required, make a copy of metric as it will be used by # If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation. # both train and evaluation.
train_metrics = [ train_metrics = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment