"magic_pdf/model/vscode:/vscode.git/clone" did not exist on "1ec5d09d13707ce82a371629e54be2f686392db8"
Commit 9cff6e3b authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 384257276
parent 0e74158f
......@@ -16,6 +16,7 @@
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
......@@ -14,6 +14,7 @@
# Lint as: python3
"""TensorFlow Model Garden Vision training driver with spatial partitioning."""
from typing import Sequence
from absl import app
from absl import flags
......@@ -33,19 +34,34 @@ from official.modeling import performance
FLAGS = flags.FLAGS
def get_computation_shape_for_model_parallelism(input_partition_dims):
"""Return computation shape to be used for TPUStrategy spatial partition."""
def get_computation_shape_for_model_parallelism(
input_partition_dims: Sequence[int]) -> Sequence[int]:
"""Returns computation shape to be used for TPUStrategy spatial partition.
Args:
input_partition_dims: The number of partitions along each dimension.
Returns:
A list of integers specifying the computation shape.
Raises:
ValueError: If the number of logical devices is not supported.
"""
num_logical_devices = np.prod(input_partition_dims)
if num_logical_devices == 1:
return [1, 1, 1, 1]
if num_logical_devices == 2:
elif num_logical_devices == 2:
return [1, 1, 1, 2]
if num_logical_devices == 4:
elif num_logical_devices == 4:
return [1, 2, 1, 2]
if num_logical_devices == 8:
elif num_logical_devices == 8:
return [2, 2, 1, 2]
if num_logical_devices == 16:
elif num_logical_devices == 16:
return [4, 2, 1, 2]
else:
raise ValueError(
'The number of logical devices %d is not supported. Supported numbers '
'are 1, 2, 4, 8, 16' % num_logical_devices)
def create_distribution_strategy(distribution_strategy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment