Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
65a967e9
Commit
65a967e9
authored
Nov 15, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Nov 15, 2021
Browse files
Internal change
PiperOrigin-RevId: 410013894
parent
30e6e03f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
7 deletions
+50
-7
official/common/distribute_utils.py
official/common/distribute_utils.py
+2
-2
official/common/distribute_utils_test.py
official/common/distribute_utils_test.py
+48
-5
No files found.
official/common/distribute_utils.py
View file @
65a967e9
...
@@ -141,8 +141,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
...
@@ -141,8 +141,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
distribution_strategy
=
distribution_strategy
.
lower
()
distribution_strategy
=
distribution_strategy
.
lower
()
if
distribution_strategy
==
"off"
:
if
distribution_strategy
==
"off"
:
if
num_gpus
>
1
:
if
num_gpus
>
1
:
raise
ValueError
(
"When {} GPUs are specified,
distribution_strategy
"
raise
ValueError
(
f
"When
{
num_gpus
}
GPUs are specified, "
"flag cannot be set to `off`."
.
format
(
num_gpus
)
)
"
distribution_strategy
flag cannot be set to `off`."
)
# Return the default distribution strategy.
# Return the default distribution strategy.
return
tf
.
distribute
.
get_strategy
()
return
tf
.
distribute
.
get_strategy
()
...
...
official/common/distribute_utils_test.py
View file @
65a967e9
...
@@ -12,24 +12,40 @@
...
@@ -12,24 +12,40 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
Tests for distribution util functions."""
"""Tests for distribution util functions."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.common
import
distribute_utils
class
GetDistributionStrategyTest
(
tf
.
test
.
TestCase
):
class
DistributeUtilsTest
(
tf
.
test
.
TestCase
):
"""Tests for get_distribution_strategy."""
"""Tests for distribute util functions."""
def
test_invalid_args
(
self
):
with
self
.
assertRaisesRegex
(
ValueError
,
'`num_gpus` can not be negative.'
):
_
=
distribute_utils
.
get_distribution_strategy
(
num_gpus
=-
1
)
with
self
.
assertRaisesRegex
(
ValueError
,
'.*If you meant to pass the string .*'
):
_
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
False
,
num_gpus
=
0
)
with
self
.
assertRaisesRegex
(
ValueError
,
'When 2 GPUs are specified.*'
):
_
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
'off'
,
num_gpus
=
2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'`OneDeviceStrategy` can not be used.*'
):
_
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
'one_device'
,
num_gpus
=
2
)
def
test_one_device_strategy_cpu
(
self
):
def
test_one_device_strategy_cpu
(
self
):
ds
=
distribute_utils
.
get_distribution_strategy
(
num_gpus
=
0
)
ds
=
distribute_utils
.
get_distribution_strategy
(
'one_device'
,
num_gpus
=
0
)
self
.
assertEquals
(
ds
.
num_replicas_in_sync
,
1
)
self
.
assertEquals
(
ds
.
num_replicas_in_sync
,
1
)
self
.
assertEquals
(
len
(
ds
.
extended
.
worker_devices
),
1
)
self
.
assertEquals
(
len
(
ds
.
extended
.
worker_devices
),
1
)
self
.
assertIn
(
'CPU'
,
ds
.
extended
.
worker_devices
[
0
])
self
.
assertIn
(
'CPU'
,
ds
.
extended
.
worker_devices
[
0
])
def
test_one_device_strategy_gpu
(
self
):
def
test_one_device_strategy_gpu
(
self
):
ds
=
distribute_utils
.
get_distribution_strategy
(
num_gpus
=
1
)
ds
=
distribute_utils
.
get_distribution_strategy
(
'one_device'
,
num_gpus
=
1
)
self
.
assertEquals
(
ds
.
num_replicas_in_sync
,
1
)
self
.
assertEquals
(
ds
.
num_replicas_in_sync
,
1
)
self
.
assertEquals
(
len
(
ds
.
extended
.
worker_devices
),
1
)
self
.
assertEquals
(
len
(
ds
.
extended
.
worker_devices
),
1
)
self
.
assertIn
(
'GPU'
,
ds
.
extended
.
worker_devices
[
0
])
self
.
assertIn
(
'GPU'
,
ds
.
extended
.
worker_devices
[
0
])
...
@@ -41,6 +57,27 @@ class GetDistributionStrategyTest(tf.test.TestCase):
...
@@ -41,6 +57,27 @@ class GetDistributionStrategyTest(tf.test.TestCase):
for
device
in
ds
.
extended
.
worker_devices
:
for
device
in
ds
.
extended
.
worker_devices
:
self
.
assertIn
(
'GPU'
,
device
)
self
.
assertIn
(
'GPU'
,
device
)
_
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
'mirrored'
,
num_gpus
=
2
,
all_reduce_alg
=
'nccl'
,
num_packs
=
2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'When used with `mirrored`, valid values for all_reduce_alg are.*'
):
_
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
'mirrored'
,
num_gpus
=
2
,
all_reduce_alg
=
'dummy'
,
num_packs
=
2
)
def
test_mwms
(
self
):
distribute_utils
.
configure_cluster
(
worker_hosts
=
None
,
task_index
=-
1
)
ds
=
distribute_utils
.
get_distribution_strategy
(
'multi_worker_mirrored'
,
all_reduce_alg
=
'nccl'
)
self
.
assertIsInstance
(
ds
,
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
def
test_no_strategy
(
self
):
def
test_no_strategy
(
self
):
ds
=
distribute_utils
.
get_distribution_strategy
(
'off'
)
ds
=
distribute_utils
.
get_distribution_strategy
(
'off'
)
self
.
assertIs
(
ds
,
tf
.
distribute
.
get_strategy
())
self
.
assertIs
(
ds
,
tf
.
distribute
.
get_strategy
())
...
@@ -54,6 +91,12 @@ class GetDistributionStrategyTest(tf.test.TestCase):
...
@@ -54,6 +91,12 @@ class GetDistributionStrategyTest(tf.test.TestCase):
ValueError
,
'distribution_strategy must be a string but got: 1'
):
ValueError
,
'distribution_strategy must be a string but got: 1'
):
distribute_utils
.
get_distribution_strategy
(
1
)
distribute_utils
.
get_distribution_strategy
(
1
)
def
test_get_strategy_scope
(
self
):
ds
=
distribute_utils
.
get_distribution_strategy
(
'one_device'
,
num_gpus
=
0
)
with
distribute_utils
.
get_strategy_scope
(
ds
):
self
.
assertIs
(
tf
.
distribute
.
get_strategy
(),
ds
)
with
distribute_utils
.
get_strategy_scope
(
None
):
self
.
assertIsNot
(
tf
.
distribute
.
get_strategy
(),
ds
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment