Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
29955e8b
Commit
29955e8b
authored
Sep 27, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 27, 2020
Browse files
Internal change
PiperOrigin-RevId: 334062074
parent
f039e4b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
12 deletions
+24
-12
official/nlp/train_ctl_continuous_finetune.py
official/nlp/train_ctl_continuous_finetune.py
+15
-6
official/nlp/train_ctl_continuous_finetune_test.py
official/nlp/train_ctl_continuous_finetune_test.py
+9
-6
No files found.
official/nlp/train_ctl_continuous_finetune.py
View file @
29955e8b
...
...
@@ -102,10 +102,24 @@ def run_continuous_finetune(
summary_writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
model_dir
,
'eval'
))
global_step
=
0
def
timeout_fn
():
if
pretrain_steps
and
global_step
<
pretrain_steps
:
# Keeps waiting for another timeout period.
logging
.
info
(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.'
,
global_step
,
pretrain_steps
)
return
False
# Quits the loop.
return
True
for
pretrain_ckpt
in
tf
.
train
.
checkpoints_iterator
(
checkpoint_dir
=
params
.
task
.
init_checkpoint
,
min_interval_secs
=
10
,
timeout
=
params
.
trainer
.
continuous_eval_timeout
):
timeout
=
params
.
trainer
.
continuous_eval_timeout
,
timeout_fn
=
timeout_fn
):
with
distribution_strategy
.
scope
():
global_step
=
train_utils
.
read_global_step_from_checkpoint
(
pretrain_ckpt
)
...
...
@@ -154,11 +168,6 @@ def run_continuous_finetune(
# if we need gc here.
gc
.
collect
()
if
pretrain_steps
and
global_step
.
numpy
()
>=
pretrain_steps
:
logging
.
info
(
'The global_step reaches the pretraining end. Continuous '
'finetuning terminates.'
)
break
if
run_post_eval
:
return
eval_metrics
return
{}
...
...
official/nlp/train_ctl_continuous_finetune_test.py
View file @
29955e8b
...
...
@@ -15,10 +15,9 @@
# ==============================================================================
import
os
# Import libraries
from
absl
import
flags
from
absl.testing
import
flagsaver
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
...
...
@@ -31,14 +30,14 @@ FLAGS = flags.FLAGS
tfm_flags
.
define_flags
()
class
ContinuousFinetuneTest
(
tf
.
test
.
TestCase
):
class
ContinuousFinetuneTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
_model_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'model_dir'
)
@
flagsaver
.
flagsaver
def
testTrainCtl
(
self
):
@
parameterized
.
parameters
(
None
,
1
)
def
testTrainCtl
(
self
,
pretrain_steps
):
src_model_dir
=
self
.
get_temp_dir
()
flags_dict
=
dict
(
experiment
=
'mock'
,
...
...
@@ -81,7 +80,11 @@ class ContinuousFinetuneTest(tf.test.TestCase):
params
=
train_utils
.
parse_configuration
(
FLAGS
)
eval_metrics
=
train_ctl_continuous_finetune
.
run_continuous_finetune
(
FLAGS
.
mode
,
params
,
FLAGS
.
model_dir
,
run_post_eval
=
True
)
FLAGS
.
mode
,
params
,
FLAGS
.
model_dir
,
run_post_eval
=
True
,
pretrain_steps
=
pretrain_steps
)
self
.
assertIn
(
'best_acc'
,
eval_metrics
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment