ModelZoo / ResNet50_tensorflow · Commits · 9f88ce51

Commit 9f88ce51, authored May 17, 2021 by Le Hou, committed by A. Unique TensorFlower on May 17, 2021.

Internal change

PiperOrigin-RevId: 374238894
Parent: fd45760c
Showing 1 changed file with 45 additions and 27 deletions.

official/modeling/progressive/trainer.py (+45, −27)
@@ -51,6 +51,9 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
     export_checkpoint_interval: An int. The number of steps between exporting
       checkpoints. If None (by default), will use the same value as
       TrainerConfig.checkpoint_interval.
+    export_max_to_keep: The maximum number of exported checkpoints to keep.
+      If None (by default), will use the same value as
+      TrainerConfig.max_to_keep.
     export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
       during the final progressive training stage. In other words, whether to
       not export small, partial models. In many cases, it is not meaningful to
@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
   progressive: Optional[policies.ProgressiveConfig] = None
   export_checkpoint: bool = True
   export_checkpoint_interval: Optional[int] = None
+  export_max_to_keep: Optional[int] = None
   export_only_final_stage_ckpt: bool = True
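Taken together with the docstring hunk above, the new fields form a fallback chain: each export_* option left at None inherits its non-export counterpart (export_checkpoint_interval falls back to checkpoint_interval, export_max_to_keep to max_to_keep), which is exactly how the exporting hunk at the end of this diff resolves them. A minimal usage sketch, assuming these dataclass configs accept keyword overrides as TF Model Garden configs typically do; the values are illustrative, not from this commit:

# Hypothetical override; values are illustrative.
trainer_config = ProgressiveTrainerConfig(
    export_checkpoint=True,
    export_checkpoint_interval=1000,  # export every 1000 steps
    export_max_to_keep=5,             # keep the 5 newest exports; None would
                                      # fall back to TrainerConfig.max_to_keep
    export_only_final_stage_ckpt=True,  # skip small, partial-stage models
)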
@@ -98,6 +102,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Directory for non-progressive checkpoint
     self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
     tf.io.gfile.makedirs(self._export_ckpt_dir)
+    self._export_ckpt_manager = None

     # Receive other checkpoint exporters, e.g., best checkpoint exporter.
     # TODO(lehou): unify the checkpoint exporting logic, although the default
@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       # Setting `self._train_iter` to None will rebuild the dataset iterator.
       self._train_iter = None
+      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+      # manager for exporting.
+      self._export_ckpt_manager = None
+
     return logs

   def _update_pt_stage_from_ckpt(self, ckpt_file):
@@ -226,6 +235,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       # Setting `self._train_iter` to None will rebuild the dataset iterator.
       self._train_iter = None
+      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+      # manager for exporting.
+      self._export_ckpt_manager = None
+
   def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
     """Export checkpoints in non-progressive format.
@@ -244,30 +257,35 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       logging.info('Not exporting checkpoints until the last stage.')
       return

-    global_step_np = self.global_step.numpy()
-    if self.config.trainer.export_checkpoint_interval is None:
-      step_interval = self.config.trainer.checkpoint_interval
-    else:
-      step_interval = self.config.trainer.export_checkpoint_interval
-    if global_step_np % step_interval != 0 and (
-        global_step_np < self._config.trainer.train_steps):
-      logging.info('Not exporting checkpoints in global step: %d.',
-                   global_step_np)
-      return
-
-    # Create a checkpoint object just now, to make sure we use
-    # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
-    # current stage.
-    if hasattr(self.model, 'checkpoint_items'):
-      checkpoint_items = self.model.checkpoint_items
-    else:
-      checkpoint_items = {}
-    checkpoint = tf.train.Checkpoint(global_step=self.global_step,
-                                     model=self.model,
-                                     optimizer=self.optimizer,
-                                     **checkpoint_items)
-    file_prefix = os.path.join(export_ckpt_dir,
-                               'ckpt-{}'.format(global_step_np))
-    checkpoint.save(file_prefix=file_prefix)
-    logging.info('Checkpoints exported: %s.', file_prefix)
+    if self._export_ckpt_manager is None:
+      # Create a checkpoint object just now, to make sure we use
+      # progressive_policy.cur_model and progressive_policy.cur_optimizer of
+      # the current stage.
+      if hasattr(self.model, 'checkpoint_items'):
+        checkpoint_items = self.model.checkpoint_items
+      else:
+        checkpoint_items = {}
+      checkpoint = tf.train.Checkpoint(global_step=self.global_step,
+                                       model=self.model,
+                                       optimizer=self.optimizer,
+                                       **checkpoint_items)
+
+      max_to_keep = self.config.trainer.export_max_to_keep or (
+          self.config.trainer.max_to_keep)
+      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
+          self.config.trainer.checkpoint_interval)
+      self._export_ckpt_manager = tf.train.CheckpointManager(
+          checkpoint,
+          directory=export_ckpt_dir,
+          checkpoint_name='ckpt',
+          step_counter=self.global_step,
+          max_to_keep=max_to_keep,
+          checkpoint_interval=checkpoint_interval,
+      )
+
+    checkpoint_path = self._export_ckpt_manager.save(
+        checkpoint_number=self.global_step.numpy(), check_interval=True)
+    if checkpoint_path:
+      logging.info('Checkpoints exported: %s.', checkpoint_path)
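The net effect of this hunk: the hand-rolled `global_step % step_interval` check and raw `checkpoint.save()` are replaced by a cached tf.train.CheckpointManager, which owns both retention (max_to_keep) and save cadence (checkpoint_interval, enforced by save(..., check_interval=True)). A self-contained sketch of that API; the directory, model, and step counts are placeholders, not from the commit:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=step, model=tf.keras.layers.Dense(4))
manager = tf.train.CheckpointManager(
    ckpt,
    directory='/tmp/exported_ckpts_demo',  # placeholder path
    checkpoint_name='ckpt',
    step_counter=step,        # read by the manager to gate saves
    max_to_keep=3,            # older exports are deleted automatically
    checkpoint_interval=100)  # minimum steps between saves

for _ in range(250):
  step.assign_add(1)
  # With check_interval=True, save() returns None (a no-op) until
  # checkpoint_interval steps have elapsed since the last save.
  path = manager.save(checkpoint_number=step.numpy(), check_interval=True)
  if path:
    print('exported:', path)

This also explains why the new code checks the returned checkpoint_path before logging: a skipped save yields None, so only real exports are reported, and deleting stale exports is handled by the manager rather than by hand.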