Commit b6313d65 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Improving experiment implementation

- fixing typos
- reusing variables
parent b36c01b6
......@@ -71,11 +71,10 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False,
"If True will run an experiment,"
"otherwise will run training and evaluatio"
"using the estimator's methods")
"otherwise will run training and evaluation"
"using the estimator interface")
# Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1,
......@@ -365,17 +364,19 @@ def input_fn(subset, num_shards):
return feature_shards, label_shards
# create experiment
def experiment_fn(run_config, hparams):
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps):
def _experiment_fn(run_config, hparams):
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config)
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=test_input_fn,
train_steps=FLAGS.train_steps,
eval_steps=num_eval_examples // FLAGS.eval_batch_size
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps
)
return _experiment_fn
def main(unused_argv):
......@@ -408,8 +409,19 @@ def main(unused_argv):
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config)
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size
if FLAGS.run_experiment:
tf.contrib.learn.learn_runner.run(experiment_fn, run_config=run_config)
tf.contrib.learn.learn_runner.run(
get_experiment_fn(train_input_fn, eval_input_fn,
train_steps, eval_steps), run_config=config)
else:
classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, config=config)
......@@ -419,17 +431,14 @@ def main(unused_argv):
tensors=tensors_to_log, every_n_iter=100)
print('Starting to train...')
classifier.train(
input_fn=functools.partial(
input_fn, subset='train', num_shards=FLAGS.num_gpus),
steps=FLAGS.train_steps,
classifier.train(input_fn=train_input_fn,
steps=train_steps,
hooks=[logging_hook])
print('Starting to evaluate...')
eval_results = classifier.evaluate(
input_fn=functools.partial(
input_fn, subset='eval', num_shards=FLAGS.num_gpus),
steps=num_eval_examples // FLAGS.eval_batch_size)
input_fn=eval_input_fn,
steps=eval_steps)
print(eval_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment