ModelZoo / ResNet50_tensorflow

Commit 19113a57, authored Oct 16, 2020 by Ruoxin Sang, committed by A. Unique TensorFlower on Oct 16, 2020.

Internal change

PiperOrigin-RevId: 337609198
Parent: a3e847b6
Showing 1 changed file with 16 additions and 4 deletions:

official/nlp/optimization.py  +16 -4
official/nlp/optimization.py  (view file @ 19113a57)

@@ -194,15 +194,27 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     return coefficients['lr_t'], dict(apply_state=apply_state)
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
-    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-    decay = self._decay_weights_op(var, lr_t, apply_state)
+    # As the weight decay doesn't take any tensors from forward pass as inputs,
+    # add a control dependency here to make sure it happens strictly in the
+    # backward pass.
+    # TODO(b/171088214): Remove it after the control dependency in
+    # nested function is fixed.
+    with tf.control_dependencies([grad]):
+      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-    decay = self._decay_weights_op(var, lr_t, apply_state)
+    # As the weight decay doesn't take any tensors from forward pass as inputs,
+    # add a control dependency here to make sure it happens strictly in the
+    # backward pass.
+    # TODO(b/171088214): Remove it after the control dependency in
+    # nested function is fixed.
+    with tf.control_dependencies([grad]):
+      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_sparse(grad, var, indices, **kwargs)
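The diff's own comments carry the rationale: `_decay_weights_op` reads only the variable and the learning rate, never the incoming gradient, so no data dependency ties the decay to the backward pass and TensorFlow is otherwise free to schedule it anywhere. The snippet below is a minimal standalone sketch of the same ordering trick, not code from this repository; the function name `apply_update` and the constants are invented for illustration.

import tensorflow as tf

v = tf.Variable([1.0, 2.0])

@tf.function
def apply_update(grad, lr, weight_decay):
  # Illustrative only. `decay` has no data dependency on `grad`, so without
  # the first control dependency the runtime could legally run it before the
  # gradient exists; pinning it to `grad` keeps it in the backward pass.
  with tf.control_dependencies([grad]):
    decay = v.assign_sub(lr * weight_decay * v)
  # The parameter update is then ordered after the decay, exactly as in
  # _resource_apply_dense above.
  with tf.control_dependencies([decay]):
    return v.assign_sub(lr * grad)

apply_update(tf.constant([0.1, 0.1]), lr=0.01, weight_decay=0.01)

Per the TODO in the patch, this ordering is a workaround for b/171088214 (control dependencies inside nested functions) rather than a permanent design.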
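For context, `AdamWeightDecay` implements decoupled weight decay in the AdamW style: each step subtracts learning_rate * weight_decay_rate * var from the variable, separately from the Adam moment updates. A hedged usage sketch follows, assuming the constructor keeps the argument names visible in this file (`weight_decay_rate`, `exclude_from_weight_decay`); exact defaults may differ across versions of the repository.

import tensorflow as tf
from official.nlp import optimization

# Sketch, not a verbatim recipe from the repo: the values are typical
# BERT fine-tuning settings, and the exclude patterns skip decay for
# variables whose names contain them (e.g. layer norm and bias terms).
optimizer = optimization.AdamWeightDecay(
    learning_rate=3e-5,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])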