Commit b6313d65 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Improving experiment implementation

- fixing typos
- reusing variables
parent b36c01b6
...@@ -71,11 +71,10 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.') ...@@ -71,11 +71,10 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True, tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.') 'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False, tf.flags.DEFINE_boolean('run_experiment', False,
"If True will run an experiment," "If True will run an experiment,"
"otherwise will run training and evaluatio" "otherwise will run training and evaluation"
"using the estimator's methods") "using the estimator interface")
# Perf flags # Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1, tf.flags.DEFINE_integer('num_intra_threads', 1,
...@@ -365,17 +364,19 @@ def input_fn(subset, num_shards): ...@@ -365,17 +364,19 @@ def input_fn(subset, num_shards):
return feature_shards, label_shards return feature_shards, label_shards
# create experiment # create experiment
def experiment_fn(run_config, hparams): def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps):
# create estimator def _experiment_fn(run_config, hparams):
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn, # create estimator
config=run_config) classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
return tf.contrib.learn.Experiment( config=run_config)
classifier, return tf.contrib.learn.Experiment(
train_input_fn=train_input_fn, classifier,
eval_input_fn=test_input_fn, train_input_fn=train_input_fn,
train_steps=FLAGS.train_steps, eval_input_fn=eval_input_fn,
eval_steps=num_eval_examples // FLAGS.eval_batch_size train_steps=train_steps,
) eval_steps=eval_steps
)
return _experiment_fn
def main(unused_argv): def main(unused_argv):
...@@ -408,8 +409,19 @@ def main(unused_argv): ...@@ -408,8 +409,19 @@ def main(unused_argv):
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config) config = config.replace(session_config=sess_config)
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size
if FLAGS.run_experiment: if FLAGS.run_experiment:
tf.contrib.learn.learn_runner.run(experiment_fn, run_config=run_config) tf.contrib.learn.learn_runner.run(
get_experiment_fn(train_input_fn, eval_input_fn,
train_steps, eval_steps), run_config=config)
else: else:
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, config=config) model_fn=_resnet_model_fn, config=config)
...@@ -419,17 +431,14 @@ def main(unused_argv): ...@@ -419,17 +431,14 @@ def main(unused_argv):
tensors=tensors_to_log, every_n_iter=100) tensors=tensors_to_log, every_n_iter=100)
print('Starting to train...') print('Starting to train...')
classifier.train( classifier.train(input_fn=train_input_fn,
input_fn=functools.partial( steps=train_steps,
input_fn, subset='train', num_shards=FLAGS.num_gpus), hooks=[logging_hook])
steps=FLAGS.train_steps,
hooks=[logging_hook])
print('Starting to evaluate...') print('Starting to evaluate...')
eval_results = classifier.evaluate( eval_results = classifier.evaluate(
input_fn=functools.partial( input_fn=eval_input_fn,
input_fn, subset='eval', num_shards=FLAGS.num_gpus), steps=eval_steps)
steps=num_eval_examples // FLAGS.eval_batch_size)
print(eval_results) print(eval_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment