Commit f1f0503c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332314917
parent 955389a9
@@ -98,7 +98,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
                               num_gpus=0,
                               all_reduce_alg=None,
                               num_packs=1,
-                              tpu_address=None):
+                              tpu_address=None,
+                              **kwargs):
   """Return a DistributionStrategy for running the model.
 
   Args:
@@ -117,6 +118,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
       or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
     tpu_address: Optional. String that represents TPU to connect to. Must not be
       None if `distribution_strategy` is set to `tpu`.
+    **kwargs: Additional kwargs for internal usages.
 
   Returns:
     tf.distribute.DistibutionStrategy object.
@@ -125,6 +127,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
   """
+  del kwargs
   if num_gpus < 0:
     raise ValueError("`num_gpus` can not be negative.")
...
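For context: the new `**kwargs` are accepted and then discarded (`del kwargs`) in this open-source code path, so existing callers keep working and extra model-parallelism options (consumed internally) are simply ignored here. A minimal sketch of calling the updated signature, assuming the CPU/`mirrored` path; the extra keyword names mirror the config fields added below:

from official.common import distribute_utils

# Extra kwargs are swallowed by `del kwargs` in this code path, so passing
# model-parallelism options does not change the returned strategy here.
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="mirrored",
    num_gpus=0,
    num_cores_per_replica=1,  # accepted via **kwargs, then ignored
    default_shard_dim=-1)     # accepted via **kwargs, then ignored
print(type(strategy))  # expected: a MirroredStrategy over the CPU device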
@@ -141,6 +141,15 @@ class RuntimeConfig(base_config.Config):
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False
 
+  # Global model parallelism configurations.
+  num_cores_per_replica: int = 1
+  default_shard_dim: int = -1
+
+  def model_parallelism(self):
+    return dict(
+        num_cores_per_replica=self.num_cores_per_replica,
+        default_shard_dim=self.default_shard_dim)
+
 
 @dataclasses.dataclass
 class TensorboardConfig(base_config.Config):
...
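The `RuntimeConfig` additions above group the model-parallelism knobs and expose them as a single dict, so call sites can splat them straight into `get_distribution_strategy`. A minimal self-contained sketch of the pattern (the class name here is a hypothetical stand-in, not the real config class):

import dataclasses

@dataclasses.dataclass
class RuntimeConfigSketch:
  # Hypothetical stand-in mirroring the two fields added in this commit.
  num_cores_per_replica: int = 1
  default_shard_dim: int = -1

  def model_parallelism(self):
    # Bundle the model-parallelism fields so they can be passed as **kwargs.
    return dict(
        num_cores_per_replica=self.num_cores_per_replica,
        default_shard_dim=self.default_shard_dim)

runtime = RuntimeConfigSketch(num_cores_per_replica=2)
print(runtime.model_parallelism())
# {'num_cores_per_replica': 2, 'default_shard_dim': -1}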
@@ -20,10 +20,10 @@ from absl import flags
 import gin
 
 from official.core import train_utils
-from official.common import distribute_utils
 # pylint: disable=unused-import
 from official.common import registry_imports
 # pylint: enable=unused-import
+from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
@@ -52,7 +52,8 @@ def main(_):
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
+      tpu_address=params.runtime.tpu,
+      **params.runtime.model_parallelism())
 
   with distribution_strategy.scope():
     task = task_factory.get_task(params.task, logging_dir=model_dir)
...
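In `main`, the only functional change is splatting that dict into the strategy constructor call. A rough sketch of the wiring, with a plain dict standing in for `params.runtime.model_parallelism()` (the real `params` object is built from flags elsewhere in the file):

from official.common import distribute_utils

# Hypothetical stand-in for params.runtime.model_parallelism().
model_parallelism = dict(num_cores_per_replica=1, default_shard_dim=-1)

distribution_strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="mirrored",
    all_reduce_alg=None,
    num_gpus=0,
    tpu_address=None,
    **model_parallelism)  # same splat as in the hunk above
with distribution_strategy.scope():
  pass  # build the task and run training inside the strategy scope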