Commit 7dde2e2f authored by Toby Boyd, committed by GitHub

Merge pull request #2044 from tfboyd/cifar_estimator

Enhanced comments and renamed ParamServerDeviceSetter
parents 78007443 2e960eb1
@@ -98,16 +98,20 @@ tf.flags.DEFINE_boolean('log_device_placement', False,
                         'Whether to log device placement.')
 
 
-# TODO(jamesqin): Replace with fix in b/62239022
-class ParamServerDeviceSetter(object):
-  """Helper class to assign variables on the least loaded ps-device."""
+class GpuParamServerDeviceSetter(object):
+  """Used with tf.device() to place variables on the least loaded GPU.
+
+  A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0',
+  'gpu:1','gpu:2'], as ps_devices.  When each variable is placed, it will be
+  placed on the least loaded gpu. All other Ops, which will be the computation
+  Ops, will be placed on the worker_device.
+  """
 
   def __init__(self, worker_device, ps_devices):
-    """Initializer for ParamServerDeviceSetter.
-
+    """Initializer for GpuParamServerDeviceSetter.
     Args:
-      worker_device: the device to use for computer ops.
-      ps_devices: a list of devices to use for Variable ops. Each variable is
+      worker_device: the device to use for computation Ops.
+      ps_devices: a list of devices to use for Variable Ops. Each variable is
         assigned to the least loaded device.
     """
     self.ps_devices = ps_devices
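The new docstring above describes the intended usage; a minimal sketch of what that looks like in caller code (the tower loop later in this file does the real version — build_tower() and the two-GPU list here are illustrative only, not part of this commit):

import tensorflow as tf

# Hypothetical: spread variables over two GPUs, run compute Ops on /gpu:0.
gpus = ['/gpu:0', '/gpu:1']
device_setter = GpuParamServerDeviceSetter(worker_device='/gpu:0',
                                           ps_devices=gpus)

with tf.device(device_setter):
  # Variable Ops created in this scope go to whichever GPU in `gpus` has the
  # smallest total variable size so far; every other Op goes to the worker.
  logits = build_tower()  # build_tower is a placeholder for model-building code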
@@ -119,7 +123,8 @@ class ParamServerDeviceSetter(object):
       return op.device
     if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
       return self.worker_device
 
+    # Gets the least loaded ps_device
     device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
     device_name = self.ps_devices[device_index]
     var_size = op.outputs[0].get_shape().num_elements()
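The comment added above marks the core of the placement policy: pick the ps device with the smallest accumulated variable size, then charge the new variable's size against it. A standalone sketch of that bookkeeping in plain Python (the sizes are made up; the real ones come from op.outputs[0].get_shape().num_elements()):

import operator

ps_devices = ['/gpu:0', '/gpu:1', '/gpu:2']
ps_sizes = [0, 0, 0]  # elements already assigned to each device

def place(var_size):
  # Same greedy rule as GpuParamServerDeviceSetter.__call__: choose the device
  # with the smallest accumulated size, then add this variable's size to it.
  device_index, _ = min(enumerate(ps_sizes), key=operator.itemgetter(1))
  ps_sizes[device_index] += var_size
  return ps_devices[device_index]

for size in [1000, 50, 200, 900]:  # hypothetical variable sizes
  print(place(size), ps_sizes)
# /gpu:0 [1000, 0, 0]
# /gpu:1 [1000, 50, 0]
# /gpu:2 [1000, 50, 200]
# /gpu:1 [1000, 950, 200]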
@@ -127,15 +132,16 @@ class ParamServerDeviceSetter(object):
     return device_name
 
 
-def _create_device_setter(is_cpu_ps, worker):
+def _create_device_setter(is_cpu_ps, worker, num_gpus):
   """Create device setter object."""
   if is_cpu_ps:
+    # tf.train.replica_device_setter supports placing variables on the CPU, all
+    # on one GPU, or on ps_servers defined in a cluster_spec.
     return tf.train.replica_device_setter(
         worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
   else:
-    gpus = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)]
-    return ParamServerDeviceSetter(worker, gpus)
+    gpus = ['/gpu:%d' % i for i in range(num_gpus)]
+    return GpuParamServerDeviceSetter(worker, gpus)
 
 
 def _resnet_model_fn(features, labels, mode):
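To make the two branches of the updated _create_device_setter() concrete, a hedged sketch of how each returned setter behaves when handed to tf.device() — the worker device and num_gpus=2 here are illustrative values, not taken from this commit:

import tensorflow as tf

worker = '/gpu:0'

# is_cpu_ps=True: a tf.train.replica_device_setter that pins every variable
# to /cpu:0 and leaves compute Ops on the worker device.
cpu_ps_setter = _create_device_setter(is_cpu_ps=True, worker=worker, num_gpus=2)

# is_cpu_ps=False: a GpuParamServerDeviceSetter that spreads variables over
# /gpu:0 and /gpu:1 by accumulated size, again leaving compute Ops on the worker.
gpu_ps_setter = _create_device_setter(is_cpu_ps=False, worker=worker, num_gpus=2)

with tf.device(gpu_ps_setter):
  w = tf.get_variable('w', shape=[1024, 1024])  # Variable Op: least loaded GPU
  y = tf.matmul(tf.ones([1, 1024]), w)          # compute Op: stays on /gpu:0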
@@ -169,7 +175,7 @@ def _resnet_model_fn(features, labels, mode):
   if num_gpus != 0:
     for i in range(num_gpus):
       worker = '/gpu:%d' % i
-      device_setter = _create_device_setter(is_cpu_ps, worker)
+      device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus)
       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
...