Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5fea53a7
Commit
5fea53a7
authored
Dec 28, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Dec 28, 2021
Browse files
Internal change
PiperOrigin-RevId: 418724903
parent
77d9fd62
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
160 additions
and
0 deletions
+160
-0
official/modeling/tf_utils.py
official/modeling/tf_utils.py
+71
-0
official/modeling/tf_utils_test.py
official/modeling/tf_utils_test.py
+89
-0
No files found.
official/modeling/tf_utils.py
View file @
5fea53a7
...
...
def safe_mean(losses):
  """Averages `losses`, returning zero for an empty tensor.

  Division is performed with `tf.math.divide_no_nan`, so an empty
  `losses` tensor produces 0 rather than NaN.

  Args:
    losses: A `Tensor` of per-example losses.

  Returns:
    A scalar `Tensor` with the mean of `losses` (0 if `losses` is empty).
  """
  element_count = tf.cast(tf.size(losses), dtype=losses.dtype)
  loss_sum = tf.reduce_sum(losses)
  return tf.math.divide_no_nan(loss_sum, element_count)
def get_replica_id():
  """Returns the current replica's id within its sync group.

  Raises:
    RuntimeError: if called outside of a `tf.distribute` replica context.
  """
  replica_ctx = tf.distribute.get_replica_context()
  # Fail fast outside of a replica context rather than returning a bogus id.
  if replica_ctx is None:
    raise RuntimeError("Unknown replica context. The `get_replica_id` method "
                       "relies on TF 2.x tf.distribute API.")
  return replica_ctx.replica_id_in_sync_group
def cross_replica_concat(value, axis, name="cross_replica_concat"):
  """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.

  In general, each core ("replica") will pass a
  replica-specific value as `value` (corresponding to some element of a
  data-parallel computation taking place across replicas).

  The resulting concatenated `Tensor` will have the same shape as `value` for
  all dimensions except `axis`, where it will be larger by a factor of the
  number of replicas. It will also have the same `dtype` as `value`.

  The position of a given replica's `value` within the resulting concatenation
  is determined by that replica's replica ID. For
  example:

  With `value` for replica 0 given as

      0 0 0
      0 0 0

  and `value` for replica 1 given as

      1 1 1
      1 1 1

  the resulting concatenation along axis 0 will be

      0 0 0
      0 0 0
      1 1 1
      1 1 1

  and this result will be identical across all replicas.

  Note that this API only works in TF2 with `tf.distribute`.

  Args:
    value: The `Tensor` to concatenate across replicas. Each replica will have a
      different value for this `Tensor`, and these replica-specific values will
      be concatenated.
    axis: The axis along which to perform the concatenation as a Python integer
      (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension.
    name: A name for the operation (used to create a name scope).

  Returns:
    The result of concatenating `value` along `axis` across replicas.

  Raises:
    RuntimeError: when called outside of a replica context, or when the batch
      (0-th) dimension is None.
  """
  with tf.name_scope(name):
    context = tf.distribute.get_replica_context()
    # Consistent with `get_replica_id`: outside a replica context
    # `get_replica_context()` returns None, and falling through would crash
    # below with an opaque AttributeError on `None.all_gather`.
    if context is None:
      raise RuntimeError("Unknown replica context. `cross_replica_concat` "
                         "relies on TF 2.x tf.distribute API.")
    # Typically this could be hit only if the tensor is derived from a
    # dataset with finite epochs and drop_remainder=False, where the last
    # batch could be of different batch size and then the dim-0 is of dynamic
    # shape.
    if value.shape.as_list()[0] is None:
      raise RuntimeError(f"{value} has unknown batch.")
    return context.all_gather(value, axis=axis)
official/modeling/tf_utils_test.py
0 → 100644
View file @
5fea53a7
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_utils."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling
import
tf_utils
def all_strategy_combinations():
  """Returns TPU and two-GPU mirrored strategy combinations, eager mode."""
  strategies = [
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.mirrored_strategy_with_two_gpus,
  ]
  return combinations.combine(strategy=strategies, mode='eager')
class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Distribution-strategy tests for the cross-replica helpers in tf_utils."""

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat(self, strategy):
    replica_count = strategy.num_replicas_in_sync
    shape = (2, 3, 4)

    def make_concat_fn(axis):
      """Builds a tf.function gathering a replica-id-filled tensor on `axis`."""

      @tf.function
      def run_concat():
        local_value = tf.fill(shape, tf_utils.get_replica_id())
        return tf_utils.cross_replica_concat(local_value, axis=axis)

      return run_concat

    def reference_result(axis):
      # Replica i contributes one block filled entirely with the constant i.
      blocks = [np.full(shape, i) for i in range(replica_count)]
      return np.concatenate(blocks, axis=axis)

    # Axis 0: additionally verify every replica observes an identical result.
    axis0_outputs = strategy.run(make_concat_fn(axis=0))
    first_result = axis0_outputs.values[0].numpy()
    for other in axis0_outputs.values[1:]:
      self.assertAllClose(other.numpy(), first_result)
    self.assertAllClose(first_result, reference_result(axis=0))

    # Remaining axes: compare only replica 0's output to the reference.
    for axis in (1, 2):
      first_result = strategy.run(make_concat_fn(axis=axis)).values[0].numpy()
      self.assertAllClose(first_result, reference_result(axis=axis))

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat_gradient(self, strategy):
    replica_count = strategy.num_replicas_in_sync
    shape = (10, 5)

    @tf.function
    def grad_of_concat_sum():
      local_value = tf.random.normal(shape)
      with tf.GradientTape() as tape:
        tape.watch(local_value)
        gathered = tf_utils.cross_replica_concat(local_value, axis=0)
        total = tf.reduce_sum(gathered)
      return tape.gradient(total, local_value)

    per_replica_grads = strategy.run(grad_of_concat_sum)

    # Summing the gathered tensor counts each local element once per replica,
    # so d(total)/d(local_value) equals `replica_count` everywhere.
    for grad in per_replica_grads.values:
      self.assertAllClose(grad, replica_count * tf.ones(shape))
# Run all test cases under the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment