Unverified Commit 76300c26 authored by Yuefeng Zhou, committed by GitHub

Scale up learning rate according to num workers in Estimator imagenet models. (#6472)

* Move distribution strategy creation before any ops are created, as
required by multi-node collective ops in eager mode.

* Scale up learning rate according to num workers in ResNet50 w/
Estimator.

* Scale up LR in cifar.

* Fix a typo.

* Add num_workers to run_params as well, and make num_workers optional in params.
parent 4c11b84b
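Note: the hunks below implement the linear learning-rate scaling rule (cf. Goyal et al., arXiv:1706.02677): with N workers the effective global batch is N times the per-worker batch, so the initial learning rate is scaled by the same factor. A minimal sketch of the rule; the helper name is an illustrative assumption, not part of the model-garden code:

```python
def scaled_initial_lr(base_lr, per_worker_batch_size, num_workers,
                      batch_denom):
  """Scale the base LR linearly with the effective global batch size."""
  global_batch_size = per_worker_batch_size * num_workers
  return base_lr * global_batch_size / batch_denom

# Example with the ImageNet defaults below (base_lr=0.128, batch_denom=256):
# 8 workers at a per-worker batch of 128 give 0.128 * 1024 / 256 = 0.512.
print(scaled_initial_lr(0.128, 128, 8, 256))  # 0.512
```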
@@ -212,9 +212,9 @@ def cifar10_model_fn(features, labels, mode, params):
   features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
   # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
-      batch_size=params['batch_size'], batch_denom=128,
-      num_images=NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
-      decay_rates=[1, 0.1, 0.01, 0.001])
+      batch_size=params['batch_size'] * params.get('num_workers', 1),
+      batch_denom=128, num_images=NUM_IMAGES['train'],
+      boundary_epochs=[91, 136, 182], decay_rates=[1, 0.1, 0.01, 0.001])
   # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
   # and seems more stable in testing. The difference was nominal for ResNet-56.
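For reference, `learning_rate_with_decay` is assumed to behave roughly like the sketch below: it scales the base rate by batch_size / batch_denom and converts the epoch boundaries into global steps. Passing the global batch size (per-worker batch times num_workers) therefore both raises the initial rate and keeps the decay boundaries aligned with the intended epochs, since each global step now consumes num_workers per-worker batches. This is a hedged reconstruction, not the resnet_run_loop source:

```python
import tensorflow as tf

def piecewise_lr(batch_size, batch_denom, num_images, boundary_epochs,
                 decay_rates, base_lr=0.1):
  """Step-decay schedule: base_lr scaled by batch size, decayed at epochs."""
  initial_lr = base_lr * batch_size / batch_denom
  batches_per_epoch = num_images / batch_size
  boundaries = [int(batches_per_epoch * e) for e in boundary_epochs]
  # piecewise_constant requires len(values) == len(boundaries) + 1.
  values = [initial_lr * d for d in decay_rates]

  def learning_rate_fn(global_step):
    return tf.train.piecewise_constant(global_step, boundaries, values)

  return learning_rate_fn
```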
@@ -318,9 +318,10 @@ def imagenet_model_fn(features, labels, mode, params):
   base_lr = .128
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
-      batch_size=params['batch_size'], batch_denom=256,
-      num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
-      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
+      batch_size=params['batch_size'] * params.get('num_workers', 1),
+      batch_denom=256, num_images=NUM_IMAGES['train'],
+      boundary_epochs=[30, 60, 80, 90], decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
+      warmup=warmup, base_lr=base_lr)
   return resnet_run_loop.resnet_model_fn(
       features=features,
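The `warmup=warmup` argument above enables a linear ramp before the step schedule kicks in, which helps stabilize training at the larger scaled rate. A sketch of one plausible form, assuming TF 1.x graph-mode ops; the actual resnet_run_loop warmup may differ in length and shape:

```python
import tensorflow as tf

def lr_with_warmup(global_step, warmup_steps, target_lr, schedule_fn):
  """Ramp linearly from 0 to target_lr, then defer to the main schedule."""
  warmup_lr = target_lr * tf.cast(global_step, tf.float32) / float(warmup_steps)
  return tf.cond(global_step < warmup_steps,
                 lambda: warmup_lr,
                 lambda: schedule_fn(global_step))
```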
@@ -573,7 +573,8 @@ def resnet_main(
           'resnet_version': int(flags_obj.resnet_version),
           'loss_scale': flags_core.get_loss_scale(flags_obj),
           'dtype': flags_core.get_tf_dtype(flags_obj),
-          'fine_tune': flags_obj.fine_tune
+          'fine_tune': flags_obj.fine_tune,
+          'num_workers': num_workers,
       })
   run_params = {
@@ -583,6 +584,7 @@
       'resnet_version': flags_obj.resnet_version,
       'synthetic_data': flags_obj.use_synthetic_data,
       'train_epochs': flags_obj.train_epochs,
+      'num_workers': num_workers,
   }
   if flags_obj.use_synthetic_data:
     dataset_name = dataset_name + '-synthetic'
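The two hunks above reference a `num_workers` variable computed earlier in `resnet_main` (not shown in this diff). One plausible way to derive it in a multi-worker job is to count entries in the TF_CONFIG cluster spec; this helper is an assumption for illustration, not the commit's code:

```python
import json
import os

def num_workers_from_tf_config(default=1):
  """Count training tasks declared in the TF_CONFIG cluster spec."""
  tf_config = os.environ.get('TF_CONFIG')
  if not tf_config:
    return default
  cluster = json.loads(tf_config).get('cluster', {})
  # A chief, if present, also runs training steps, so include it.
  return len(cluster.get('worker', [])) + len(cluster.get('chief', []))
```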