Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9cff6e3b
Commit
9cff6e3b
authored
Jul 12, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jul 12, 2021
Browse files
Internal change.
PiperOrigin-RevId: 384257276
parent
0e74158f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
6 deletions
+23
-6
official/vision/beta/projects/volumetric_models/registry_imports.py
...ision/beta/projects/volumetric_models/registry_imports.py
+1
-0
official/vision/beta/train_spatial_partitioning.py
official/vision/beta/train_spatial_partitioning.py
+22
-6
No files found.
official/vision/beta/projects/volumetric_models/registry_imports.py
View file @
9cff6e3b
...
...
@@ -16,6 +16,7 @@
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
semantic_segmentation_3d_cfg
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
official/vision/beta/train_spatial_partitioning.py
View file @
9cff6e3b
...
...
@@ -14,6 +14,7 @@
# Lint as: python3
"""TensorFlow Model Garden Vision training driver with spatial partitioning."""
from
typing
import
Sequence
from
absl
import
app
from
absl
import
flags
...
...
@@ -33,19 +34,34 @@ from official.modeling import performance
FLAGS
=
flags
.
FLAGS
def
get_computation_shape_for_model_parallelism
(
input_partition_dims
):
"""Return computation shape to be used for TPUStrategy spatial partition."""
def
get_computation_shape_for_model_parallelism
(
input_partition_dims
:
Sequence
[
int
])
->
Sequence
[
int
]:
"""Returns computation shape to be used for TPUStrategy spatial partition.
Args:
input_partition_dims: The number of partitions along each dimension.
Returns:
A list of integers specifying the computation shape.
Raises:
ValueError: If the number of logical devices is not supported.
"""
num_logical_devices
=
np
.
prod
(
input_partition_dims
)
if
num_logical_devices
==
1
:
return
[
1
,
1
,
1
,
1
]
if
num_logical_devices
==
2
:
el
if
num_logical_devices
==
2
:
return
[
1
,
1
,
1
,
2
]
if
num_logical_devices
==
4
:
el
if
num_logical_devices
==
4
:
return
[
1
,
2
,
1
,
2
]
if
num_logical_devices
==
8
:
el
if
num_logical_devices
==
8
:
return
[
2
,
2
,
1
,
2
]
if
num_logical_devices
==
16
:
el
if
num_logical_devices
==
16
:
return
[
4
,
2
,
1
,
2
]
else
:
raise
ValueError
(
'The number of logical devices %d is not supported. Supported numbers '
'are 1, 2, 4, 8, 16'
%
num_logical_devices
)
def
create_distribution_strategy
(
distribution_strategy
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment