Unverified Commit 20b19b61 authored by Yuefeng Zhou's avatar Yuefeng Zhou Committed by GitHub
Browse files

tf.estimator.train_and_evalute doesn't return anything in multi-worker case. (#6582)

* Update resnet_run_loop.py

* Update resnet_run_loop.py

* Update resnet_run_loop.py

* Update resnet_run_loop.py

* Update resnet_run_loop.py
parent 8c68bc0d
......@@ -628,21 +628,19 @@ def resnet_main(
train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
flags_obj.train_epochs)
use_train_and_evaluate = flags_obj.use_train_and_evaluate or (
distribution_strategy.__class__.__name__ in [
'CollectiveAllReduceStrategy', 'MultiWorkerMirroredStrategy'])
use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
if use_train_and_evaluate:
train_spec = tf.estimator.TrainSpec(
input_fn=lambda input_context=None: input_fn_train(
train_epochs, input_context=input_context),
hooks=train_hooks,
max_steps=flags_obj.max_train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval,
steps=flags_obj.max_train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
tf.compat.v1.logging.info('Starting to train and evaluate.')
eval_results, _ = tf.estimator.train_and_evaluate(classifier, train_spec,
eval_spec)
benchmark_logger.log_evaluation_result(eval_results)
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
# tf.estimator.train_and_evalute doesn't return anything in multi-worker
# case.
return {}
else:
if train_epochs == 0:
# If --eval_only is set, perform a single loop with zero train epochs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment