"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9734dcdec4249086b8278cda305f3d4f9f3b9b12"
Commit d793ea82 authored by Ayush Dubey's avatar Ayush Dubey Committed by Toby Boyd
Browse files

Change `CollectiveAllReduceStrategy` to `MultiWorkerMirroredStrategy`. (#6282)

* s/CollectiveAllReduceStrategy/MultiWorkerMirroredStrategy

* More s/contrib.distribute/distribute.experimental
parent 54dffe2e
...@@ -63,8 +63,7 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -63,8 +63,7 @@ def get_distribution_strategy(distribution_strategy="default",
return None return None
if distribution_strategy == "multi_worker_mirrored" or num_workers > 1: if distribution_strategy == "multi_worker_mirrored" or num_workers > 1:
return tf.contrib.distribute.CollectiveAllReduceStrategy( return tf.distribute.experimental.MultiWorkerMirroredStrategy()
num_gpus_per_worker=num_gpus)
if (distribution_strategy == "one_device" or if (distribution_strategy == "one_device" or
(distribution_strategy == "default" and num_gpus <= 1)): (distribution_strategy == "default" and num_gpus <= 1)):
...@@ -91,8 +90,7 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -91,8 +90,7 @@ def get_distribution_strategy(distribution_strategy="default",
return tf.distribute.MirroredStrategy(devices=devices) return tf.distribute.MirroredStrategy(devices=devices)
if distribution_strategy == "parameter_server": if distribution_strategy == "parameter_server":
return tf.contrib.distribute.ParameterServerStrategy( return tf.distribute.experimental.ParameterServerStrategy()
num_gpus_per_worker=num_gpus)
raise ValueError( raise ValueError(
"Unrecognized Distribution Strategy: %r" % distribution_strategy) "Unrecognized Distribution Strategy: %r" % distribution_strategy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment