Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bde9cdca
Commit
bde9cdca
authored
Feb 07, 2021
by
Le Hou
Committed by
A. Unique TensorFlower
Feb 07, 2021
Browse files
Internal change
PiperOrigin-RevId: 356158181
parent
98839bd2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
21 deletions
+27
-21
official/modeling/progressive/trainer.py
official/modeling/progressive/trainer.py
+2
-21
official/modeling/progressive/utils.py
official/modeling/progressive/utils.py
+25
-0
No files found.
official/modeling/progressive/trainer.py
View file @
bde9cdca
...
...
@@ -32,6 +32,7 @@ from official.core import base_task
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
config_definitions
from
official.modeling.progressive
import
policies
from
official.modeling.progressive
import
utils
ExperimentConfig
=
config_definitions
.
ExperimentConfig
...
...
@@ -61,26 +62,6 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
export_only_final_stage_ckpt
:
bool
=
True
class
CheckpointWithHooks
(
tf
.
train
.
Checkpoint
):
"""Same as tf.train.Checkpoint but supports hooks.
When running continuous_eval jobs, when a new checkpoint arrives, we have to
update our model and optimizer etc. to match the stage_id of the checkpoint.
However, when orbit loads a checkpoint, it does not inform us. So we use this
class to update our model to the correct stage before checkpoint restore.
"""
def
__init__
(
self
,
before_load_hook
,
**
kwargs
):
self
.
_before_load_hook
=
before_load_hook
super
(
CheckpointWithHooks
,
self
).
__init__
(
**
kwargs
)
# override
def
read
(
self
,
save_path
,
options
=
None
):
self
.
_before_load_hook
(
save_path
)
logging
.
info
(
'Ran before_load_hook.'
)
super
(
CheckpointWithHooks
,
self
).
read
(
save_path
=
save_path
,
options
=
options
)
@
gin
.
configurable
class
ProgressiveTrainer
(
trainer_lib
.
Trainer
):
"""Implements the progressive trainer shared for TensorFlow models."""
...
...
@@ -124,7 +105,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
self
.
_checkpoint
=
CheckpointWithHooks
(
self
.
_checkpoint
=
utils
.
CheckpointWithHooks
(
before_load_hook
=
self
.
_update_pt_stage_from_ckpt
,
global_step
=
self
.
global_step
,
**
self
.
_task
.
cur_checkpoint_items
)
...
...
official/modeling/progressive/utils.py
View file @
bde9cdca
...
...
@@ -14,6 +14,9 @@
"""Util classes and functions."""
from
absl
import
logging
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.training.tracking
import
tracking
...
...
@@ -29,3 +32,25 @@ class VolatileTrackable(tracking.AutoTrackable):
for
k
,
v
in
kwargs
.
items
():
delattr
(
self
,
k
)
# untrack this object
setattr
(
self
,
k
,
v
)
# track the new object
class
CheckpointWithHooks
(
tf
.
train
.
Checkpoint
):
"""Same as tf.train.Checkpoint but supports hooks.
In progressive training, use this class instead of tf.train.Checkpoint.
Since the network architecture changes during progressive training, we need to
prepare something (like switch to the correct architecture) before loading the
checkpoint. This class supports a hook that will be executed before checkpoint
loading.
"""
def
__init__
(
self
,
before_load_hook
,
**
kwargs
):
self
.
_before_load_hook
=
before_load_hook
super
(
CheckpointWithHooks
,
self
).
__init__
(
**
kwargs
)
# override
def
read
(
self
,
save_path
,
options
=
None
):
self
.
_before_load_hook
(
save_path
)
logging
.
info
(
'Ran before_load_hook.'
)
super
(
CheckpointWithHooks
,
self
).
read
(
save_path
=
save_path
,
options
=
options
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment