Commit a0eea42a authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332314917
parent 2218a0dd
@@ -98,7 +98,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
                               num_gpus=0,
                               all_reduce_alg=None,
                               num_packs=1,
-                              tpu_address=None):
+                              tpu_address=None,
+                              **kwargs):
   """Return a DistributionStrategy for running the model.
 
   Args:
@@ -117,6 +118,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
       or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
     tpu_address: Optional. String that represents TPU to connect to. Must not be
       None if `distribution_strategy` is set to `tpu`.
+    **kwargs: Additional kwargs for internal usages.
 
   Returns:
     tf.distribute.DistibutionStrategy object.
@@ -125,6 +127,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
       `num_gpus` is larger than 1; or `num_gpus` is negative or if
       `distribution_strategy` is `tpu` but `tpu_address` is not specified.
   """
+  del kwargs
   if num_gpus < 0:
     raise ValueError("`num_gpus` can not be negative.")
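For context, a minimal sketch (not part of this commit) of how the relaxed signature behaves in the open-source path: extra keyword arguments, such as the model-parallelism options forwarded by the trainer, are accepted and immediately discarded by `del kwargs`. The argument values below are illustrative only.

from official.common import distribute_utils

# The extra keys mirror what RuntimeConfig.model_parallelism() forwards;
# in the OSS implementation they are swallowed by **kwargs and dropped.
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="mirrored",
    num_gpus=0,                 # illustrative; works on a CPU-only machine
    num_cores_per_replica=1,    # ignored via `del kwargs`
    default_shard_dim=-1)       # ignored via `del kwargs`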
@@ -141,6 +141,15 @@ class RuntimeConfig(base_config.Config):
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False
 
+  # Global model parallelism configurations.
+  num_cores_per_replica: int = 1
+  default_shard_dim: int = -1
+
+  def model_parallelism(self):
+    return dict(
+        num_cores_per_replica=self.num_cores_per_replica,
+        default_shard_dim=self.default_shard_dim)
+
 
 @dataclasses.dataclass
 class TensorboardConfig(base_config.Config):
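A minimal usage sketch of the new helper, assuming RuntimeConfig is the dataclass in official.core.config_definitions and can be constructed directly with field keywords (in the trainer these values normally come from the parsed experiment config, params.runtime). Not part of this commit; the field values are hypothetical.

from official.core import config_definitions as cfg

# Hypothetical direct construction for illustration.
runtime = cfg.RuntimeConfig(num_cores_per_replica=2, default_shard_dim=0)
assert runtime.model_parallelism() == {
    "num_cores_per_replica": 2,
    "default_shard_dim": 0,
}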
@@ -20,10 +20,10 @@ from absl import flags
 import gin
 
 from official.core import train_utils
+from official.common import distribute_utils
 # pylint: disable=unused-import
 from official.common import registry_imports
 # pylint: enable=unused-import
-from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
@@ -52,7 +52,8 @@ def main(_):
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
+      tpu_address=params.runtime.tpu,
+      **params.runtime.model_parallelism())
   with distribution_strategy.scope():
     task = task_factory.get_task(params.task, logging_dir=model_dir)
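Taken together, a hedged sketch of the wiring this commit sets up (names follow the hunks above; not part of the commit itself): the runtime config's model_parallelism() dict is splatted into get_distribution_strategy, where the open-source implementation simply drops it.

# Inside main(), after params are parsed (sketch only):
mp_kwargs = params.runtime.model_parallelism()
# mp_kwargs == {"num_cores_per_replica": 1, "default_shard_dim": -1} by default.
distribution_strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy=params.runtime.distribution_strategy,
    all_reduce_alg=params.runtime.all_reduce_alg,
    num_gpus=params.runtime.num_gpus,
    tpu_address=params.runtime.tpu,
    **mp_kwargs)  # accepted, then discarded by `del kwargs` in the OSS path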