Commit b3ab2074 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307089919
parent 85b50c88
......@@ -353,7 +353,8 @@ def train_and_eval(
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics)
metrics=metrics,
experimental_steps_per_execution=params.train.steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
......@@ -389,9 +390,8 @@ def train_and_eval(
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
**validation_kwargs,
experimental_steps_per_execution=params.train.steps_per_loop,
verbose=2)
verbose=2,
**validation_kwargs)
validation_output = None
if not params.evaluation.skip_eval:
......@@ -441,8 +441,7 @@ def run(flags_obj: flags.FlagValues,
def main(_):
with logger.benchmark_context(flags.FLAGS):
stats = run(flags.FLAGS)
stats = run(flags.FLAGS)
if stats:
logging.info('Run stats:\n%s', stats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment