ModelZoo / ResNet50_tensorflow · Commits

Commit 5ceb3acf
Authored Sep 09, 2019 by David Chen
Committed Sep 09, 2019 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 268057019
Parent: e91c41c2

Changes: 1 changed file with 8 additions and 7 deletions
official/transformer/v2/transformer_main.py (+8 / -7) @ 5ceb3acf
@@ -180,13 +180,13 @@ class TransformerTask(object):
       if not params["static_batch"]:
         raise ValueError("TPU requires static batch for input data.")
     else:
-      print("Running transformer with num_gpus =", num_gpus)
+      logging.info("Running transformer with num_gpus =", num_gpus)
 
     if self.distribution_strategy:
-      print("For training, using distribution strategy: ",
-            self.distribution_strategy)
+      logging.info("For training, using distribution strategy: ",
+                   self.distribution_strategy)
     else:
-      print("Not using any distribution strategy.")
+      logging.info("Not using any distribution strategy.")
 
   @property
   def use_tpu(self):
@@ -289,7 +289,8 @@ class TransformerTask(object):
           else flags_obj.steps_between_evals)
     current_iteration = current_step // flags_obj.steps_between_evals
 
-    print("Start train iteration at global step:{}".format(current_step))
+    logging.info("Start train iteration at global step:{}".format(
+        current_step))
     history = None
     if params["use_ctl"]:
       if not self.use_tpu:
@@ -324,7 +325,7 @@ class TransformerTask(object):
       current_step += train_steps_per_eval
       logging.info("Train history: {}".format(history.history))
 
-    print("End train iteration at global step:{}".format(current_step))
+    logging.info("End train iteration at global step:{}".format(current_step))
 
     if (flags_obj.bleu_source and flags_obj.bleu_ref):
       uncased_score, cased_score = self.eval()
@@ -401,7 +402,7 @@ class TransformerTask(object):
       else:
         model.load_weights(init_weight_path)
     else:
-      print("Weights not loaded from path:{}".format(init_weight_path))
+      logging.info("Weights not loaded from path:{}".format(init_weight_path))
 
   def _create_optimizer(self):
     """Creates optimizer."""
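The commit replaces built-in print calls with logging.info. Below is a minimal, self-contained sketch of that logging pattern, assuming transformer_main.py takes its logging module from absl (from absl import logging), which is the usual convention in the TF official models; the import itself is not visible in this diff, and the num_gpus value here is purely hypothetical.

# Minimal sketch (not part of the commit): absl-style logging as used after this change.
from absl import app
from absl import logging


def main(_):
  # Make sure INFO messages are emitted regardless of the default verbosity.
  logging.set_verbosity(logging.INFO)

  num_gpus = 2  # hypothetical value, for illustration only
  # absl's logging.info uses lazy printf-style interpolation: extra positional
  # arguments are substituted into %-placeholders in the format string.
  logging.info("Running transformer with num_gpus = %d", num_gpus)
  # Pre-formatted strings (as in the diff's .format(...) calls) also work.
  logging.info("Start train iteration at global step:{}".format(0))


if __name__ == "__main__":
  app.run(main)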