Commit 28328ae3 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro
Browse files

Small fixes

parent b6313d65
...@@ -72,9 +72,9 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.') ...@@ -72,9 +72,9 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True, tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.') 'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False, tf.flags.DEFINE_boolean('run_experiment', False,
"If True will run an experiment," 'If True will run an experiment,'
"otherwise will run training and evaluation" 'otherwise will run training and evaluation'
"using the estimator interface") 'using the estimator interface')
# Perf flags # Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1, tf.flags.DEFINE_integer('num_intra_threads', 1,
...@@ -363,9 +363,11 @@ def input_fn(subset, num_shards): ...@@ -363,9 +363,11 @@ def input_fn(subset, num_shards):
label_shards = [tf.parallel_stack(x) for x in label_shards] label_shards = [tf.parallel_stack(x) for x in label_shards]
return feature_shards, label_shards return feature_shards, label_shards
# create experiment # create experiment
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps): def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps):
def _experiment_fn(run_config, hparams): def _experiment_fn(run_config, hparams):
del hparams # unused arg
# create estimator # create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn, classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config) config=run_config)
...@@ -413,7 +415,7 @@ def main(unused_argv): ...@@ -413,7 +415,7 @@ def main(unused_argv):
num_shards=FLAGS.num_gpus) num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval', eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus) num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size eval_steps = num_eval_examples // FLAGS.eval_batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment