ModelZoo / ResNet50_tensorflow · Commits

Commit 9f88ce51
Authored May 17, 2021 by Le Hou; committed by A. Unique TensorFlower, May 17, 2021
Parent: fd45760c

Internal change

PiperOrigin-RevId: 374238894
Changes: 1 changed file with 45 additions and 27 deletions.

official/modeling/progressive/trainer.py (+45, -27)
@@ -51,6 +51,9 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
     export_checkpoint_interval: A bool. The number of steps between exporting
       checkpoints. If None (by default), will use the same value as
       TrainerConfig.checkpoint_interval.
+    export_max_to_keep: The maximum number of exported checkpoints to keep.
+      If None (by default), will use the same value as
+      TrainerConfig.max_to_keep.
     export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
       during the final progressive training stage. In other words, whether to
       not export small, partial models. In many cases, it is not meaningful to
@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
...
@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
progressive
:
Optional
[
policies
.
ProgressiveConfig
]
=
None
progressive
:
Optional
[
policies
.
ProgressiveConfig
]
=
None
export_checkpoint
:
bool
=
True
export_checkpoint
:
bool
=
True
export_checkpoint_interval
:
Optional
[
int
]
=
None
export_checkpoint_interval
:
Optional
[
int
]
=
None
export_max_to_keep
:
Optional
[
int
]
=
None
export_only_final_stage_ckpt
:
bool
=
True
export_only_final_stage_ckpt
:
bool
=
True
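
The new field follows the same fallback convention as export_checkpoint_interval: left as None, it inherits the corresponding non-export setting. A minimal usage sketch, assuming the config is constructed directly as a dataclass (Model Garden experiments usually set these fields through params/YAML overrides instead, so treat the construction path here as an assumption):

    # Hypothetical sketch -- assumes direct dataclass construction of the config.
    from official.modeling.progressive.trainer import ProgressiveTrainerConfig

    trainer_config = ProgressiveTrainerConfig(
        checkpoint_interval=1000,          # progressive-format checkpoints
        max_to_keep=5,
        export_checkpoint=True,            # also export plain checkpoints
        export_checkpoint_interval=None,   # None -> falls back to checkpoint_interval
        export_max_to_keep=None,           # None -> falls back to max_to_keep (new field)
        export_only_final_stage_ckpt=True)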
@@ -98,6 +102,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Directory for non-progressive checkpoint
     self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
     tf.io.gfile.makedirs(self._export_ckpt_dir)
+    self._export_ckpt_manager = None

     # Receive other checkpoint export, e.g, best checkpoint exporter.
     # TODO(lehou): unify the checkpoint exporting logic, although the default
@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
...
@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
# Setting `self._train_iter` to None will rebuild the dataset iterator.
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self
.
_train_iter
=
None
self
.
_train_iter
=
None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# for exporting.
self
.
_export_ckpt_manager
=
None
return
logs
return
logs
def
_update_pt_stage_from_ckpt
(
self
,
ckpt_file
):
def
_update_pt_stage_from_ckpt
(
self
,
ckpt_file
):
@@ -226,6 +235,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Setting `self._train_iter` to None will rebuild the dataset iterator.
     self._train_iter = None
+    # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+    # for exporting.
+    self._export_ckpt_manager = None

   def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
     """Export checkpoints in non-progressive format.
@@ -244,30 +257,35 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       logging.info('Not exporting checkpoints until the last stage.')
       return

-    global_step_np = self.global_step.numpy()
-    if self.config.trainer.export_checkpoint_interval is None:
-      step_interval = self.config.trainer.checkpoint_interval
-    else:
-      step_interval = self.config.trainer.export_checkpoint_interval
-    if global_step_np % step_interval != 0 and (
-        global_step_np < self._config.trainer.train_steps):
-      logging.info('Not exporting checkpoints in global step: %d.',
-                   global_step_np)
-      return
-
-    # Create a checkpoint object just now, to make sure we use
-    # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
-    # current stage.
-    if hasattr(self.model, 'checkpoint_items'):
-      checkpoint_items = self.model.checkpoint_items
-    else:
-      checkpoint_items = {}
-    checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step,
-        model=self.model,
-        optimizer=self.optimizer,
-        **checkpoint_items)
-    file_prefix = os.path.join(export_ckpt_dir,
-                               'ckpt-{}'.format(global_step_np))
-    checkpoint.save(file_prefix=file_prefix)
-    logging.info('Checkpoints exported: %s.', file_prefix)
+    if self._export_ckpt_manager is None:
+      # Create a checkpoint object just now, to make sure we use
+      # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
+      # current stage.
+      if hasattr(self.model, 'checkpoint_items'):
+        checkpoint_items = self.model.checkpoint_items
+      else:
+        checkpoint_items = {}
+      checkpoint = tf.train.Checkpoint(
+          global_step=self.global_step,
+          model=self.model,
+          optimizer=self.optimizer,
+          **checkpoint_items)
+
+      max_to_keep = self.config.trainer.export_max_to_keep or (
+          self.config.trainer.max_to_keep)
+      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
+          self.config.trainer.checkpoint_interval)
+      self._export_ckpt_manager = tf.train.CheckpointManager(
+          checkpoint,
+          directory=export_ckpt_dir,
+          checkpoint_name='ckpt',
+          step_counter=self.global_step,
+          max_to_keep=max_to_keep,
+          checkpoint_interval=checkpoint_interval,
+      )
+
+    checkpoint_path = self._export_ckpt_manager.save(
+        checkpoint_number=self.global_step.numpy(),
+        check_interval=True)
+    if checkpoint_path:
+      logging.info('Checkpoints exported: %s.', checkpoint_path)
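
Net effect of this hunk: the hand-rolled step_interval modulo check and unbounded checkpoint.save() are replaced by a cached tf.train.CheckpointManager, which provides both behaviors itself. With checkpoint_interval and step_counter set, save(check_interval=True) returns None (and saves nothing) until the interval has elapsed since the last save, and max_to_keep prunes older exports automatically. A self-contained sketch of those semantics (toy variable and a made-up directory, unrelated to the trainer's objects):

    # Standalone sketch of the tf.train.CheckpointManager semantics relied on above.
    import tensorflow as tf

    step = tf.Variable(0, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(
        ckpt,
        directory='/tmp/exported_ckpts_demo',  # hypothetical path
        checkpoint_name='ckpt',
        step_counter=step,
        max_to_keep=3,            # older exports are deleted automatically
        checkpoint_interval=100)  # save() is a no-op between intervals

    for _ in range(250):
      step.assign_add(1)
      # With check_interval=True, save() returns None unless at least
      # `checkpoint_interval` steps have passed since the last save.
      path = manager.save(checkpoint_number=step.numpy(), check_interval=True)
      if path:
        print('exported:', path)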