"src/turbomind/vscode:/vscode.git/clone" did not exist on "d5a89465226eced35c994784af386192e9f76090"
Unverified Commit 7f351c62 authored by Qianli Scott Zhu's avatar Qianli Scott Zhu Committed by GitHub
Browse files

Minor update the resnet runloop. (#4113)

1. Training hooks and the train/eval functions do not need to be
declared several times.
2. change to use tf.logging.info instead of print
3. Show the current index and total count of training cycles, which
gives the user some hint about where they are in the whole process.
parent e6082458
...@@ -401,28 +401,29 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -401,28 +401,29 @@ def resnet_main(flags, model_function, input_function, shape=None):
benchmark_logger = logger.config_benchmark_logger(flags.benchmark_log_dir) benchmark_logger = logger.config_benchmark_logger(flags.benchmark_log_dir)
benchmark_logger.log_run_info('resnet') benchmark_logger.log_run_info('resnet')
for _ in range(flags.train_epochs // flags.epochs_between_evals):
train_hooks = hooks_helper.get_train_hooks( train_hooks = hooks_helper.get_train_hooks(
flags.hooks, flags.hooks,
batch_size=flags.batch_size, batch_size=flags.batch_size,
benchmark_log_dir=flags.benchmark_log_dir) benchmark_log_dir=flags.benchmark_log_dir)
print('Starting a training cycle.')
def input_fn_train(): def input_fn_train():
return input_function(True, flags.data_dir, flags.batch_size, return input_function(True, flags.data_dir, flags.batch_size,
flags.epochs_between_evals, flags.epochs_between_evals,
flags.num_parallel_calls, flags.multi_gpu) flags.num_parallel_calls, flags.multi_gpu)
classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps)
print('Starting to evaluate.')
# Evaluate the model and print results
def input_fn_eval(): def input_fn_eval():
return input_function(False, flags.data_dir, flags.batch_size, return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls, flags.multi_gpu) 1, flags.num_parallel_calls, flags.multi_gpu)
total_training_cycle = flags.train_epochs // flags.epochs_between_evals
for cycle_index in range(total_training_cycle):
tf.logging.info('Starting a training cycle: %d/%d',
cycle_index, total_training_cycle)
classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps)
tf.logging.info('Starting to evaluate.')
# flags.max_train_steps is generally associated with testing and profiling. # flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will # As a result it is frequently called with synthetic data, which will
# iterate forever. Passing steps=flags.max_train_steps allows the eval # iterate forever. Passing steps=flags.max_train_steps allows the eval
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment