Commit d9103d00 authored by Eli Bixby

Fix variable names, remove device-specific loop

parent 2164c8db
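
For context, this change collapses two device-specific code paths (a GPU-only tower loop plus a separate CPU branch) into one loop parameterized by device type. The following is a minimal sketch of that pattern only, not the repository code; build_towers and build_tower are hypothetical stand-ins for the estimator's _tower_fn plumbing:

def build_towers(num_gpus, build_tower):
    """Build one model tower per device with a single, device-agnostic loop."""
    # Derive the device count and type once instead of branching into
    # separate CPU and GPU code paths.
    if num_gpus == 0:
        num_devices = 1
        device_type = 'cpu'
    else:
        num_devices = num_gpus
        device_type = 'gpu'

    outputs = []
    for i in range(num_devices):
        worker_device = '/{}:{}'.format(device_type, i)
        # build_tower stands in for the per-tower model function.
        outputs.append(build_tower(worker_device, is_cpu=(device_type == 'cpu')))
    return outputs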
@@ -74,61 +74,50 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
     tower_gradvars = []
     tower_preds = []
-    if num_gpus != 0:
-      for i in range(num_gpus):
-        worker_device = '/gpu:{}'.format(i)
-        if variable_strategy == 'CPU':
-          device_setter = cifar10_utils.local_device_setter(
-              worker_device=worker_device)
-        elif variable_strategy == 'GPU':
-          device_setter = cifar10_utils.local_device_setter(
-              ps_device_type='gpu',
-              worker_device=worker_device,
-              ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
-                  num_gpus,
-                  tf.contrib.training.byte_size_load_fn
-              )
-          )
-        with tf.variable_scope('resnet', reuse=bool(i != 0)):
-          with tf.name_scope('tower_%d' % i) as name_scope:
-            with tf.device(device_setter):
-              loss, gradvars, preds = _tower_fn(
-                  is_training,
-                  weight_decay,
-                  tower_features[i],
-                  tower_labels[i],
-                  False,
-                  params['num_layers'],
-                  params['batch_norm_decay'],
-                  params['batch_norm_epsilon'])
-              tower_losses.append(loss)
-              tower_gradvars.append(gradvars)
-              tower_preds.append(preds)
-              if i == 0:
-                # Only trigger batch_norm moving mean and variance update from
-                # the 1st tower. Ideally, we should grab the updates from all
-                # towers but these stats accumulate extremely fast so we can
-                # ignore the other stats from the other towers without
-                # significant detriment.
-                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
-                                               name_scope)
+    if num_gpus == 0:
+      num_devices = 1
+      device_type = 'cpu'
     else:
-      with tf.variable_scope('resnet'), tf.device('/cpu:0'):
-        with tf.name_scope('tower_cpu') as name_scope:
-          loss, gradvars, preds = _tower_fn(
-              is_training,
-              weight_decay,
-              tower_features[0],
-              tower_labels[0],
-              True,
-              params['num_layers'],
-              params['batch_norm_decay'],
-              params['batch_norm_epsilon'])
-          tower_losses.append(loss)
-          tower_gradvars.append(gradvars)
-          tower_preds.append(preds)
-          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
+      num_devices = num_gpus
+      device_type = 'gpu'
+    for i in range(num_devices):
+      worker_device = '/{}:{}'.format(device_type, i)
+      if variable_strategy == 'CPU':
+        device_setter = cifar10_utils.local_device_setter(
+            worker_device=worker_device)
+      elif variable_strategy == 'GPU':
+        device_setter = cifar10_utils.local_device_setter(
+            ps_device_type='gpu',
+            worker_device=worker_device,
+            ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
+                num_gpus,
+                tf.contrib.training.byte_size_load_fn
+            )
+        )
+      with tf.variable_scope('resnet', reuse=bool(i != 0)):
+        with tf.name_scope('tower_%d' % i) as name_scope:
+          with tf.device(device_setter):
+            loss, gradvars, preds = _tower_fn(
+                is_training,
+                weight_decay,
+                tower_features[i],
+                tower_labels[i],
+                (device_type == 'cpu'),
+                params['num_layers'],
+                params['batch_norm_decay'],
+                params['batch_norm_epsilon'])
+            tower_losses.append(loss)
+            tower_gradvars.append(gradvars)
+            tower_preds.append(preds)
+            if i == 0:
+              # Only trigger batch_norm moving mean and variance update from
+              # the 1st tower. Ideally, we should grab the updates from all
+              # towers but these stats accumulate extremely fast so we can
+              # ignore the other stats from the other towers without
+              # significant detriment.
+              update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
+                                             name_scope)
     # Now compute global loss and gradients.
     gradvars = []
@@ -420,7 +409,7 @@ if __name__ == '__main__':
       help='The directory where the model will be stored.'
   )
   parser.add_argument(
-      '--variable_strategy',
+      '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
@@ -520,13 +509,13 @@ if __name__ == '__main__':
       help='Whether to log device placement.'
   )
   parser.add_argument(
-      '--batch_norm_decay',
+      '--batch-norm-decay',
       type=float,
       default=0.997,
       help='Decay for batch norm.'
   )
   parser.add_argument(
-      '--batch_norm_epsilon',
+      '--batch-norm-epsilon',
       type=float,
       default=1e-5,
       help='Epsilon for batch norm.'
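
A note on the flag renames in the hunks above: argparse derives the attribute name from the long option by stripping the leading dashes and replacing '-' with '_', so '--batch-norm-decay' is still read as args.batch_norm_decay. A small standalone example (not from the repository) illustrating this behavior:

import argparse

parser = argparse.ArgumentParser()
# Hyphenated long options; argparse maps '-' to '_' for the dest attribute.
parser.add_argument('--batch-norm-decay', type=float, default=0.997)
parser.add_argument('--batch-norm-epsilon', type=float, default=1e-5)

args = parser.parse_args(['--batch-norm-decay', '0.9'])
print(args.batch_norm_decay)    # 0.9
print(args.batch_norm_epsilon)  # 1e-05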