Commit b3ab2074 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307089919
parent 85b50c88
...@@ -353,7 +353,8 @@ def train_and_eval( ...@@ -353,7 +353,8 @@ def train_and_eval(
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy() loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, model.compile(optimizer=optimizer,
loss=loss_obj, loss=loss_obj,
metrics=metrics) metrics=metrics,
experimental_steps_per_execution=params.train.steps_per_loop)
initial_epoch = 0 initial_epoch = 0
if params.train.resume_checkpoint: if params.train.resume_checkpoint:
...@@ -389,9 +390,8 @@ def train_and_eval( ...@@ -389,9 +390,8 @@ def train_and_eval(
steps_per_epoch=train_steps, steps_per_epoch=train_steps,
initial_epoch=initial_epoch, initial_epoch=initial_epoch,
callbacks=callbacks, callbacks=callbacks,
**validation_kwargs, verbose=2,
experimental_steps_per_execution=params.train.steps_per_loop, **validation_kwargs)
verbose=2)
validation_output = None validation_output = None
if not params.evaluation.skip_eval: if not params.evaluation.skip_eval:
...@@ -441,8 +441,7 @@ def run(flags_obj: flags.FlagValues, ...@@ -441,8 +441,7 @@ def run(flags_obj: flags.FlagValues,
def main(_): def main(_):
with logger.benchmark_context(flags.FLAGS): stats = run(flags.FLAGS)
stats = run(flags.FLAGS)
if stats: if stats:
logging.info('Run stats:\n%s', stats) logging.info('Run stats:\n%s', stats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment