ModelZoo / ResNet50_tensorflow · Commits

Commit 997eaa19
authored Jun 28, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Jun 28, 2020

Internal change

PiperOrigin-RevId: 318714418

parent 68d983b9

Showing 3 changed files with 4 additions and 80 deletions
official/nlp/transformer/optimizer.py          +0  -71
official/nlp/transformer/transformer_main.py   +3   -8
official/nlp/transformer/translate.py          +1   -1
official/nlp/transformer/optimizer.py
@@ -18,9 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import numpy as np
 import tensorflow as tf

-K = tf.keras.backend


 class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -66,72 +64,3 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
         'hidden_size': self.hidden_size,
         'warmup_steps': self.warmup_steps,
     }
-
-
-class LearningRateFn(object):
-  """Creates learning rate function."""
-
-  def __init__(self, learning_rate, hidden_size, warmup_steps):
-    self.learning_rate = learning_rate
-    self.hidden_size = hidden_size
-    self.warmup_steps = float(warmup_steps)
-
-  def __call__(self, global_step):
-    """Calculate learning rate with linear warmup and rsqrt decay."""
-    step = float(global_step)
-    learning_rate = self.learning_rate
-    learning_rate *= (self.hidden_size**-0.5)
-    # Apply linear warmup
-    learning_rate *= np.minimum(1.0, step / self.warmup_steps)
-    # Apply rsqrt decay
-    learning_rate /= np.sqrt(np.maximum(step, self.warmup_steps))
-    return learning_rate
-
-
-class LearningRateScheduler(tf.keras.callbacks.Callback):
-  """Keras callback to schedule learning rate.
-
-  TODO(tianlin): Refactor this scheduler and LearningRateBatchScheduler in
-  official/resnet/keras/keras_common.py.
-  """
-
-  def __init__(self, schedule, init_steps=None, verbose=False):
-    super(LearningRateScheduler, self).__init__()
-    self.schedule = schedule
-    self.verbose = verbose
-    if init_steps is None:
-      init_steps = 0.0
-    self.steps = float(init_steps)  # Total steps during training.
-
-  def on_epoch_begin(self, epoch, logs=None):
-    if not hasattr(self.model.optimizer, 'lr'):
-      raise ValueError('Optimizer must have a "lr" attribute.')
-    if not hasattr(self.model.optimizer, 'iterations'):
-      raise ValueError('Optimizer must have a "iterations" attribute.')
-
-  def on_train_batch_begin(self, batch, logs=None):
-    """Adjusts learning rate for each train batch."""
-    if self.verbose > 0:
-      iterations = K.get_value(self.model.optimizer.iterations)
-      print('Original iteration %d' % iterations)
-
-    self.steps += 1.0
-    try:  # new API
-      lr = float(K.get_value(self.model.optimizer.lr))
-      lr = self.schedule(self.steps, lr)
-    except TypeError:  # Support for old API for backward compatibility
-      lr = self.schedule(self.steps)
-    if not isinstance(lr, (float, np.float32, np.float64)):
-      raise ValueError('The output of the "schedule" function '
-                       'should be float.')
-    K.set_value(self.model.optimizer.lr, lr)
-    K.set_value(self.model.optimizer.iterations, self.steps)
-
-    if self.verbose > 0:
-      print('Batch %05d Step %05d: LearningRateScheduler setting learning '
-            'rate to %s.' % (batch + 1, self.steps, lr))
-
-  def on_epoch_end(self, epoch, logs=None):
-    logs = logs or {}
-    logs['lr'] = K.get_value(self.model.optimizer.lr)
-    logs['steps'] = self.steps
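The deleted LearningRateFn computed the Transformer schedule on the host with NumPy (base_lr * hidden_size**-0.5 * min(1, step / warmup_steps) / sqrt(max(step, warmup_steps))), and the deleted LearningRateScheduler callback pushed that value into the optimizer before every batch via K.set_value. The LearningRateSchedule class kept at the top of this file already expresses the same rule as a Keras-native schedule, making the callback pair redundant. As a rough, self-contained sketch of that idea, assuming only public TF 2.x APIs (the class name WarmupRsqrtSchedule is invented here and is not the repository's class):

import tensorflow as tf


class WarmupRsqrtSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup followed by rsqrt decay, expressed with TF ops."""

  def __init__(self, learning_rate, hidden_size, warmup_steps):
    super(WarmupRsqrtSchedule, self).__init__()
    self.learning_rate = learning_rate
    self.hidden_size = hidden_size
    self.warmup_steps = float(warmup_steps)

  def __call__(self, global_step):
    # Same formula as the deleted LearningRateFn, but traceable by TF.
    step = tf.cast(global_step, tf.float32)
    lr = self.learning_rate * (self.hidden_size ** -0.5)
    lr *= tf.minimum(1.0, step / self.warmup_steps)      # linear warmup
    lr /= tf.sqrt(tf.maximum(step, self.warmup_steps))   # rsqrt decay
    return lr

  def get_config(self):
    return {
        'learning_rate': self.learning_rate,
        'hidden_size': self.hidden_size,
        'warmup_steps': self.warmup_steps,
    }

A schedule object like this can be handed straight to a Keras optimizer, which evaluates it against optimizer.iterations at every apply step; no per-batch callback or K.set_value bookkeeping is needed, which is what the transformer_main.py changes below rely on.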
official/nlp/transformer/transformer_main.py
@@ -241,7 +241,7 @@ class TransformerTask(object):
     if params["use_ctl"]:
       train_ds_iterator = iter(train_ds)

-    callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
+    callbacks = self._create_callbacks(flags_obj.model_dir, params)

     # Only TimeHistory callback is supported for CTL
     if params["use_ctl"]:
@@ -408,14 +408,9 @@ class TransformerTask(object):
     for i in range(length):
       translate.translate_from_input(val_outputs[i], subtokenizer)

-  def _create_callbacks(self, cur_log_dir, init_steps, params):
+  def _create_callbacks(self, cur_log_dir, params):
     """Creates a list of callbacks."""
-    sfunc = optimizer.LearningRateFn(params["learning_rate"],
-                                     params["hidden_size"],
-                                     params["learning_rate_warmup_steps"])
-    scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
     callbacks = misc.get_callbacks()
-    callbacks.append(scheduler_callback)
     if params["enable_checkpointing"]:
       ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
       callbacks.append(
@@ -445,7 +440,7 @@ class TransformerTask(object):
         params["learning_rate"], params["hidden_size"],
         params["learning_rate_warmup_steps"])
     opt = tf.keras.optimizers.Adam(
-        lr_schedule if self.use_tpu else params["learning_rate"],
+        lr_schedule,
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
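With the scheduler callback gone, _create_callbacks drops its init_steps argument and the Adam optimizer is constructed with the schedule object itself on every platform, not only on TPU. A minimal, hedged sketch of that pattern, using a stock Keras schedule as a stand-in for the repository's optimizer.LearningRateSchedule and illustrative hyperparameter values rather than the Transformer params:

import tensorflow as tf

# Stand-in schedule; the real code builds optimizer.LearningRateSchedule
# from params["learning_rate"], params["hidden_size"] and
# params["learning_rate_warmup_steps"].
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=10000, decay_rate=0.96)

# Passing the schedule (not a float) makes Keras re-evaluate the learning
# rate from optimizer.iterations on every training step.
opt = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,      # illustrative; the repo passes params["optimizer_adam_beta1"]
    beta_2=0.997,    # illustrative; params["optimizer_adam_beta2"]
    epsilon=1e-9)    # illustrative; params["optimizer_adam_epsilon"]

Because the schedule reads optimizer.iterations, which is part of the optimizer's checkpointed state, there is no separate step counter (the old init_steps) to keep in sync.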
official/nlp/transformer/translate.py
@@ -181,7 +181,7 @@ def translate_file(model,
     raise ValueError("File output is a directory, will not save outputs to "
                      "file.")

   logging.info("Writing to file %s", output_file)
-  with tf.compat.v1.gfile.Open(output_file, "w") as f:
+  with tf.io.gfile.GFile(output_file, "w") as f:
     for i in sorted_keys:
       f.write("%s\n" % translations[i])
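The translate.py change swaps the TF1 compatibility file API for its TF2 equivalent: tf.io.gfile.GFile behaves like the old tf.compat.v1.gfile.Open and handles local paths as well as remote filesystems such as gs:// buckets. A standalone sketch of the new call, with a made-up path and data:

import tensorflow as tf

translations = ["hello world", "guten tag"]   # illustrative data only
output_file = "/tmp/translations.txt"         # hypothetical path

# tf.io.gfile.GFile is a drop-in context manager for writing text.
with tf.io.gfile.GFile(output_file, "w") as f:
  for line in translations:
    f.write("%s\n" % line)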