Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2afac6d1
"vscode:/vscode.git/clone" did not exist on "8c1d1b08fa3bbff5b5e77d62fc0038418b520224"
Commit
2afac6d1
authored
Jul 13, 2021
by
Rino Lee
Committed by
A. Unique TensorFlower
Jul 13, 2021
Browse files
Internal change
PiperOrigin-RevId: 384478988
parent
bac45446
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
0 deletions
+63
-0
official/core/actions.py
official/core/actions.py
+62
-0
official/core/train_lib.py
official/core/train_lib.py
+1
-0
No files found.
official/core/actions.py
View file @
2afac6d1
...
@@ -20,12 +20,57 @@ from typing import List
...
@@ -20,12 +20,57 @@ from typing import List
import
gin
import
gin
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.core
import
base_trainer
from
official.core
import
base_trainer
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
optimization
class
PruningActions
:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
pruning metrics to tensorboard.
This action must be used when training a pruned model to avoid pruning error.
"""
def
__init__
(
self
,
export_dir
:
str
,
model
:
tf
.
keras
.
Model
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self
.
_optimizer
=
optimizer
self
.
update_pruning_step
=
tfmot
.
sparsity
.
keras
.
UpdatePruningStep
()
self
.
update_pruning_step
.
set_model
(
model
)
self
.
update_pruning_step
.
on_train_begin
()
self
.
pruning_summaries
=
tfmot
.
sparsity
.
keras
.
PruningSummaries
(
log_dir
=
export_dir
)
model
.
optimizer
=
optimizer
self
.
pruning_summaries
.
set_model
(
model
)
def
__call__
(
self
,
output
:
orbit
.
runner
.
Output
):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self
.
update_pruning_step
.
on_epoch_end
(
batch
=
None
)
self
.
pruning_summaries
.
on_epoch_begin
(
epoch
=
None
)
class
EMACheckpointing
:
class
EMACheckpointing
:
"""Eval action to save checkpoint with average weights when EMA is used.
"""Eval action to save checkpoint with average weights when EMA is used.
...
@@ -92,3 +137,20 @@ def get_eval_actions(
...
@@ -92,3 +137,20 @@ def get_eval_actions(
max_to_keep
=
params
.
trainer
.
max_to_keep
))
max_to_keep
=
params
.
trainer
.
max_to_keep
))
return
eval_actions
return
eval_actions
@
gin
.
configurable
def
get_train_actions
(
params
:
config_definitions
.
ExperimentConfig
,
trainer
:
base_trainer
.
Trainer
,
model_dir
:
str
)
->
List
[
orbit
.
Action
]:
"""Gets train actions for TFM trainer."""
train_actions
=
[]
# Adds pruning callback actions.
if
hasattr
(
params
.
task
,
'pruning'
):
train_actions
.
append
(
PruningActions
(
export_dir
=
model_dir
,
model
=
trainer
.
model
,
optimizer
=
trainer
.
optimizer
))
return
train_actions
official/core/train_lib.py
View file @
2afac6d1
...
@@ -105,6 +105,7 @@ def run_experiment(
...
@@ -105,6 +105,7 @@ def run_experiment(
(
save_summary
)
else
None
,
(
save_summary
)
else
None
,
summary_interval
=
params
.
trainer
.
summary_interval
if
summary_interval
=
params
.
trainer
.
summary_interval
if
(
save_summary
)
else
None
,
(
save_summary
)
else
None
,
train_actions
=
actions
.
get_train_actions
(
params
,
trainer
,
model_dir
),
eval_actions
=
actions
.
get_eval_actions
(
params
,
trainer
,
model_dir
))
eval_actions
=
actions
.
get_eval_actions
(
params
,
trainer
,
model_dir
))
logging
.
info
(
'Starts to execute mode: %s'
,
mode
)
logging
.
info
(
'Starts to execute mode: %s'
,
mode
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment