"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ff31b6e3b7897ab951167d72173a83df2cda6c57"
Commit 9c377959 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 384257276
parent adc12a16
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports from official.common import registry_imports
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# Lint as: python3 # Lint as: python3
"""TensorFlow Model Garden Vision training driver with spatial partitioning.""" """TensorFlow Model Garden Vision training driver with spatial partitioning."""
from typing import Sequence
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -33,19 +34,34 @@ from official.modeling import performance ...@@ -33,19 +34,34 @@ from official.modeling import performance
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def get_computation_shape_for_model_parallelism(input_partition_dims): def get_computation_shape_for_model_parallelism(
"""Return computation shape to be used for TPUStrategy spatial partition.""" input_partition_dims: Sequence[int]) -> Sequence[int]:
"""Returns computation shape to be used for TPUStrategy spatial partition.
Args:
input_partition_dims: The number of partitions along each dimension.
Returns:
A list of integers specifying the computation shape.
Raises:
ValueError: If the number of logical devices is not supported.
"""
num_logical_devices = np.prod(input_partition_dims) num_logical_devices = np.prod(input_partition_dims)
if num_logical_devices == 1: if num_logical_devices == 1:
return [1, 1, 1, 1] return [1, 1, 1, 1]
if num_logical_devices == 2: elif num_logical_devices == 2:
return [1, 1, 1, 2] return [1, 1, 1, 2]
if num_logical_devices == 4: elif num_logical_devices == 4:
return [1, 2, 1, 2] return [1, 2, 1, 2]
if num_logical_devices == 8: elif num_logical_devices == 8:
return [2, 2, 1, 2] return [2, 2, 1, 2]
if num_logical_devices == 16: elif num_logical_devices == 16:
return [4, 2, 1, 2] return [4, 2, 1, 2]
else:
raise ValueError(
'The number of logical devices %d is not supported. Supported numbers '
'are 1, 2, 4, 8, 16' % num_logical_devices)
def create_distribution_strategy(distribution_strategy, def create_distribution_strategy(distribution_strategy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment