"tests/vscode:/vscode.git/clone" did not exist on "b785ddb654e4be3ae0066e231734754bdb2a191c"
Unverified commit 76300c26, authored by Yuefeng Zhou, committed by GitHub

Scale up learning rate according to num workers in Estimator imagenet models. (#6472)

* Move distribution strategy creation before any ops are created; multi-node
collective ops in eager mode require this ordering.

* Scale up learning rate according to num workers in ResNet50 w/
Estimator.

* Scale up LR in cifar.

* Fix a typo.

* Add num_workers to run_params as well. Make num_workers optional in
params (a sketch of the scaling rule follows below).
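The underlying idea is the linear scaling rule for large-batch SGD (Goyal et al., arXiv:1706.02677): when N workers each process batch_size examples per step, the effective global batch grows N-fold, so the learning rate is scaled up proportionally. A minimal sketch of the rule as this commit applies it, assuming (as the call sites below suggest) that learning_rate_with_decay derives its initial rate as base_lr * batch_size / batch_denom:

    # Sketch only; the real logic lives in resnet_run_loop.learning_rate_with_decay.
    def scaled_initial_lr(base_lr, batch_size, num_workers, batch_denom):
        """Linear scaling rule: LR grows with global batch / reference batch."""
        global_batch = batch_size * num_workers  # what the diff computes inline
        return base_lr * global_batch / batch_denom

    # ImageNet settings from the diff (base_lr=0.128, batch_denom=256):
    # one worker with batch_size=256 keeps the LR at 0.128; 8 workers give 1.024.
    print(scaled_initial_lr(0.128, 256, 8, 256))  # 1.024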
parent 4c11b84b
@@ -212,9 +212,9 @@ def cifar10_model_fn(features, labels, mode, params):
   features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])

   # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
-      batch_size=params['batch_size'], batch_denom=128,
-      num_images=NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
-      decay_rates=[1, 0.1, 0.01, 0.001])
+      batch_size=params['batch_size'] * params.get('num_workers', 1),
+      batch_denom=128, num_images=NUM_IMAGES['train'],
+      boundary_epochs=[91, 136, 182], decay_rates=[1, 0.1, 0.01, 0.001])

   # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
   # and seems more stable in testing. The difference was nominal for ResNet-56.
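For context, here is a hedged reconstruction of what learning_rate_with_decay plausibly builds from these arguments (the helper itself is not part of this diff): an initial rate scaled by the global batch over batch_denom, then piecewise-constant decay at the given epoch boundaries. Because batch_size now folds in num_workers, batches-per-epoch is computed against the global batch, which is the right notion of a training step in multi-worker mode.

    # Hypothetical stand-in for resnet_run_loop.learning_rate_with_decay.
    def learning_rate_with_decay_sketch(base_lr, batch_size, batch_denom,
                                        num_images, boundary_epochs, decay_rates):
        initial_lr = base_lr * batch_size / batch_denom  # linear scaling
        batches_per_epoch = num_images / batch_size      # global steps per epoch
        boundaries = [int(batches_per_epoch * e) for e in boundary_epochs]
        vals = [initial_lr * d for d in decay_rates]

        def learning_rate_fn(global_step):
            # Piecewise-constant schedule: drop the rate at each boundary step.
            for boundary, val in zip(boundaries, vals):
                if global_step < boundary:
                    return val
            return vals[-1]

        return learning_rate_fn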
@@ -318,9 +318,10 @@ def imagenet_model_fn(features, labels, mode, params):
     base_lr = .128

   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
-      batch_size=params['batch_size'], batch_denom=256,
-      num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
-      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
+      batch_size=params['batch_size'] * params.get('num_workers', 1),
+      batch_denom=256, num_images=NUM_IMAGES['train'],
+      boundary_epochs=[30, 60, 80, 90], decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
+      warmup=warmup, base_lr=base_lr)

   return resnet_run_loop.resnet_model_fn(
       features=features,
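The ImageNet call also threads through warmup and base_lr. Scaled-up rates are typically paired with a linear warmup (again per Goyal et al.); a sketch of what the warmup=True branch might look like, assuming a ramp from zero to the scaled initial rate over a fixed number of early steps:

    # Assumed behavior, not the commit's code: linear ramp during warmup.
    def warmup_lr(initial_lr, global_step, warmup_steps):
        """Scale the rate linearly from 0 to initial_lr over warmup_steps."""
        if global_step < warmup_steps:
            return initial_lr * global_step / warmup_steps
        return initial_lr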
@@ -573,7 +573,8 @@ def resnet_main(
           'resnet_version': int(flags_obj.resnet_version),
           'loss_scale': flags_core.get_loss_scale(flags_obj),
           'dtype': flags_core.get_tf_dtype(flags_obj),
-          'fine_tune': flags_obj.fine_tune
+          'fine_tune': flags_obj.fine_tune,
+          'num_workers': num_workers,
       })

   run_params = {
@@ -583,6 +584,7 @@ def resnet_main(
       'resnet_version': flags_obj.resnet_version,
       'synthetic_data': flags_obj.use_synthetic_data,
       'train_epochs': flags_obj.train_epochs,
+      'num_workers': num_workers,
   }

   if flags_obj.use_synthetic_data:
     dataset_name = dataset_name + '-synthetic'
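The hunks above read a num_workers variable whose definition lies outside the shown context. One hypothetical way to derive it for a multi-worker Estimator job is to count the batch-consuming tasks in the TF_CONFIG cluster spec (the names and defaults here are illustrative, not taken from this commit):

    import json
    import os

    def count_workers():
        """Hypothetical helper: number of batch-consuming tasks in TF_CONFIG."""
        tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
        cluster = tf_config.get('cluster', {})
        # Chief and workers both train in multi-worker strategies; fall back
        # to a single worker when no cluster spec is set.
        n = len(cluster.get('worker', [])) + len(cluster.get('chief', []))
        return max(n, 1)

    num_workers = count_workers()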