OpenDAS / nni · Commits

Commit f24dc27b (unverified)
Authored Jun 30, 2022 by J-shang; committed by GitHub on Jun 30, 2022
Parent: 00e4debb

[Compression] block sparse refactor (#4932)

Showing 10 changed files with 619 additions and 477 deletions (+619 / -477)
Changed files:

    nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py              +24   -23
    nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py            +1    -1
    nni/algorithms/compression/v2/pytorch/pruning/tools/base.py                +112  -149
    nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py  +58   -111
    nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py  +156  -187
    nni/algorithms/compression/v2/pytorch/utils/__init__.py                    +6    -1
    nni/algorithms/compression/v2/pytorch/utils/attr.py                        +32   -0
    nni/algorithms/compression/v2/pytorch/utils/scaling.py                     +195  -0
    test/algo/compression/v2/test_pruning_tools_torch.py                       +6    -5
    test/algo/compression/v2/test_scaling.py                                   +29   -0
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py

@@ -13,8 +13,7 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from nni.common.serializer import Traceable
-from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
-from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper
+from ..base import Pruner
 from .tools import (
     DataCollector,
@@ -38,9 +37,11 @@ from .tools import (
     NormalSparsityAllocator,
     BankSparsityAllocator,
     GlobalSparsityAllocator,
-    Conv2dDependencyAwareAllocator
+    DependencyAwareAllocator
 )
+from ..utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper, Scaling
 
 _logger = logging.getLogger(__name__)
 
 __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
@@ -275,12 +276,12 @@ class NormPruner(BasicPruner):
         else:
             self.data_collector.reset()
         if self.metrics_calculator is None:
-            self.metrics_calculator = NormMetricsCalculator(p=self.p, dim=0)
+            self.metrics_calculator = NormMetricsCalculator(p=self.p, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         if self.sparsity_allocator is None:
             if self.mode == 'normal':
-                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             elif self.mode == 'dependency_aware':
-                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
+                self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             else:
                 raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
@@ -440,12 +441,12 @@ class FPGMPruner(BasicPruner):
         else:
             self.data_collector.reset()
         if self.metrics_calculator is None:
-            self.metrics_calculator = DistMetricsCalculator(p=2, dim=0)
+            self.metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         if self.sparsity_allocator is None:
             if self.mode == 'normal':
-                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             elif self.mode == 'dependency_aware':
-                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
+                self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             else:
                 raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
@@ -688,16 +689,16 @@ class ActivationPruner(BasicPruner):
         else:
             self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
         if self.metrics_calculator is None:
-            self.metrics_calculator = self._get_metrics_calculator()
+            self.metrics_calculator = self._create_metrics_calculator()
         if self.sparsity_allocator is None:
             if self.mode == 'normal':
-                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             elif self.mode == 'dependency_aware':
-                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
+                self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             else:
                 raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
 
-    def _get_metrics_calculator(self) -> MetricsCalculator:
+    def _create_metrics_calculator(self) -> MetricsCalculator:
         raise NotImplementedError()
@@ -782,8 +783,8 @@ class ActivationAPoZRankPruner(ActivationPruner):
         # return a matrix that the position of zero in `output` is one, others is zero.
         return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
 
-    def _get_metrics_calculator(self) -> MetricsCalculator:
-        return APoZRankMetricsCalculator(dim=1)
+    def _create_metrics_calculator(self) -> MetricsCalculator:
+        return APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
 
 
 class ActivationMeanRankPruner(ActivationPruner):
@@ -865,8 +866,8 @@ class ActivationMeanRankPruner(ActivationPruner):
         # return the activation of `output` directly.
         return self._activation(output.detach())
 
-    def _get_metrics_calculator(self) -> MetricsCalculator:
-        return MeanRankMetricsCalculator(dim=1)
+    def _create_metrics_calculator(self) -> MetricsCalculator:
+        return MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
 
 
 class TaylorFOWeightPruner(BasicPruner):
@@ -1009,14 +1010,14 @@ class TaylorFOWeightPruner(BasicPruner):
         else:
             self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
         if self.metrics_calculator is None:
-            self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
+            self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         if self.sparsity_allocator is None:
             if self.mode == 'normal':
-                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             elif self.mode == 'global':
-                self.sparsity_allocator = GlobalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = GlobalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             elif self.mode == 'dependency_aware':
-                self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
+                self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
             else:
                 raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
@@ -1146,12 +1147,12 @@ class ADMMPruner(BasicPruner):
             if self.granularity == 'fine-grained':
                 self.metrics_calculator = NormMetricsCalculator(p=1)
             elif self.granularity == 'coarse-grained':
-                self.metrics_calculator = NormMetricsCalculator(dim=0, p=1)
+                self.metrics_calculator = NormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         if self.sparsity_allocator is None:
             if self.granularity == 'fine-grained':
                 self.sparsity_allocator = NormalSparsityAllocator(self)
             elif self.granularity == 'coarse-grained':
-                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
+                self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
 
     def compress(self) -> Tuple[Module, Dict]:
         """
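The recurring change in this file is that every `dim=0` / block-size argument is replaced by a `Scaling` object. A minimal sketch of the new wiring, using only names that appear in this diff; `pruner` and `dummy_input` are assumed to exist and are not defined here:

    # Per-output-channel (filter) granularity, previously expressed as `dim=0`,
    # is now expressed as a scaling kernel of size [1] padded at the back,
    # i.e. one metric / mask value per output channel.
    scaler = Scaling(kernel_size=[1], kernel_padding_mode='back')
    metrics_calculator = NormMetricsCalculator(p=1, scalers=scaler)
    sparsity_allocator = NormalSparsityAllocator(pruner, scaler)
    dep_allocator = DependencyAwareAllocator(pruner, dummy_input, scaler)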
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py

@@ -25,7 +25,7 @@ from .sparsity_allocator import (
     NormalSparsityAllocator,
     BankSparsityAllocator,
     GlobalSparsityAllocator,
-    Conv2dDependencyAwareAllocator
+    DependencyAwareAllocator
 )
 from .task_generator import (
     AGPTaskGenerator,
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from __future__ import annotations
+
 from datetime import datetime
 import logging
 from pathlib import Path
@@ -13,12 +14,24 @@ from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
-from nni.algorithms.compression.v2.pytorch.base import Pruner, LayerInfo, Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
+from ...base import Pruner, LayerInfo, Task, TaskResult
+from ...utils import OptimizerConstructHelper, Scaling
 
 _logger = logging.getLogger(__name__)
 
 
+def _get_scaler(scalers: Dict[str, Dict[str, Scaling]] | None, module_name: str, target_name: str) -> Scaling | None:
+    # Get scaler for the specific target in the specific module. Return None if don't find it.
+    # `module_name` is not used in current nni version, will support different modules using different scalers in the future.
+    if scalers:
+        default_module_scalers = scalers.get('_default', {})
+        default_target_scaler = default_module_scalers.get(target_name, default_module_scalers.get('_default', None))
+        module_scalers = scalers.get(module_name, {})
+        return module_scalers.get(target_name, module_scalers.get('_default', default_target_scaler))
+    else:
+        return None
+
+
 class DataCollector:
     """
     An abstract class for collect the data needed by the compressor.
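A short illustration of the fallback order implemented by `_get_scaler` above; the module and target names are made up for the example:

    scalers = {
        '_default': {'_default': Scaling(kernel_size=[1], kernel_padding_mode='back')},   # global fallback
        'conv1':    {'weight': Scaling(kernel_size=[1, -1], kernel_padding_mode='back')}, # per-module, per-target override
    }
    _get_scaler(scalers, 'conv1', 'weight')   # -> the 'conv1'/'weight' scaler
    _get_scaler(scalers, 'conv2', 'weight')   # -> falls back to scalers['_default']['_default']
    _get_scaler(None, 'conv1', 'weight')      # -> None, i.e. no scaling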
@@ -245,49 +258,21 @@ class MetricsCalculator:
     Parameters
     ----------
-    dim
-        The dimensions that corresponding to the under pruning weight dimensions in collected data.
-        None means one-to-one correspondence between pruned dimensions and data, which equal to set `dim` as all data dimensions.
-        Only these `dim` will be kept and other dimensions of the data will be reduced.
-
-        Example:
-
-        If you want to prune the Conv2d weight in filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernal-size-1, kernal-size-2].
-        Then the under pruning dimensions is [0], which means you want to prune the filter or out-channel.
-
-            Case 1: Directly collect the conv module weight as data to calculate the metric.
-            Then the data has size (32, 16, 3, 3).
-            Mention that the dimension 0 of the data is corresponding to the under pruning weight dimension 0.
-            So in this case, `dim=0` will set in `__init__`.
-
-            Case 2: Use the output of the conv module as data to calculate the metric.
-            Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
-            Mention that the dimension 1 of the data is corresponding to the under pruning weight dimension 0.
-            So in this case, `dim=1` will set in `__init__`.
-
-        In both of these two case, the metric of this module has size (32,).
-    block_sparse_size
-        This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)).
-        Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim.
-
-        Example:
-
-        The under pruning weight size is (768, 768), and you want to apply a block sparse on dim=[0] with block size [64, 768],
-        then you can set block_sparse_size=[64]. The final metric size is (12,).
+    scalers
+        Scaler is used to scale the metrics' size. It scaling metric to the same size as the shrinked mask in the sparsity allocator.
+        If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
+        If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
+        If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
+        Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
+        Passing in `None` means no need to scale.
     """
 
-    def __init__(self, dim: Optional[Union[int, List[int]]] = None,
-                 block_sparse_size: Optional[Union[int, List[int]]] = None):
-        self.dim = dim if not isinstance(dim, int) else [dim]
-        self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
-        if self.block_sparse_size is not None:
-            assert all(i >= 1 for i in self.block_sparse_size)
-        elif self.dim is not None:
-            self.block_sparse_size = [1] * len(self.dim)
-        if self.dim is not None:
-            assert all(i >= 0 for i in self.dim)
-            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore
+    def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
+        self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) \
+            else {'_default': {'_default': scalers}}  # type: ignore
+
+    def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
+        scaler = _get_scaler(self.scalers, module_name, target_name)
+        return scaler if scaler else Scaling([1])
 
     def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
         """
@@ -307,142 +292,120 @@ class MetricsCalculator:
 class SparsityAllocator:
     """
-    An abstract class for allocate mask based on metrics.
+    A base class for allocating mask based on metrics.
 
     Parameters
     ----------
     pruner
         The pruner that binded with this `SparsityAllocator`.
-    dim
-        The under pruning weight dimensions, which metric size should equal to the under pruning weight size on these dimensions.
-        None means one-to-one correspondence between pruned dimensions and metric, which equal to set `dim` as all under pruning weight dimensions.
-        The mask will expand to the weight size depend on `dim`.
-
-        Example:
-
-        The under pruning weight has size (2, 3, 4), and `dim=1` means the under pruning weight dimension is 1.
-        Then the metric should have a size (3,), i.e., `metric=[0.9, 0.1, 0.8]`.
-        Assuming by some kind of `SparsityAllocator` get the mask on weight dimension 1 `mask=[1, 0, 1]`,
-        then the dimension mask will expand to the final mask `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
-    block_sparse_size
-        This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)).
-        Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim.
-
-        Example:
-
-        The metric size is (12,), and block_sparse_size=[64], then the mask will expand to (768,) at first before expand with `dim`.
+    scalers
+        Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
+        or expands the mask of the same size as the metric to the same size as the pruning target.
+        If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
+        If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
+        If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
+        Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
+        Passing in `None` means no need to scale.
     continuous_mask
-        Inherit the mask already in the wrapper if set True.
+        If set True, the part that has been masked will be masked first.
+        If set False, the part that has been masked may be unmasked due to the increase of its corresponding metric.
     """
 
-    def __init__(self, pruner: Pruner, dim: Optional[Union[int, List[int]]] = None,
-                 block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
+    def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None,
+                 continuous_mask: bool = True):
         self.pruner = pruner
-        self.dim = dim if not isinstance(dim, int) else [dim]
-        self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
-        if self.block_sparse_size is not None:
-            assert all(i >= 1 for i in self.block_sparse_size)
-        elif self.dim is not None:
-            self.block_sparse_size = [1] * len(self.dim)
-        if self.dim is not None:
-            assert all(i >= 0 for i in self.dim)
-            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore
+        self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) \
+            else {'_default': {'_default': scalers}}  # type: ignore
         self.continuous_mask = continuous_mask
 
-    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
+    def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
+        return _get_scaler(self.scalers, module_name, target_name)
+
+    def _expand_mask(self, module_name: str, target_name: str, mask: Tensor) -> Tensor:
+        # Expand the shrinked mask to the pruning target size.
+        scaler = self._get_scaler(module_name=module_name, target_name=target_name)
+        if scaler:
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            return scaler.expand(mask, getattr(wrapper, f'{target_name}_mask').shape)
+        else:
+            return mask.clone()
+
+    def _shrink_mask(self, module_name: str, target_name: str, mask: Tensor) -> Tensor:
+        # Shrink the mask by scaler, shrinked mask usually has the same size with metric.
+        scaler = self._get_scaler(module_name=module_name, target_name=target_name)
+        if scaler:
+            mask = (scaler.shrink(mask) != 0).type_as(mask)
+        return mask
+
+    def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
+        # Set the already masked part in the metric to the minimum value.
+        target_name = 'weight'
+        for module_name, target_mask in new_masks.items():
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
+            if old_target_mask is not None:
+                new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
+        return new_masks
+
+    def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
+        """
+        Generate masks for metrics-dependent targets.
+
+        Parameters
+        ----------
+        metrics
+            A metric dict. The key is the name of layer, the value is its metric.
+            The format is {module_name: weight_metric}.
+            The metric of `weight` usually has the same size with shrinked mask.
+
+        Return
+        ------
+        Dict[str, Dict[str, Tensor]]
+            The format is {module_name: {target_name: mask}}.
+            Return the masks of the same size as its target.
+        """
+        raise NotImplementedError()
 
-    def _expand_mask(self, name: str, mask: Tensor) -> Dict[str, Tensor]:
-        """
-        Parameters
-        ----------
-        name
-            The masked module name.
-        mask
-            The reduced mask with `self.dim` and `self.block_sparse_size`.
-
-        Returns
-        -------
-        Dict[str, Tensor]
-            The key is `weight` or `bias`, value is the final mask.
-        """
-        weight_mask = mask.clone()
-
-        if self.block_sparse_size is not None:
-            # expend mask with block_sparse_size
-            expand_size = list(weight_mask.size())
-            reshape_size = list(weight_mask.size())
-            for i, block_width in reversed(list(enumerate(self.block_sparse_size))):
-                weight_mask = weight_mask.unsqueeze(i + 1)
-                expand_size.insert(i + 1, block_width)
-                reshape_size[i] *= block_width
-            weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)
-
-        wrapper = self.pruner.get_modules_wrapper()[name]
-        weight_size = wrapper.weight.data.size()  # type: ignore
-
-        if self.dim is None:
-            assert weight_mask.size() == weight_size
-            expand_mask = {'weight': weight_mask}
-        else:
-            # expand mask to weight size with dim
-            assert len(weight_mask.size()) == len(self.dim)
-            assert all(weight_size[j] == weight_mask.size(i) for i, j in enumerate(self.dim))
-
-            idxs = list(range(len(weight_size)))
-            [idxs.pop(i) for i in reversed(self.dim)]
-            for i in idxs:
-                weight_mask = weight_mask.unsqueeze(i)
-            expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
-            # NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
-            # If we support more kind of masks, this place need refactor.
-            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():  # type: ignore
-                expand_mask['bias'] = weight_mask.clone()
-        return expand_mask
+    def special_target_masks_generation(self, masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
+        """
+        Some pruning targets' mask generation depends on other targets, i.e., bias mask depends on weight mask.
+        This function is used to generate these masks, and it be called at the end of `generate_sparsity`.
+
+        Parameters
+        ----------
+        masks
+            The format is {module_name: {target_name: mask}}.
+            It is usually the return value of `common_target_masks_generation`.
+        """
+        for module_name, module_masks in masks.items():
+            # generate bias mask, this may move to wrapper in the future
+            weight_mask = module_masks.get('weight', None)
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            old_bias_mask = getattr(wrapper, 'bias_mask', None)
+            if weight_mask is not None and old_bias_mask is not None and weight_mask.shape[0] == old_bias_mask.shape[0]:
+                # keep dim 0 and reduce all other dims by sum
+                reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
+                # count unmasked number of values on dim 0 (output channel) of weight
+                unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
+                module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
+        return masks
 
-    def _compress_mask(self, mask: Tensor) -> Tensor:
-        """
-        This function will reduce the mask with `self.dim` and `self.block_sparse_size`.
-        e.g., a mask tensor with size [50, 60, 70], self.dim is (0, 1), self.block_sparse_size is [10, 10].
-        Then, the reduced mask size is [50 / 10, 60 / 10] => [5, 6].
-
-        Parameters
-        ----------
-        name
-            The masked module name.
-        mask
-            The entire mask has the same size with weight.
-
-        Returns
-        -------
-        Tensor
-            Reduced mask.
-        """
-        if self.dim is None or len(mask.size()) == 1:
-            mask = mask.clone()
-        else:
-            mask_dim = list(range(len(mask.size())))
-            for dim in self.dim:
-                mask_dim.remove(dim)
-            mask = torch.sum(mask, dim=mask_dim)
-
-        if self.block_sparse_size is not None:
-            # operation like pooling
-            lower_case_letters = 'abcdefghijklmnopqrstuvwxyz'
-            ein_expression = ''
-            for i, step in enumerate(self.block_sparse_size):
-                mask = mask.unfold(i, step, step)
-                ein_expression += lower_case_letters[i]
-            ein_expression = '...{},{}'.format(ein_expression, ein_expression)
-            mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
-
-        return (mask != 0).type_as(mask)
+    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
+        """
+        The main function of `SparsityAllocator`, generate a set of masks based on the given metrics.
+
+        Parameters
+        ----------
+        metrics
+            A metric dict with format {module_name: weight_metric}
+
+        Returns
+        -------
+        Dict[str, Dict[str, Tensor]]
+            The masks format is {module_name: {target_name: mask}}.
+        """
+        masks = self.common_target_masks_generation(metrics)
+        masks = self.special_target_masks_generation(masks)
+        if self.continuous_mask:
+            masks = self._continuous_mask(masks)
+        return masks
 
 
 class TaskGenerator:
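A minimal sketch of how the new `SparsityAllocator` template above is meant to be specialized: only `common_target_masks_generation` is abstract, while `generate_sparsity` chains it with `special_target_masks_generation` and the optional continuous-mask step. The subclass name is made up for illustration; `pruner` and `metrics` are assumed to exist:

    class KeepAllAllocator(SparsityAllocator):
        # Illustrative only: produces an all-ones (keep everything) mask per module.
        def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
            masks = {}
            target_name = 'weight'
            for module_name, target_metric in metrics.items():
                shrinked_mask = torch.ones_like(target_metric)
                masks[module_name] = {target_name: self._expand_mask(module_name, target_name, shrinked_mask)}
            return masks

    # masks = KeepAllAllocator(pruner).generate_sparsity(metrics)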
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from typing import Dict, List, Optional, Union
+from __future__ import annotations
+
+from typing import Dict, List
 
 import torch
 from torch import Tensor
 
 from .base import MetricsCalculator
+from ...utils import Scaling
 
 __all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
            'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
@@ -28,49 +31,28 @@ class NormMetricsCalculator(MetricsCalculator):
     """
     Calculate the specify norm for each tensor in data.
     L1, L2, Level, Slim pruner use this to calculate metric.
-    """
 
-    def __init__(self, dim: Optional[Union[int, List[int]]] = None,
-                 p: Optional[Union[int, float]] = None):
-        """
-        Parameters
-        ----------
-        dim
-            The dimensions that corresponding to the under pruning weight dimensions in collected data.
-            None means one-to-one correspondence between pruned dimensions and data, which equal to set `dim` as all data dimensions.
-            Only these `dim` will be kept and other dimensions of the data will be reduced.
-
-            Example:
-
-            If you want to prune the Conv2d weight in filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernal-size-1, kernal-size-2].
-            Then the under pruning dimensions is [0], which means you want to prune the filter or out-channel.
-
-                Case 1: Directly collect the conv module weight as data to calculate the metric.
-                Then the data has size (32, 16, 3, 3).
-                Mention that the dimension 0 of the data is corresponding to the under pruning weight dimension 0.
-                So in this case, `dim=0` will set in `__init__`.
-
-                Case 2: Use the output of the conv module as data to calculate the metric.
-                Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
-                Mention that the dimension 1 of the data is corresponding to the under pruning weight dimension 0.
-                So in this case, `dim=1` will set in `__init__`.
-
-            In both of these two case, the metric of this module has size (32,).
-        p
-            The order of norm. None means Frobenius norm.
-        """
-        super().__init__(dim=dim)
+    Parameters
+    ----------
+    p
+        The order of norm. None means Frobenius norm.
+    scalers
+        Please view the base class `MetricsCalculator` docstring.
+    """
+
+    def __init__(self, p: int | float | None = None,
+                 scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
+        super().__init__(scalers=scalers)
         self.p = p if p is not None else 'fro'
 
     def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        def reduce_func(t: Tensor) -> Tensor:
+            return t.norm(p=self.p, dim=-1)  # type: ignore
+
         metrics = {}
-        for name, tensor in data.items():
-            keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(tensor.size())))
-            [across_dim.pop(i) for i in reversed(keeped_dim)]
-            if len(across_dim) == 0:
-                metrics[name] = tensor.abs()
-            else:
-                metrics[name] = tensor.norm(p=self.p, dim=across_dim)  # type: ignore
+        target_name = 'weight'
+        for module_name, target_data in data.items():
+            scaler = self._get_scaler(module_name, target_name)
+            metrics[module_name] = scaler.shrink(target_data, reduce_func)
         return metrics
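To make the shrink-based metric computation concrete, consider the filter-pruning case described in the removed docstring (a Conv2d weight of size (32, 16, 3, 3) pruned on the output channel). The exact reshaping happens inside `Scaling.shrink`, which is not shown on this page, so treat the shapes below as an assumption about the behaviour the calculator code relies on:

    weight = torch.rand(32, 16, 3, 3)                               # data collected for one module
    scaler = Scaling(kernel_size=[1], kernel_padding_mode='back')   # one value per output channel
    metric = scaler.shrink(weight, lambda t: t.norm(p=1, dim=-1))   # expected shape: (32,)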
@@ -90,66 +72,32 @@ class DistMetricsCalculator(MetricsCalculator):
     """
     Calculate the sum of specify distance for each element with all other elements in specify `dim` in each tensor in data.
     FPGM pruner uses this to calculate metric.
-    """
 
-    def __init__(self, p: float, dim: Union[int, List[int]]):
-        """
-        Parameters
-        ----------
-        dim
-            The dimensions that corresponding to the under pruning weight dimensions in collected data.
-            None means one-to-one correspondence between pruned dimensions and data, which equal to set `dim` as all data dimensions.
-            Only these `dim` will be kept and other dimensions of the data will be reduced.
-
-            Example:
-
-            If you want to prune the Conv2d weight in filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernal-size-1, kernal-size-2].
-            Then the under pruning dimensions is [0], which means you want to prune the filter or out-channel.
-
-                Case 1: Directly collect the conv module weight as data to calculate the metric.
-                Then the data has size (32, 16, 3, 3).
-                Mention that the dimension 0 of the data is corresponding to the under pruning weight dimension 0.
-                So in this case, `dim=0` will set in `__init__`.
-
-                Case 2: Use the output of the conv module as data to calculate the metric.
-                Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
-                Mention that the dimension 1 of the data is corresponding to the under pruning weight dimension 0.
-                So in this case, `dim=1` will set in `__init__`.
-
-            In both of these two case, the metric of this module has size (32,).
-        p
-            The order of norm.
-        """
-        super().__init__(dim=dim)
-        self.p = p
+    Parameters
+    ----------
+    p
+        The order of norm. None means Frobenius norm.
+    scalers
+        Please view the base class `MetricsCalculator` docstring.
+    """
+
+    def __init__(self, p: int | float | None = None,
+                 scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
+        super().__init__(scalers=scalers)
+        self.p = p if p is not None else 'fro'
 
     def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        def reduce_func(t: Tensor) -> Tensor:
+            reshape_data = t.reshape(-1, t.shape[-1])
+            metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
+            for i in range(reshape_data.shape[0]):
+                metric[i] = (reshape_data - reshape_data[i]).norm(p=self.p, dim=-1).sum()  # type: ignore
+            return metric.reshape(t.shape[:-1])
+
         metrics = {}
-        for name, tensor in data.items():
-            keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
-            reorder_dim = list(keeped_dim)
-            reorder_dim.extend([i for i in range(len(tensor.size())) if i not in keeped_dim])
-            reorder_tensor = tensor.permute(*reorder_dim).clone()
-
-            metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
-            across_dim = list(range(len(keeped_dim), len(reorder_dim)))
-            idxs = metric.nonzero(as_tuple=False)
-            for idx in idxs:
-                other = reorder_tensor
-                for i in idx:
-                    other = other[i]
-                other = other.clone()
-                if len(across_dim) == 0:
-                    dist_sum = torch.abs(reorder_tensor - other).sum()
-                else:
-                    dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()  # type: ignore
-                # NOTE: this place need refactor when support layer level pruning.
-                tmp_metric = metric
-                for i in idx[:-1]:
-                    tmp_metric = tmp_metric[i]
-                tmp_metric[idx[-1]] = dist_sum
-
-            metrics[name] = metric
+        target_name = 'weight'
+        for module_name, target_data in data.items():
+            scaler = self._get_scaler(module_name, target_name)
+            metrics[module_name] = scaler.shrink(target_data, reduce_func)
         return metrics
@@ -161,19 +109,15 @@ class APoZRankMetricsCalculator(MetricsCalculator):
     APoZRank pruner uses this to calculate metric.
     """
     def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
+        def reduce_func(t: Tensor) -> Tensor:
+            return 1 - t.mean(dim=-1)
+
         metrics = {}
-        for name, (num, zero_counts) in data.items():
-            keeped_dim = list(range(len(zero_counts.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(zero_counts.size())))
-            [across_dim.pop(i) for i in reversed(keeped_dim)]
-            # The element number on each keeped_dim in zero_counts
-            total_size = num
-            for dim, dim_size in enumerate(zero_counts.size()):
-                if dim not in keeped_dim:
-                    total_size *= dim_size
-            _apoz = torch.sum(zero_counts, dim=across_dim).type_as(zero_counts) / total_size
-            # NOTE: the metric is (1 - apoz) because we assume the smaller metric value is more needed to be pruned.
-            metrics[name] = torch.ones_like(_apoz) - _apoz
+        target_name = 'weight'
+        for module_name, target_data in data.items():
+            target_data = target_data[1] / target_data[0]
+            scaler = self._get_scaler(module_name, target_name)
+            metrics[module_name] = scaler.shrink(target_data, reduce_func)
         return metrics
@@ -183,11 +127,14 @@ class MeanRankMetricsCalculator(MetricsCalculator):
     This metric simply calculate the average on `self.dim`, then divide by the batch_number.
     MeanRank pruner uses this to calculate metric.
     """
-    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
+    def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
+        def reduce_func(t: Tensor) -> Tensor:
+            return t.mean(dim=-1)
+
         metrics = {}
-        for name, (num, activation_sum) in data.items():
-            keeped_dim = list(range(len(activation_sum.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(activation_sum.size())))
-            [across_dim.pop(i) for i in reversed(keeped_dim)]
-            metrics[name] = torch.mean(activation_sum, across_dim) / num
+        target_name = 'weight'
+        for module_name, target_data in data.items():
+            target_data = target_data[1] / target_data[0]
+            scaler = self._get_scaler(module_name, target_name)
+            metrics[module_name] = scaler.shrink(target_data, reduce_func)
         return metrics
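In both activation-based calculators above, the collected `data[module_name]` is a two-element list: element 0 is a count and element 1 is the accumulated tensor (zero counts for APoZ, summed activations for MeanRank), so `target_data[1] / target_data[0]` is the per-position average before `reduce_func` and the scaler are applied. A hedged sketch with made-up shapes:

    num, activation_sum = 8, torch.rand(64, 14, 14)   # 8 accumulated batches; shapes are illustrative only
    data = {'conv1': [num, activation_sum]}
    # MeanRankMetricsCalculator then averages activation_sum / num and shrinks it with the module's scaler.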
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-import math
+from __future__ import annotations
+
 import itertools
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
 
 import numpy as np
 import torch
 from torch import Tensor
 
-from nni.algorithms.compression.v2.pytorch.base import Pruner
+from nni.common.graph_utils import TorchModuleGraph
 from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
 from .base import SparsityAllocator
+from ...base import Pruner
+from ...utils import Scaling
 
 
 class NormalSparsityAllocator(SparsityAllocator):
     """
-    This allocator simply pruned the weight with smaller metrics in layer level.
+    This allocator directly masks the locations of each pruning target with lower metric values.
     """
 
-    def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
+    def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
         masks = {}
-        for name, wrapper in self.pruner.get_modules_wrapper().items():
+        # TODO: Support more target type in wrapper & config list refactor
+        target_name = 'weight'
+        for module_name, target_metric in metrics.items():
+            masks[module_name] = {}
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
             sparsity_rate = wrapper.config['total_sparsity']
-            assert name in metrics, 'Metric of {} is not calculated.'.format(name)
-
-            # We assume the metric value are all positive right now.
-            metric = metrics[name]
-            if self.continuous_mask:
-                metric *= self._compress_mask(wrapper.weight_mask)  # type: ignore
-            prune_num = int(sparsity_rate * metric.numel())
-            if prune_num == 0:
-                threshold = metric.min() - 1
+            prune_num = int(sparsity_rate * target_metric.numel())
+            if prune_num != 0:
+                threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
+                shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
             else:
-                threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
-            mask = torch.gt(metric, threshold).type_as(metric)
-            masks[name] = self._expand_mask(name, mask)
-            if self.continuous_mask:
-                masks[name]['weight'] *= wrapper.weight_mask
+                # target_metric should have the same size as shrinked_mask
+                shrinked_mask = torch.ones_like(target_metric)
+            masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
         return masks
 
 
 class BankSparsityAllocator(SparsityAllocator):
     """
     In bank pruner, all values in weight are divided into different sub blocks each shape
     aligned with balance_gran. Each sub block has the same sparsity which equal to the overall sparsity.
     This allocator pruned the weight in the granularity of block.
     """
 
     def __init__(self, pruner: Pruner, balance_gran: list):
         super().__init__(pruner)
         self.balance_gran = balance_gran
@@ -54,199 +56,166 @@ class BankSparsityAllocator(SparsityAllocator):
             assert isinstance(gran, int) and gran > 0, 'All values in list balance_gran \
                 should be type int and bigger than zero'
 
-    def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
+    def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
         masks = {}
-        for name, wrapper in self.pruner.get_modules_wrapper().items():
+        # TODO: Support more target type in wrapper & config list refactor
+        target_name = 'weight'
+        for module_name, target_metric in metrics.items():
+            masks[module_name] = {}
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
             sparsity_rate = wrapper.config['total_sparsity']
-            assert name in metrics, 'Metric of {} is not calculated.'.format(name)
-
-            # We assume the metric value are all positive right now.
-            metric = metrics[name]
-            if self.continuous_mask:
-                metric *= self._compress_mask(wrapper.weight_mask)  # type: ignore
-            n_dim = len(metric.shape)
+            n_dim = len(target_metric.shape)
             assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
             # make up for balance_gran
             balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
-            for i, j in zip(metric.shape, balance_gran):
-                assert i % j == 0, 'Length of {} weight is not aligned with balance granularity'.format(name)
+            for i, j in zip(target_metric.shape, balance_gran):
+                assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
 
-            mask = torch.zeros(metric.shape).type_as(metric)
-            loop_iters = [range(int(i / j)) for i, j in zip(metric.shape, balance_gran)]
+            # FIXME: The following code need refactor, do it after scaling refactor is done.
+            shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
+            loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
             for iter_params in itertools.product(*loop_iters):
                 index_str_list = [f"{iter_param * gran}:{(iter_param + 1) * gran}" \
                                   for iter_param, gran in zip(iter_params, balance_gran)]
                 index_str = ",".join(index_str_list)
-                sub_metric_str = "metric[{}]".format(index_str)
-                sub_mask_str = "mask[{}] = mask_bank".format(index_str)
-                metric_bank = eval(sub_metric_str)
+                sub_metric_str = "target_metric[{}]".format(index_str)
+                sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
+                metric_bank: Tensor = eval(sub_metric_str)
                 prune_num = int(sparsity_rate * metric_bank.numel())
-                if prune_num == 0:
-                    threshold = metric_bank.min() - 1
-                else:
-                    threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
                 # mask_bank will be used in exec(sub_mask_str)
-                mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
+                if prune_num != 0:
+                    threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
+                    mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
+                else:
+                    mask_bank = torch.ones_like(metric_bank)
                 exec(sub_mask_str)
-            masks[name] = self._expand_mask(name, mask)
-            if self.continuous_mask:
-                masks[name]['weight'] *= wrapper.weight_mask
+            masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
         return masks
 
 
 class GlobalSparsityAllocator(SparsityAllocator):
     """
-    This allocator pruned the weight with smaller metrics in group level.
-    This means all layers in a group will sort metrics uniformly.
-    The layers with the same config in config_list is a group.
+    This allocator sorts all metrics as a whole, mask the locations of pruning target with lower metric value.
     """
 
-    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
+    def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
         masks = {}
-        # {group_index: {layer_name: metric}}
-        grouped_metrics = {idx: {name: metrics[name] for name in names}
-                           for idx, names in self.pruner.generate_module_groups().items()}
-        for _, group_metric_dict in grouped_metrics.items():
-            threshold, sub_thresholds = self._calculate_threshold(group_metric_dict)
-            for name, metric in group_metric_dict.items():
-                mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric)
-                masks[name] = self._expand_mask(name, mask)
-                if self.continuous_mask:
-                    masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
-        return masks
-
-    def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]:
-        metric_list = []
-        sub_thresholds = {}
-        total_weight_num = 0
-        temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
-        total_sparsity = temp_wrapper_config['total_sparsity']
-        max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', {})
-
-        for name, metric in group_metric_dict.items():
-            wrapper = self.pruner.get_modules_wrapper()[name]
-            # We assume the metric value are all positive right now.
-            if self.continuous_mask:
-                metric = metric * self._compress_mask(wrapper.weight_mask)  # type: ignore
-            layer_weight_num = wrapper.weight.data.numel()  # type: ignore
-            total_weight_num += layer_weight_num
-            expend_times = int(layer_weight_num / metric.numel())
-
-            retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
-            retention_numel = math.ceil(retention_ratio * layer_weight_num)
-            removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))  # type: ignore
-            stay_metric_num = metric.numel() - removed_metric_num
-            if stay_metric_num <= 0:
-                sub_thresholds[name] = metric.min().item() - 1
-                continue
-            # Remove the weight parts that must be left
-            stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
-            sub_thresholds[name] = stay_metric.max()
-            if expend_times > 1:
-                stay_metric = stay_metric.expand(int(layer_weight_num / metric.numel()), stay_metric_num).contiguous().view(-1)
-            metric_list.append(stay_metric)
-
-        total_prune_num = int(total_sparsity * total_weight_num)
-        if total_prune_num == 0:
-            threshold = torch.cat(metric_list).min().item() - 1
-        else:
-            threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
-        return threshold, sub_thresholds
+        if not metrics:
+            return masks
+
+        # TODO: support more target type in wrapper & config list refactor
+        target_name = 'weight'
+
+        # validate all wrapper setting the same sparsity
+        # TODO: move validation logic to pruner
+        global_sparsity_rate = self.pruner.get_modules_wrapper()[list(metrics.keys())[0]].config['total_sparsity']
+        for module_name, target_metric in metrics.items():
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            assert global_sparsity_rate == wrapper.config['total_sparsity']
+
+        # find the largest metric value among all metrics
+        max_metric_value = list(metrics.values())[0].max()
+        for module_name, target_metric in metrics.items():
+            max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
+
+        # prevent each module from being over-pruned, prevent ratio is 'max_sparsity_per_layer'
+        for module_name, target_metric in metrics.items():
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
+            assert 0 <= max_sparsity <= 1
+            old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
+            expand_times = old_target_mask.numel() // target_metric.numel()
+            max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
+            threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
+            metrics[module_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
+
+        # build the global_matric & calculate global threshold
+        metric_list = []
+        for module_name, target_metric in metrics.items():
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
+            expand_times = old_target_mask.numel() // target_metric.numel()
+            metric_list.append(target_metric.reshape(-1).unsqueeze(0).expand(expand_times, -1).reshape(-1))
+        global_metric = torch.cat(metric_list)
+        max_pruning_num = int((global_metric != max_metric_value).sum().item())
+        total_pruning_num = min(int(global_sparsity_rate * global_metric.numel()), max_pruning_num)
+        global_threshold = torch.topk(global_metric.reshape(-1), total_pruning_num, largest=False)[0].max()
+
+        # generate masks for each target
+        for module_name, target_metric in metrics.items():
+            masks[module_name] = {}
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
+            masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
+        return masks
 
 
-class Conv2dDependencyAwareAllocator(SparsityAllocator):
+class DependencyAwareAllocator(NormalSparsityAllocator):
     """
-    An allocator specific for Conv2d with dependency-aware.
+    An specific allocator for Conv2d & Linear module with dependency-aware.
+    It will generate a public mask for the modules that have dependencies,
+    then generate the part of the non-public mask for each module.
+    For other module types, the way to generate the mask is the same as `NormalSparsityAllocator`.
     """
 
-    def __init__(self, pruner: Pruner, dim: int, dummy_input: Any):
-        assert isinstance(dim, int), 'Only support single dim in Conv2dDependencyAwareAllocator.'
-        super().__init__(pruner, dim=dim)
-        self.dummy_input = dummy_input
+    def __init__(self, pruner: Pruner, dummy_input: Any,
+                 scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
+        # Scaling(kernel_size=[1], kernel_padding_mode='back') means output channel pruning.
+        scalers = scalers if scalers else Scaling(kernel_size=[1], kernel_padding_mode='back')
+        super().__init__(pruner, scalers=scalers)
+        self.channel_dependency, self.group_dependency = self._get_dependency(dummy_input)
 
-    def _get_dependency(self):
-        graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
+    def _get_dependency(self, dummy_input: Any):
+        # get the channel dependency and group dependency
+        # channel dependency format: [[module_name1, module_name2], [module_name3], ...]
+        # group dependency format: {module_name: group_num}
         self.pruner._unwrap_model()
-        self.channel_depen = ChannelDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
-        self.group_depen = GroupDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
+        graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
+        channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
+                                               traced_model=graph.trace).dependency_sets
+        group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
+                                           traced_model=graph.trace).dependency_sets
         self.pruner._wrap_model()
+        return channel_dependency, group_dependency
 
-    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
-        self._get_dependency()
-        masks = {}
-        grouped_metrics = {}
-        grouped_names = set()
-        # combine metrics with channel dependence
-        for idx, names in enumerate(self.channel_depen):
-            grouped_metric = {name: metrics[name] for name in names if name in metrics}
-            grouped_names.update(grouped_metric.keys())
-            if self.continuous_mask:
-                for name, metric in grouped_metric.items():
-                    metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)  # type: ignore
-            if len(grouped_metric) > 0:
-                grouped_metrics[idx] = grouped_metric
-        # ungrouped metrics stand alone as a group
-        ungrouped_names = set(metrics.keys()).difference(grouped_names)
-        for name in ungrouped_names:
-            idx += 1  # type: ignore
-            grouped_metrics[idx] = {name: metrics[name]}
-        # generate masks
-        for _, group_metric_dict in grouped_metrics.items():
-            group_metric = self._group_metric_calculate(group_metric_dict)
-            sparsities = {name: self.pruner.get_modules_wrapper()[name].config['total_sparsity'] for name in group_metric_dict.keys()}
-            min_sparsity = min(sparsities.values())
-            # generate group mask
-            conv2d_groups, group_mask = [], []
-            for name in group_metric_dict.keys():
-                if name in self.group_depen:
-                    conv2d_groups.append(self.group_depen[name])
-                else:
-                    # not in group_depen means not a Conv2d layer, in this case, assume the group number is 1
-                    conv2d_groups.append(1)
-            max_conv2d_group = np.lcm.reduce(conv2d_groups)
-            pruned_per_conv2d_group = int(group_metric.numel() / max_conv2d_group * min_sparsity)
-            conv2d_group_step = int(group_metric.numel() / max_conv2d_group)
-            for gid in range(max_conv2d_group):
-                _start = gid * conv2d_group_step
-                _end = (gid + 1) * conv2d_group_step
-                if pruned_per_conv2d_group > 0:
-                    threshold = torch.topk(group_metric[_start:_end], pruned_per_conv2d_group, largest=False)[0].max()
-                    conv2d_group_mask = torch.gt(group_metric[_start:_end], threshold).type_as(group_metric)
-                else:
-                    conv2d_group_mask = torch.ones(conv2d_group_step, device=group_metric.device)
-                group_mask.append(conv2d_group_mask)
-            group_mask = torch.cat(group_mask, dim=0)
-            # generate final mask
-            for name, metric in group_metric_dict.items():
-                # We assume the metric value are all positive right now.
-                metric = metric * group_mask
-                pruned_num = int(sparsities[name] * len(metric))
-                if pruned_num == 0:
-                    threshold = metric.min() - 1
-                else:
-                    threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
-                mask = torch.gt(metric, threshold).type_as(metric)
-                masks[name] = self._expand_mask(name, mask)
-                if self.continuous_mask:
-                    masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
-        return masks
-
-    def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
-        """
-        Add all metric value in the same position in one group.
-        """
-        group_metrics = list(group_metrics.values()) if isinstance(group_metrics, dict) else group_metrics
-        assert all(group_metrics[0].size() == group_metric.size() for group_metric in group_metrics), 'Metrics size do not match.'
-        group_sum_metric = torch.zeros(group_metrics[0].size(), device=group_metrics[0].device)
-        for group_metric in group_metrics:
-            group_sum_metric += group_metric
-        return group_sum_metric
+    def _metric_fuse(self, metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
+        # Sum all metric value in the same position.
+        metrics = list(metrics.values()) if isinstance(metrics, dict) else metrics
+        assert all(metrics[0].size() == metric.size() for metric in metrics), 'Metrics size do not match.'
+        fused_metric = torch.zeros_like(metrics[0])
+        for metric in metrics:
+            fused_metric += metric
+        return fused_metric
+
+    def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
+        # generate public part for modules that have dependencies
+        for module_names in self.channel_dependency:
+            sub_metrics = {module_name: metrics[module_name] for module_name in module_names if module_name in metrics}
+            if not sub_metrics:
+                continue
+            fused_metric = self._metric_fuse(sub_metrics)
+
+            sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity']
+                              for module_name in sub_metrics.keys()}
+            min_sparsity_rate = min(sparsity_rates.values())
+
+            group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
+            max_group_nums = int(np.lcm.reduce(group_nums))
+            pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
+            group_step = fused_metric.shape[0] // max_group_nums
+
+            # get the public part of the mask of the module with dependencies
+            sub_masks = []
+            for gid in range(max_group_nums):
+                _start = gid * group_step
+                _end = (gid + 1) * group_step
+                if pruned_numel_per_group > 0:
+                    threshold = torch.topk(fused_metric[_start:_end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
+                    sub_mask = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
+                else:
+                    sub_mask = torch.ones_like(fused_metric[_start:_end])
+                sub_masks.append(sub_mask)
+            dependency_mask = torch.cat(sub_masks, dim=0)
+
+            # change the metric value corresponding to the public mask part to the minimum value
+            for module_name, target_metric in sub_metrics.items():
+                min_value = target_metric.min()
+                metrics[module_name] = torch.where(dependency_mask != 0, target_metric, min_value)
+
+        return super().common_target_masks_generation(metrics)
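The constructor call sites for the rewritten allocator appear in basic_pruner.py above. A sketch, assuming `pruner` is a dependency-aware pruner bound to a model and `dummy_input` is a tensor that can be traced through it:

    allocator = DependencyAwareAllocator(pruner, dummy_input,
                                         Scaling(kernel_size=[1], kernel_padding_mode='back'))
    # allocator.channel_dependency: [[module_name1, module_name2], [module_name3], ...]
    # allocator.group_dependency:   {module_name: group_num}
    # Modules in the same channel-dependency set share a fused "public" mask;
    # everything else falls through to NormalSparsityAllocator.common_target_masks_generation.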
nni/algorithms/compression/v2/pytorch/utils/__init__.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from .attr import (
+    get_nested_attr,
+    set_nested_attr
+)
 from .config_validation import CompressorSchema
+from .constructor_helper import *
 from .pruning import (
     config_list_canonical,
     unfold_config_list,
@@ -12,4 +17,4 @@ from .pruning import (
     get_model_weights_numel,
     get_module_by_name
 )
-from .constructor_helper import *
+from .scaling import Scaling
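With this change the utils package re-exports the new helpers next to the existing ones, so downstream code in this commit imports them from one place:

    from nni.algorithms.compression.v2.pytorch.utils import (
        Scaling, get_nested_attr, set_nested_attr, OptimizerConstructHelper,
    )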
nni/algorithms/compression/v2/pytorch/utils/attr.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from functools import reduce
from typing import Any, overload


@overload
def get_nested_attr(__o: object, __name: str) -> Any:
    ...

@overload
def get_nested_attr(__o: object, __name: str, __default: Any) -> Any:
    ...

def get_nested_attr(__o: object, __name: str, *args) -> Any:
    """
    Get a nested named attribute from an object by a `.` separated name.
    rgetattr(x, 'y.z') is equivalent to getattr(getattr(x, 'y'), 'z') and x.y.z.
    """
    def _getattr(__o, __name):
        return getattr(__o, __name, *args)
    return reduce(_getattr, [__o] + __name.split('.'))  # type: ignore


def set_nested_attr(__obj: object, __name: str, __value: Any):
    """
    Set the nested named attribute on the given object to the specified value by a `.` separated name.
    set_nested_attr(x, 'y.z', v) is equivalent to setattr(getattr(x, 'y'), 'z', v) x.y.z = v.
    """
    pre, _, post = __name.rpartition('.')
    return setattr(get_nested_attr(__obj, pre) if pre else __obj, post, __value)
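A short usage sketch of the two helpers, following the equivalences stated in their docstrings; the objects are made up for the example:

    class _Inner: z = 1
    class _Outer: y = _Inner()
    x = _Outer()

    get_nested_attr(x, 'y.z')           # == x.y.z == 1
    get_nested_attr(x, 'y.missing', 0)  # default returned, like getattr(x.y, 'missing', 0)
    set_nested_attr(x, 'y.z', 2)        # equivalent to x.y.z = 2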
nni/algorithms/compression/v2/pytorch/utils/scaling.py
0 → 100644
View file @ f24dc27b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from functools import reduce
from typing import Callable, List, overload

from typing_extensions import Literal

import torch
from torch import Tensor


class Scaling:
    """
    In the process of generating masks, a large number of operations like pooling or upsampling are involved.
    This class provides tensor-related scaling functions for a given scaling kernel.
    Similar to the concept of a convolutional kernel, the scaling kernel also moves over the tensor and does operations.
    The scaling kernel in this class is defined by two parts: the kernel size and the scaling functions (shrink and expand).

    Parameters
    ----------
    kernel_size
        kernel_size is the scale, which determines how large a range in a tensor should shrink to one value,
        or how large a value in a tensor should expand.
        `-1` can be used to indicate a full step in that dimension,
        and the dimension where -1 is located will be reduced or unsqueezed during scaling.

        Example::

            kernel_size = [2, -1]

            # For a given 2D-tensor with size (4, 3),
            [[1, 2, 3],
             [4, 5, 6],
             [7, 8, 9],
             [10, 11, 12]]

            # shrinking it by the shrink function, its size becomes (2,) after shrinking:
            [shrink([[1, 2, 3], [4, 5, 6]]), shrink([[7, 8, 9], [10, 11, 12]])]

            # expanding it by the expand function with a given expand size,
            # if the expand function repeats the values and the expand size is (4, 6, 2):
            [[[1, 1],
              [1, 1],
              [2, 2],
              [2, 2],
              [3, 3],
              [3, 3]],
              ...
              [9, 9]]]

            # note that the original tensor with size (4, 3) will be unsqueezed to size (4, 3, 1) at first
            # for the `-1` in kernel_size, then expanded from size (4, 3, 1) to size (4, 6, 2).
    kernel_padding_mode
        'front' or 'back', default is 'front'.
        If set to 'front', for a given tensor when shrinking, pad `1` at the front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
        for a given expand size when expanding, pad `1` at the front of kernel_size until `len(expand_size) == len(kernel_size)`.
        If set to 'back', for a given tensor when shrinking, pad `-1` at the back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
        for a given expand size when expanding, pad `-1` at the back of kernel_size until `len(expand_size) == len(kernel_size)`.
    """

    def __init__(self, kernel_size: List[int], kernel_padding_mode: Literal['front', 'back'] = 'front') -> None:
        self.kernel_size = kernel_size
        assert kernel_padding_mode in ['front', 'back'], \
            f"kernel_padding_mode should be one of ['front', 'back'], but got kernel_padding_mode={kernel_padding_mode}."
        self.kernel_padding_mode = kernel_padding_mode

    def _padding(self, _list: List[int], length: int, padding_value: int = -1,
                 padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
        """
        Pad `_list` to a specific length with `padding_value`.

        Parameters
        ----------
        _list
            The list of int values to be padded.
        length
            The length to pad to.
        padding_value
            The padding value, should be an int.
        padding_mode
            If `padding_mode` is `'front'`, the padding is applied at the front of the size list.
            If `padding_mode` is `'back'`, the padding is applied at the back of the size list.

        Returns
        -------
        List[int]
            The padded list.
        """
        assert len(_list) <= length
        padding = [padding_value for _ in range(length - len(_list))]
        if padding_mode == 'front':
            new_list = padding + list(_list)
        elif padding_mode == 'back':
            new_list = list(_list) + padding
        else:
            raise ValueError(f'Unsupported padding mode: {padding_mode}.')
        return new_list

    def _shrink(self, target: Tensor, kernel_size: List[int], reduce_func: Callable[[Tensor], Tensor] | None = None) -> Tensor:
        """
        Main logic about how to shrink the target. A subclass could override this function to customize.
        Sums all values covered by the kernel as a simple implementation.
        """
        # step 1: put the part covered by the kernel at the end of the converted target.
        # e.g., target size is [10, 20], kernel_size is [2, 4], then new_target size is [5, 5, 8].
        reshape_size = []
        final_size = []
        reduced_dims = []
        for (dim, step) in enumerate(kernel_size):
            if step == -1:
                step = target.shape[dim]
                reduced_dims.insert(0, dim)
            assert target.shape[dim] % step == 0
            reshape_size.append(target.shape[dim] // step)
            final_size.append(target.shape[dim] // step)
            reshape_size.append(step)
        permute_dims = [2 * _ for _ in range(len(kernel_size))] + [2 * _ + 1 for _ in range(len(kernel_size))]
        converted_target = target.reshape(reshape_size).permute(permute_dims).reshape(final_size + [-1])
        # step 2: reduce the last dim of converted_target in a certain way, by default converted_target.sum(-1).
        result = reduce_func(converted_target) if reduce_func else converted_target.sum(-1)
        # step 3: reduce the dims where kernel_size is -1.
        # e.g., target size is [10, 40], kernel_size is [-1, 4], result size is [1, 10], then reduce result to size [10].
        result = reduce(lambda t, dim: t.squeeze(dim), [result] + reduced_dims)  # type: ignore
        return result

    def _expand(self, target: Tensor, kernel_size: List[int], expand_size: List[int]) -> Tensor:
        """
        Main logic about how to expand the target to a specific size. A subclass could override this function to customize.
        Repeats each value to reach the kernel size as a simple implementation.
        """
        # step 1: unsqueeze the target tensor where -1 is located in kernel_size.
        unsqueezed_dims = [dim for (dim, step) in enumerate(kernel_size) if step == -1]
        new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [target] + unsqueezed_dims)  # type: ignore
        # step 2: build the _expand_size and unsqueeze the target tensor on each dim.
        _expand_size = []
        for a, b in zip(kernel_size, expand_size):
            if a == -1:
                _expand_size.append(1)
                _expand_size.append(b)
            else:
                assert b % a == 0, f'Can not expand tensor with {target.shape} to {expand_size} with kernel size {kernel_size}.'
                _expand_size.append(b // a)
                _expand_size.append(a)
        new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [new_target] + [2 * _ + 1 for _ in range(len(expand_size))])  # type: ignore
        # step 3: expand the new target to _expand_size and reshape to expand_size.
        # Note that we could also expose an interface for how to expand the tensor, like `reduce_func` in `_shrink`; currently we don't have that need.
        result = new_target.expand(_expand_size).reshape(expand_size).clone()
        return result

    def shrink(self, target: Tensor, reduce_func: Callable[[Tensor], Tensor] | None = None) -> Tensor:
        # Canonicalize kernel_size to the target size length at first.
        # If kernel_padding_mode is 'front', pad 1 at the front of `self.kernel_size`.
        # e.g., pad kernel_size [2, 2] to [1, 2, 2] when the target size length is 3.
        # If kernel_padding_mode is 'back', pad -1 at the back of `self.kernel_size`.
        # e.g., pad kernel_size [1] to [1, -1, -1] when the target size length is 3.
        if self.kernel_padding_mode == 'front':
            kernel_size = self._padding(self.kernel_size, len(target.shape), 1, 'front')
        elif self.kernel_padding_mode == 'back':
            kernel_size = self._padding(self.kernel_size, len(target.shape), -1, 'back')
        else:
            raise ValueError(f'Unsupported kernel padding mode: {self.kernel_padding_mode}.')
        return self._shrink(target, kernel_size, reduce_func)

    def expand(self, target: Tensor, expand_size: List[int]):
        # Similar to `self.shrink`, canonicalize kernel_size to the expand_size length at first.
        if self.kernel_padding_mode == 'front':
            kernel_size = self._padding(self.kernel_size, len(expand_size), 1, 'front')
        elif self.kernel_padding_mode == 'back':
            kernel_size = self._padding(self.kernel_size, len(expand_size), -1, 'back')
        else:
            raise ValueError(f'Unsupported kernel padding mode: {self.kernel_padding_mode}.')
        return self._expand(target, kernel_size, expand_size)

    @overload
    def validate(self, target: List[int]):
        ...

    @overload
    def validate(self, target: Tensor):
        ...

    def validate(self, target: List[int] | Tensor):
        """
        Validate that the target tensor can be scaled losslessly in shape,
        i.e., the shape will not change after `shrink` then `expand`.
        """
        target = target if isinstance(target, Tensor) else torch.rand(target)
        if self.expand((self.shrink(target)), list(target.shape)).shape != target.shape:
            raise ValueError(f'The tensor with shape {target.shape} can not be shape-losslessly scaled with ' +
                             f'kernel size {self.kernel_size} and kernel_padding_mode {self.kernel_padding_mode}.')
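To connect this to the pruner changes, here is a small hedged sketch (tensor shapes are illustrative; the import path is the one shown above) of how `Scaling(kernel_size=[1], kernel_padding_mode='back')` reproduces the per-output-channel behaviour the pruners previously expressed with `dim=0`: for a 4-D convolution weight the kernel is padded to `[1, -1, -1, -1]`, so `shrink` sums everything except the output-channel dimension, and `expand` broadcasts one decision per channel back over the full weight shape.

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

weight = torch.rand(8, 4, 3, 3)                      # (out_channels, in_channels, kH, kW)
scaler = Scaling(kernel_size=[1], kernel_padding_mode='back')

per_channel = scaler.shrink(weight)                  # kernel is padded to [1, -1, -1, -1]
assert per_channel.shape == (8,)                     # one summed value per output channel

# keep the higher-scored half of the channels (the threshold choice is purely illustrative)
mask = (per_channel > per_channel.median()).float()
full_mask = scaler.expand(mask, list(weight.shape))  # each channel decision repeated over (4, 3, 3)
assert full_mask.shape == weight.shape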
test/algo/compression/v2/test_pruning_tools_torch.py
View file @ f24dc27b

...
@@ -26,6 +26,7 @@ from nni.algorithms.compression.v2.pytorch.pruning.tools import (
 )
 from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
 from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
+from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling
 from nni.algorithms.compression.v2.pytorch.utils.constructor_helper import OptimizerConstructHelper
...
@@ -112,7 +113,7 @@ class PruningToolsTestCase(unittest.TestCase):
     def test_metrics_calculator(self):
         # Test NormMetricsCalculator
-        metrics_calculator = NormMetricsCalculator(dim=0, p=2)
+        metrics_calculator = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
             '1': torch.ones(3, 3, 3),
             '2': torch.ones(4, 4) * 2
...
@@ -125,7 +126,7 @@ class PruningToolsTestCase(unittest.TestCase):
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
         # Test DistMetricsCalculator
-        metrics_calculator = DistMetricsCalculator(dim=0, p=2)
+        metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
             '1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
             '2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
...
@@ -138,7 +139,7 @@ class PruningToolsTestCase(unittest.TestCase):
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
         # Test MultiDataNormMetricsCalculator
-        metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
+        metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
             '1': [2, torch.ones(3, 3, 3) * 2],
             '2': [2, torch.ones(4, 4) * 2]
...
@@ -151,7 +152,7 @@ class PruningToolsTestCase(unittest.TestCase):
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
         # Test APoZRankMetricsCalculator
-        metrics_calculator = APoZRankMetricsCalculator(dim=1)
+        metrics_calculator = APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
         data = {
             '1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
             '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
...
@@ -164,7 +165,7 @@ class PruningToolsTestCase(unittest.TestCase):
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
         # Test MeanRankMetricsCalculator
-        metrics_calculator = MeanRankMetricsCalculator(dim=1)
+        metrics_calculator = MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
         data = {
             '1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
             '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
...
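The activation-based calculators above switch from `dim=1` to `Scaling(kernel_size=[-1, 1], kernel_padding_mode='back')`. A minimal sketch of what that scaler does to an activation-shaped tensor (the shape below is made up; by default `shrink` sums, while the calculators may pass their own `reduce_func`):

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

activation = torch.rand(16, 32, 7, 7)       # (batch, channels, H, W)
scaler = Scaling(kernel_size=[-1, 1], kernel_padding_mode='back')

per_channel = scaler.shrink(activation)     # kernel is padded to [-1, 1, -1, -1]
assert per_channel.shape == (32,)           # batch and spatial dims reduced, channel dim kept, matching the old dim=1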
test/algo/compression/v2/test_scaling.py
0 → 100644
View file @ f24dc27b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

import torch

from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling


def test_scaling():
    data = torch.tensor([_ for _ in range(100)]).reshape(10, 10)

    scaler = Scaling([5], kernel_padding_mode='front')
    shrinked_data = scaler.shrink(data)
    assert list(shrinked_data.shape) == [10, 2]
    expanded_data = scaler.expand(data, [10, 50])
    assert list(expanded_data.shape) == [10, 50]

    scaler = Scaling([5, 5], kernel_padding_mode='back')
    shrinked_data = scaler.shrink(data)
    assert list(shrinked_data.shape) == [2, 2]
    expanded_data = scaler.expand(data, [50, 50, 10])
    assert list(expanded_data.shape) == [50, 50, 10]
    scaler.validate([10, 10, 10])


if __name__ == '__main__':
    test_scaling()
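One more hedged note, not part of the committed test: `_shrink` asserts that every kernel step divides the corresponding dimension, so `validate` fails for shapes that do not line up with the kernel. A sketch under that assumption:

import pytest
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

def test_scaling_indivisible_shape():
    # kernel step 5 does not divide dimension size 12, so the divisibility assert in `_shrink` fires
    scaler = Scaling([5], kernel_padding_mode='front')
    with pytest.raises((AssertionError, ValueError)):
        scaler.validate([10, 12])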