OpenDAS / nni · Commit e5c3ac63

Compression v2 Stage 1 (#3917)

Unverified commit, authored Aug 16, 2021 by J-shang; committed by GitHub on Aug 16, 2021.
Parent: e219bae8

Showing 14 changed files with 2196 additions and 0 deletions (+2196, -0)
examples/model_compress/pruning/v2/naive_prune_torch.py                     +153  -0
nni/algorithms/compression/v2/pytorch/__init__.py                           +0    -0
nni/algorithms/compression/v2/pytorch/base/__init__.py                      +2    -0
nni/algorithms/compression/v2/pytorch/base/compressor.py                    +285  -0
nni/algorithms/compression/v2/pytorch/base/pruner.py                        +178  -0
nni/algorithms/compression/v2/pytorch/pruning/__init__.py                   +1    -0
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py               +645  -0
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py             +23   -0
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py                 +442  -0
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py       +58   -0
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py   +184  -0
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py   +150  -0
nni/algorithms/compression/v2/pytorch/utils/__init__.py                     +0    -0
nni/algorithms/compression/v2/pytorch/utils/config_validation.py            +75   -0
examples/model_compress/pruning/v2/naive_prune_torch.py (new file, mode 100644)
```python
import argparse
import logging
from pathlib import Path

import torch
from torchvision import transforms, datasets

from nni.algorithms.compression.v2.pytorch import pruning
from nni.compression.pytorch import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG

logging.getLogger().setLevel(logging.DEBUG)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device)

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=200, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def evaluator(model):
    model.eval()
    criterion = torch.nn.NLLLoss()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('Test Loss: {}  Accuracy: {}%\n'.format(test_loss, acc))
    return acc


optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def main(args):
    if args.pre_train:
        for i in range(1):
            trainer(model, finetune_optimizer, criterion, epoch=i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}]
    kwargs = {
        'model': model,
        'config_list': config_list,
    }

    # one-shot pruners only need the model and config; trainer-based pruners need more
    if args.pruner == 'level':
        pruner = pruning.LevelPruner(**kwargs)
    else:
        kwargs['mode'] = args.mode
        if kwargs['mode'] == 'dependency_aware':
            kwargs['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)
        if args.pruner == 'l1norm':
            pruner = pruning.L1NormPruner(**kwargs)
        elif args.pruner == 'l2norm':
            pruner = pruning.L2NormPruner(**kwargs)
        elif args.pruner == 'fpgm':
            pruner = pruning.FPGMPruner(**kwargs)
        else:
            kwargs['trainer'] = trainer
            kwargs['optimizer'] = optimizer
            kwargs['criterion'] = criterion
            if args.pruner == 'slim':
                kwargs['config_list'] = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8,
                                          'max_sparsity_per_layer': 0.9}]
                kwargs['training_epochs'] = 1
                pruner = pruning.SlimPruner(**kwargs)
            else:
                # the activation-based and Taylor pruners require `training_batches`;
                # the value here is illustrative
                kwargs['training_batches'] = 20
                if args.pruner == 'mean_activation':
                    pruner = pruning.ActivationMeanRankPruner(**kwargs)
                elif args.pruner == 'apoz':
                    pruner = pruning.ActivationAPoZRankPruner(**kwargs)
                elif args.pruner == 'taylorfo':
                    pruner = pruning.TaylorFOWeightPruner(**kwargs)

    pruned_model, masks = pruner.compress()
    pruner.show_pruned_weights()

    if args.speed_up:
        # convert the v2 mask format to the format expected by ModelSpeedup
        tmp_masks = {}
        for name, mask in masks.items():
            tmp_masks[name] = {}
            tmp_masks[name]['weight'] = mask.get('weight_mask')
            if mask.get('bias_mask') is not None:
                tmp_masks[name]['bias'] = mask.get('bias_mask')
        torch.save(tmp_masks, Path('./temp_masks.pth'))
        pruner._unwrap_model()
        ModelSpeedup(model, torch.rand(10, 3, 32, 32).to(device), Path('./temp_masks.pth')).speedup_model()

    if args.finetune:
        for i in range(1):
            trainer(pruned_model, finetune_optimizer, criterion, epoch=i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example for model compression')
    parser.add_argument('--pruner', type=str, default='l1norm',
                        choices=['level', 'l1norm', 'l2norm', 'slim', 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                        help='pruner to use')
    parser.add_argument('--mode', type=str, default='normal',
                        choices=['normal', 'dependency_aware', 'global'])
    parser.add_argument('--pre-train', action='store_true', default=False,
                        help='Whether to pre-train the model')
    parser.add_argument('--speed-up', action='store_true', default=False,
                        help='Whether to speed-up the pruned model')
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='Whether to finetune the pruned model')
    args = parser.parse_args()

    main(args)
```
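The script above drives everything through argparse; for a quick interactive check of the new `pruning` package, a minimal sketch along these lines should work (the toy `Sequential` model, its layer sizes, and the 0.5 sparsity are illustrative choices, not part of the commit):

```python
import torch
from torch.nn import Conv2d, Flatten, Linear, ReLU, Sequential
from nni.algorithms.compression.v2.pytorch import pruning

# Toy model: prune half of the Conv2d filters by L1 norm.
model = Sequential(Conv2d(3, 8, 3), ReLU(), Flatten(), Linear(8 * 30 * 30, 10))
config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.5}]

pruner = pruning.L1NormPruner(model, config_list)
pruned_model, masks = pruner.compress()  # wrapped model plus per-layer mask dict
pruner.show_pruned_weights()             # logs remain/total filters per layer
```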
nni/algorithms/compression/v2/pytorch/__init__.py (new file, mode 100644, empty)
nni/algorithms/compression/v2/pytorch/base/__init__.py (new file, mode 100644)
```python
from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper
```
nni/algorithms/compression/v2/pytorch/base/compressor.py (new file, mode 100644)
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import collections
import logging
from typing import List, Dict, Optional, OrderedDict, Tuple, Any

import torch
from torch.nn import Module

from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils import get_module_by_name

_logger = logging.getLogger(__name__)

__all__ = ['LayerInfo', 'Compressor']


class LayerInfo:
    def __init__(self, name: str, module: Module):
        self.module = module
        self.name = name
        self.type = type(module).__name__


def _setattr(model: Module, name: str, module: Module):
    parent_module, _ = get_module_by_name(model, name)
    if parent_module is not None:
        name_list = name.split(".")
        setattr(parent_module, name_list[-1], module)
    else:
        raise ValueError('{} does not exist.'.format(name))


weighted_modules = [
    'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
    'Linear', 'Bilinear',
    'PReLU',
    'Embedding', 'EmbeddingBag',
]


class Compressor:
    """
    The abstract base PyTorch compressor.
    """

    def __init__(self, model: Module, config_list: List[Dict]):
        """
        Parameters
        ----------
        model
            The model to be compressed.
        config_list
            The config list used by the compressor, usually specifies the 'op_types' or 'op_names' to compress.
        """
        assert isinstance(model, Module)

        self.is_wrapped = False
        self.reset(model=model, config_list=config_list)

    def reset(self, model: Module, config_list: List[Dict]):
        """
        Reset the compressor with model and config_list.

        Parameters
        ----------
        model
            The model to be compressed.
        config_list
            The config list used by the compressor, usually specifies the 'op_types' or 'op_names' to compress.
        """
        assert isinstance(model, Module), 'Only support compressing pytorch Module, but the type of model is {}.'.format(type(model))
        self.bound_model = model
        self.config_list = config_list
        self.validate_config(model=model, config_list=config_list)

        self._unwrap_model()

        self._modules_to_compress = None
        self.modules_wrapper = collections.OrderedDict()
        for layer, config in self._detect_modules_to_compress():
            wrapper = self._wrap_modules(layer, config)
            self.modules_wrapper[layer.name] = wrapper

        self._wrap_model()

    def _detect_modules_to_compress(self) -> List[Tuple[LayerInfo, Dict]]:
        """
        Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
        The model will be instrumented and the user should never edit it after calling this method.
        """
        if self._modules_to_compress is None:
            self._modules_to_compress = []
            for name, module in self.bound_model.named_modules():
                if module == self.bound_model:
                    continue
                layer = LayerInfo(name, module)
                config = self._select_config(layer)
                if config is not None:
                    self._modules_to_compress.append((layer, config))
        return self._modules_to_compress

    def _select_config(self, layer: LayerInfo) -> Optional[Dict]:
        """
        Find the configuration for `layer` by parsing `self.config_list`.

        Parameters
        ----------
        layer
            The layer to check for a compression configuration.

        Returns
        -------
        Optional[Dict]
            The retrieved configuration for this layer; if None, this layer should not be compressed.
        """
        ret = None
        for config in self.config_list:
            config = config.copy()
            # expand config if key `default` is in config['op_types']
            if 'op_types' in config and 'default' in config['op_types']:
                expanded_op_types = []
                for op_type in config['op_types']:
                    if op_type == 'default':
                        expanded_op_types.extend(weighted_modules)
                    else:
                        expanded_op_types.append(op_type)
                config['op_types'] = expanded_op_types

            # check if the condition is satisfied
            if 'op_types' in config and layer.type not in config['op_types']:
                continue
            if 'op_names' in config and layer.name not in config['op_names']:
                continue

            ret = config
        if ret is None or 'exclude' in ret:
            return None
        return ret

    def get_modules_wrapper(self) -> OrderedDict[str, Module]:
        """
        Returns
        -------
        OrderedDict[str, Module]
            An ordered dict whose key is the name of the module and value is the wrapper of the module.
        """
        return self.modules_wrapper

    def _wrap_model(self):
        """
        Wrap all modules that need to be compressed.
        """
        if not self.is_wrapped:
            for _, wrapper in reversed(self.get_modules_wrapper().items()):
                _setattr(self.bound_model, wrapper.name, wrapper)
            self.is_wrapped = True

    def _unwrap_model(self):
        """
        Unwrap all modules that need to be compressed.
        """
        if self.is_wrapped:
            for _, wrapper in self.get_modules_wrapper().items():
                _setattr(self.bound_model, wrapper.name, wrapper.module)
            self.is_wrapped = False

    def set_wrappers_attribute(self, name: str, value: Any):
        """
        Register an attribute used in the wrapped module's forward method.
        If the value is a torch.Tensor, it is registered as a buffer in the wrapper,
        which will be saved by model.state_dict. Otherwise, the value is just a regular variable in the wrapper.

        Parameters
        ----------
        name
            Name of the variable.
        value
            Value of the variable.
        """
        # iterate over the wrapper modules themselves, not the dict keys
        for wrapper in self.get_modules_wrapper().values():
            if isinstance(value, torch.Tensor):
                wrapper.register_buffer(name, value.clone())
            else:
                setattr(wrapper, name, value)

    def generate_graph(self, dummy_input: Any) -> TorchModuleGraph:
        """
        Generate a `TorchModuleGraph` instance of `self.bound_model` based on `jit.trace`.

        Parameters
        ----------
        dummy_input
            The dummy input for `jit.trace`; users should put it on the right device before passing it in.

        Returns
        -------
        TorchModuleGraph
            A `TorchModuleGraph` instance.
        """
        self._unwrap_model()
        graph = TorchModuleGraph(model=self.bound_model, dummy_input=dummy_input)
        self._wrap_model()
        return graph

    def generate_module_groups(self) -> Dict[int, List[str]]:
        """
        Get all module names in each config in config_list.

        Returns
        -------
        Dict[int, List[str]]
            A dict whose key is the config idx in config_list and value is the module name list,
            e.g., {1: ['layer.0', 'layer.2']}.
        """
        self._unwrap_model()

        module_groups = {}
        for name, module in self.bound_model.named_modules():
            if module == self.bound_model:
                continue
            layer = LayerInfo(name, module)
            ret = None
            for idx, config in enumerate(self.config_list):
                config = config.copy()
                # expand config if key `default` is in config['op_types']
                if 'op_types' in config and 'default' in config['op_types']:
                    expanded_op_types = []
                    for op_type in config['op_types']:
                        if op_type == 'default':
                            expanded_op_types.extend(weighted_modules)
                        else:
                            expanded_op_types.append(op_type)
                    config['op_types'] = expanded_op_types
                # check if the condition is satisfied
                if 'op_types' in config and layer.type not in config['op_types']:
                    continue
                if 'op_names' in config and layer.name not in config['op_names']:
                    continue
                ret = (idx, config)
            if ret is not None and 'exclude' not in ret[1]:
                module_groups.setdefault(ret[0], [])
                module_groups[ret[0]].append(name)

        self._wrap_model()
        return module_groups

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
        """
        This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`.

        Parameters
        ----------
        layer
            The layer to instrument the compression operation.
        config
            The configuration for compressing this layer.
        """
        raise NotImplementedError()

    def validate_config(self, model: Module, config_list: List[Dict]):
        """
        Subclasses can optionally implement this method to check if config_list is valid.

        Parameters
        ----------
        model
            The model to be compressed.
        config_list
            The config list used by the compressor, usually specifies the 'op_types' or 'op_names' to compress.
        """
        pass

    def compress(self) -> Module:
        """
        Compress the model with the algorithm implemented by the subclass.
        The model will be instrumented and the user should never edit it after calling this method.
        `self._modules_to_compress` records all the to-be-compressed layers.

        Returns
        -------
        torch.nn.Module
            Model with the specified modules compressed.
        """
        return self.bound_model
```
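One subtlety worth calling out in `_select_config`: it walks the entire `config_list` and keeps the last matching entry, so later configs override earlier ones, and a matching entry that carries `exclude` removes the layer from compression entirely. A hedged illustration (the layer name is hypothetical):

```python
# Later entries win: a layer matched by both configs ends up excluded.
config_list = [
    {'op_types': ['default'], 'sparsity_per_layer': 0.5},  # 'default' expands to `weighted_modules`
    {'op_names': ['features.conv2'], 'exclude': True},     # overrides the first entry for this layer
]
```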
nni/algorithms/compression/v2/pytorch/base/pruner.py (new file, mode 100644)
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module

from .compressor import Compressor, LayerInfo

_logger = logging.getLogger(__name__)

__all__ = ['Pruner']


class PrunerModuleWrapper(Module):
    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
        """
        Wrap a module to enable data parallel, forward method customization and buffer registration.

        Parameters
        ----------
        module
            The module user wants to compress.
        config
            The configurations that users specify for compression.
        module_name
            The name of the module to compress; the wrapper module shares the same name.
        pruner
            The pruner used to calculate the mask.
        """
        super().__init__()
        # origin layer information
        self.module = module
        self.name = module_name
        # config and pruner
        self.config = config
        self.pruner = pruner

        # register buffer for mask
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
        else:
            self.register_buffer("bias_mask", None)

    def forward(self, *inputs):
        # apply mask to weight, bias
        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
        return self.module(*inputs)


class Pruner(Compressor):
    """
    The abstract class for pruning algorithms. Inherit this class and implement `_reset_tools` to customize a pruner.
    """

    def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
        super().reset(model=model, config_list=config_list)

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
        """
        Create a wrapper module to replace the original one.

        Parameters
        ----------
        layer
            The layer to instrument the mask.
        config
            The configuration for generating the mask.
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
        wrapper = PrunerModuleWrapper(layer.module, layer.name, config, self)
        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
        # move newly registered buffers to the same device as the weight
        wrapper.to(layer.module.weight.device)
        return wrapper

    def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
        """
        Load existing masks on the wrappers. You can train the model with existing masks after loading them.

        Parameters
        ----------
        masks
            The masks dict with format {'op_name': {'weight_mask': mask, 'bias_mask': mask}}.
        """
        wrappers = self.get_modules_wrapper()
        for name, layer_mask in masks.items():
            assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
            for mask_type, mask in layer_mask.items():
                assert hasattr(wrappers[name], mask_type), 'there is no attribute {} in wrapper'.format(mask_type)
                setattr(wrappers[name], mask_type, mask)

    def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
        """
        Returns
        -------
        Tuple[Module, Dict]
            Return the wrapped model and mask.
        """
        return self.bound_model, {}

    # NOTE: need to refactor `dim` to support a list
    def show_pruned_weights(self, dim: int = 0):
        """
        Log the simulated prune sparsity.

        Parameters
        ----------
        dim
            The pruned dim.
        """
        for _, wrapper in self.get_modules_wrapper().items():
            weight_mask = wrapper.weight_mask
            mask_size = weight_mask.size()
            if len(mask_size) == 1:
                index = torch.nonzero(weight_mask.abs() != 0, as_tuple=False).tolist()
            else:
                sum_idx = list(range(len(mask_size)))
                sum_idx.remove(dim)
                index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
            _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')

    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export pruned model weights, masks and onnx model (optional).

        Parameters
        ----------
        model_path
            Path to save the pruned model state_dict.
        mask_path
            (optional) Path to save the mask dict.
        onnx_path
            (optional) Path to save the onnx model.
        input_shape
            Input shape of the onnx model.
        device
            Device of the model, used to place the dummy input tensor for exporting the onnx file.
            The tensor is placed on cpu if `device` is None.
        """
        assert model_path is not None, 'model_path must be specified'
        mask_dict = {}
        self._unwrap_model()  # used for generating correct state_dict names without wrapper state

        for name, wrapper in self.get_modules_wrapper().items():
            weight_mask = wrapper.weight_mask
            bias_mask = wrapper.bias_mask
            if weight_mask is not None:
                mask_sum = weight_mask.sum().item()
                mask_num = weight_mask.numel()
                _logger.debug('Layer: %s  Sparsity: %.4f', name, 1 - mask_sum / mask_num)
                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
            if bias_mask is not None:
                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
            # save mask to dict
            mask_dict[name] = {"weight_mask": weight_mask, "bias_mask": bias_mask}

        torch.save(self.bound_model.state_dict(), model_path)
        _logger.info('Model state_dict saved to %s', model_path)

        if mask_path is not None:
            torch.save(mask_dict, mask_path)
            _logger.info('Mask dict saved to %s', mask_path)

        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # input info needed
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

        self._wrap_model()
```
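Putting `compress()` and `export_model()` together, a typical end-to-end flow looks like this sketch (paths and the input shape are placeholders; `pruner` stands for any fitted `Pruner` subclass):

```python
pruned_model, masks = pruner.compress()
pruner.show_pruned_weights()

# Persist the masked state_dict and the mask dict; ONNX export additionally needs input_shape.
pruner.export_model(model_path='pruned.pth', mask_path='masks.pth',
                    onnx_path='pruned.onnx', input_shape=(1, 3, 32, 32))
```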
nni/algorithms/compression/v2/pytorch/pruning/__init__.py (new file, mode 100644)
```python
from .basic_pruner import *
```
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py (new file, mode 100644)
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import List, Dict, Tuple, Callable, Optional

from schema import And, Optional as SchemaOptional
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module
from torch.optim import Optimizer

from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils.config_validation import PrunerSchema

from .tools import (
    DataCollector,
    HookCollectorInfo,
    WeightDataCollector,
    WeightTrainerBasedDataCollector,
    SingleHookTrainerBasedDataCollector
)

from .tools import (
    MetricsCalculator,
    NormMetricsCalculator,
    MultiDataNormMetricsCalculator,
    DistMetricsCalculator,
    APoZRankMetricsCalculator,
    MeanRankMetricsCalculator
)

from .tools import (
    SparsityAllocator,
    NormalSparsityAllocator,
    GlobalSparsityAllocator,
    Conv2dDependencyAwareAllocator
)

_logger = logging.getLogger(__name__)

__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner',
           'ActivationPruner', 'ActivationAPoZRankPruner', 'ActivationMeanRankPruner',
           'TaylorFOWeightPruner']


class OneShotPruner(Pruner):
    def __init__(self, model: Module, config_list: List[Dict]):
        self.data_collector: DataCollector = None
        self.metrics_calculator: MetricsCalculator = None
        self.sparsity_allocator: SparsityAllocator = None
        self._convert_config_list(config_list)
        super().__init__(model, config_list)

    def _convert_config_list(self, config_list: List[Dict]):
        """
        Convert `sparsity` in config to `sparsity_per_layer`.
        """
        for config in config_list:
            if 'sparsity' in config:
                if 'sparsity_per_layer' in config:
                    raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
                else:
                    config['sparsity_per_layer'] = config.pop('sparsity')

    def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        super().reset(model=model, config_list=config_list)
        self.reset_tools()

    def reset_tools(self):
        """
        This function is used to reset `self.data_collector`, `self.metrics_calculator` and `self.sparsity_allocator`.
        The subclass needs to implement this function to complete the pruning process.
        See `compress()` to understand how NNI uses these three parts to generate the mask for the bound model.
        """
        raise NotImplementedError()

    def compress(self) -> Tuple[Module, Dict]:
        """
        Used to generate the mask. The pruning process is divided into three stages:
        `self.data_collector` collects the data used to calculate the specified metric,
        `self.metrics_calculator` calculates the metric, and `self.sparsity_allocator`
        generates the mask depending on the metric.

        Returns
        -------
        Tuple[Module, Dict]
            Return the wrapped model and mask.
        """
        data = self.data_collector.collect()
        _logger.debug('Collected Data:\n%s', data)
        metrics = self.metrics_calculator.calculate_metrics(data)
        _logger.debug('Metrics Calculate:\n%s', metrics)
        masks = self.sparsity_allocator.generate_sparsity(metrics)
        _logger.debug('Masks:\n%s', masks)
        self.load_masks(masks)
        return self.bound_model, masks


class LevelPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict]):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Operation types to prune.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        """
        self.mode = 'normal'
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): [str],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def reset_tools(self):
        if self.data_collector is None:
            self.data_collector = WeightDataCollector(self)
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = NormMetricsCalculator()
        if self.sparsity_allocator is None:
            self.sparsity_allocator = NormalSparsityAllocator(self)


class NormPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict], p: int,
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Conv2d and Linear are supported in NormPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        p
            The order of the norm.
        mode
            'normal' or 'dependency_aware'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the norm of weights and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        self.p = p
        self.mode = mode
        self.dummy_input = dummy_input
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def reset_tools(self):
        if self.data_collector is None:
            self.data_collector = WeightDataCollector(self)
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = NormMetricsCalculator(p=self.p, dim=0)
        if self.sparsity_allocator is None:
            if self.mode == 'normal':
                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
            elif self.mode == 'dependency_aware':
                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
            else:
                raise NotImplementedError('Only support mode `normal` and `dependency_aware`')


class L1NormPruner(NormPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Conv2d and Linear are supported in L1NormPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        mode
            'normal' or 'dependency_aware'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the l1-norm of weights and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        super().__init__(model, config_list, 1, mode, dummy_input)


class L2NormPruner(NormPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Conv2d and Linear are supported in L2NormPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        mode
            'normal' or 'dependency_aware'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the l2-norm of weights and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        super().__init__(model, config_list, 2, mode, dummy_input)


class FPGMPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Conv2d and Linear are supported in FPGMPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        mode
            'normal' or 'dependency_aware'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the FPGM of weights and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        self.mode = mode
        self.dummy_input = dummy_input
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def reset_tools(self):
        if self.data_collector is None:
            self.data_collector = WeightDataCollector(self)
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = DistMetricsCalculator(p=2, dim=0)
        if self.sparsity_allocator is None:
            if self.mode == 'normal':
                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
            elif self.mode == 'dependency_aware':
                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
            else:
                raise NotImplementedError('Only support mode `normal` and `dependency_aware`')


class SlimPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 trainer: Callable[[Module, Optimizer, Callable], None], optimizer: Optimizer,
                 criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
                 scale: float = 0.0001, mode='global'):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - total_sparsity : This is to specify the total sparsity for all layers in this config;
                  each layer may have a different sparsity.
                - max_sparsity_per_layer : Always used with total_sparsity. Limits the max sparsity of each layer.
                - op_types : Only BatchNorm2d is supported in SlimPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        trainer
            A callable function used to train the model or just run inference. Takes model, optimizer, criterion as input.
            The model will be trained or inferenced `training_epochs` epochs.

            Example::

                def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
                    training = model.training
                    model.train(mode=True)
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    for batch_idx, (data, target) in enumerate(train_loader):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        # If you don't want to update the model, you can skip `optimizer.step()` and set train mode to False.
                        optimizer.step()
                    model.train(mode=training)
        optimizer
            The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
            so do not use this optimizer in other places.
        criterion
            The criterion function used in the trainer. Takes model output and target value as input, and returns the loss.
        training_epochs
            The number of epochs for training the model to sparsify the BN weights.
        scale
            The scale factor applied to the L1 regularization term on the BatchNorm weights (see `criterion_patch`).
        mode
            'normal' or 'global'.
            If pruning the model in a global way, all layer weights with the same config are considered uniformly.
            That means a single layer may not reach or exceed the sparsity set in the config,
            but the total pruned weights meet the sparsity setting.
        """
        self.mode = mode
        self.trainer = trainer
        self.optimizer = optimizer
        self.criterion = criterion
        self.training_epochs = training_epochs
        self._scale = scale
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): ['BatchNorm2d'],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
        def patched_criterion(input_tensor: Tensor, target: Tensor):
            sum_l1 = 0
            for _, wrapper in self.get_modules_wrapper().items():
                sum_l1 += torch.norm(wrapper.module.weight.data, p=1)
            return criterion(input_tensor, target) + self._scale * sum_l1
        return patched_criterion

    def reset_tools(self):
        if self.data_collector is None:
            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
                                                                  self.training_epochs, criterion_patch=self.criterion_patch)
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = NormMetricsCalculator()
        if self.sparsity_allocator is None:
            if self.mode == 'normal':
                self.sparsity_allocator = NormalSparsityAllocator(self)
            elif self.mode == 'global':
                self.sparsity_allocator = GlobalSparsityAllocator(self)
            else:
                raise NotImplementedError('Only support mode `normal` and `global`')


class ActivationPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 trainer: Callable[[Module, Optimizer, Callable], None], optimizer: Optimizer,
                 criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                 activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - op_types : Conv2d and Linear are supported in ActivationPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        trainer
            A callable function used to train the model or just run inference. Takes model, optimizer, criterion as input.

            Example::

                def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
                    training = model.training
                    model.train(mode=True)
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    for batch_idx, (data, target) in enumerate(train_loader):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        # If you don't want to update the model, you can skip `optimizer.step()` and set train mode to False.
                        optimizer.step()
                    model.train(mode=training)
        optimizer
            The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
            so do not use this optimizer in other places.
        criterion
            The criterion function used in the trainer. Takes model output and target value as input, and returns the loss.
        training_batches
            The number of batches used to collect activations.
        activation
            'relu' or 'relu6'. The activation function whose output is collected to rank the channels.
        mode
            'normal' or 'dependency_aware'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the activation-based metrics and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        self.mode = mode
        self.dummy_input = dummy_input
        self.trainer = trainer
        self.optimizer = optimizer
        self.criterion = criterion
        self.training_batches = training_batches
        self._activation = self._choose_activation(activation)
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def _choose_activation(self, activation: str = 'relu') -> Callable:
        if activation == 'relu':
            return nn.functional.relu
        elif activation == 'relu6':
            return nn.functional.relu6
        else:
            raise ValueError('Unsupported activation {}'.format(activation))

    def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
        def collect_activation(_module: Module, _input: Tensor, output: Tensor):
            if len(buffer) < self.training_batches:
                buffer.append(self._activation(output.detach()))
        return collect_activation

    def reset_tools(self):
        collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()],
                                           'forward', self._collector)
        if self.data_collector is None:
            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer,
                                                                      self.criterion, 1,
                                                                      collector_infos=[collector_info])
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = self._get_metrics_calculator()
        if self.sparsity_allocator is None:
            if self.mode == 'normal':
                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
            elif self.mode == 'dependency_aware':
                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
            else:
                raise NotImplementedError('Only support mode `normal` and `dependency_aware`')

    def _get_metrics_calculator(self) -> MetricsCalculator:
        raise NotImplementedError()


class ActivationAPoZRankPruner(ActivationPruner):
    def _get_metrics_calculator(self) -> MetricsCalculator:
        return APoZRankMetricsCalculator(dim=1)


class ActivationMeanRankPruner(ActivationPruner):
    def _get_metrics_calculator(self) -> MetricsCalculator:
        return MeanRankMetricsCalculator(dim=1)


class TaylorFOWeightPruner(OneShotPruner):
    def __init__(self, model: Module, config_list: List[Dict],
                 trainer: Callable[[Module, Optimizer, Callable], None], optimizer: Optimizer,
                 criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        """
        Parameters
        ----------
        model
            Model to be pruned.
        config_list
            Supported keys:
                - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
                - sparsity_per_layer : Equals to sparsity.
                - total_sparsity : This is to specify the total sparsity for all layers in this config;
                  each layer may have a different sparsity.
                - max_sparsity_per_layer : Always used with total_sparsity. Limits the max sparsity of each layer.
                - op_types : Conv2d and Linear are supported in TaylorFOWeightPruner.
                - op_names : Operation names to prune.
                - exclude : Set True, then the layers selected by op_types and op_names will be excluded from pruning.
        trainer
            A callable function used to train the model or just run inference. Takes model, optimizer, criterion as input.

            Example::

                def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
                    training = model.training
                    model.train(mode=True)
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    for batch_idx, (data, target) in enumerate(train_loader):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        # If you don't want to update the model, you can skip `optimizer.step()` and set train mode to False.
                        optimizer.step()
                    model.train(mode=training)
        optimizer
            The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
            so do not use this optimizer in other places.
        criterion
            The criterion function used in the trainer. Takes model output and target value as input, and returns the loss.
        training_batches
            The number of batches used to collect gradients.
        mode
            'normal', 'dependency_aware' or 'global'.
            If pruning the model in a dependency-aware way, this pruner will prune the model according to
            the taylorFO metric and the channel-dependency or group-dependency of the model.
            In this way, the pruner will force the conv layers that have dependencies to prune the same channels,
            so the speedup module can better harvest the speed benefit from the pruned model.
            Note that if 'dependency_aware' is set, dummy_input cannot be None,
            because the pruner needs a dummy input to trace the dependency between the conv layers.
            If pruning the model in a global way, all layer weights with the same config are considered uniformly.
            That means a single layer may not reach or exceed the sparsity set in the config,
            but the total pruned weights meet the sparsity setting.
        dummy_input
            The dummy input to analyze the topology constraints. Note that the dummy_input
            should be on the same device as the model.
        """
        self.mode = mode
        self.dummy_input = dummy_input
        self.trainer = trainer
        self.optimizer = optimizer
        self.criterion = criterion
        self.training_batches = training_batches
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        schema = PrunerSchema([{
            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
            SchemaOptional('op_names'): [str],
            SchemaOptional('exclude'): bool
        }], model, _logger)

        schema.validate(config_list)

    def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Module, Tensor, Tensor], None]:
        def collect_taylor(grad: Tensor):
            if len(buffer) < self.training_batches:
                buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
        return collect_taylor

    def _calculate_taylor_expansion(self, weight_tensor: Tensor, grad: Tensor) -> Tensor:
        return (weight_tensor.detach() * grad.detach()).data.pow(2)

    def reset_tools(self):
        hook_targets = {layer_info.name: layer_info.module.weight
                        for layer_info, _ in self._detect_modules_to_compress()}
        collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
        if self.data_collector is None:
            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer,
                                                                      self.criterion, 1,
                                                                      collector_infos=[collector_info])
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
        if self.sparsity_allocator is None:
            if self.mode == 'normal':
                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
            elif self.mode == 'global':
                self.sparsity_allocator = GlobalSparsityAllocator(self, dim=0)
            elif self.mode == 'dependency_aware':
                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
            else:
                raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
```
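All the pruners above share the same three-stage pipeline (data collector, then metrics calculator, then sparsity allocator); a trainer-based pruner only differs in what it feeds into that pipeline. A sketch of wiring up `TaylorFOWeightPruner`, assuming `model`, `trainer`, `optimizer`, and `criterion` follow the docstring contracts above (the sparsity and batch count are illustrative):

```python
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5}]
pruner = pruning.TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion,
                                      training_batches=20, mode='global')
pruned_model, masks = pruner.compress()  # gradients collected for 20 batches, then masks allocated globally
```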
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py (new file, mode 100644)
```python
from .base import (
    HookCollectorInfo,
    DataCollector,
    MetricsCalculator,
    SparsityAllocator
)

from .data_collector import (
    WeightDataCollector,
    WeightTrainerBasedDataCollector,
    SingleHookTrainerBasedDataCollector
)

from .metrics_calculator import (
    NormMetricsCalculator,
    MultiDataNormMetricsCalculator,
    DistMetricsCalculator,
    APoZRankMetricsCalculator,
    MeanRankMetricsCalculator
)

from .sparsity_allocator import (
    NormalSparsityAllocator,
    GlobalSparsityAllocator,
    Conv2dDependencyAwareAllocator
)
```
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py (new file, mode 100644; the listing below is truncated in the source)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
types
from
typing
import
List
,
Dict
,
Optional
,
Callable
,
Union
import
torch
from
torch
import
Tensor
from
torch.nn
import
Module
from
torch.optim
import
Optimizer
from
nni.algorithms.compression.v2.pytorch.base
import
Compressor
,
LayerInfo
_logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'DataCollector'
,
'TrainerBasedDataCollector'
,
'HookCollectorInfo'
,
'MetricsCalculator'
,
'SparsityAllocator'
]
class
DataCollector
:
"""
An abstract class for collect the data needed by the compressor.
"""
def
__init__
(
self
,
compressor
:
Compressor
):
"""
Parameters
----------
compressor
The compressor binded with this DataCollector.
"""
self
.
compressor
=
compressor
def
reset
(
self
):
"""
Reset the `DataCollector`.
"""
raise
NotImplementedError
()
def
collect
(
self
)
->
Dict
:
"""
Collect the compressor needed data, i.e., module weight, the output of activation function.
Returns
-------
Dict
Usually has format like {module_name: tensor_type_data}.
"""
raise
NotImplementedError
()
class
HookCollectorInfo
:
def
__init__
(
self
,
targets
:
Union
[
Dict
[
str
,
Tensor
],
List
[
LayerInfo
]],
hook_type
:
str
,
collector
:
Union
[
Callable
[[
List
,
Tensor
],
Callable
[[
Tensor
],
None
]],
Callable
[[
List
],
Callable
[[
Module
,
Tensor
,
Tensor
],
None
]]]):
"""
This class used to aggregate the information of what kind of hook is placed on which layers.
Parameters
----------
targets
List of LayerInfo or Dict of {layer_name: weight_tensor}, the hook targets.
hook_type
'forward' or 'backward'.
collector
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function.
The buffer is used to store the data wanted to hook.
"""
self
.
targets
=
targets
self
.
hook_type
=
hook_type
self
.
collector
=
collector
class
TrainerBasedDataCollector
(
DataCollector
):
"""
This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
"""
def
__init__
(
self
,
compressor
:
Compressor
,
trainer
:
Callable
[[
Module
,
Optimizer
,
Callable
],
None
],
optimizer
:
Optimizer
,
criterion
:
Callable
[[
Tensor
,
Tensor
],
Tensor
],
training_epochs
:
int
,
opt_before_tasks
:
List
=
[],
opt_after_tasks
:
List
=
[],
collector_infos
:
List
[
HookCollectorInfo
]
=
[],
criterion_patch
:
Callable
[[
Callable
],
Callable
]
=
None
):
"""
Parameters
----------
compressor
The compressor binded with this DataCollector.
trainer
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
so do not use this optimizer in other places.
criterion
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_epochs
The total number of calling trainer.
opt_before_tasks
A list of function that will be called one by one before origin `optimizer.step()`.
Note that these functions will be patched into `optimizer.step()`.
opt_after_tasks
A list of function that will be called one by one after origin `optimizer.step()`.
Note that these functions will be patched into `optimizer.step()`.
collector_infos
A list of `HookCollectorInfo` instance. And the hooks will be registered in `__init__`.
criterion_patch
A callable function used to patch the criterion. Take a criterion function as input and return a new one.
Example::
def criterion_patch(criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
weight = ...
def patched_criterion(output, target):
return criterion(output, target) + torch.norm(weight)
return patched_criterion
"""
super
().
__init__
(
compressor
)
self
.
trainer
=
trainer
self
.
training_epochs
=
training_epochs
self
.
_origin_optimizer
=
optimizer
self
.
_origin_criterion
=
criterion
self
.
_opt_before_tasks
=
opt_before_tasks
self
.
_opt_after_tasks
=
opt_after_tasks
self
.
_collector_infos
=
collector_infos
self
.
_criterion_patch
=
criterion_patch
self
.
reset
()
def
reset
(
self
):
# refresh optimizer and criterion
self
.
compressor
.
_unwrap_model
()
if
self
.
_origin_optimizer
is
not
None
:
optimizer_cls
=
self
.
_origin_optimizer
.
__class__
if
optimizer_cls
.
__name__
==
'SGD'
:
self
.
optimizer
=
optimizer_cls
(
self
.
compressor
.
bound_model
.
parameters
(),
lr
=
0.001
)
else
:
self
.
optimizer
=
optimizer_cls
(
self
.
compressor
.
bound_model
.
parameters
())
self
.
optimizer
.
load_state_dict
(
self
.
_origin_optimizer
.
state_dict
())
else
:
self
.
optimizer
=
None
if
self
.
_criterion_patch
is
not
None
:
self
.
criterion
=
self
.
_criterion_patch
(
self
.
_origin_criterion
)
else
:
self
.
criterion
=
self
.
_origin_criterion
self
.
compressor
.
_wrap_model
()
# patch optimizer
self
.
_patch_optimizer
()
# hook
self
.
_remove_all_hook
()
self
.
_hook_id
=
0
self
.
_hook_handles
=
{}
self
.
_hook_buffer
=
{}
self
.
_add_all_hook
()
def
_patch_optimizer
(
self
):
def
patch_step
(
old_step
):
def
new_step
(
_
,
*
args
,
**
kwargs
):
for
task
in
self
.
_opt_before_tasks
:
task
()
# call origin optimizer step method
output
=
old_step
(
*
args
,
**
kwargs
)
for
task
in
self
.
_opt_after_tasks
:
task
()
return
output
return
new_step
if
self
.
optimizer
is
not
None
:
self
.
optimizer
.
step
=
types
.
MethodType
(
patch_step
(
self
.
optimizer
.
step
),
self
.
optimizer
)
def
_add_hook
(
self
,
collector_info
:
HookCollectorInfo
)
->
int
:
self
.
_hook_id
+=
1
self
.
_hook_handles
[
self
.
_hook_id
]
=
{}
self
.
_hook_buffer
[
self
.
_hook_id
]
=
{}
if
collector_info
.
hook_type
==
'forward'
:
self
.
_add_forward_hook
(
self
.
_hook_id
,
collector_info
.
targets
,
collector_info
.
collector
)
elif
collector_info
.
hook_type
==
'backward'
:
self
.
_add_backward_hook
(
self
.
_hook_id
,
collector_info
.
targets
,
collector_info
.
collector
)
elif
collector_info
.
hook_type
==
'tensor'
:
self
.
_add_tensor_hook
(
self
.
_hook_id
,
collector_info
.
targets
,
collector_info
.
collector
)
else
:
_logger
.
warning
(
'Skip unsupported hook type: %s'
,
collector_info
.
hook_type
)
return
self
.
_hook_id

    def _add_forward_hook(self, hook_id: int, layers: List[LayerInfo],
                          collector: Callable[[List], Callable[[Module, Tensor, Tensor], None]]):
        assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
        for layer in layers:
            self._hook_buffer[hook_id][layer.name] = []
            handle = layer.module.register_forward_hook(collector(self._hook_buffer[hook_id][layer.name]))
            self._hook_handles[hook_id][layer.name] = handle

    def _add_backward_hook(self, hook_id: int, layers: List[LayerInfo],
                           collector: Callable[[List], Callable[[Module, Tensor, Tensor], None]]):
        assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
        for layer in layers:
            self._hook_buffer[hook_id][layer.name] = []
            handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))
            self._hook_handles[hook_id][layer.name] = handle

    def _add_tensor_hook(self, hook_id: int, tensors: Dict[str, Tensor],
                         collector: Callable[[List, Tensor], Callable[[Tensor], None]]):
        assert all(isinstance(tensor, Tensor) for _, tensor in tensors.items())
        for layer_name, tensor in tensors.items():
            self._hook_buffer[hook_id][layer_name] = []
            handle = tensor.register_hook(collector(self._hook_buffer[hook_id][layer_name], tensor))
            self._hook_handles[hook_id][layer_name] = handle

    def _remove_hook(self, hook_id: int):
        if hook_id not in self._hook_handles:
            raise ValueError("%s is not a valid collector id" % str(hook_id))
        # iterate over the handle values; iterating the dict itself would yield layer names
        for handle in self._hook_handles[hook_id].values():
            handle.remove()
        del self._hook_handles[hook_id]

    def _add_all_hook(self):
        for collector_info in self._collector_infos:
            self._add_hook(collector_info)

    def _remove_all_hook(self):
        if hasattr(self, '_hook_handles'):
            for hook_id in list(self._hook_handles.keys()):
                self._remove_hook(hook_id)

class MetricsCalculator:
    """
    An abstract class for calculating a kind of metric from the given data.
    """

    def __init__(self, dim: Optional[Union[int, List[int]]] = None,
                 block_sparse_size: Optional[Union[int, List[int]]] = None):
        """
        Parameters
        ----------
        dim
            The dimensions in the collected data that correspond to the under-pruning weight dimensions.
            None means a one-to-one correspondence between pruned dimensions and data, which is equal to setting `dim` to all data dimensions.
            Only these `dim` will be kept; the other dimensions of the data will be reduced.

            Example: say you want to prune a Conv2d weight at filter level, and the weight size is
            (32, 16, 3, 3) [out-channel, in-channel, kernel-size-1, kernel-size-2].
            Then the under-pruning dimension is [0], i.e., you want to prune the filters (out-channels).

            Case 1: Directly collect the conv module weight as the data to calculate the metric.
            The data then has size (32, 16, 3, 3), and data dimension 0 corresponds to under-pruning
            weight dimension 0, so `dim=0` should be set in `__init__`.

            Case 2: Use the output of the conv module as the data to calculate the metric.
            The data then has size (batch_num, 32, feature_map_size_1, feature_map_size_2), and data
            dimension 1 corresponds to under-pruning weight dimension 0, so `dim=1` should be set in `__init__`.

            In both cases, the metric of this module has size (32,).
        block_sparse_size
            This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
            Make sure len(dim) == len(block_sparse_size), with each block_sparse_size entry corresponding to the dim entry at the same position.

            Example: the under-pruning weight size is (768, 768), and you want block sparsity on dim=[0]
            with block size [64, 768]; then set block_sparse_size=[64]. The final metric size is (12,).
        """
        self.dim = dim if not isinstance(dim, int) else [dim]
        self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
        if self.block_sparse_size is not None:
            assert all(i >= 1 for i in self.block_sparse_size)
        elif self.dim is not None:
            self.block_sparse_size = [1] * len(self.dim)
        if self.dim is not None:
            assert all(i >= 0 for i in self.dim)
            # sort dim ascending and keep block_sparse_size aligned with it
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
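
A quick walk-through of that last normalization step (a standalone sketch, not part of the commit): scalars are wrapped into lists, a default block size of 1 per dim is filled in, and `dim` and `block_sparse_size` are sorted together so the block sizes stay attached to their dimensions.

dim, block_sparse_size = [2, 0], [64, 8]
dim, block_sparse_size = (list(t) for t in zip(*sorted(zip(dim, block_sparse_size))))
print(dim, block_sparse_size)  # [0, 2] [8, 64]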

    def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
        """
        Parameters
        ----------
        data
            A dict that holds the data used to calculate metrics. Usually has the format {module_name: tensor_type_data}.

        Returns
        -------
        Dict[str, Tensor]
            The key is the layer name, the value is its metric.
            Note that the metric has the same size as the data size on `dim`.
        """
        raise NotImplementedError()

class SparsityAllocator:
    """
    An abstract class for allocating masks based on metrics.
    """

    def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
                 block_sparse_size: Optional[Union[int, List[int]]] = None):
        """
        Parameters
        ----------
        pruner
            The pruner that is bound to this `SparsityAllocator`.
        dim
            The under-pruning weight dimensions; the metric size should equal the under-pruning weight size on these dimensions.
            None means a one-to-one correspondence between pruned dimensions and metric, which is equal to setting `dim` to all under-pruning weight dimensions.
            The mask will be expanded to the weight size depending on `dim`.

            Example: the under-pruning weight has size (2, 3, 4), and `dim=1` means the under-pruning weight dimension is 1.
            Then the metric should have size (3,), e.g., `metric=[0.9, 0.1, 0.8]`.
            Assuming some kind of `SparsityAllocator` gets the mask on weight dimension 1 as `mask=[1, 0, 1]`,
            this dimension mask will be expanded to the final mask
            `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
        block_sparse_size
            This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
            Make sure len(dim) == len(block_sparse_size), with each block_sparse_size entry corresponding to the dim entry at the same position.

            Example: the metric size is (12,) and block_sparse_size=[64]; then the mask will first be expanded to (768,) before being expanded with `dim`.
        """
        self.pruner = pruner
        self.dim = dim if not isinstance(dim, int) else [dim]
        self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
        if self.block_sparse_size is not None:
            assert all(i >= 1 for i in self.block_sparse_size)
        elif self.dim is not None:
            self.block_sparse_size = [1] * len(self.dim)
        if self.dim is not None:
            assert all(i >= 0 for i in self.dim)
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
        """
        Parameters
        ----------
        metrics
            A metric dict. The key is the name of a layer, the value is its metric.
        """
        raise NotImplementedError()

    def _expand_mask(self, name: str, mask: Tensor) -> Dict[str, Tensor]:
        """
        Parameters
        ----------
        name
            The masked module name.
        mask
            The reduced mask with `self.dim` and `self.block_sparse_size`.

        Returns
        -------
        Dict[str, Tensor]
            The key is `weight_mask` or `bias_mask`, the value is the final mask.
        """
        weight_mask = mask.clone()
        if self.block_sparse_size is not None:
            # expand the mask with block_sparse_size
            expand_size = list(weight_mask.size())
            reshape_size = list(weight_mask.size())
            for i, block_width in reversed(list(enumerate(self.block_sparse_size))):
                weight_mask = weight_mask.unsqueeze(i + 1)
                expand_size.insert(i + 1, block_width)
                reshape_size[i] *= block_width
            weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)

        wrapper = self.pruner.get_modules_wrapper()[name]
        weight_size = wrapper.module.weight.data.size()
        if self.dim is None:
            assert weight_mask.size() == weight_size
            expand_mask = {'weight_mask': weight_mask}
        else:
            # expand the mask to the weight size along `dim`
            assert len(weight_mask.size()) == len(self.dim)
            assert all(weight_size[j] == weight_mask.size(i) for i, j in enumerate(self.dim))

            idxs = list(range(len(weight_size)))
            for i in reversed(self.dim):
                idxs.pop(i)
            for i in idxs:
                weight_mask = weight_mask.unsqueeze(i)
            expand_mask = {'weight_mask': weight_mask.expand(weight_size).clone()}
            # NOTE: assume we only mask the output, so the mask and the bias have a one-to-one correspondence.
            # If more kinds of masks are supported, this place needs refactoring.
            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
                expand_mask['bias_mask'] = weight_mask.clone()
        return expand_mask
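
As a concrete illustration of what `_expand_mask` does (a standalone sketch, not part of the commit), a per-channel mask of shape (3,) computed on dim=1 is broadcast back to a weight of shape (2, 3, 4):

import torch

weight_size = (2, 3, 4)
mask = torch.tensor([1., 0., 1.])  # channel-level mask on weight dimension 1

# unsqueeze the non-pruned dims, then broadcast to the full weight size
expanded = mask.unsqueeze(0).unsqueeze(-1).expand(weight_size).clone()
print(expanded.shape)  # torch.Size([2, 3, 4])
print(expanded[0, 1])  # tensor([0., 0., 0., 0.]) -- channel 1 is fully masked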

    def _compress_mask(self, mask: Tensor) -> Tensor:
        """
        Parameters
        ----------
        mask
            The entire mask, which has the same size as the weight.

        Returns
        -------
        Tensor
            The mask reduced with `self.dim` and `self.block_sparse_size`.
        """
        if self.dim is None or len(mask.size()) == 1:
            mask = mask.clone()
        else:
            mask_dim = list(range(len(mask.size())))
            for dim in self.dim:
                mask_dim.remove(dim)
            mask = torch.sum(mask, dim=mask_dim)

        if self.block_sparse_size is not None:
            # an operation like pooling
            lower_case_letters = 'abcdefghijklmnopqrstuvwxyz'
            ein_expression = ''
            for i, step in enumerate(self.block_sparse_size):
                mask = mask.unfold(i, step, step)
                ein_expression += lower_case_letters[i]
            ein_expression = '...{},{}'.format(ein_expression, ein_expression)
            mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size))

        return (mask != 0).type_as(mask)
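
The `unfold` plus `einsum` combination above behaves like sum pooling over non-overlapping blocks; a block survives if any of its elements does. A tiny sketch (assumed shapes, not part of the commit):

import torch

mask = torch.tensor([1., 1., 0., 0., 0., 1.])  # element-level mask, block size 2
blocked = mask.unfold(0, 2, 2)                 # shape (3, 2): one row per block
pooled = torch.einsum('...a,a', blocked, torch.ones(2))
print(pooled)                 # tensor([2., 0., 1.]) -- per-block sums
print((pooled != 0).float())  # tensor([1., 0., 1.]) -- the reduced block mask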
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py
0 → 100644
View file @
e5c3ac63
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Dict, List

from torch import Tensor

from .base import DataCollector, TrainerBasedDataCollector

_logger = logging.getLogger(__name__)

__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']

class WeightDataCollector(DataCollector):
    """
    Collect all wrapper weights.
    """

    def reset(self):
        pass

    def collect(self) -> Dict[str, Tensor]:
        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
            data[wrapper.name] = wrapper.module.weight.data.clone().detach()
        return data

class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
    """
    Collect all wrapper weights after training or inference.
    """

    def collect(self) -> Dict[str, Tensor]:
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
            data[wrapper.name] = wrapper.module.weight.data.clone().detach()
        return data

class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
    """
    Add hooks and collect data during training or inference.
    Single means each wrapper only has one hook to collect data.
    """

    def collect(self) -> Dict[str, List[Tensor]]:
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

        data = {}
        for _, buffer_dict in self._hook_buffer.items():
            data.update(buffer_dict)
        return data
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py
0 → 100644
View file @
e5c3ac63
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional, Union

import torch
from torch import Tensor

from .base import MetricsCalculator

__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
           'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator']

class NormMetricsCalculator(MetricsCalculator):
    """
    Calculate the specified norm for each tensor in the data.
    The L1, L2, Level, and Slim pruners use this to calculate their metrics.
    """

    def __init__(self, dim: Optional[Union[int, List[int]]] = None,
                 p: Optional[Union[int, float]] = None):
        """
        Parameters
        ----------
        dim
            The dimensions in the collected data that correspond to the under-pruning weight dimensions.
            None means a one-to-one correspondence between pruned dimensions and data, which is equal to setting `dim` to all data dimensions.
            Only these `dim` will be kept; the other dimensions of the data will be reduced.
            See `MetricsCalculator` for a detailed example.
        p
            The order of the norm. None means the Frobenius norm.
        """
        super().__init__(dim=dim)
        self.p = p if p is not None else 'fro'

    def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        metrics = {}
        for name, tensor in data.items():
            keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
            across_dim = list(range(len(tensor.size())))
            for i in reversed(keeped_dim):
                across_dim.pop(i)
            if len(across_dim) == 0:
                metrics[name] = tensor.abs()
            else:
                metrics[name] = tensor.norm(p=self.p, dim=across_dim)
        return metrics
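
For example (a usage sketch under assumed shapes, not part of the commit), the L1 filter metric for a conv weight of shape (32, 16, 3, 3) with `dim=0` reduces over dimensions 1-3:

import torch

weight = torch.randn(32, 16, 3, 3)
metric = weight.norm(p=1, dim=[1, 2, 3])  # what NormMetricsCalculator(dim=0, p=1) computes per filter
print(metric.shape)  # torch.Size([32]) -- one score per out-channel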

class MultiDataNormMetricsCalculator(NormMetricsCalculator):
    """
    Sum each list of tensors in the data first, then calculate the specified norm for each summed tensor.
    The TaylorFO pruner uses this to calculate its metric.
    """

    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
        new_data = {name: sum(list_tensor) for name, list_tensor in data.items()}
        return super().calculate_metrics(new_data)

class DistMetricsCalculator(MetricsCalculator):
    """
    For each element along the kept `dim`, calculate the sum of the specified distance to all other elements in each tensor in the data.
    The FPGM pruner uses this to calculate its metric.
    """

    def __init__(self, p: float, dim: Union[int, List[int]]):
        """
        Parameters
        ----------
        dim
            The dimensions in the collected data that correspond to the under-pruning weight dimensions.
            None means a one-to-one correspondence between pruned dimensions and data, which is equal to setting `dim` to all data dimensions.
            Only these `dim` will be kept; the other dimensions of the data will be reduced.
            See `MetricsCalculator` for a detailed example.
        p
            The order of the norm.
        """
        super().__init__(dim=dim)
        self.p = p

    def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        metrics = {}
        for name, tensor in data.items():
            keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
            # move the kept dims to the front
            reorder_dim = list(keeped_dim)
            reorder_dim.extend([i for i in range(len(tensor.size())) if i not in keeped_dim])
            reorder_tensor = tensor.permute(*reorder_dim).clone()

            metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
            across_dim = list(range(len(keeped_dim), len(reorder_dim)))
            idxs = metric.nonzero()
            for idx in idxs:
                other = reorder_tensor
                for i in idx:
                    other = other[i]
                other = other.clone()
                if len(across_dim) == 0:
                    dist_sum = torch.abs(reorder_tensor - other).sum()
                else:
                    dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()
                # NOTE: this place needs refactoring when layer-level pruning is supported.
                tmp_metric = metric
                for i in idx[:-1]:
                    tmp_metric = tmp_metric[i]
                tmp_metric[idx[-1]] = dist_sum

            metrics[name] = metric
        return metrics
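
For the common 2-D case this is equivalent to summing pairwise distances between flattened filters; a sketch (not part of the commit) using `torch.cdist` yields the same FPGM-style metric for `dim=0`:

import torch

weight = torch.randn(8, 16 * 3 * 3)                   # filters flattened to rows
metric = torch.cdist(weight, weight, p=2).sum(dim=1)  # sum of L2 distances to all other filters
print(metric.shape)  # torch.Size([8]) -- filters near the geometric median score low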

class APoZRankMetricsCalculator(MetricsCalculator):
    """
    This metric counts the number of zeros at each position across the tensor list in the data,
    then sums the zero counts over the reduced dimensions and computes the non-zero rate.
    Note that the metric we return is (1 - APoZ), because we assume a higher metric value means higher importance.
    The APoZRank pruner uses this to calculate its metric.
    """

    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
        metrics = {}
        for name, tensor_list in data.items():
            # NOTE: dim=0 means the batch dim is 0
            activations = torch.cat(tensor_list, dim=0)
            _eq_zero = torch.eq(activations, torch.zeros_like(activations))
            keeped_dim = list(range(len(_eq_zero.size()))) if self.dim is None else self.dim
            across_dim = list(range(len(_eq_zero.size())))
            for i in reversed(keeped_dim):
                across_dim.pop(i)
            # the number of elements reduced over for each kept position in _eq_zero
            total_size = 1
            for dim, dim_size in enumerate(_eq_zero.size()):
                if dim not in keeped_dim:
                    total_size *= dim_size
            _apoz = torch.sum(_eq_zero, dim=across_dim, dtype=torch.float64) / total_size
            # NOTE: the metric is (1 - apoz) because we assume a smaller metric value means the element is more likely to be pruned.
            metrics[name] = torch.ones_like(_apoz) - _apoz
        return metrics
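
A tiny numeric sketch of the (1 - APoZ) metric above (assumed post-ReLU activation shapes, not part of the commit):

import torch

activations = torch.relu(torch.randn(4, 32, 8, 8))  # assumed shape: (batch, channels, h, w)
eq_zero = torch.eq(activations, torch.zeros_like(activations))
apoz = torch.sum(eq_zero, dim=[0, 2, 3], dtype=torch.float64) / (4 * 8 * 8)
metric = 1 - apoz    # a higher value means fewer zeros, i.e., a more important channel
print(metric.shape)  # torch.Size([32])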

class MeanRankMetricsCalculator(MetricsCalculator):
    """
    This metric simply concatenates the list of tensors on dim 0, then averages over the reduced dimensions.
    The MeanRank pruner uses this to calculate its metric.
    """

    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
        metrics = {}
        for name, tensor_list in data.items():
            # NOTE: dim=0 means the batch dim is 0
            activations = torch.cat(tensor_list, dim=0)
            keeped_dim = list(range(len(activations.size()))) if self.dim is None else self.dim
            across_dim = list(range(len(activations.size())))
            for i in reversed(keeped_dim):
                across_dim.pop(i)
            metrics[name] = torch.mean(activations, across_dim)
        return metrics
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
0 → 100644
View file @
e5c3ac63
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency

from .base import SparsityAllocator

class NormalSparsityAllocator(SparsityAllocator):
    """
    This allocator simply prunes the weights with the smaller metric values at layer level.
    """

    def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
        masks = {}
        for name, wrapper in self.pruner.get_modules_wrapper().items():
            sparsity_rate = wrapper.config['sparsity_per_layer']

            assert name in metrics, 'Metric of %s is not calculated.' % name
            # zero out positions that are already masked before ranking
            metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
            prune_num = int(sparsity_rate * metric.numel())
            if prune_num == 0:
                continue
            threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
            mask = torch.gt(metric, threshold).type_as(metric)
            masks[name] = self._expand_mask(name, mask)
        return masks
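
The thresholding idiom used throughout these allocators, in isolation (a standalone sketch, not part of the commit): take the `prune_num` smallest metric values, use their maximum as the cut-off, and keep everything strictly above it.

import torch

metric = torch.tensor([0.9, 0.1, 0.8, 0.3, 0.5])
prune_num = int(0.4 * metric.numel())  # 40% sparsity -> prune 2 of 5
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
print(mask)  # tensor([1., 0., 1., 0., 1.]) -- the two smallest entries are masked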

class GlobalSparsityAllocator(SparsityAllocator):
    """
    This allocator prunes the weights with the smaller metric values at group level.
    This means all layers in a group rank their metrics uniformly.
    The layers with the same config in the config_list form a group.
    """

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
        masks = {}
        # {group_index: {layer_name: metric}}
        grouped_metrics = {idx: {name: metrics[name] for name in names}
                           for idx, names in self.pruner.generate_module_groups().items()}
        for _, group_metric_dict in grouped_metrics.items():
            threshold, sub_thresholds = self._calculate_threshold(group_metric_dict)
            for name, metric in group_metric_dict.items():
                mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric)
                masks[name] = self._expand_mask(name, mask)
        return masks

    def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]:
        metric_list = []
        sub_thresholds = {}
        total_weight_num = 0

        temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
        total_sparsity = temp_wrapper_config['total_sparsity']
        max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', 1.0)

        for name, metric in group_metric_dict.items():
            wrapper = self.pruner.get_modules_wrapper()[name]
            metric = metric * self._compress_mask(wrapper.weight_mask)
            layer_weight_num = wrapper.module.weight.data.numel()
            stay_num = int(metric.numel() * max_sparsity_per_layer)

            # remove the weight parts that must be left unpruned
            stay_metric = torch.topk(metric.view(-1), stay_num, largest=False)[0]
            sub_thresholds[name] = stay_metric.max()
            expend_times = int(layer_weight_num / metric.numel())
            if expend_times > 1:
                stay_metric = stay_metric.expand(stay_num, int(layer_weight_num / metric.numel())).view(-1)
            metric_list.append(stay_metric)
            total_weight_num += layer_weight_num

        assert total_sparsity <= max_sparsity_per_layer, 'total_sparsity should be less than max_sparsity_per_layer.'
        total_prune_num = int(total_sparsity * total_weight_num)
        threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
        return threshold, sub_thresholds

class Conv2dDependencyAwareAllocator(SparsityAllocator):
    """
    A specific allocator for dependency-aware Conv2d pruning.
    """

    def __init__(self, pruner: Pruner, dim: int, dummy_input: Any):
        assert isinstance(dim, int), 'Only support single dim in Conv2dDependencyAwareAllocator.'
        super().__init__(pruner, dim=dim)
        self.dummy_input = dummy_input

    def _get_dependency(self):
        graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
        self.channel_depen = ChannelDependency(traced_model=graph.trace).dependency_sets
        self.group_depen = GroupDependency(traced_model=graph.trace).dependency_sets

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
        self._get_dependency()
        masks = {}
        grouped_metrics = {}
        for idx, names in enumerate(self.channel_depen):
            grouped_metric = {name: metrics[name] * self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)
                              for name in names if name in metrics}
            if len(grouped_metric) > 0:
                grouped_metrics[idx] = grouped_metric
        for _, group_metric_dict in grouped_metrics.items():
            group_metric = self._group_metric_calculate(group_metric_dict)

            sparsities = {name: self.pruner.get_modules_wrapper()[name].config['sparsity_per_layer']
                          for name in group_metric_dict.keys()}
            min_sparsity = min(sparsities.values())

            conv2d_groups = [self.group_depen[name] for name in group_metric_dict.keys()]
            max_conv2d_group = np.lcm.reduce(conv2d_groups)

            pruned_per_conv2d_group = int(group_metric.numel() / max_conv2d_group * min_sparsity)
            conv2d_group_step = int(group_metric.numel() / max_conv2d_group)

            group_mask = []
            for gid in range(max_conv2d_group):
                _start = gid * conv2d_group_step
                _end = (gid + 1) * conv2d_group_step
                if pruned_per_conv2d_group > 0:
                    threshold = torch.topk(group_metric[_start:_end], pruned_per_conv2d_group, largest=False)[0].max()
                    conv2d_group_mask = torch.gt(group_metric[_start:_end], threshold).type_as(group_metric)
                else:
                    conv2d_group_mask = torch.ones(conv2d_group_step, device=group_metric.device)
                group_mask.append(conv2d_group_mask)
            group_mask = torch.cat(group_mask, dim=0)

            for name, metric in group_metric_dict.items():
                metric = (metric - metric.min()) * group_mask
                pruned_num = int(sparsities[name] * len(metric))
                threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
                mask = torch.gt(metric, threshold).type_as(metric)
                masks[name] = self._expand_mask(name, mask)
        return masks

    def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
        """
        Add up the metric values at the same position across one group.
        """
        group_metrics = list(group_metrics.values()) if isinstance(group_metrics, dict) else group_metrics
        assert all(group_metrics[0].size() == group_metric.size() for group_metric in group_metrics), 'Metric sizes do not match.'
        group_sum_metric = torch.zeros(group_metrics[0].size(), device=group_metrics[0].device)
        for group_metric in group_metrics:
            group_sum_metric += group_metric
        return group_sum_metric
nni/algorithms/compression/v2/pytorch/utils/__init__.py
0 → 100644
View file @
e5c3ac63
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
0 → 100644
View file @
e5c3ac63
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from schema import Schema, And, SchemaError

def validate_op_names(model, op_names, logger):
    found_names = set(map(lambda x: x[0], model.named_modules()))

    not_found_op_names = list(set(op_names) - found_names)
    if not_found_op_names:
        logger.warning('op_names %s not found in model', not_found_op_names)

    return True

def validate_op_types(model, op_types, logger):
    found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))

    not_found_op_types = list(set(op_types) - found_types)
    if not_found_op_types:
        logger.warning('op_types %s not found in model', not_found_op_types)

    return True

def validate_op_types_op_names(data):
    if not ('op_types' in data or 'op_names' in data):
        raise SchemaError('Either op_types or op_names must be specified.')
    return True

class CompressorSchema:
    def __init__(self, data_schema, model, logger):
        assert isinstance(data_schema, list) and len(data_schema) <= 1
        self.data_schema = data_schema
        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))

    def _modify_schema(self, data_schema, model, logger):
        if not data_schema:
            return data_schema

        for k in data_schema[0]:
            old_schema = data_schema[0][k]
            if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
                new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
                data_schema[0][k] = new_schema
            if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
                new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
                data_schema[0][k] = new_schema

        data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))

        return data_schema

    def validate(self, data):
        self.compressor_schema.validate(data)
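
A hypothetical usage sketch (schema keys and model assumed, not part of the commit) showing how `CompressorSchema` validates a config list and rejects an entry that specifies neither `op_types` nor `op_names`:

import logging
import torch
from schema import Optional, SchemaError

logger = logging.getLogger('example')
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

data_schema = [{
    Optional('op_types'): [str],
    Optional('op_names'): [str],
}]
schema = CompressorSchema(data_schema, model, logger)

schema.validate([{'op_types': ['Conv2d']}])  # passes
try:
    schema.validate([{}])  # neither key is present -> SchemaError
except SchemaError as e:
    print(e)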

def validate_exclude_sparsity(data):
    if not ('exclude' in data or 'sparsity_per_layer' in data or 'total_sparsity' in data):
        raise SchemaError('One of [sparsity_per_layer, total_sparsity, exclude] should be specified.')
    return True


def validate_exclude_quant_types_quant_bits(data):
    if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
        raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
    return True

class PrunerSchema(CompressorSchema):
    def _modify_schema(self, data_schema, model, logger):
        data_schema = super()._modify_schema(data_schema, model, logger)
        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
        return data_schema


class QuantizerSchema(CompressorSchema):
    def _modify_schema(self, data_schema, model, logger):
        data_schema = super()._modify_schema(data_schema, model, logger)
        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
        return data_schema