Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
34706ba0
Unverified
Commit
34706ba0
authored
May 15, 2020
by
Jared T Nielsen
Committed by
GitHub
May 15, 2020
Browse files
Allow for None gradients in GradientAccumulator. (#4372)
parent
edf9ac11
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
3 deletions
+7
-3
src/transformers/optimization_tf.py
src/transformers/optimization_tf.py
+7
-3
No files found.
src/transformers/optimization_tf.py
View file @
34706ba0
...
...
@@ -217,7 +217,7 @@ class GradientAccumulator(object):
"""The accumulated gradients on the current replica."""
if
not
self
.
_gradients
:
raise
ValueError
(
"The accumulator should be called first to initialize the gradients"
)
return
list
(
gradient
.
value
()
for
gradient
in
self
.
_gradients
)
return
list
(
gradient
.
value
()
if
gradient
is
not
None
else
gradient
for
gradient
in
self
.
_gradients
)
def
__call__
(
self
,
gradients
):
"""Accumulates :obj:`gradients` on the current replica."""
...
...
@@ -231,6 +231,8 @@ class GradientAccumulator(object):
synchronization
=
tf
.
VariableSynchronization
.
ON_READ
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
,
)
if
gradient
is
not
None
else
gradient
for
gradient
in
gradients
]
)
...
...
@@ -238,7 +240,8 @@ class GradientAccumulator(object):
raise
ValueError
(
"Expected %s gradients, but got %d"
%
(
len
(
self
.
_gradients
),
len
(
gradients
)))
for
accum_gradient
,
gradient
in
zip
(
self
.
_gradients
,
gradients
):
accum_gradient
.
assign_add
(
gradient
)
if
accum_gradient
is
not
None
and
gradient
is
not
None
:
accum_gradient
.
assign_add
(
gradient
)
self
.
_accum_steps
.
assign_add
(
1
)
...
...
@@ -248,4 +251,5 @@ class GradientAccumulator(object):
return
self
.
_accum_steps
.
assign
(
0
)
for
gradient
in
self
.
_gradients
:
gradient
.
assign
(
tf
.
zeros_like
(
gradient
))
if
gradient
is
not
None
:
gradient
.
assign
(
tf
.
zeros_like
(
gradient
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment