Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b547c6fa
Commit
b547c6fa
authored
Sep 24, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 24, 2020
Browse files
Avoid global_step binding to model. It is buggy due to tf.train.Checkpoint delayed restoration...
PiperOrigin-RevId: 333591143
parent
eaff981c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
11 deletions
+5
-11
official/nlp/nhnet/evaluation.py
official/nlp/nhnet/evaluation.py
+2
-8
official/nlp/nhnet/trainer.py
official/nlp/nhnet/trainer.py
+3
-3
No files found.
official/nlp/nhnet/evaluation.py
View file @
b547c6fa
...
@@ -15,11 +15,6 @@
...
@@ -15,11 +15,6 @@
# ==============================================================================
# ==============================================================================
"""Evaluation for Bert2Bert."""
"""Evaluation for Bert2Bert."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
os
import
os
# Import libraries
# Import libraries
from
absl
import
logging
from
absl
import
logging
...
@@ -114,7 +109,6 @@ def continuous_eval(strategy,
...
@@ -114,7 +109,6 @@ def continuous_eval(strategy,
dtype
=
tf
.
int64
,
dtype
=
tf
.
int64
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
,
shape
=
[])
shape
=
[])
model
.
global_step
=
global_step
@
tf
.
function
@
tf
.
function
def
test_step
(
inputs
):
def
test_step
(
inputs
):
...
@@ -149,7 +143,7 @@ def continuous_eval(strategy,
...
@@ -149,7 +143,7 @@ def continuous_eval(strategy,
eval_results
=
{}
eval_results
=
{}
for
latest_checkpoint
in
tf
.
train
.
checkpoints_iterator
(
for
latest_checkpoint
in
tf
.
train
.
checkpoints_iterator
(
model_dir
,
timeout
=
timeout
):
model_dir
,
timeout
=
timeout
):
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
global_step
=
global_step
)
checkpoint
.
restore
(
latest_checkpoint
).
expect_partial
()
checkpoint
.
restore
(
latest_checkpoint
).
expect_partial
()
logging
.
info
(
"Loaded checkpoint %s"
,
latest_checkpoint
)
logging
.
info
(
"Loaded checkpoint %s"
,
latest_checkpoint
)
...
@@ -162,7 +156,7 @@ def continuous_eval(strategy,
...
@@ -162,7 +156,7 @@ def continuous_eval(strategy,
metric
.
update_state
(
func
(
logits
.
numpy
(),
targets
.
numpy
()))
metric
.
update_state
(
func
(
logits
.
numpy
(),
targets
.
numpy
()))
with
eval_summary_writer
.
as_default
():
with
eval_summary_writer
.
as_default
():
step
=
model
.
global_step
.
numpy
()
step
=
global_step
.
numpy
()
for
metric
,
_
in
metrics_and_funcs
:
for
metric
,
_
in
metrics_and_funcs
:
eval_results
[
metric
.
name
]
=
metric
.
result
().
numpy
().
astype
(
float
)
eval_results
[
metric
.
name
]
=
metric
.
result
().
numpy
().
astype
(
float
)
tf
.
summary
.
scalar
(
tf
.
summary
.
scalar
(
...
...
official/nlp/nhnet/trainer.py
View file @
b547c6fa
...
@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
...
@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
FLAGS
.
model_type
,
params
,
init_checkpoint
=
FLAGS
.
init_checkpoint
)
FLAGS
.
model_type
,
params
,
init_checkpoint
=
FLAGS
.
init_checkpoint
)
opt
=
optimizer
.
create_optimizer
(
params
)
opt
=
optimizer
.
create_optimizer
(
params
)
trainer
=
Trainer
(
model
,
params
)
trainer
=
Trainer
(
model
,
params
)
model
.
global_step
=
opt
.
iterations
trainer
.
compile
(
trainer
.
compile
(
optimizer
=
opt
,
optimizer
=
opt
,
...
@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
...
@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
summary_dir
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
"summaries"
)
summary_dir
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
"summaries"
)
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
summary_dir
,
update_freq
=
max
(
100
,
FLAGS
.
steps_per_loop
))
summary_dir
,
update_freq
=
max
(
100
,
FLAGS
.
steps_per_loop
))
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
opt
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
opt
,
global_step
=
opt
.
iterations
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
checkpoint
,
directory
=
FLAGS
.
model_dir
,
directory
=
FLAGS
.
model_dir
,
max_to_keep
=
10
,
max_to_keep
=
10
,
step_counter
=
model
.
global_step
,
step_counter
=
opt
.
iterations
,
checkpoint_interval
=
FLAGS
.
checkpoint_interval
)
checkpoint_interval
=
FLAGS
.
checkpoint_interval
)
if
checkpoint_manager
.
restore_or_initialize
():
if
checkpoint_manager
.
restore_or_initialize
():
logging
.
info
(
"Training restored from the checkpoints in: %s"
,
logging
.
info
(
"Training restored from the checkpoints in: %s"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment