ModelZoo / ResNet50_tensorflow · Commit 80501b61

Authored Sep 15, 2020 by Hongkun Yu
Committed by A. Unique TensorFlower, Sep 15, 2020
Parent: eb18c12c

Internal change

PiperOrigin-RevId: 331877703
Showing 2 changed files with 21 additions and 10 deletions:

    official/nlp/train_ctl_continuous_finetune.py        +19  -8
    official/nlp/train_ctl_continuous_finetune_test.py     +2  -2
official/nlp/train_ctl_continuous_finetune.py
@@ -17,7 +17,7 @@
 import os
 import time
 
-from typing import Mapping, Any
+from typing import Any, Mapping, Optional
 
 from absl import app
 from absl import flags
@@ -36,30 +36,36 @@ from official.core import train_utils
 from official.modeling import performance
 from official.modeling.hyperparams import config_definitions
 
 FLAGS = flags.FLAGS
 
+flags.DEFINE_integer(
+    'pretrain_steps',
+    default=None,
+    help='The number of total training steps for the pretraining job.')
+
 
 def run_continuous_finetune(
     mode: str,
     params: config_definitions.ExperimentConfig,
     model_dir: str,
     run_post_eval: bool = False,
+    pretrain_steps: Optional[int] = None,
 ) -> Mapping[str, Any]:
   """Run modes with continuous training.
 
   Currently only supports continuous_train_and_eval.
 
   Args:
-    mode: A 'str', specifying the mode.
-      continuous_train_and_eval - monitors a checkpoint directory. Once a new
-      checkpoint is discovered, loads the checkpoint, finetune the model by
-      training it (probably on another dataset or with another task), then
-      evaluate the finetuned model.
+    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
+      checkpoint directory. Once a new checkpoint is discovered, loads the
+      checkpoint, finetune the model by training it (probably on another dataset
+      or with another task), then evaluate the finetuned model.
     params: ExperimentConfig instance.
     model_dir: A 'str', a path to store model checkpoints and summaries.
     run_post_eval: Whether to run post eval once after training, metrics logs
       are returned.
+    pretrain_steps: Optional, the number of total training steps for the
+      pretraining job.
 
   Returns:
     eval logs: returns eval metrics logs when run_post_eval is set to True,
@@ -140,6 +146,11 @@ def run_continuous_finetune(
       train_utils.remove_ckpts(model_dir)
 
+      if pretrain_steps and global_step.numpy() >= pretrain_steps:
+        logging.info('The global_step reaches the pretraining end. Continuous '
+                     'finetuning terminates.')
+        break
+
   if run_post_eval:
     return eval_metrics
 
   return {}
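This guard is what finally lets the watcher loop exit: global_step is restored from each newly discovered pretraining checkpoint, and once it reaches pretrain_steps no further checkpoints will ever appear, so there is nothing left to wait for. A minimal sketch of the same control flow, with a plain integer standing in for the checkpoint's global_step variable (the function and variable names here are illustrative, not from the commit):

def watch_checkpoints(checkpoint_steps, pretrain_steps=None):
  """Illustrative stand-in for the continuous finetune watcher loop."""
  processed = []
  for step in checkpoint_steps:  # one iteration per discovered checkpoint
    processed.append(step)       # finetuning + evaluation would happen here
    if pretrain_steps and step >= pretrain_steps:
      # The pretraining job has reached its final step; no more checkpoints
      # will appear, so stop watching instead of waiting forever.
      break
  return processed


# The final pretraining checkpoint (step 2000) is still finetuned and
# evaluated before the loop exits.
assert watch_checkpoints([1000, 2000], pretrain_steps=2000) == [1000, 2000]

The last hunk threads the new flag through main():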
@@ -150,7 +161,7 @@ def main(_):
   params = train_utils.parse_configuration(FLAGS)
   model_dir = FLAGS.model_dir
   train_utils.serialize_config(params, model_dir)
-  run_continuous_finetune(FLAGS.mode, params, model_dir)
+  run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)
 
 
 if __name__ == '__main__':
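Note that the flag defaults to None, so existing invocations keep the old unbounded behavior. For readers unfamiliar with the absl pattern this commit follows, here is a self-contained sketch of the same optional-flag wiring; everything except the 'pretrain_steps' name and its help string is illustrative, not part of the commit:

from typing import Any, Mapping, Optional

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'pretrain_steps',
    default=None,
    help='The number of total training steps for the pretraining job.')


def run_job(pretrain_steps: Optional[int] = None) -> Mapping[str, Any]:
  # None (flag left unset) preserves the pre-commit behavior: no step bound.
  bound = 'unbounded' if pretrain_steps is None else pretrain_steps
  return {'pretrain_steps': bound}


def main(_):
  # FLAGS.pretrain_steps is threaded through exactly as in main() above.
  print(run_job(FLAGS.pretrain_steps))


if __name__ == '__main__':
  app.run(main)

Running this with and without --pretrain_steps=100000 shows the two behaviors.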
official/nlp/train_ctl_continuous_finetune_test.py
@@ -31,10 +31,10 @@ FLAGS = flags.FLAGS
 tfm_flags.define_flags()
 
 
-class MainContinuousFinetuneTest(tf.test.TestCase):
+class ContinuousFinetuneTest(tf.test.TestCase):
 
   def setUp(self):
-    super(MainContinuousFinetuneTest, self).setUp()
+    super().setUp()
     self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
 
   @flagsaver.flagsaver
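The test changes are mechanical: the class is renamed, and the two-argument Python 2 style super() call becomes the zero-argument Python 3 form, which resolves the class at runtime and therefore survives exactly this kind of rename. A minimal illustration, where unittest.TestCase stands in for tf.test.TestCase and the class name is hypothetical:

import unittest


class RenamedTest(unittest.TestCase):

  def setUp(self):
    # Zero-argument super() is equivalent to super(RenamedTest, self) but
    # does not repeat the class name, so renaming the class cannot leave a
    # stale reference behind.
    super().setUp()
    self._ready = True

  def test_setup_ran(self):
    self.assertTrue(self._ready)


if __name__ == '__main__':
  unittest.main()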