Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dc4d1121
Commit
dc4d1121
authored
Nov 08, 2020
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Nov 08, 2020
Browse files
Remove explicit control dependency for weight decay.
PiperOrigin-RevId: 341329653
parent
2bde2485
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
16 deletions
+4
-16
official/nlp/optimization.py
official/nlp/optimization.py
+4
-16
No files found.
official/nlp/optimization.py
View file @
dc4d1121
...
...
@@ -191,27 +191,15 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return
coefficients
[
'lr_t'
],
dict
(
apply_state
=
apply_state
)
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies a dense gradient update to `var`, running weight decay first.

    Computes the weight-decay op for `var`, then forces the parent Adam
    dense update to run strictly after it via a control dependency, so
    the decay cannot be reordered past (or folded into) the Adam step.

    Note: per commit "Remove explicit control dependency for weight
    decay", the former `tf.control_dependencies([tf.identity(grad)])`
    wrapper (which pinned the decay into the backward pass; see
    b/171088214) is intentionally dropped.

    Args:
      grad: Gradient tensor for `var`.
      var: The `tf.Variable` being updated.
      apply_state: Optional dict of precomputed per-(device, dtype)
        update coefficients, as built by the optimizer's prepare step.

    Returns:
      The op returned by the parent class's `_resource_apply_dense`.
    """
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    decay = self._decay_weights_op(var, lr_t, apply_state)
    # The Adam update must observe the decayed weights, hence the
    # explicit ordering constraint on `decay`.
    with tf.control_dependencies([decay]):
        return super(AdamWeightDecay, self)._resource_apply_dense(
            grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies a sparse gradient update to `var`, running weight decay first.

    Mirrors `_resource_apply_dense`: the weight-decay op is computed,
    then the parent Adam sparse update is constrained to run after it
    via a control dependency.

    Note: per commit "Remove explicit control dependency for weight
    decay", the former `tf.control_dependencies([tf.identity(grad)])`
    wrapper (which pinned the decay into the backward pass; see
    b/171088214) is intentionally dropped.

    Args:
      grad: Gradient values for the sparse update.
      var: The `tf.Variable` being updated.
      indices: Indices into `var` corresponding to rows of `grad`.
      apply_state: Optional dict of precomputed per-(device, dtype)
        update coefficients, as built by the optimizer's prepare step.

    Returns:
      The op returned by the parent class's `_resource_apply_sparse`.
    """
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    decay = self._decay_weights_op(var, lr_t, apply_state)
    # The Adam update must observe the decayed weights, hence the
    # explicit ordering constraint on `decay`.
    with tf.control_dependencies([decay]):
        return super(AdamWeightDecay, self)._resource_apply_sparse(
            grad, var, indices, **kwargs)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment