Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7c83a9d7
Commit
7c83a9d7
authored
Mar 25, 2020
by
Will Cromar
Committed by
A. Unique TensorFlower
Mar 25, 2020
Browse files
Internal change
PiperOrigin-RevId: 302921283
parent
9fc4fd08
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
1 deletion
+22
-1
official/nlp/transformer/misc.py
official/nlp/transformer/misc.py
+4
-1
official/nlp/transformer/transformer_main.py
official/nlp/transformer/transformer_main.py
+18
-0
No files found.
official/nlp/transformer/misc.py
View file @
7c83a9d7
...
@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
...
@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
"""Returns common callbacks."""
"""Returns common callbacks."""
callbacks
=
[]
callbacks
=
[]
if
FLAGS
.
enable_time_history
:
if
FLAGS
.
enable_time_history
:
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
log_steps
)
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
log_steps
,
FLAGS
.
model_dir
if
FLAGS
.
enable_tensorboard
else
None
)
callbacks
.
append
(
time_callback
)
callbacks
.
append
(
time_callback
)
if
FLAGS
.
enable_tensorboard
:
if
FLAGS
.
enable_tensorboard
:
...
...
official/nlp/transformer/transformer_main.py
View file @
7c83a9d7
...
@@ -246,6 +246,11 @@ class TransformerTask(object):
...
@@ -246,6 +246,11 @@ class TransformerTask(object):
callbacks
=
self
.
_create_callbacks
(
flags_obj
.
model_dir
,
0
,
params
)
callbacks
=
self
.
_create_callbacks
(
flags_obj
.
model_dir
,
0
,
params
)
# Only TimeHistory callback is supported for CTL
if
params
[
"use_ctl"
]:
callbacks
=
[
cb
for
cb
in
callbacks
if
isinstance
(
cb
,
keras_utils
.
TimeHistory
)]
# TODO(b/139418525): Refactor the custom training loop logic.
# TODO(b/139418525): Refactor the custom training loop logic.
@
tf
.
function
@
tf
.
function
def
train_steps
(
iterator
,
steps
):
def
train_steps
(
iterator
,
steps
):
...
@@ -299,8 +304,13 @@ class TransformerTask(object):
...
@@ -299,8 +304,13 @@ class TransformerTask(object):
if
not
self
.
use_tpu
:
if
not
self
.
use_tpu
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Custom training loop on GPUs is not implemented."
)
"Custom training loop on GPUs is not implemented."
)
# Runs training steps.
# Runs training steps.
with
summary_writer
.
as_default
():
with
summary_writer
.
as_default
():
for
cb
in
callbacks
:
cb
.
on_epoch_begin
(
current_iteration
)
cb
.
on_batch_begin
(
0
)
train_steps
(
train_steps
(
train_ds_iterator
,
train_ds_iterator
,
tf
.
convert_to_tensor
(
train_steps_per_eval
,
dtype
=
tf
.
int32
))
tf
.
convert_to_tensor
(
train_steps_per_eval
,
dtype
=
tf
.
int32
))
...
@@ -309,10 +319,18 @@ class TransformerTask(object):
...
@@ -309,10 +319,18 @@ class TransformerTask(object):
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
current_step
,
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
current_step
,
flags_obj
.
train_steps
,
train_loss
)
flags_obj
.
train_steps
,
train_loss
)
for
cb
in
callbacks
:
cb
.
on_batch_end
(
train_steps_per_eval
-
1
)
cb
.
on_epoch_end
(
current_iteration
)
if
params
[
"enable_tensorboard"
]:
if
params
[
"enable_tensorboard"
]:
for
metric_obj
in
train_metrics
:
for
metric_obj
in
train_metrics
:
tf
.
compat
.
v2
.
summary
.
scalar
(
metric_obj
.
name
,
metric_obj
.
result
(),
tf
.
compat
.
v2
.
summary
.
scalar
(
metric_obj
.
name
,
metric_obj
.
result
(),
current_step
)
current_step
)
summary_writer
.
flush
()
for
cb
in
callbacks
:
cb
.
on_train_end
()
if
flags_obj
.
enable_checkpointing
:
if
flags_obj
.
enable_checkpointing
:
# avoid check-pointing when running for benchmarking.
# avoid check-pointing when running for benchmarking.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment