Commit b6313d65 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Improving experiment implementation

- fixing typos
- reusing variables
parent b36c01b6
......@@ -71,11 +71,10 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False,
"If True will run an experiment,"
"otherwise will run training and evaluatio"
"using the estimator's methods")
"If True will run an experiment,"
"otherwise will run training and evaluation"
"using the estimator interface")
# Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1,
......@@ -365,17 +364,19 @@ def input_fn(subset, num_shards):
return feature_shards, label_shards
# create experiment
def experiment_fn(run_config, hparams):
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config)
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=test_input_fn,
train_steps=FLAGS.train_steps,
eval_steps=num_eval_examples // FLAGS.eval_batch_size
)
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps):
def _experiment_fn(run_config, hparams):
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config)
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps
)
return _experiment_fn
def main(unused_argv):
......@@ -408,8 +409,19 @@ def main(unused_argv):
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config)
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size
if FLAGS.run_experiment:
tf.contrib.learn.learn_runner.run(experiment_fn, run_config=run_config)
tf.contrib.learn.learn_runner.run(
get_experiment_fn(train_input_fn, eval_input_fn,
train_steps, eval_steps), run_config=config)
else:
classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, config=config)
......@@ -419,17 +431,14 @@ def main(unused_argv):
tensors=tensors_to_log, every_n_iter=100)
print('Starting to train...')
classifier.train(
input_fn=functools.partial(
input_fn, subset='train', num_shards=FLAGS.num_gpus),
steps=FLAGS.train_steps,
hooks=[logging_hook])
classifier.train(input_fn=train_input_fn,
steps=train_steps,
hooks=[logging_hook])
print('Starting to evaluate...')
eval_results = classifier.evaluate(
input_fn=functools.partial(
input_fn, subset='eval', num_shards=FLAGS.num_gpus),
steps=num_eval_examples // FLAGS.eval_batch_size)
input_fn=eval_input_fn,
steps=eval_steps)
print(eval_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment