Commit bf44e372 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 452380386
parent f67822f5
......@@ -119,7 +119,7 @@ def run_experiment(
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
if mode == 'train' or mode == 'train_and_post_eval':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
......@@ -152,9 +152,9 @@ def run_experiment(
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval:
if run_post_eval or mode == 'train_and_post_eval':
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
return trainer.model, controller.evaluate(
steps=params.trainer.validation_steps)
else:
return trainer.model, {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment