Commit 6a2a62a6 authored Mar 08, 2022 by Hongkun Yu, committed by A. Unique TensorFlower on Mar 09, 2022
Internal change
PiperOrigin-RevId: 433381529
parent a9322830
Showing 11 changed files with 691 additions and 0 deletions (+691 −0)
official/projects/pruning/README.md (+44 −0)
official/projects/pruning/configs/__init__.py (+18 −0)
official/projects/pruning/configs/experiments/image_classification/imagenet_mobilenetv2_pruning_gpu.yaml (+59 −0)
official/projects/pruning/configs/experiments/image_classification/imagenet_resnet50_pruning_gpu.yaml (+60 −0)
official/projects/pruning/configs/image_classification.py (+81 −0)
official/projects/pruning/configs/image_classification_test.py (+49 −0)
official/projects/pruning/registry_imports.py (+18 −0)
official/projects/pruning/tasks/__init__.py (+18 −0)
official/projects/pruning/tasks/image_classification.py (+139 −0)
official/projects/pruning/tasks/image_classification_test.py (+176 −0)
official/projects/pruning/train.py (+29 −0)
official/projects/pruning/README.md (new file, mode 100644)
# Training with Pruning

[TOC]

⚠️ Disclaimer: All datasets hyperlinked from this page are not owned or distributed by Google. The datasets are made available by third parties. Please review the terms and conditions made available by the third parties before using the data.

## Overview

This project includes pruning code for TensorFlow models. The examples show how to apply the Model Optimization Toolkit's [pruning API](https://www.tensorflow.org/model_optimization/guide/pruning).
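For orientation, a minimal sketch of that API on a toy Keras model is shown below: the model is wrapped with `prune_low_magnitude` and a `PolynomialDecay` schedule, and the pruning step is kept up to date during training. The toy two-layer model and the schedule values are illustrative only; they are not the Model Garden backbones or the configs used in this project.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Illustrative stand-in model; this project applies the same wrapping to the
# Model Garden ResNet50 / MobileNetV2 backbones.
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10),
])

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.2, final_sparsity=0.8,
        begin_step=0, end_step=1000, frequency=100),
}
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

# The pruning masks only advance when tfmot's UpdatePruningStep callback (or an
# equivalent action, as in this project's training loop) runs during training.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
```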
## How to train a model

```bash
EXPERIMENT=xxx   # Change this for your run, for example, 'resnet_imagenet_pruning'
CONFIG_FILE=xxx  # Change this for your run, for example, the path of imagenet_resnet50_pruning_gpu.yaml
MODEL_DIR=xxx    # Change this for your run, for example, /tmp/model_dir

python3 train.py \
  --experiment=${EXPERIMENT} \
  --config_file=${CONFIG_FILE} \
  --model_dir=${MODEL_DIR} \
  --mode=train_and_eval
```
## Accuracy

<figure align="center">
  <img width=70% src=https://storage.googleapis.com/tf_model_garden/models/pruning/images/readme-pruning-classification-resnet.png>
  <img width=70% src=https://storage.googleapis.com/tf_model_garden/models/pruning/images/readme-pruning-classification-mobilenet.png>
  <figcaption>Comparison of ImageNet top-1 accuracy for the classification models</figcaption>
</figure>

Note: The top-1 model accuracy is measured on the validation set of [ImageNet](https://www.image-net.org/).

## Pre-trained Models

### Image Classification

Model       | Resolution | Top-1 Accuracy (Dense) | Top-1 Accuracy (50% sparsity) | Top-1 Accuracy (80% sparsity) | Config | Download
------------|------------|------------------------|-------------------------------|-------------------------------|--------|---------
MobileNetV2 | 224x224    | 72.768%                | 71.334%                       | 61.378%                       | [config](https://github.com/tensorflow/models/blob/master/official/projects/pruning/configs/experiments/image_classification/imagenet_mobilenetv2_pruning_gpu.yaml) | [TFLite (50% sparsity)](https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v2_1.0_float/mobilenet_v2_0.5_pruned_1.00_224_float.tflite)
ResNet50    | 224x224    | 76.704%                | 76.61%                        | 75.508%                       | [config](https://github.com/tensorflow/models/blob/master/official/projects/pruning/configs/experiments/image_classification/imagenet_resnet50_pruning_gpu.yaml) | [TFLite (80% sparsity)](https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_0.8_pruned_224_float.tflite)
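The released TFLite files linked above can be exercised with the standard TFLite interpreter. The sketch below is minimal: the local file name assumes the 80%-sparsity ResNet50 file has already been downloaded, and the all-zeros input is a placeholder for a preprocessed 224x224 image.

```python
import numpy as np
import tensorflow as tf

# Assumes resnet_50_0.8_pruned_224_float.tflite was downloaded to the working directory.
interpreter = tf.lite.Interpreter(model_path='resnet_50_0.8_pruned_224_float.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Placeholder input; replace with a preprocessed image batch of the expected shape.
image = np.zeros(input_details['shape'], dtype=input_details['dtype'])
interpreter.set_tensor(input_details['index'], image)
interpreter.invoke()
logits = interpreter.get_tensor(output_details['index'])
print(logits.shape)
```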
official/projects/pruning/configs/__init__.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.projects.pruning.configs import image_classification
official/projects/pruning/configs/experiments/image_classification/imagenet_mobilenetv2_pruning_gpu.yaml (new file, mode 100644)
# MobileNetV2_1.0 ImageNet classification.
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  loss_scale: 'dynamic'
task:
  model:
    num_classes: 1001
    input_size: [224, 224, 3]
    backbone:
      type: 'mobilenet'
      mobilenet:
        model_id: 'MobileNetV2'
        filter_size_scale: 1.0
    dropout_rate: 0.1
  losses:
    l2_weight_decay: 0
    one_hot: true
    label_smoothing: 0.1
  train_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 1024
    dtype: 'float32'
  validation_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 1024
    dtype: 'float32'
    drop_remainder: false
  pruning:
    pretrained_original_checkpoint: 'gs://**/mobilenetv2_gpu/22984194/ckpt-625500'
    pruning_schedule: 'PolynomialDecay'
    begin_step: 0
    end_step: 80000
    initial_sparsity: 0.2
    final_sparsity: 0.5
    frequency: 400
trainer:
  # Top1 accuracy 71.33% after 17hr for 8 GPUs with pruning.
  # Pretrained network without pruning has Top1 accuracy 72.77%.
  train_steps: 125100  # 50 epochs
  validation_steps: 98
  validation_interval: 2502
  steps_per_loop: 2502
  summary_interval: 2502
  checkpoint_interval: 2502
  optimizer_config:
    learning_rate:
      type: 'exponential'
      exponential:
        initial_learning_rate: 0.04
        decay_steps: 5004
        decay_rate: 0.85
        staircase: true
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 0
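For a rough sense of the pruning block in this config: the sketch below traces its sparsity ramp (0.2 → 0.5 over the first 80,000 steps), assuming the cubic power that `tfmot.sparsity.keras.PolynomialDecay` applies by default.

```python
# Back-of-the-envelope sparsity ramp for the config above, assuming
# PolynomialDecay's default power of 3:
#   s(t) = final + (initial - final) * (1 - (t - begin) / (end - begin)) ** 3
initial, final, begin, end = 0.2, 0.5, 0, 80000

def sparsity_at(step):
  t = min(max(step, begin), end)
  return final + (initial - final) * (1.0 - (t - begin) / (end - begin)) ** 3

for step in (0, 20000, 40000, 80000):
  print(step, round(sparsity_at(step), 3))
```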
official/projects/pruning/configs/experiments/image_classification/imagenet_resnet50_pruning_gpu.yaml (new file, mode 100644)
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  loss_scale: 'dynamic'
task:
  model:
    num_classes: 1001
    input_size: [224, 224, 3]
    backbone:
      type: 'resnet'
      resnet:
        model_id: 50
  losses:
    l2_weight_decay: 0
    one_hot: true
    label_smoothing: 0.1
  train_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 1024
    dtype: 'float32'
  validation_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 1024
    dtype: 'float32'
    drop_remainder: false
  pruning:
    pretrained_original_checkpoint: 'gs://**/resnet_classifier_gpu/ckpt-56160'
    pruning_schedule: 'PolynomialDecay'
    begin_step: 0
    end_step: 40000
    initial_sparsity: 0.2
    final_sparsity: 0.8
    frequency: 40
trainer:
  # Top1 accuracy 75.508% after 7hr for 8 GPUs with pruning.
  # Pretrained network without pruning has Top1 accuracy 76.7%.
  train_steps: 50000
  validation_steps: 50
  validation_interval: 1251
  steps_per_loop: 1251
  summary_interval: 1251
  checkpoint_interval: 1251
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'exponential'
      exponential:
        initial_learning_rate: 0.01
        decay_steps: 2502
        decay_rate: 0.9
        staircase: true
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 0
official/projects/pruning/configs/image_classification.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
import dataclasses
from typing import Optional, Tuple

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import image_classification


@dataclasses.dataclass
class PruningConfig(hyperparams.Config):
  """Pruning parameters.

  Attributes:
    pretrained_original_checkpoint: The pretrained checkpoint location of the
      original model.
    pruning_schedule: A string that indicates the name of the `PruningSchedule`
      object that controls the pruning rate throughout training. Currently
      available options are `PolynomialDecay` and `ConstantSparsity`.
    begin_step: Step at which to begin pruning.
    end_step: Step at which to end pruning.
    initial_sparsity: Sparsity ratio at which pruning begins.
    final_sparsity: Sparsity ratio at which pruning ends.
    frequency: Number of training steps between sparsity adjustments.
    sparsity_m_by_n: Structured sparsity specification. It specifies m zeros
      over n consecutive weight elements.
  """
  pretrained_original_checkpoint: Optional[str] = None
  pruning_schedule: str = 'PolynomialDecay'
  begin_step: int = 0
  end_step: int = 1000
  initial_sparsity: float = 0.0
  final_sparsity: float = 0.1
  frequency: int = 100
  sparsity_m_by_n: Optional[Tuple[int, int]] = None


@dataclasses.dataclass
class ImageClassificationTask(image_classification.ImageClassificationTask):
  pruning: Optional[PruningConfig] = None


@exp_factory.register_config_factory('resnet_imagenet_pruning')
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Builds an image classification config for the ResNet with pruning."""
  config = image_classification.image_classification_imagenet()
  task = ImageClassificationTask.from_args(
      pruning=PruningConfig(), **config.task.as_dict())
  config.task = task
  runtime = cfg.RuntimeConfig(enable_xla=False)
  config.runtime = runtime
  return config


@exp_factory.register_config_factory('mobilenet_imagenet_pruning')
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
  """Builds an image classification config for the MobileNetV2 with pruning."""
  config = image_classification.image_classification_imagenet_mobilenet()
  task = ImageClassificationTask.from_args(
      pruning=PruningConfig(), **config.task.as_dict())
  config.task = task
  return config
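A minimal sketch of how the factories registered above are typically consumed, mirroring what configs/image_classification_test.py does; it assumes the Model Garden `official` package is importable.

```python
from official.core import exp_factory
from official.projects.pruning.configs import image_classification  # registers the experiments

config = exp_factory.get_exp_config('resnet_imagenet_pruning')
config.task.pruning.final_sparsity = 0.8  # override defaults before training
config.validate()
```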
official/projects/pruning/configs/image_classification_test.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for image_classification."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.pruning.configs import image_classification as pruning_exp_cfg
from official.vision import beta
from official.vision.configs import image_classification as exp_cfg


class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      ('resnet_imagenet_pruning',),
      ('mobilenet_imagenet_pruning'),
  )
  def test_image_classification_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.ImageClassificationTask)
    self.assertIsInstance(config.task, pruning_exp_cfg.ImageClassificationTask)
    self.assertIsInstance(config.task.pruning, pruning_exp_cfg.PruningConfig)
    self.assertIsInstance(config.task.model, exp_cfg.ImageClassificationModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.validate()
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/projects/pruning/registry_imports.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration on pruning project."""
# pylint: disable=unused-import
from official.projects.pruning import configs
from official.projects.pruning.tasks import image_classification
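A minimal sketch of how this module is intended to be used: importing it registers the pruning experiment configs and the task with the Model Garden factories. The `task_factory.get_task` call below mirrors how the train library typically resolves a registered task; treat the exact call as an assumption.

```python
from official.projects.pruning import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory

config = exp_factory.get_exp_config('mobilenet_imagenet_pruning')
task = task_factory.get_task(config.task)
```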
official/projects/pruning/tasks/__init__.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Modeling package definition."""
from official.projects.pruning.tasks import image_classification
official/projects/pruning/tasks/image_classification.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
from absl import logging
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.core import task_factory
from official.projects.pruning.configs import image_classification as exp_cfg
from official.vision.modeling.backbones import mobilenet
from official.vision.modeling.layers import nn_blocks
from official.vision.tasks import image_classification


@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(image_classification.ImageClassificationTask):
  """A task for image classification with pruning."""

  _BLOCK_LAYER_SUFFIX_MAP = {
      nn_blocks.BottleneckBlock: (
          'conv2d/kernel:0',
          'conv2d_1/kernel:0',
          'conv2d_2/kernel:0',
          'conv2d_3/kernel:0',
      ),
      nn_blocks.InvertedBottleneckBlock: (
          'conv2d/kernel:0',
          'conv2d_1/kernel:0',
          'depthwise_conv2d/depthwise_kernel:0'),
      mobilenet.Conv2DBNBlock: ('conv2d/kernel:0',),
  }

  def build_model(self) -> tf.keras.Model:
    """Builds classification model with pruning."""
    model = super(ImageClassificationTask, self).build_model()
    if self.task_config.pruning is None:
      return model

    pruning_cfg = self.task_config.pruning
    prunable_model = tf.keras.models.clone_model(
        model,
        clone_function=self._make_block_prunable,
    )

    original_checkpoint = pruning_cfg.pretrained_original_checkpoint
    if original_checkpoint is not None:
      ckpt = tf.train.Checkpoint(model=prunable_model, **model.checkpoint_items)
      status = ckpt.read(original_checkpoint)
      status.expect_partial().assert_existing_objects_matched()

    pruning_params = {}
    if pruning_cfg.sparsity_m_by_n is not None:
      pruning_params['sparsity_m_by_n'] = pruning_cfg.sparsity_m_by_n

    if pruning_cfg.pruning_schedule == 'PolynomialDecay':
      pruning_params['pruning_schedule'] = tfmot.sparsity.keras.PolynomialDecay(
          initial_sparsity=pruning_cfg.initial_sparsity,
          final_sparsity=pruning_cfg.final_sparsity,
          begin_step=pruning_cfg.begin_step,
          end_step=pruning_cfg.end_step,
          frequency=pruning_cfg.frequency)
    elif pruning_cfg.pruning_schedule == 'ConstantSparsity':
      pruning_params['pruning_schedule'] = tfmot.sparsity.keras.ConstantSparsity(
          target_sparsity=pruning_cfg.final_sparsity,
          begin_step=pruning_cfg.begin_step,
          frequency=pruning_cfg.frequency)
    else:
      raise NotImplementedError(
          'Only PolynomialDecay and ConstantSparsity are currently supported. '
          'Not supported: %s' % pruning_cfg.pruning_schedule)

    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
        prunable_model, **pruning_params)

    # Print out prunable weights for debugging purposes.
    prunable_layers = collect_prunable_layers(pruned_model)
    pruned_weights = []
    for layer in prunable_layers:
      pruned_weights += [weight.name for weight, _, _ in layer.pruning_vars]
    unpruned_weights = [
        weight.name
        for weight in pruned_model.weights
        if weight.name not in pruned_weights
    ]

    logging.info(
        '%d / %d weights are pruned.\nPruned weights: [\n%s\n],\n'
        'Unpruned weights: [\n%s\n],', len(pruned_weights), len(model.weights),
        ', '.join(pruned_weights), ', '.join(unpruned_weights))

    return pruned_model

  def _make_block_prunable(
      self, layer: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
    if isinstance(layer, tf.keras.Model):
      return tf.keras.models.clone_model(
          layer, input_tensors=None, clone_function=self._make_block_prunable)

    if layer.__class__ not in self._BLOCK_LAYER_SUFFIX_MAP:
      return layer

    prunable_weights = []
    for layer_suffix in self._BLOCK_LAYER_SUFFIX_MAP[layer.__class__]:
      for weight in layer.weights:
        if weight.name.endswith(layer_suffix):
          prunable_weights.append(weight)

    def get_prunable_weights():
      return prunable_weights

    layer.get_prunable_weights = get_prunable_weights

    return layer


def collect_prunable_layers(model):
  """Recursively collect the prunable layers in the model."""
  prunable_layers = []
  for layer in model.layers:
    if isinstance(layer, tf.keras.Model):
      prunable_layers += collect_prunable_layers(layer)
    if layer.__class__.__name__ == 'PruneLowMagnitude':
      prunable_layers.append(layer)
  return prunable_layers
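After training with a task like the one above, the `PruneLowMagnitude` wrappers are typically stripped before export so that only the sparse weights remain. The sketch below uses the standard TFMOT `strip_pruning` utility on a stand-in model; the stripping step and the export path are not part of this file and are shown only as an assumption about the usual deployment flow.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Stand-in for the output of ImageClassificationTask.build_model() above.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(8,))]))

export_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
export_model.save('/tmp/pruned_export')  # hypothetical export path
```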
official/projects/pruning/tasks/image_classification_test.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for image classification task."""
# pylint: disable=unused-import
import tempfile

from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.core import actions
from official.core import exp_factory
from official.modeling import optimization
from official.projects.pruning.tasks import image_classification as img_cls_task
from official.vision import beta


class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):

  def _validate_model_pruned(self, model, config_name):
    pruning_weight_names = []
    prunable_layers = img_cls_task.collect_prunable_layers(model)
    for layer in prunable_layers:
      for weight, _, _ in layer.pruning_vars:
        pruning_weight_names.append(weight.name)

    if config_name == 'resnet_imagenet_pruning':
      # Conv2D : 1
      # BottleneckBlockGroup : 4+3+3 = 10
      # BottleneckBlockGroup1 : 4+3+3+3 = 13
      # BottleneckBlockGroup2 : 4+3+3+3+3+3 = 19
      # BottleneckBlockGroup3 : 4+3+3 = 10
      # FullyConnected : 1
      # Total : 54
      self.assertLen(pruning_weight_names, 54)
    elif config_name == 'mobilenet_imagenet_pruning':
      # Conv2DBN = 1
      # InvertedBottleneckBlockGroup = 2
      # InvertedBottleneckBlockGroup1~16 = 48
      # Conv2DBN = 1
      # FullyConnected : 1
      # Total : 53
      self.assertLen(pruning_weight_names, 53)

  def _check_2x4_sparsity(self, model):

    def _is_pruned_2_by_4(weights):
      if weights.shape.rank == 2:
        prepared_weights = tf.transpose(weights)
      elif weights.shape.rank == 4:
        perm_weights = tf.transpose(weights, perm=[3, 0, 1, 2])
        prepared_weights = tf.reshape(perm_weights,
                                      [-1, perm_weights.shape[-1]])

      prepared_weights_np = prepared_weights.numpy()
      for row in range(0, prepared_weights_np.shape[0]):
        for col in range(0, prepared_weights_np.shape[1], 4):
          if np.count_nonzero(prepared_weights_np[row, col:col + 4]) > 2:
            return False
      return True

    prunable_layers = img_cls_task.collect_prunable_layers(model)
    for layer in prunable_layers:
      for weight, _, _ in layer.pruning_vars:
        if weight.shape[-2] % 4 == 0:
          self.assertTrue(_is_pruned_2_by_4(weight))

  def _validate_metrics(self, logs, metrics):
    for metric in metrics:
      logs[metric.name] = metric.result()
    self.assertIn('loss', logs)
    self.assertIn('accuracy', logs)
    self.assertIn('top_5_accuracy', logs)

  @parameterized.parameters(('resnet_imagenet_pruning'),
                            ('mobilenet_imagenet_pruning'))
  def testTaskWithUnstructuredSparsity(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    config.task.train_data.global_batch_size = 2

    task = img_cls_task.ImageClassificationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()

    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.task.train_data)
    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    if isinstance(
        optimizer,
        optimization.ExponentialMovingAverage) and not optimizer.has_shadow_copy:
      optimizer.shadow_copy(model)

    if config.task.pruning:
      # This is an auxiliary initialization required to prune a model, which is
      # originally done in the train library.
      actions.PruningAction(
          export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)

    # Check that all layers and target weights are successfully pruned.
    self._validate_model_pruned(model, config_name)

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self._validate_metrics(logs, metrics)

    logs = task.validation_step(next(iterator), model, metrics=metrics)
    self._validate_metrics(logs, metrics)

  @parameterized.parameters(('resnet_imagenet_pruning'),
                            ('mobilenet_imagenet_pruning'))
  def testTaskWithStructuredSparsity(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    config.task.train_data.global_batch_size = 2

    # Add structured sparsity.
    config.task.pruning.sparsity_m_by_n = (2, 4)
    config.task.pruning.frequency = 1

    task = img_cls_task.ImageClassificationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()

    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.task.train_data)
    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    if isinstance(
        optimizer,
        optimization.ExponentialMovingAverage) and not optimizer.has_shadow_copy:
      optimizer.shadow_copy(model)

    # This is an auxiliary initialization required to prune a model, which is
    # originally done in the train library.
    pruning_actions = actions.PruningAction(
        export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)

    # Check that all layers and target weights are successfully pruned.
    self._validate_model_pruned(model, config_name)

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self._validate_metrics(logs, metrics)

    logs = task.validation_step(next(iterator), model, metrics=metrics)
    self._validate_metrics(logs, metrics)

    pruning_actions.update_pruning_step.on_epoch_end(batch=None)

    # Check whether the weights are pruned in a 2x4 pattern.
    self._check_2x4_sparsity(model)


if __name__ == '__main__':
  tf.test.main()
official/projects/pruning/train.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including Pruning configs.."""
from
absl
import
app
from
official.common
import
flags
as
tfm_flags
# To build up a connection with the training binary for pruning, the custom
# configs & tasks are imported while unused.
from
official.projects.pruning
import
configs
# pylint: disable=unused-import
from
official.projects.pruning.tasks
import
image_classification
# pylint: disable=unused-import
from
official.vision
import
train
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
train
.
main
)