Commit d9103d00 authored by Eli Bixby

Fix variable names; remove device-specific loop

parent 2164c8db
@@ -74,61 +74,50 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
     tower_gradvars = []
     tower_preds = []
-    if num_gpus != 0:
-      for i in range(num_gpus):
-        worker_device = '/gpu:{}'.format(i)
-        if variable_strategy == 'CPU':
-          device_setter = cifar10_utils.local_device_setter(
-              worker_device=worker_device)
-        elif variable_strategy == 'GPU':
-          device_setter = cifar10_utils.local_device_setter(
-              ps_device_type='gpu',
-              worker_device=worker_device,
-              ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
-                  num_gpus,
-                  tf.contrib.training.byte_size_load_fn
-              )
-          )
-        with tf.variable_scope('resnet', reuse=bool(i != 0)):
-          with tf.name_scope('tower_%d' % i) as name_scope:
-            with tf.device(device_setter):
-              loss, gradvars, preds = _tower_fn(
-                  is_training,
-                  weight_decay,
-                  tower_features[i],
-                  tower_labels[i],
-                  False,
-                  params['num_layers'],
-                  params['batch_norm_decay'],
-                  params['batch_norm_epsilon'])
-              tower_losses.append(loss)
-              tower_gradvars.append(gradvars)
-              tower_preds.append(preds)
-              if i == 0:
-                # Only trigger batch_norm moving mean and variance update from
-                # the 1st tower. Ideally, we should grab the updates from all
-                # towers but these stats accumulate extremely fast so we can
-                # ignore the other stats from the other towers without
-                # significant detriment.
-                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
-                                               name_scope)
+    if num_gpus == 0:
+      num_devices = 1
+      device_type = 'cpu'
     else:
-      with tf.variable_scope('resnet'), tf.device('/cpu:0'):
-        with tf.name_scope('tower_cpu') as name_scope:
-          loss, gradvars, preds = _tower_fn(
-              is_training,
-              weight_decay,
-              tower_features[0],
-              tower_labels[0],
-              True,
-              params['num_layers'],
-              params['batch_norm_decay'],
-              params['batch_norm_epsilon'])
-          tower_losses.append(loss)
-          tower_gradvars.append(gradvars)
-          tower_preds.append(preds)
-
-          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
+      num_devices = num_gpus
+      device_type = 'gpu'
+
+    for i in range(num_devices):
+      worker_device = '/{}:{}'.format(device_type, i)
+      if variable_strategy == 'CPU':
+        device_setter = cifar10_utils.local_device_setter(
+            worker_device=worker_device)
+      elif variable_strategy == 'GPU':
+        device_setter = cifar10_utils.local_device_setter(
+            ps_device_type='gpu',
+            worker_device=worker_device,
+            ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
+                num_gpus,
+                tf.contrib.training.byte_size_load_fn
+            )
+        )
+      with tf.variable_scope('resnet', reuse=bool(i != 0)):
+        with tf.name_scope('tower_%d' % i) as name_scope:
+          with tf.device(device_setter):
+            loss, gradvars, preds = _tower_fn(
+                is_training,
+                weight_decay,
+                tower_features[i],
+                tower_labels[i],
+                (device_type == 'cpu'),
+                params['num_layers'],
+                params['batch_norm_decay'],
+                params['batch_norm_epsilon'])
+            tower_losses.append(loss)
+            tower_gradvars.append(gradvars)
+            tower_preds.append(preds)
+            if i == 0:
+              # Only trigger batch_norm moving mean and variance update from
+              # the 1st tower. Ideally, we should grab the updates from all
+              # towers but these stats accumulate extremely fast so we can
+              # ignore the other stats from the other towers without
+              # significant detriment.
+              update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
+                                             name_scope)
 
     # Now compute global loss and gradients.
     gradvars = []
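For readers skimming the hunk above, the gist is that the GPU-only loop and the separate CPU branch collapse into a single loop over whichever device type is in use. Below is a minimal, self-contained sketch of that pattern in plain Python (no TensorFlow); build_towers and build_tower are hypothetical names standing in for the per-tower model function in the diff.

# Minimal sketch of the unified device loop from the hunk above.
# `build_tower` is a hypothetical stand-in for the per-tower model function.
def build_towers(num_gpus, build_tower):
  # Fall back to a single CPU "tower" when no GPUs are requested.
  if num_gpus == 0:
    num_devices, device_type = 1, 'cpu'
  else:
    num_devices, device_type = num_gpus, 'gpu'

  outputs = []
  for i in range(num_devices):
    worker_device = '/{}:{}'.format(device_type, i)  # e.g. '/gpu:0' or '/cpu:0'
    outputs.append(build_tower(worker_device, is_cpu=(device_type == 'cpu')))
  return outputs

Note that the boolean that used to be hard-coded (False on GPU towers, True on the CPU tower) is now derived from device_type, which is what the (device_type == 'cpu') argument in the diff expresses.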
@@ -420,7 +409,7 @@ if __name__ == '__main__':
       help='The directory where the model will be stored.'
   )
   parser.add_argument(
-      '--variable_strategy',
+      '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
@@ -520,13 +509,13 @@ if __name__ == '__main__':
       help='Whether to log device placement.'
   )
   parser.add_argument(
-      '--batch_norm_decay',
+      '--batch-norm-decay',
       type=float,
       default=0.997,
       help='Decay for batch norm.'
   )
   parser.add_argument(
-      '--batch_norm_epsilon',
+      '--batch-norm-epsilon',
       type=float,
       default=1e-5,
       help='Epsilon for batch norm.'
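The flag renames in the two hunks above only change the command line: argparse converts dashes in an option name to underscores in the parsed namespace, so code that reads the underscore names (for example batch_norm_decay) keeps working. A small self-contained check, using only the flags, choices, and defaults visible in this diff:

import argparse

# Sketch showing that the renamed dash-style flags still surface as
# underscore attributes on the parsed args.
parser = argparse.ArgumentParser()
parser.add_argument('--variable-strategy', choices=['CPU', 'GPU'], type=str, default='CPU')
parser.add_argument('--batch-norm-decay', type=float, default=0.997)
parser.add_argument('--batch-norm-epsilon', type=float, default=1e-5)

args = parser.parse_args(['--variable-strategy', 'GPU', '--batch-norm-decay', '0.9'])
print(args.variable_strategy)   # -> GPU   (dash in the flag, underscore in the attribute)
print(args.batch_norm_decay)    # -> 0.9
print(args.batch_norm_epsilon)  # -> 1e-05 (default)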