Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
34db82bf
Commit
34db82bf
authored
Sep 15, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 15, 2020
Browse files
Internal change
PiperOrigin-RevId: 331877703
parent
9e8d7643
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
10 deletions
+21
-10
official/nlp/train_ctl_continuous_finetune.py
official/nlp/train_ctl_continuous_finetune.py
+19
-8
official/nlp/train_ctl_continuous_finetune_test.py
official/nlp/train_ctl_continuous_finetune_test.py
+2
-2
No files found.
official/nlp/train_ctl_continuous_finetune.py
View file @
34db82bf
...
...
@@ -17,7 +17,7 @@
import
os
import
time
from
typing
import
Mapping
,
Any
from
typing
import
Any
,
Mapping
,
Optional
from
absl
import
app
from
absl
import
flags
...
...
@@ -36,30 +36,36 @@ from official.core import train_utils
from
official.modeling
import
performance
from
official.modeling.hyperparams
import
config_definitions
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_integer
(
'pretrain_steps'
,
default
=
None
,
help
=
'The number of total training steps for the pretraining job.'
)
def
run_continuous_finetune
(
mode
:
str
,
params
:
config_definitions
.
ExperimentConfig
,
model_dir
:
str
,
run_post_eval
:
bool
=
False
,
pretrain_steps
:
Optional
[
int
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode.
continuous_train_and_eval - monitors a checkpoint directory. Once a new
checkpoint is discovered, loads the checkpoint, finetune the model by
training it (probably on another dataset or with another task), then
evaluate the finetuned model.
mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint, finetune the model by training it (probably on another dataset
or with another task), then evaluate the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
...
...
@@ -140,6 +146,11 @@ def run_continuous_finetune(
train_utils
.
remove_ckpts
(
model_dir
)
if
pretrain_steps
and
global_step
.
numpy
()
>=
pretrain_steps
:
logging
.
info
(
'The global_step reaches the pretraining end. Continuous '
'finetuning terminates.'
)
break
if
run_post_eval
:
return
eval_metrics
return
{}
...
...
@@ -150,7 +161,7 @@ def main(_):
params
=
train_utils
.
parse_configuration
(
FLAGS
)
model_dir
=
FLAGS
.
model_dir
train_utils
.
serialize_config
(
params
,
model_dir
)
run_continuous_finetune
(
FLAGS
.
mode
,
params
,
model_dir
)
run_continuous_finetune
(
FLAGS
.
mode
,
params
,
model_dir
,
FLAGS
.
pretrain_steps
)
if
__name__
==
'__main__'
:
...
...
official/nlp/train_ctl_continuous_finetune_test.py
View file @
34db82bf
...
...
@@ -31,10 +31,10 @@ FLAGS = flags.FLAGS
tfm_flags
.
define_flags
()
class
Main
ContinuousFinetuneTest
(
tf
.
test
.
TestCase
):
class
ContinuousFinetuneTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
MainContinuousFinetuneTest
,
self
).
setUp
()
super
().
setUp
()
self
.
_model_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'model_dir'
)
@
flagsaver
.
flagsaver
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment