Commit 7dde2e2f authored by Toby Boyd, committed by GitHub

Merge pull request #2044 from tfboyd/cifar_estimator

Enhanced comments and renamed ParamServerDeviceSetter
Parents: 78007443, 2e960eb1
@@ -98,16 +98,20 @@ tf.flags.DEFINE_boolean('log_device_placement', False,
                         'Whether to log device placement.')
 
 
 # TODO(jamesqin): Replace with fix in b/62239022
-class ParamServerDeviceSetter(object):
-  """Helper class to assign variables on the least loaded ps-device."""
+class GpuParamServerDeviceSetter(object):
+  """Used with tf.device() to place variables on the least loaded GPU.
 
-  def __init__(self, worker_device, ps_devices):
-    """Initializer for ParamServerDeviceSetter.
+  A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0',
+  'gpu:1','gpu:2'], as ps_devices. When each variable is placed, it will be
+  placed on the least loaded gpu. All other Ops, which will be the computation
+  Ops, will be placed on the worker_device.
+  """
+
+  def __init__(self, worker_device, ps_devices):
+    """Initializer for GpuParamServerDeviceSetter.
 
     Args:
-      worker_device: the device to use for computer ops.
-      ps_devices: a list of devices to use for Variable ops. Each variable is
+      worker_device: the device to use for computation Ops.
+      ps_devices: a list of devices to use for Variable Ops. Each variable is
         assigned to the least loaded device.
     """
     self.ps_devices = ps_devices
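Annotation: tf.device() accepts a callable, so the setter above decides placement Op by Op, which is the usage the new docstring describes. A minimal sketch of that usage, assuming TensorFlow 1.x graph mode with the class in scope; the device names and shapes here are illustrative only:

import tensorflow as tf  # assumes TensorFlow 1.x

setter = GpuParamServerDeviceSetter(worker_device='/gpu:0',
                                    ps_devices=['/gpu:0', '/gpu:1'])
with tf.device(setter):
  # Variable Ops go to whichever of gpu:0/gpu:1 is least loaded so far;
  # all other (computation) Ops stay on the worker device, /gpu:0.
  w = tf.get_variable('w', shape=[1024, 1024])
  h = tf.matmul(tf.ones([1, 1024]), w)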
@@ -120,6 +124,7 @@ class ParamServerDeviceSetter(object):
     if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
       return self.worker_device
 
+    # Gets the least loaded ps_device
     device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
     device_name = self.ps_devices[device_index]
     var_size = op.outputs[0].get_shape().num_elements()
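Annotation: the line that charges var_size back into ps_sizes falls in the context elided between this hunk and the next (old lines 126-127), presumably a single `+=`. A small pure-Python sketch of the greedy least-loaded scheme, using a hypothetical stand-in class and marking the assumed elided bookkeeping:

import operator

class _LeastLoadedChooser(object):  # hypothetical stand-in for the setter's state
  def __init__(self, ps_devices):
    self.ps_devices = ps_devices
    self.ps_sizes = [0] * len(ps_devices)  # assumed: one running size per device

  def choose(self, var_size):
    # As in the hunk above: pick the index with the smallest accumulated size.
    device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
    device_name = self.ps_devices[device_index]
    self.ps_sizes[device_index] += var_size  # assumed elided line
    return device_name

chooser = _LeastLoadedChooser(['/gpu:0', '/gpu:1'])
print(chooser.choose(4096))  # /gpu:0
print(chooser.choose(1024))  # /gpu:1, since /gpu:0 already carries 4096 elements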
@@ -127,15 +132,16 @@ class ParamServerDeviceSetter(object):
     return device_name
 
 
-def _create_device_setter(is_cpu_ps, worker):
+def _create_device_setter(is_cpu_ps, worker, num_gpus):
   """Create device setter object."""
   if is_cpu_ps:
+    # tf.train.replica_device_setter supports placing variables on the CPU, all
+    # on one GPU, or on ps_servers defined in a cluster_spec.
     return tf.train.replica_device_setter(
         worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
   else:
-    gpus = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)]
-    return ParamServerDeviceSetter(worker, gpus)
+    gpus = ['/gpu:%d' % i for i in range(num_gpus)]
+    return GpuParamServerDeviceSetter(worker, gpus)
 
 
 def _resnet_model_fn(features, labels, mode):
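Annotation: the is_cpu_ps branch delegates to tf.train.replica_device_setter, which is likewise a callable from Op to device string; threading num_gpus through as a parameter, instead of reading FLAGS inside the helper, also makes the function usable without the global flag. A short sketch of the CPU-ps placement, assuming TensorFlow 1.x; the variable and shapes are illustrative:

import tensorflow as tf  # assumes TensorFlow 1.x

# Same call as the is_cpu_ps branch above: variables pinned to the CPU,
# computation Ops left on the tower's GPU.
setter = tf.train.replica_device_setter(
    worker_device='/gpu:0', ps_device='/cpu:0', ps_tasks=1)
with tf.device(setter):
  v = tf.get_variable('v', shape=[10])  # placed on /cpu:0
  y = v * 2.0                           # placed on /gpu:0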
@@ -169,7 +175,7 @@ def _resnet_model_fn(features, labels, mode):
   if num_gpus != 0:
     for i in range(num_gpus):
       worker = '/gpu:%d' % i
-      device_setter = _create_device_setter(is_cpu_ps, worker)
+      device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus)
       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
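Annotation: the reuse=bool(i != 0) in this loop is what makes the towers share a single set of weights: tower 0 creates each variable and every later tower reuses it. A minimal sketch of that scoping pattern on its own, assuming TensorFlow 1.x:

import tensorflow as tf  # assumes TensorFlow 1.x

for i in range(2):
  with tf.variable_scope('resnet', reuse=bool(i != 0)):
    with tf.name_scope('tower_%d' % i):
      # Both iterations return the same underlying variable 'resnet/w'.
      w = tf.get_variable('w', shape=[3, 3])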