"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d9023a671ad7d947ae4f1366c9246d4ae8201d00"
Commit 2f69dc64 authored by Marianne Linhares Monteiro, committed by GitHub

Small style fixes

parent 68218034
@@ -33,6 +33,8 @@ import functools
 import operator
 import os
+import cifar10
+import cifar10_model
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
@@ -41,8 +43,6 @@ from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
-import cifar10
-import cifar10_model

 tf.logging.set_verbosity(tf.logging.INFO)
@@ -80,9 +80,10 @@ tf.flags.DEFINE_boolean('run_experiment', False,
                         'If True will run an experiment,'
                         'otherwise will run training and evaluation'
                         'using the estimator interface.'
-                        'Experiments perform training on several workers in parallel'
-                        ', in other words experiments know how to invoke train and'
-                        ' eval in a sensible fashion for distributed training.')
+                        'Experiments perform training on several workers in'
+                        ' parallel, in other words experiments know how to'
+                        ' invoke train and eval in a sensible fashion for'
+                        ' distributed training.')

 tf.flags.DEFINE_boolean('sync', False,
                         'If true when running in a distributed environment'
@@ -117,8 +118,8 @@ tf.flags.DEFINE_boolean('log_device_placement', False,

 class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
-  """Hook to print out examples per second
+  """Hook to print out examples per second.
   Total time is tracked and then divided by the total number of steps
   to get the average step time and then batch_size is used to determine
   the running average of examples per second. The examples per second for the
@@ -131,15 +132,16 @@ class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
                every_n_steps=100,
                every_n_secs=None,):
     """Initializer for ExamplesPerSecondHook.
+
     Args:
       batch_size: Total batch size used to calculate examples/second from
         global time.
       every_n_steps: Log stats every n steps.
       every_n_secs: Log stats every n seconds.
     """
     if (every_n_steps is None) == (every_n_secs is None):
-      raise ValueError(
-          'exactly one of every_n_steps and every_n_secs should be provided.')
+      raise ValueError('exactly one of every_n_steps'
+                       ' and every_n_secs should be provided.')
     self._timer = basic_session_run_hooks.SecondOrStepTimer(
         every_steps=every_n_steps, every_secs=every_n_secs)
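The arithmetic behind this hook is easy to see in isolation. Below is a minimal, framework-free sketch of the running average its docstring describes; the class name, method, and numbers are illustrative, not part of the actual hook.

class _ExamplesPerSecond(object):
  """Illustrative only: mirrors the hook's running-average arithmetic."""

  def __init__(self, batch_size):
    self._batch_size = batch_size
    self._total_time = 0.0
    self._total_steps = 0

  def update(self, elapsed_time, elapsed_steps):
    # Accumulate wall-clock time and step count across logging intervals,
    # then compute the average step time over the whole run so far.
    self._total_time += elapsed_time
    self._total_steps += elapsed_steps
    average_step_time = self._total_time / self._total_steps
    # Each step processes batch_size examples, so examples/sec is
    # batch_size divided by the average step time.
    return self._batch_size / average_step_time

calc = _ExamplesPerSecond(batch_size=128)
print(calc.update(elapsed_time=2.0, elapsed_steps=100))  # 6400.0 examples/sec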
@@ -188,6 +190,7 @@ class GpuParamServerDeviceSetter(object):

   def __init__(self, worker_device, ps_devices):
     """Initializer for GpuParamServerDeviceSetter.
+
     Args:
       worker_device: the device to use for computation Ops.
       ps_devices: a list of devices to use for Variable Ops. Each variable is
@@ -202,7 +205,7 @@ class GpuParamServerDeviceSetter(object):
       return op.device
     if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
       return self.worker_device
     # Gets the least loaded ps_device
     device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
     device_name = self.ps_devices[device_index]
@@ -211,6 +214,7 @@ class GpuParamServerDeviceSetter(object):
     return device_name


 def _create_device_setter(is_cpu_ps, worker, num_gpus):
   """Create device setter object."""
   if is_cpu_ps:
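For reference, the least-loaded selection performed in __call__ above can be exercised on its own. The device names and byte counts below are hypothetical; min() over enumerate(...) keyed on the size returns the index of the parameter-server device carrying the least data, which keeps variable placement roughly balanced.

import operator

ps_devices = ['/gpu:0', '/gpu:1', '/gpu:2']  # hypothetical ps devices
ps_sizes = [2048, 512, 1024]                 # bytes already placed on each

device_index, _ = min(enumerate(ps_sizes), key=operator.itemgetter(1))
print(ps_devices[device_index])  # '/gpu:1', currently the least loaded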
@@ -400,7 +404,8 @@ def input_fn(subset, num_shards):
   elif subset == 'validate' or subset == 'eval':
     batch_size = FLAGS.eval_batch_size
   else:
-    raise ValueError('Subset must be one of \'train\', \'validate\' and \'eval\'')
+    raise ValueError('Subset must be one of \'train\''
+                     ', \'validate\' and \'eval\'')
   with tf.device('/cpu:0'):
     use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
     dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
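As a rough, self-contained sketch of the subset dispatch above (the helper name and default batch sizes are made up, not the script's flags):

def pick_batch_size(subset, train_batch_size=128, eval_batch_size=100):
  # Mirror the branching above: train gets the training batch size,
  # validate/eval share the evaluation batch size, anything else is rejected.
  if subset == 'train':
    return train_batch_size
  elif subset == 'validate' or subset == 'eval':
    return eval_batch_size
  else:
    raise ValueError("Subset must be one of 'train', 'validate' and 'eval'")

print(pick_batch_size('eval'))  # 100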
@@ -429,7 +434,14 @@ def input_fn(subset, num_shards):

 # create experiment
 def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps,
                       train_hooks):
+  """Returns an Experiment function.
+
+  Experiments perform training on several workers in parallel,
+  in other words experiments know how to invoke train and eval in a sensible
+  fashion for distributed training.
+  """
   def _experiment_fn(run_config, hparams):
+    """Returns an Experiment."""
     del hparams  # unused arg
     # create estimator
     classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
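The new docstrings describe the closure pattern used here: get_experiment_fn captures its arguments and returns a callable that a runner later invokes with (run_config, hparams). A minimal, framework-free sketch of that pattern follows; every name in it is illustrative.

def get_runner_fn(train_steps, eval_steps):
  # Illustrative only: build a function with the (run_config, hparams)
  # signature that the captured arguments are baked into.
  def _runner_fn(run_config, hparams):
    del hparams  # unused, as in the diff above
    # The real script constructs an Estimator and an Experiment here.
    return {'model_dir': run_config['model_dir'],
            'train_steps': train_steps,
            'eval_steps': eval_steps}

  return _runner_fn

runner_fn = get_runner_fn(train_steps=1000, eval_steps=100)
print(runner_fn({'model_dir': '/tmp/cifar10'}, hparams=None))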
@@ -491,7 +503,7 @@ def main(unused_argv):
   logging_hook = tf.train.LoggingTensorHook(
       tensors=tensors_to_log, every_n_iter=100)
   examples_sec_hook = ExamplesPerSecondHook(
       FLAGS.train_batch_size, every_n_steps=10)
...