"vscode:/vscode.git/clone" did not exist on "1a01cecb7cadfb96365acfa72e23e9b89c7197a6"
Unverified Commit b3594a83 authored by Yuefeng Zhou's avatar Yuefeng Zhou Committed by GitHub
Browse files

Move distribution strategy creation before creating any ops, which is (#6435)

required by multi-node collective ops in eager mode.
parent 6765b16d
......@@ -121,6 +121,12 @@ def run(flags_obj):
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
strategy_scope = keras_common.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
......@@ -147,12 +153,6 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
......
......@@ -118,6 +118,13 @@ def run(flags_obj):
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster())
strategy_scope = keras_common.get_strategy_scope(strategy)
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
......@@ -150,13 +157,6 @@ def run(flags_obj):
parse_record_fn=parse_record_keras,
dtype=dtype)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster())
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
if dtype == 'float16':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment