Commit 2f69dc64 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Small style fixes

parent 68218034
......@@ -33,6 +33,8 @@ import functools
import operator
import os
import cifar10
import cifar10_model
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
......@@ -41,8 +43,6 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
import cifar10
import cifar10_model
tf.logging.set_verbosity(tf.logging.INFO)
......@@ -80,9 +80,10 @@ tf.flags.DEFINE_boolean('run_experiment', False,
'If True will run an experiment,'
'otherwise will run training and evaluation'
'using the estimator interface.'
'Experiments perform training on several workers in parallel'
', in other words experiments know how to invoke train and'
' eval in a sensible fashion for distributed training.')
'Experiments perform training on several workers in'
'parallel, in other words experiments know how to'
' invoke train and eval in a sensible fashion for'
' distributed training.')
tf.flags.DEFINE_boolean('sync', False,
'If true when running in a distributed environment'
......@@ -117,8 +118,8 @@ tf.flags.DEFINE_boolean('log_device_placement', False,
class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
"""Hook to print out examples per second
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
......@@ -131,15 +132,16 @@ class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
every_n_steps=100,
every_n_secs=None,):
"""Initializer for ExamplesPerSecondHook.
Args:
Args:
batch_size: Total batch size used to calculate examples/second from
global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds.
"""
"""
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError(
'exactly one of every_n_steps and every_n_secs should be provided.')
raise ValueError('exactly one of every_n_steps'
' and every_n_secs should be provided.')
self._timer = basic_session_run_hooks.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
......@@ -188,6 +190,7 @@ class GpuParamServerDeviceSetter(object):
def __init__(self, worker_device, ps_devices):
"""Initializer for GpuParamServerDeviceSetter.
Args:
worker_device: the device to use for computation Ops.
ps_devices: a list of devices to use for Variable Ops. Each variable is
......@@ -202,7 +205,7 @@ class GpuParamServerDeviceSetter(object):
return op.device
if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
return self.worker_device
# Gets the least loaded ps_device
device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
......@@ -211,6 +214,7 @@ class GpuParamServerDeviceSetter(object):
return device_name
def _create_device_setter(is_cpu_ps, worker, num_gpus):
"""Create device setter object."""
if is_cpu_ps:
......@@ -400,7 +404,8 @@ def input_fn(subset, num_shards):
elif subset == 'validate' or subset == 'eval':
batch_size = FLAGS.eval_batch_size
else:
raise ValueError('Subset must be one of \'train\', \'validate\' and \'eval\'')
raise ValueError('Subset must be one of \'train\''
', \'validate\' and \'eval\'')
with tf.device('/cpu:0'):
use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
......@@ -429,7 +434,14 @@ def input_fn(subset, num_shards):
# create experiment
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps,
train_hooks):
"""Returns an Experiment function.
Experiments perform training on several workers in parallel,
in other words experiments know how to invoke train and eval in a sensible
fashion for distributed training.
"""
def _experiment_fn(run_config, hparams):
"""Returns an Experiment."""
del hparams # unused arg
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
......@@ -491,7 +503,7 @@ def main(unused_argv):
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
examples_sec_hook = ExamplesPerSecondHook(
FLAGS.train_batch_size, every_n_steps=10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment