OpenDAS / nni · Commits · 2b9f5f8c
"testing/vscode:/vscode.git/clone" did not exist on "8cdc185bb4ca34fcfda70d7e329ddc30c44aadae"
Unverified commit 2b9f5f8c, authored Aug 26, 2021 by J-shang; committed by GitHub on Aug 26, 2021.
[Model Compression] update config list key (#4074)
Parent: 862c67df

Showing 6 changed files with 209 additions and 122 deletions (+209 -122).
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py  +73 -61
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py  +1 -1
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py  +2 -2
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py  +15 -10
nni/algorithms/compression/v2/pytorch/utils/config_validation.py  +34 -48
nni/algorithms/compression/v2/pytorch/utils/pruning.py  +84 -0
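In user-facing terms, the commit changes the keys a pruner's config_list is normalized to: a 'sparsity' (or 'sparsity_per_layer') entry is validated against new schema constants, split per matched op, and stored under 'total_sparsity', which the sparsity allocators now read. A rough before/after sketch (the layer names are made up):

    # What a user writes (unchanged by this commit):
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

    # Canonical form produced internally after this commit, roughly one entry per matched op:
    canonical = [
        {'total_sparsity': 0.5, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
        {'total_sparsity': 0.5, 'op_types': ['Conv2d'], 'op_names': ['conv2']},
    ]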
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from copy import deepcopy
 import logging
 from typing import List, Dict, Tuple, Callable, Optional
 
-from schema import And, Optional as SchemaOptional
+from schema import And, Or, Optional as SchemaOptional
 
 import torch
 from torch import Tensor
 import torch.nn as nn
...
@@ -12,7 +13,8 @@ from torch.nn import Module
 from torch.optim import Optimizer
 
 from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
-from nni.algorithms.compression.v2.pytorch.utils.config_validation import PrunerSchema
+from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
+from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical
 from .tools import (
     DataCollector,
...
@@ -43,26 +45,47 @@ _logger = logging.getLogger(__name__)
 __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
            'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
+NORMAL_SCHEMA = {
+    Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
+    SchemaOptional('op_types'): [str],
+    SchemaOptional('op_names'): [str]
+}
+
+GLOBAL_SCHEMA = {
+    'total_sparsity': And(float, lambda n: 0 <= n < 1),
+    SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1),
+    SchemaOptional('op_types'): [str],
+    SchemaOptional('op_names'): [str]
+}
+
+EXCLUDE_SCHEMA = {
+    'exclude': bool,
+    SchemaOptional('op_types'): [str],
+    SchemaOptional('op_names'): [str]
+}
+
+INTERNAL_SCHEMA = {
+    'total_sparsity': And(float, lambda n: 0 <= n < 1),
+    SchemaOptional('max_sparsity_per_layer'): {str: float},
+    SchemaOptional('op_types'): [str],
+    SchemaOptional('op_names'): [str]
+}
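These dicts follow the conventions of the schema package: dict keys may themselves be matchers such as Or(...) or Optional(...), and And(float, lambda ...) constrains both the type and the range of a value. A minimal sketch of validating one hypothetical config entry directly against the NORMAL_SCHEMA dict above, outside of CompressorSchema:

    from schema import Schema, SchemaError

    # Accepted: a float sparsity in [0, 1) plus an optional op-type filter.
    Schema(NORMAL_SCHEMA).validate({'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']})

    # Rejected: 1.5 fails the `0 <= n < 1` check, so Schema raises SchemaError.
    try:
        Schema(NORMAL_SCHEMA).validate({'sparsity': 1.5, 'op_types': ['Conv2d']})
    except SchemaError:
        pass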
 class OneShotPruner(Pruner):
     def __init__(self, model: Module, config_list: List[Dict]):
         self.data_collector: DataCollector = None
         self.metrics_calculator: MetricsCalculator = None
         self.sparsity_allocator: SparsityAllocator = None
 
-        self._convert_config_list(config_list)
         super().__init__(model, config_list)
 
-    def _convert_config_list(self, config_list: List[Dict]):
-        """
-        Convert `sparsity` in config to `sparsity_per_layer`.
-        """
-        for config in config_list:
-            if 'sparsity' in config:
-                if 'sparsity_per_layer' in config:
-                    raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
-                else:
-                    config['sparsity_per_layer'] = config.pop('sparsity')
+    def validate_config(self, model: Module, config_list: List[Dict]):
+        self._validate_config_before_canonical(model, config_list)
+        self.config_list = config_list_canonical(model, config_list)
+
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        pass
 
     def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
         super().reset(model=model, config_list=config_list)
...
@@ -115,14 +138,9 @@ class LevelPruner(OneShotPruner):
         self.mode = 'normal'
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): [str],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
 
     def reset_tools(self):
...
@@ -171,13 +189,11 @@ class NormPruner(OneShotPruner):
         self.dummy_input = dummy_input
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        for sub_shcema in schema_list:
+            sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
...
@@ -291,13 +307,11 @@ class FPGMPruner(OneShotPruner):
         self.dummy_input = dummy_input
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        for sub_shcema in schema_list:
+            sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
...
@@ -376,15 +390,15 @@ class SlimPruner(OneShotPruner):
         self._scale = scale
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): ['BatchNorm2d'],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        if self.mode == 'global':
+            schema_list.append(deepcopy(GLOBAL_SCHEMA))
+        else:
+            schema_list.append(deepcopy(NORMAL_SCHEMA))
+        for sub_shcema in schema_list:
+            sub_shcema[SchemaOptional('op_types')] = ['BatchNorm2d']
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
...
@@ -477,13 +491,11 @@ class ActivationPruner(OneShotPruner):
         self._activation = self._choose_activation(activation)
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        for sub_shcema in schema_list:
+            sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
...
@@ -603,19 +615,19 @@ class TaylorFOWeightPruner(OneShotPruner):
         self.training_batches = training_batches
         super().__init__(model, config_list)
 
-    def validate_config(self, model: Module, config_list: List[Dict]):
-        schema = PrunerSchema([{
-            SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
-            SchemaOptional('op_types'): ['Conv2d', 'Linear'],
-            SchemaOptional('op_names'): [str],
-            SchemaOptional('exclude'): bool
-        }], model, _logger)
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        if self.mode == 'global':
+            schema_list.append(deepcopy(GLOBAL_SCHEMA))
+        else:
+            schema_list.append(deepcopy(NORMAL_SCHEMA))
+        for sub_shcema in schema_list:
+            sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
+        schema = CompressorSchema(schema_list, model, _logger)
 
         schema.validate(config_list)
-    def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Module, Tensor, Tensor], None]:
+    def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
         def collect_taylor(grad: Tensor):
             if len(buffer) < self.training_batches:
                 buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
...
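The narrowed return annotation matches the signature of a tensor gradient hook, as registered with Tensor.register_hook, which passes only the gradient to the callback. A toy sketch of that hook shape (not the pruner's actual wiring):

    import torch

    buffer = []
    weight = torch.randn(3, 3, requires_grad=True)

    # A gradient hook is Callable[[Tensor], None]: it receives only `grad`.
    weight.register_hook(lambda grad: buffer.append((weight.detach() * grad).abs()))

    (weight * 2).sum().backward()
    print(len(buffer))  # 1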
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
...
@@ -437,6 +437,6 @@ class SparsityAllocator:
             mask = mask.unfold(i, step, step)
             ein_expression += lower_case_letters[i]
         ein_expression = '...{},{}'.format(ein_expression, ein_expression)
-        mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size))
+        mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
 
         return (mask != 0).type_as(mask)
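The added .to(mask.device) keeps the freshly created ones tensor on the same device as the mask; mixing a CUDA mask with a CPU operand in torch.einsum raises a device-mismatch error. A small sketch of the pattern with hypothetical shapes (it also runs on CPU-only machines):

    import torch

    mask = torch.ones(4, 4)
    if torch.cuda.is_available():
        mask = mask.cuda()

    block = torch.ones(2, 2).to(mask.device)   # follow the mask's device
    expanded = torch.einsum('ab,cd->abcd', mask, block)
    print(expanded.shape)  # torch.Size([4, 4, 2, 2])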
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py
...
@@ -120,7 +120,7 @@ class DistMetricsCalculator(MetricsCalculator):
         metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
         across_dim = list(range(len(keeped_dim), len(reorder_dim)))
-        idxs = metric.nonzero()
+        idxs = metric.nonzero(as_tuple=False)
         for idx in idxs:
             other = reorder_tensor
             for i in idx:
...
@@ -161,7 +161,7 @@ class APoZRankMetricsCalculator(MetricsCalculator):
         for dim, dim_size in enumerate(_eq_zero.size()):
             if dim not in keeped_dim:
                 total_size *= dim_size
-        _apoz = torch.sum(_eq_zero, dim=across_dim, dtype=torch.float64) / total_size
+        _apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
         # NOTE: the metric is (1 - apoz) because we assume the smaller metric value is more needed to be pruned.
         metrics[name] = torch.ones_like(_apoz) - _apoz
     return metrics
...
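nonzero(as_tuple=False) keeps the index-matrix behaviour while avoiding the deprecation warning newer PyTorch versions emit for the bare call, and .type_as(activations) keeps the APoZ ratio in the activations' dtype instead of forcing float64. A quick sketch of both calls on toy data:

    import torch

    activations = torch.tensor([[0.0, 1.0], [2.0, 0.0]])

    idxs = (activations != 0).nonzero(as_tuple=False)  # one (row, col) index per non-zero entry
    print(idxs.tolist())                               # [[0, 1], [1, 0]]

    apoz = torch.sum(activations == 0, dim=1).type_as(activations) / activations.size(1)
    print(apoz)                                        # tensor([0.5000, 0.5000])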
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import math
 from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
...
@@ -20,7 +21,7 @@ class NormalSparsityAllocator(SparsityAllocator):
     def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
         masks = {}
         for name, wrapper in self.pruner.get_modules_wrapper().items():
-            sparsity_rate = wrapper.config['sparsity_per_layer']
+            sparsity_rate = wrapper.config['total_sparsity']
             assert name in metrics, 'Metric of %s is not calculated.'
             metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
...
@@ -58,27 +59,31 @@ class GlobalSparsityAllocator(SparsityAllocator):
         temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
         total_sparsity = temp_wrapper_config['total_sparsity']
-        max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', 1.0)
+        max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', {})
 
         for name, metric in group_metric_dict.items():
             wrapper = self.pruner.get_modules_wrapper()[name]
             metric = metric * self._compress_mask(wrapper.weight_mask)
-            print(metric)
             layer_weight_num = wrapper.module.weight.data.numel()
-            stay_num = int(metric.numel() * max_sparsity_per_layer)
+            retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
+            retention_numel = math.ceil(retention_ratio * layer_weight_num)
+            removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
+            stay_metric_num = metric.numel() - removed_metric_num
             # Remove the weight parts that must be left
-            stay_metric = torch.topk(metric.view(-1), stay_num, largest=False)[0]
+            stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
             sub_thresholds[name] = stay_metric.max()
             expend_times = int(layer_weight_num / metric.numel())
             if expend_times > 1:
-                stay_metric = stay_metric.expand(stay_num, int(layer_weight_num / metric.numel())).view(-1)
+                stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
             metric_list.append(stay_metric)
             total_weight_num += layer_weight_num
 
-        assert total_sparsity <= max_sparsity_per_layer, 'total_sparsity should less than max_sparsity_per_layer.'
         total_prune_num = int(total_sparsity * total_weight_num)
-        threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
+        if total_prune_num == 0:
+            threshold = torch.cat(metric_list).min().item() - 1
+        else:
+            threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
 
         return threshold, sub_thresholds
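With this change max_sparsity_per_layer is consumed as a per-op-name dict (the canonical form produced by config_list_canonical) instead of a single float, and the number of metric entries that may still be pruned is derived per layer. A rough sketch of that arithmetic with made-up numbers:

    import math

    layer_weight_num = 1000                     # elements in the layer's weight
    metric_numel = 100                          # one metric value covers 10 weights
    max_sparsity_per_layer = {'conv1': 0.8}

    retention_ratio = 1 - max_sparsity_per_layer.get('conv1', 1)        # 0.2 of the weights must stay
    retention_numel = math.ceil(retention_ratio * layer_weight_num)     # 200 weights
    removed_metric_num = math.ceil(retention_numel / (layer_weight_num / metric_numel))  # 20 metric entries
    stay_metric_num = metric_numel - removed_metric_num                 # 80 entries remain prunable
    print(stay_metric_num)  # 80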
...
@@ -108,7 +113,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
         for _, group_metric_dict in grouped_metrics.items():
             group_metric = self._group_metric_calculate(group_metric_dict)
 
-            sparsities = {name: self.pruner.get_modules_wrapper()[name].config['sparsity_per_layer'] for name in group_metric_dict.keys()}
+            sparsities = {name: self.pruner.get_modules_wrapper()[name].config['total_sparsity'] for name in group_metric_dict.keys()}
             min_sparsity = min(sparsities.values())
             conv2d_groups = [self.group_depen[name] for name in group_metric_dict.keys()]
...
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from logging import Logger
+from typing import Dict, List
+
 from schema import Schema, And, SchemaError
+
+from torch.nn import Module
+
+
+class CompressorSchema:
+    def __init__(self, data_schema: List[Dict], model: Module, logger: Logger):
+        assert isinstance(data_schema, list)
+        self.data_schema = data_schema
+        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))
+
+    def _modify_schema(self, data_schema: List[Dict], model: Module, logger: Logger) -> List[Dict]:
+        if not data_schema:
+            return data_schema
+
+        for i, sub_schema in enumerate(data_schema):
+            for k, old_schema in sub_schema.items():
+                if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
+                    new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
+                    sub_schema[k] = new_schema
+                if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
+                    new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
+                    sub_schema[k] = new_schema
+            data_schema[i] = And(sub_schema, lambda d: validate_op_types_op_names(d))
+
+        return data_schema
+
+    def validate(self, data):
+        self.compressor_schema.validate(data)
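For illustration, a rough sketch of how the reworked class accepts several alternative sub-schemas and validates a config list against any of them (toy model and hypothetical schema entries; requires an nni build that contains this commit):

    import logging

    import torch.nn as nn
    from schema import And, Optional as SchemaOptional
    from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))
    schema_list = [
        {'total_sparsity': And(float, lambda n: 0 <= n < 1), SchemaOptional('op_types'): [str]},
        {'exclude': bool, SchemaOptional('op_names'): [str]},
    ]

    schema = CompressorSchema(schema_list, model, logging.getLogger('example'))
    # Each config entry only has to satisfy one of the sub-schemas.
    schema.validate([{'total_sparsity': 0.5, 'op_types': ['Conv2d']},
                     {'exclude': True, 'op_names': ['1']}])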
 def validate_op_names(model, op_names, logger):
     found_names = set(map(lambda x: x[0], model.named_modules()))
...
@@ -12,6 +44,7 @@ def validate_op_names(model, op_names, logger):
     return True
 
 def validate_op_types(model, op_types, logger):
     found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))
...
@@ -21,55 +54,8 @@ def validate_op_types(model, op_types, logger):
     return True
 
 def validate_op_types_op_names(data):
     if not ('op_types' in data or 'op_names' in data):
         raise SchemaError('Either op_types or op_names must be specified.')
     return True
-class CompressorSchema:
-    def __init__(self, data_schema, model, logger):
-        assert isinstance(data_schema, list) and len(data_schema) <= 1
-        self.data_schema = data_schema
-        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))
-
-    def _modify_schema(self, data_schema, model, logger):
-        if not data_schema:
-            return data_schema
-
-        for k in data_schema[0]:
-            old_schema = data_schema[0][k]
-            if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
-                new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
-                data_schema[0][k] = new_schema
-            if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
-                new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
-                data_schema[0][k] = new_schema
-
-        data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
-
-        return data_schema
-
-    def validate(self, data):
-        self.compressor_schema.validate(data)
-
-
-def validate_exclude_sparsity(data):
-    if not ('exclude' in data or 'sparsity_per_layer' in data or 'total_sparsity' in data):
-        raise SchemaError('One of [sparsity_per_layer, total_sparsity, exclude] should be specified.')
-    return True
-
-
-def validate_exclude_quant_types_quant_bits(data):
-    if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
-        raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
-    return True
-
-
-class PrunerSchema(CompressorSchema):
-    def _modify_schema(self, data_schema, model, logger):
-        data_schema = super()._modify_schema(data_schema, model, logger)
-        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
-        return data_schema
-
-
-class QuantizerSchema(CompressorSchema):
-    def _modify_schema(self, data_schema, model, logger):
-        data_schema = super()._modify_schema(data_schema, model, logger)
-        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
-        return data_schema
nni/algorithms/compression/v2/pytorch/utils/pruning.py (new file, mode 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
from typing import Dict, List

from torch.nn import Module


def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
    '''
    Split the config by op_names if 'sparsity' or 'sparsity_per_layer' in config,
    and set the sub_config['total_sparsity'] = config['sparsity_per_layer'].
    '''
    for config in config_list:
        if 'sparsity' in config:
            if 'sparsity_per_layer' in config:
                raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
            else:
                config['sparsity_per_layer'] = config.pop('sparsity')

    config_list = dedupe_config_list(unfold_config_list(model, config_list))
    new_config_list = []

    for config in config_list:
        if 'sparsity_per_layer' in config:
            sparsity_per_layer = config.pop('sparsity_per_layer')
            op_names = config.pop('op_names')
            for op_name in op_names:
                sub_config = deepcopy(config)
                sub_config['op_names'] = [op_name]
                sub_config['total_sparsity'] = sparsity_per_layer
                new_config_list.append(sub_config)
        elif 'max_sparsity_per_layer' in config and isinstance(config['max_sparsity_per_layer'], float):
            op_names = config.get('op_names', [])
            max_sparsity_per_layer = {}
            max_sparsity = config['max_sparsity_per_layer']
            for op_name in op_names:
                max_sparsity_per_layer[op_name] = max_sparsity
            config['max_sparsity_per_layer'] = max_sparsity_per_layer
            new_config_list.append(config)
        else:
            new_config_list.append(config)

    return new_config_list


def unfold_config_list(model: Module, config_list: List[Dict]) -> List[Dict]:
    '''
    Unfold config_list to op_names level.
    '''
    unfolded_config_list = []
    for config in config_list:
        op_names = []
        for module_name, module in model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            op_names.append(module_name)
        unfolded_config = deepcopy(config)
        unfolded_config['op_names'] = op_names
        unfolded_config_list.append(unfolded_config)
    return unfolded_config_list


def dedupe_config_list(config_list: List[Dict]) -> List[Dict]:
    '''
    Dedupe the op_names in unfolded config_list.
    '''
    exclude = set()
    exclude_idxes = []
    config_list = deepcopy(config_list)
    for idx, config in reversed(list(enumerate(config_list))):
        if 'exclude' in config:
            exclude.update(config['op_names'])
            exclude_idxes.append(idx)
            continue
        config['op_names'] = sorted(list(set(config['op_names']).difference(exclude)))
        exclude.update(config['op_names'])
    for idx in sorted(exclude_idxes, reverse=True):
        config_list.pop(idx)
    return config_list
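A short, hypothetical example of the canonical form these helpers produce (a two-layer toy model; requires an nni build that contains this commit):

    import torch.nn as nn
    from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d', 'Linear']}]

    print(config_list_canonical(model, config_list))
    # Roughly: [{'op_types': ['Conv2d', 'Linear'], 'op_names': ['0'], 'total_sparsity': 0.5},
    #           {'op_types': ['Conv2d', 'Linear'], 'op_names': ['1'], 'total_sparsity': 0.5}]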