OpenDAS / nni / Commits / 92f6754e

Commit 92f6754e, authored May 25, 2021 by colorjam, committed by GitHub on May 25, 2021

[Model Compression] Update api of iterative pruners (#3507)

parent 26f47727

Showing 5 changed files with 116 additions and 101 deletions (+116 −101)
- nni/compression/pytorch/utils/mask_conflict.py (+5 −7)
- test/ut/sdk/test_compressor_torch.py (+14 −9)
- test/ut/sdk/test_dependecy_aware.py (+33 −21)
- test/ut/sdk/test_model_speedup.py (+2 −2)
- test/ut/sdk/test_pruners.py (+62 −62)
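Taken together, these diffs migrate the compression unit tests from the old optimizer-driven pruner constructors to the new trainer/criterion-based ones. A rough, hedged sketch of the before/after, using only the call shapes visible in the diffs below (the toy model, data and sparsity here are made up for illustration):

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

    # Hypothetical toy setup, only to make the sketch self-contained.
    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

    def trainer(model, optimizer, criterion, epoch):
        # One tiny random batch per epoch; the pruner drives this callback itself.
        x, y = torch.randn(2, 1, 28, 28), torch.tensor([0, 1])
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Old style (removed by this commit):
    #   TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)
    # New style (what the updated tests construct):
    pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer,
                                        trainer=trainer, criterion=criterion,
                                        sparsity_training_epochs=1)
    pruner.compress()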
nni/compression/pytorch/utils/mask_conflict.py  (view file @ 92f6754e)

@@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         # if the input is the path of the mask_file
         assert os.path.exists(masks)
         masks = torch.load(masks)
-    assert len(masks) > 0, 'Mask tensor cannot be empty'
+    assert len(masks) > 0, 'Mask tensor cannot be empty'
     # if the user uses the model and dummy_input to trace the model, we
     # should get the traced model handly, so that, we only trace the
     # model once, GroupMaskConflict and ChannelMaskConflict will reuse

@@ -181,10 +181,8 @@ class GroupMaskConflict(MaskFix):
             w_mask = self.masks[layername]['weight']
             shape = w_mask.size()
             count = np.prod(shape[1:])
-            all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
-            all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
+            all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
+            all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
             if len(all_ones) + len(all_zeros) < w_mask.size(0):
                 # In fine-grained pruning, skip this layer
                 _logger.info('Layers %s using fine-grained pruning', layername)

@@ -198,7 +196,7 @@ class GroupMaskConflict(MaskFix):
             group_masked = []
             for i in range(group):
                 _start = step * i
-                _end = step * (i + 1)
+                _end = step * (i + 1)
                 _tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
                 group_masked.append(_tmp_list)

@@ -286,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
                 0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
             channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
             channel_masks.append(channel_mask)
-            if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
+            if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
                 fine_grained = True
             else:
                 raise RuntimeError(
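For orientation, fix_mask_conflict (whose signature is shown in the hunk header above) accepts either a path to an exported mask file or the already-loaded mask dict. A minimal sketch under that assumption; the toy model and the './mask_tmp.pth' path are hypothetical stand-ins:

    import torch
    import torch.nn as nn
    from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
    dummy_input = torch.randn(1, 3, 32, 32)
    # './mask_tmp.pth' stands in for a mask file produced by pruner.export_model(...);
    # passing model and dummy_input lets the utility trace the graph once and resolve
    # group/channel mask conflicts across dependent layers.
    fixed_masks = fix_mask_conflict('./mask_tmp.pth', model=model, dummy_input=dummy_input)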
test/ut/sdk/test_compressor_torch.py  (view file @ 92f6754e)

@@ -61,9 +61,8 @@ class CompressorTestCase(TestCase):
     def test_torch_level_pruner(self):
         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
-        torch_pruner.LevelPruner(model, configure_list, optimizer).compress()
+        torch_pruner.LevelPruner(model, configure_list).compress()

     def test_torch_naive_quantizer(self):
         model = TorchModel()

@@ -93,7 +92,7 @@ class CompressorTestCase(TestCase):
         model = TorchModel()
         config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
-        pruner = torch_pruner.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))
+        pruner = torch_pruner.FPGMPruner(model, config_list)

         model.conv2.module.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(model.conv2)

@@ -152,7 +151,7 @@ class CompressorTestCase(TestCase):
         config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(-w).float()
-        pruner = torch_pruner.SlimPruner(model, config_list)
+        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)

@@ -165,7 +164,7 @@ class CompressorTestCase(TestCase):
         config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(w).float()
-        pruner = torch_pruner.SlimPruner(model, config_list)
+        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)

@@ -202,8 +201,8 @@ class CompressorTestCase(TestCase):
         model = TorchModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)
+        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1)
         x = torch.rand((1, 1, 28, 28), requires_grad=True)
         model.conv1.module.weight.data = torch.tensor(w1).float()
         model.conv2.module.weight.data = torch.tensor(w2).float()

@@ -345,7 +344,7 @@ class CompressorTestCase(TestCase):
             ],
             [
                 {'sparsity': 0.2},
-                {'sparsity': 0.6, 'op_names': 'abc'}
+                {'sparsity': 0.6, 'op_names': 'abc'}
             ]
         ]
         model = TorchModel()

@@ -353,7 +352,13 @@ class CompressorTestCase(TestCase):
         for pruner_class in pruner_classes:
             for config_list in bad_configs:
                 try:
-                    pruner_class(model, config_list, optimizer)
+                    kwargs = {}
+                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner,
+                                        torch_pruner.ActivationMeanRankFilterPruner,
+                                        torch_pruner.ActivationAPoZRankFilterPruner):
+                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}
+                    print('kwargs', kwargs)
+                    pruner_class(model, config_list, **kwargs)
                     print(config_list)
                     assert False, 'Validation error should be raised for bad configuration'
                 except schema.SchemaError:
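The Slim/TaylorFO changes above boil down to constructing these pruners with explicit keyword arguments. A small sketch of the SlimPruner form the test now exercises (the toy model and config are hypothetical; the unit test itself only goes on to call calc_mask):

    import torch.nn as nn
    from nni.algorithms.compression.pytorch.pruning import SlimPruner

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    config_list = [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}]

    # After this commit the constructor takes optimizer/trainer/criterion keywords;
    # None is accepted when, as in the unit test, only mask calculation is exercised.
    pruner = SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)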
test/ut/sdk/test_dependecy_aware.py  (view file @ 92f6754e)

@@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model):
                 'sparsity': sparsity})
     return cfg_list

+def train(model, criterion, optimizer, callback=None):
+    model.train()
+    device = next(model.parameters()).device
+    data = torch.randn(2, 3, 224, 224).to(device)
+    target = torch.tensor([0, 1]).long().to(device)
+    optimizer.zero_grad()
+    output = model(data)
+    loss = criterion(output, target)
+    loss.backward()
+    # callback should be inserted between loss.backward() and optimizer.step()
+    if callback:
+        callback()
+    optimizer.step()
+
+def trainer(model, optimizer, criterion, epoch, callback=None):
+    return train(model, criterion, optimizer, callback=callback)
+
 class DependencyawareTest(TestCase):
     @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")

@@ -55,6 +73,7 @@ class DependencyawareTest(TestCase):
         sparsity = 0.7
         cfg_list = [{'op_types': ['Conv2d'], 'sparsity': sparsity}]
         dummy_input = torch.ones(1, 3, 224, 224)
         for model_name in model_zoo:
             for pruner in pruners:
                 print('Testing on ', pruner)

@@ -72,16 +91,12 @@ class DependencyawareTest(TestCase):
                                             momentum=0.9, weight_decay=4e-5)
                 criterion = torch.nn.CrossEntropyLoss()
-                tmp_pruner = pruner(net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
-                # train one single batch so that the the pruner can collect the
-                # statistic
-                optimizer.zero_grad()
-                out = net(dummy_input)
-                batchsize = dummy_input.size(0)
-                loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
-                loss.backward()
-                optimizer.step()
+                if pruner == TaylorFOWeightFilterPruner:
+                    tmp_pruner = pruner(net, cfg_list, optimizer, trainer=trainer, criterion=criterion,
+                                        dependency_aware=True, dummy_input=dummy_input)
+                else:
+                    tmp_pruner = pruner(net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
                 tmp_pruner.compress()
                 tmp_pruner.export_model(MODEL_FILE, MASK_FILE)

@@ -91,7 +106,7 @@ class DependencyawareTest(TestCase):
             ms.speedup_model()
             for name, module in net.named_modules():
                 if isinstance(module, nn.Conv2d):
-                    expected = int(ori_filters[name] * (1 - sparsity))
+                    expected = int(ori_filters[name] * (1 - sparsity))
                     filter_diff = abs(expected - module.out_channels)
                     errmsg = '%s Ori: %d, Expected: %d, Real: %d' % (name, ori_filters[name], expected, module.out_channels)

@@ -124,16 +139,13 @@ class DependencyawareTest(TestCase):
                                             momentum=0.9, weight_decay=4e-5)
                 criterion = torch.nn.CrossEntropyLoss()
-                tmp_pruner = pruner(net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
-                # train one single batch so that the the pruner can collect the
-                # statistic
-                optimizer.zero_grad()
-                out = net(dummy_input)
-                batchsize = dummy_input.size(0)
-                loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
-                loss.backward()
-                optimizer.step()
+                if pruner in (TaylorFOWeightFilterPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner):
+                    tmp_pruner = pruner(net, cfg_list, optimizer, trainer=trainer, criterion=criterion,
+                                        dependency_aware=True, dummy_input=dummy_input)
+                else:
+                    tmp_pruner = pruner(net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
                 tmp_pruner.compress()
                 tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
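The dependency-aware test keeps the same pruner interface but now routes the "train one batch" step through the trainer callback shown above. A compact sketch of the one-shot path it exercises (hypothetical toy network; the real test iterates over a model zoo and uses MODEL_FILE/MASK_FILE constants):

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
    dummy_input = torch.randn(1, 3, 224, 224)
    cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]

    # One-shot pruners take dependency_aware/dummy_input directly; the iterative
    # ones in the branches above additionally receive trainer= and criterion=.
    pruner = L1FilterPruner(net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
    pruner.compress()
    pruner.export_model('./model.pth', './mask.pth')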
test/ut/sdk/test_model_speedup.py  (view file @ 92f6754e)

@@ -17,7 +17,7 @@ from unittest import TestCase, main
 from nni.compression.pytorch import ModelSpeedup, apply_compression_results
 from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
 from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
-from nni.algorithms.compression.pytorch.pruning.one_shot import _StructuredFilterPruner
+from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner

 torch.manual_seed(0)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -205,7 +205,7 @@ class L1ChannelMasker(WeightMasker):
         return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

-class L1ChannelPruner(_StructuredFilterPruner):
+class L1ChannelPruner(DependencyAwarePruner):
     def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
         super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
                          dependency_aware=dependency_aware, dummy_input=dummy_input)
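test_model_speedup.py only swaps the base class its custom pruner inherits from; the speedup flow it tests is unchanged. As a rough sketch of that flow, assuming a mask file exported by one of the pruners above (the toy model and paths are hypothetical):

    import torch
    import torch.nn as nn
    from nni.compression.pytorch import ModelSpeedup, apply_compression_results

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
    dummy_input = torch.randn(1, 3, 32, 32)
    mask_file = './mask_tmp.pth'

    apply_compression_results(model, mask_file)                   # re-apply exported masks to a fresh copy
    ModelSpeedup(model, dummy_input, mask_file).speedup_model()   # physically shrink the masked layers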
test/ut/sdk/test_pruners.py  (view file @ 92f6754e)

@@ -42,13 +42,10 @@ prune_config = {
     'agp': {
         'pruner_class': AGPPruner,
         'config_list': [{
-            'initial_sparsity': 0.,
-            'final_sparsity': 0.8,
-            'start_epoch': 0,
-            'end_epoch': 10,
-            'frequency': 1,
+            'sparsity': 0.8,
             'op_types': ['Conv2d']
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': []
     },
     'slim': {
@@ -57,6 +54,7 @@ prune_config = {
             'sparsity': 0.7,
             'op_types': ['BatchNorm2d']
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
         ]
@@ -97,6 +95,7 @@ prune_config = {
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
        'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -107,6 +106,7 @@ prune_config = {
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -117,6 +117,7 @@ prune_config = {
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -127,7 +128,7 @@ prune_config = {
             'sparsity': 0.5,
             'op_types': ['Conv2d']
         }],
-        'short_term_fine_tuner': lambda model: model,
+        'short_term_fine_tuner': lambda model: model,
         'evaluator': lambda model: 0.9,
         'validators': []
     },
@@ -146,7 +147,7 @@ prune_config = {
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
-        'trainer': lambda model, optimizer, criterion, epoch, callback: model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -158,7 +159,7 @@ prune_config = {
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'l1',
-        'trainer': lambda model, optimizer, criterion, epoch, callback: model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -170,7 +171,7 @@ prune_config = {
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'l2',
-        'trainer': lambda model, optimizer, criterion, epoch, callback: model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
        'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -182,7 +183,7 @@ prune_config = {
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'fpgm',
-        'trainer': lambda model, optimizer, criterion, epoch, callback: model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -206,88 +207,87 @@ class Model(nn.Module):
     def forward(self, x):
         return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))

+class SimpleDataset:
+    def __getitem__(self, index):
+        return torch.randn(3, 32, 32), 1.
+
+    def __len__(self):
+        return 1000
+
+def train(model, train_loader, criterion, optimizer):
+    model.train()
+    device = next(model.parameters()).device
+    x = torch.randn(2, 1, 28, 28).to(device)
+    y = torch.tensor([0, 1]).long().to(device)
+    # print('hello...')
+    for _ in range(2):
+        out = model(x)
+        loss = criterion(out, y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz',
                                'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2',
                                'autocompress_fpgm'], bias=True):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dummy_input = torch.randn(2, 1, 28, 28).to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
+
+    def trainer(model, optimizer, criterion, epoch):
+        return train(model, train_loader, criterion, optimizer)
+
     for pruner_name in pruner_names:
         print('testing {}...'.format(pruner_name))
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = Model(bias=bias).to(device)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         config_list = prune_config[pruner_name]['config_list']
-        x = torch.randn(2, 1, 28, 28).to(device)
-        y = torch.tensor([0, 1]).long().to(device)
-        out = model(x)
-        loss = F.cross_entropy(out, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
         if pruner_name == 'netadapt':
             pruner = prune_config[pruner_name]['pruner_class'](
                 model, config_list, short_term_fine_tuner=prune_config[pruner_name]['short_term_fine_tuner'],
                 evaluator=prune_config[pruner_name]['evaluator'])
         elif pruner_name == 'simulatedannealing':
             pruner = prune_config[pruner_name]['pruner_class'](
                 model, config_list, evaluator=prune_config[pruner_name]['evaluator'])
+        elif pruner_name in ('agp', 'slim', 'taylorfo', 'apoz', 'mean_activation'):
+            pruner = prune_config[pruner_name]['pruner_class'](
+                model, config_list, trainer=trainer, optimizer=optimizer, criterion=criterion)
         elif pruner_name == 'admm':
-            pruner = prune_config[pruner_name]['pruner_class'](
-                model, config_list, trainer=prune_config[pruner_name]['trainer'])
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer)
         elif pruner_name.startswith('autocompress'):
-            pruner = prune_config[pruner_name]['pruner_class'](
-                model, config_list, trainer=prune_config[pruner_name]['trainer'],
-                evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x,
-                base_algo=prune_config[pruner_name]['base_algo'])
+            pruner = prune_config[pruner_name]['pruner_class'](
+                model, config_list, trainer=prune_config[pruner_name]['trainer'],
+                evaluator=prune_config[pruner_name]['evaluator'],
+                criterion=torch.nn.CrossEntropyLoss(), dummy_input=dummy_input,
+                base_algo=prune_config[pruner_name]['base_algo'])
         else:
-            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
-        pruner.compress()
-        x = torch.randn(2, 1, 28, 28).to(device)
-        y = torch.tensor([0, 1]).long().to(device)
-        out = model(x)
-        loss = F.cross_entropy(out, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        if pruner_name == 'taylorfo':
-            # taylorfo algorithm calculate contributions at first iteration(step), and do pruning
-            # when iteration >= statistics_batch_num (default 1)
-            optimizer.step()
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
+        pruner.compress()
         pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth',
                             input_shape=(2, 1, 28, 28), device=device)
         for v in prune_config[pruner_name]['validators']:
             v(model)
     filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json']
     for f in filePaths:
         if os.path.exists(f):
             os.remove(f)

-def _test_agp(pruning_algorithm):
-    model = Model()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-    config_list = prune_config['agp']['config_list']
-    pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm)
-    pruner.compress()
-    x = torch.randn(2, 1, 28, 28)
-    y = torch.tensor([0, 1]).long()
-    for epoch in range(config_list[0]['start_epoch'], config_list[0]['end_epoch'] + 1):
-        pruner.update_epoch(epoch)
-        out = model(x)
-        loss = F.cross_entropy(out, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        target_sparsity = pruner.compute_target_sparsity(config_list[0])
-        actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
-        # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
-        assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
+def _test_agp(pruning_algorithm):
+    train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
+    model = Model()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    def trainer(model, optimizer, criterion, epoch):
+        return train(model, train_loader, criterion, optimizer)
+
+    config_list = prune_config['agp']['config_list']
+    pruner = AGPPruner(model, config_list, optimizer=optimizer, trainer=trainer,
+                       criterion=torch.nn.CrossEntropyLoss(), pruning_algorithm=pruning_algorithm)
+    pruner.compress()
+    target_sparsity = pruner.compute_target_sparsity(config_list[0])
+    actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
+    # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
+    assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)

-class SimpleDataset:
-    def __getitem__(self, index):
-        return torch.randn(3, 32, 32), 1.
-
-    def __len__(self):
-        return 1000

 class PrunerTestCase(TestCase):
     def test_pruners(self):
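For the AGP path specifically, the rewritten _test_agp above suggests roughly the following flow after this commit. This is a hedged sketch: the toy model, data and trainer mirror the test helpers rather than any documented recipe, and the simplified config only keeps the keys the updated prune_config retains:

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.pruning import AGPPruner

    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def trainer(model, optimizer, criterion, epoch):
        x, y = torch.randn(2, 1, 28, 28), torch.tensor([0, 1])
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
    pruner = AGPPruner(model, config_list, optimizer=optimizer, trainer=trainer,
                       criterion=nn.CrossEntropyLoss(), pruning_algorithm='l1')
    pruner.compress()
    print(pruner.compute_target_sparsity(config_list[0]))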