ModelZoo / ResNet50_tensorflow

Commit 08f45dc4, authored Mar 10, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 300203487
Parent: 682d36b5

Showing 2 changed files with 18 additions and 8 deletions.
official/nlp/transformer/misc.py              +6  -0
official/nlp/transformer/transformer_main.py  +12 -8
official/nlp/transformer/misc.py (view file @ 08f45dc4)
@@ -205,6 +205,12 @@ def define_transformer_flags():
           'use padded_decode, it has not been tested. In addition, this method '
           'will introduce unnecessary overheads which grow quadratically with '
           'the max sequence length.'))
+  flags.DEFINE_bool(
+      name='enable_checkpointing', default=True,
+      help=flags_core.help_wrap(
+          'Whether to do checkpointing during training. When running under '
+          'benchmark harness, we will avoid checkpointing.'))
+
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
...
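For context, the new flag is an ordinary absl boolean flag; flags_core.help_wrap (a Model Garden flag utility) only re-wraps the help string for display. A minimal sketch of how such a flag is defined and read, using absl-py directly; the script body and print are illustrative, not part of this commit:

    from absl import app, flags

    flags.DEFINE_bool(
        name='enable_checkpointing', default=True,
        help='Whether to do checkpointing during training.')

    FLAGS = flags.FLAGS

    def main(_):
      # A benchmark harness would pass --enable_checkpointing=false.
      print('checkpointing enabled:', FLAGS.enable_checkpointing)

    if __name__ == '__main__':
      app.run(main)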
official/nlp/transformer/transformer_main.py (view file @ 08f45dc4)
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
+    params["enable_checkpointing"] = flags_obj.enable_checkpointing

     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
...
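distribution_utils.get_distribution_strategy is a Model Garden helper that maps the --distribution_strategy string onto a tf.distribute strategy object. A rough sketch of that kind of mapping; the function below is an assumption for illustration, not the helper's actual code:

    import tensorflow as tf

    def get_strategy(strategy_name, num_gpus=0):
      # Hypothetical stand-in for distribution_utils.get_distribution_strategy.
      if strategy_name == 'mirrored':
        return tf.distribute.MirroredStrategy()
      if strategy_name == 'one_device':
        device = '/gpu:0' if num_gpus else '/cpu:0'
        return tf.distribute.OneDeviceStrategy(device)
      raise ValueError('Unknown strategy: %s' % strategy_name)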
@@ -313,6 +314,8 @@ class TransformerTask(object):
             tf.compat.v2.summary.scalar(metric_obj.name,
                                         metric_obj.result(),
                                         current_step)
+      if flags_obj.enable_checkpointing:
+        # avoid check-pointing when running for benchmarking.
+        checkpoint_name = checkpoint.save(
+            os.path.join(flags_obj.model_dir,
+                         "ctl_step_{}.ckpt".format(current_step)))
...
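This is the custom-training-loop path, which saves via tf.train.Checkpoint. A self-contained sketch of the same pattern; the tiny model, optimizer, and step loop are stand-ins, not this repo's code:

    import os
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    model_dir = '/tmp/transformer_model'
    for current_step in range(0, 300, 100):
      # ... run training steps here ...
      # checkpoint.save appends its own save counter to the given prefix.
      checkpoint_name = checkpoint.save(
          os.path.join(model_dir, 'ctl_step_{}.ckpt'.format(current_step)))
      print('Saved checkpoint to', checkpoint_name)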
@@ -397,6 +400,7 @@ class TransformerTask(object):
       scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
       callbacks = misc.get_callbacks(params["steps_between_evals"])
       callbacks.append(scheduler_callback)
+      if params["enable_checkpointing"]:
+        ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
...
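The Keras compile/fit path instead registers a tf.keras.callbacks.ModelCheckpoint. The diff is truncated mid-call, so the save_weights_only=True argument below is an assumption; the toy model and random data are illustrative only:

    import os
    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    model.compile(optimizer='adam', loss='mse')

    cur_log_dir = '/tmp/transformer_model'
    ckpt_full_path = os.path.join(cur_log_dir, 'cp-{epoch:04d}.ckpt')

    model.fit(
        np.random.rand(32, 8), np.random.rand(32, 1),
        epochs=3,
        callbacks=[
            # Keras expands {epoch:04d} per epoch, e.g. cp-0001.ckpt.
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True),
        ])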