"vscode:/vscode.git/clone" did not exist on "477ab964e2165cb586b5c00425f6e463d7edeadd"
Commit 1af7172d authored by Yanhui Liang, committed by A. Unique TensorFlower

Remove 'num_workers' arg from get_distribution_strategy() method.

PiperOrigin-RevId: 291810091
parent 569ec532
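
For context, a minimal sketch of the call pattern after this change, as seen in the hunks below: configure_cluster() is still invoked for its cluster-configuration side effect, but its return value is discarded rather than forwarded as num_workers. The import path and the concrete argument values here are illustrative assumptions, not part of this commit.

    # Sketch only: assumes the module lives at official.utils.misc.distribution_utils
    # and uses placeholder argument values.
    from official.utils.misc import distribution_utils

    # Still called so the cluster spec gets configured; the returned worker
    # count is no longer threaded through to get_distribution_strategy().
    _ = distribution_utils.configure_cluster(worker_hosts=None, task_index=-1)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy="mirrored",
        num_gpus=1,
        all_reduce_alg=None,
        num_packs=1,
        tpu_address=None)
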
@@ -673,12 +673,11 @@ class ExecutorBuilder(object):
   """
   def __init__(self, strategy_type=None, strategy_config=None):
-    num_workers = distribution_utils.configure_cluster(
+    _ = distribution_utils.configure_cluster(
         strategy_config.worker_hosts, strategy_config.task_index)
     self._strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=strategy_type,
         num_gpus=strategy_config.num_gpus,
-        num_workers=num_workers,
         all_reduce_alg=strategy_config.all_reduce_alg,
         num_packs=strategy_config.num_packs,
         tpu_address=strategy_config.tpu)
...
@@ -563,7 +563,6 @@ def resnet_main(
   distribution_strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_core.get_num_gpus(flags_obj),
-      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)
...
@@ -83,7 +83,6 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
 def get_distribution_strategy(distribution_strategy="mirrored",
                               num_gpus=0,
-                              num_workers=1,
                               all_reduce_alg=None,
                               num_packs=1,
                               tpu_address=None):
@@ -96,7 +95,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
       'off' means not to use Distribution Strategy; 'tpu' means to use
       TPUStrategy using `tpu_address`.
     num_gpus: Number of GPUs to run this model.
-    num_workers: Number of workers to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
       all-reduce. For `MirroredStrategy`, valid values are "nccl" and
       "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
@@ -120,8 +118,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   if distribution_strategy == "off":
     if num_gpus > 1:
       raise ValueError(
-          "When {} GPUs and {} workers are specified, distribution_strategy "
-          "flag cannot be set to 'off'.".format(num_gpus, num_workers))
+          "When {} GPUs are specified, distribution_strategy "
+          "flag cannot be set to 'off'.".format(num_gpus))
     return None
   if distribution_strategy == "tpu":
...
@@ -104,7 +104,6 @@ def run(flags_obj):
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)
...
@@ -212,7 +212,6 @@ def run(flags_obj):
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs,
       tpu_address=flags_obj.tpu)
...
@@ -84,13 +84,12 @@ def run(flags_obj):
   tf.keras.backend.set_image_data_format(data_format)
   # Configures cluster spec for distribution strategy.
-  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
-                                                     flags_obj.task_index)
+  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+                                           flags_obj.task_index)
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs,
       tpu_address=flags_obj.tpu)
...