Commit b36c01b6 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Adding run_experiment option

parent d71cbd0c
...@@ -72,6 +72,11 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.') ...@@ -72,6 +72,11 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True, tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.') 'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False,
                        # Fixed: implicit string concatenation was missing the
                        # separating spaces and dropped the 'n' in 'evaluation',
                        # producing "experiment,otherwise ... evaluatiousing".
                        'If True will run an experiment, '
                        'otherwise will run training and evaluation '
                        "using the estimator's methods")
# Perf flags # Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1, tf.flags.DEFINE_integer('num_intra_threads', 1,
"""Number of threads to use for intra-op parallelism. """Number of threads to use for intra-op parallelism.
...@@ -359,6 +364,19 @@ def input_fn(subset, num_shards): ...@@ -359,6 +364,19 @@ def input_fn(subset, num_shards):
label_shards = [tf.parallel_stack(x) for x in label_shards] label_shards = [tf.parallel_stack(x) for x in label_shards]
return feature_shards, label_shards return feature_shards, label_shards
# create experiment
def experiment_fn(run_config, hparams):
  """Builds a tf.contrib.learn.Experiment for learn_runner.run.

  Args:
    run_config: RunConfig passed through to the Estimator constructor.
    hparams: hyperparameters object; required by the learn_runner
      experiment_fn signature but unused here.

  Returns:
    A tf.contrib.learn.Experiment that trains for FLAGS.train_steps and
    evaluates over the full evaluation set.
  """
  # create estimator
  classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
                                      config=run_config)
  # Bug fix: the original referenced undefined names `train_input_fn` and
  # `test_input_fn` (NameError when learn_runner invokes this). Build the
  # input functions from the module-level input_fn exactly the way main()
  # drives the estimator directly.
  # NOTE(review): `num_eval_examples` is computed inside main() in the
  # visible code — confirm it is also bound at module level, otherwise it
  # must be passed in (e.g. via hparams) before this path can run.
  return tf.contrib.learn.Experiment(
      classifier,
      train_input_fn=functools.partial(
          input_fn, subset='train', num_shards=FLAGS.num_gpus),
      eval_input_fn=functools.partial(
          input_fn, subset='eval', num_shards=FLAGS.num_gpus),
      train_steps=FLAGS.train_steps,
      eval_steps=num_eval_examples // FLAGS.eval_batch_size
  )
def main(unused_argv): def main(unused_argv):
# The env variable is on deprecation path, default is set to off. # The env variable is on deprecation path, default is set to off.
...@@ -381,7 +399,7 @@ def main(unused_argv): ...@@ -381,7 +399,7 @@ def main(unused_argv):
if num_eval_examples % FLAGS.eval_batch_size != 0: if num_eval_examples % FLAGS.eval_batch_size != 0:
raise ValueError('validation set size must be multiple of eval_batch_size') raise ValueError('validation set size must be multiple of eval_batch_size')
config = tf.estimator.RunConfig() config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
sess_config = tf.ConfigProto() sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement sess_config.log_device_placement = FLAGS.log_device_placement
...@@ -390,26 +408,29 @@ def main(unused_argv): ...@@ -390,26 +408,29 @@ def main(unused_argv):
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config) config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator( if FLAGS.run_experiment:
model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config) tf.contrib.learn.learn_runner.run(experiment_fn, run_config=run_config)
else:
tensors_to_log = {'learning_rate': 'learning_rate'} classifier = tf.estimator.Estimator(
logging_hook = tf.train.LoggingTensorHook( model_fn=_resnet_model_fn, config=config)
tensors=tensors_to_log, every_n_iter=100)
tensors_to_log = {'learning_rate': 'learning_rate'}
print('Starting to train...') logging_hook = tf.train.LoggingTensorHook(
classifier.train( tensors=tensors_to_log, every_n_iter=100)
input_fn=functools.partial(
input_fn, subset='train', num_shards=FLAGS.num_gpus), print('Starting to train...')
steps=FLAGS.train_steps, classifier.train(
hooks=[logging_hook]) input_fn=functools.partial(
input_fn, subset='train', num_shards=FLAGS.num_gpus),
print('Starting to evaluate...') steps=FLAGS.train_steps,
eval_results = classifier.evaluate( hooks=[logging_hook])
input_fn=functools.partial(
input_fn, subset='eval', num_shards=FLAGS.num_gpus), print('Starting to evaluate...')
steps=num_eval_examples // FLAGS.eval_batch_size) eval_results = classifier.evaluate(
print(eval_results) input_fn=functools.partial(
input_fn, subset='eval', num_shards=FLAGS.num_gpus),
steps=num_eval_examples // FLAGS.eval_batch_size)
print(eval_results)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment