Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
92f6754e
Unverified
Commit
92f6754e
authored
May 25, 2021
by
colorjam
Committed by
GitHub
May 25, 2021
Browse files
[Model Compression] Update api of iterative pruners (#3507)
parent
26f47727
Changes
45
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
116 additions
and
101 deletions
+116
-101
nni/compression/pytorch/utils/mask_conflict.py
nni/compression/pytorch/utils/mask_conflict.py
+5
-7
test/ut/sdk/test_compressor_torch.py
test/ut/sdk/test_compressor_torch.py
+14
-9
test/ut/sdk/test_dependecy_aware.py
test/ut/sdk/test_dependecy_aware.py
+33
-21
test/ut/sdk/test_model_speedup.py
test/ut/sdk/test_model_speedup.py
+2
-2
test/ut/sdk/test_pruners.py
test/ut/sdk/test_pruners.py
+62
-62
No files found.
nni/compression/pytorch/utils/mask_conflict.py
View file @
92f6754e
...
...
@@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# if the input is the path of the mask_file
assert
os
.
path
.
exists
(
masks
)
masks
=
torch
.
load
(
masks
)
assert
len
(
masks
)
>
0
,
'Mask tensor cannot be empty'
assert
len
(
masks
)
>
0
,
'Mask tensor cannot be empty'
# if the user uses the model and dummy_input to trace the model, we
# should get the traced model handly, so that, we only trace the
# model once, GroupMaskConflict and ChannelMaskConflict will reuse
...
...
@@ -181,10 +181,8 @@ class GroupMaskConflict(MaskFix):
w_mask
=
self
.
masks
[
layername
][
'weight'
]
shape
=
w_mask
.
size
()
count
=
np
.
prod
(
shape
[
1
:])
all_ones
=
(
w_mask
.
flatten
(
1
).
sum
(
-
1
)
==
count
).
nonzero
().
squeeze
(
1
).
tolist
()
all_zeros
=
(
w_mask
.
flatten
(
1
).
sum
(
-
1
)
==
0
).
nonzero
().
squeeze
(
1
).
tolist
()
all_ones
=
(
w_mask
.
flatten
(
1
).
sum
(
-
1
)
==
count
).
nonzero
().
squeeze
(
1
).
tolist
()
all_zeros
=
(
w_mask
.
flatten
(
1
).
sum
(
-
1
)
==
0
).
nonzero
().
squeeze
(
1
).
tolist
()
if
len
(
all_ones
)
+
len
(
all_zeros
)
<
w_mask
.
size
(
0
):
# In fine-grained pruning, skip this layer
_logger
.
info
(
'Layers %s using fine-grained pruning'
,
layername
)
...
...
@@ -198,7 +196,7 @@ class GroupMaskConflict(MaskFix):
group_masked
=
[]
for
i
in
range
(
group
):
_start
=
step
*
i
_end
=
step
*
(
i
+
1
)
_end
=
step
*
(
i
+
1
)
_tmp_list
=
list
(
filter
(
lambda
x
:
_start
<=
x
and
x
<
_end
,
all_zeros
))
group_masked
.
append
(
_tmp_list
)
...
...
@@ -286,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
0
,
2
,
3
)
if
self
.
conv_prune_dim
==
0
else
(
1
,
2
,
3
)
channel_mask
=
(
mask
.
abs
().
sum
(
tmp_sum_idx
)
!=
0
).
int
()
channel_masks
.
append
(
channel_mask
)
if
(
channel_mask
.
sum
()
*
(
mask
.
numel
()
/
mask
.
shape
[
1
-
self
.
conv_prune_dim
])).
item
()
!=
(
mask
>
0
).
sum
().
item
():
if
(
channel_mask
.
sum
()
*
(
mask
.
numel
()
/
mask
.
shape
[
1
-
self
.
conv_prune_dim
])).
item
()
!=
(
mask
>
0
).
sum
().
item
():
fine_grained
=
True
else
:
raise
RuntimeError
(
...
...
test/ut/sdk/test_compressor_torch.py
View file @
92f6754e
...
...
@@ -61,9 +61,8 @@ class CompressorTestCase(TestCase):
def
test_torch_level_pruner
(
self
):
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
[
'default'
]}]
torch_pruner
.
LevelPruner
(
model
,
configure_list
,
optimizer
).
compress
()
torch_pruner
.
LevelPruner
(
model
,
configure_list
).
compress
()
def
test_torch_naive_quantizer
(
self
):
model
=
TorchModel
()
...
...
@@ -93,7 +92,7 @@ class CompressorTestCase(TestCase):
model
=
TorchModel
()
config_list
=
[{
'sparsity'
:
0.6
,
'op_types'
:
[
'Conv2d'
]},
{
'sparsity'
:
0.2
,
'op_types'
:
[
'Conv2d'
]}]
pruner
=
torch_pruner
.
FPGMPruner
(
model
,
config_list
,
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
)
pruner
=
torch_pruner
.
FPGMPruner
(
model
,
config_list
)
model
.
conv2
.
module
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
masks
=
pruner
.
calc_mask
(
model
.
conv2
)
...
...
@@ -152,7 +151,7 @@ class CompressorTestCase(TestCase):
config_list
=
[{
'sparsity'
:
0.2
,
'op_types'
:
[
'BatchNorm2d'
]}]
model
.
bn1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
bn2
.
weight
.
data
=
torch
.
tensor
(
-
w
).
float
()
pruner
=
torch_pruner
.
SlimPruner
(
model
,
config_list
)
pruner
=
torch_pruner
.
SlimPruner
(
model
,
config_list
,
optimizer
=
None
,
trainer
=
None
,
criterion
=
None
)
mask1
=
pruner
.
calc_mask
(
model
.
bn1
)
mask2
=
pruner
.
calc_mask
(
model
.
bn2
)
...
...
@@ -165,7 +164,7 @@ class CompressorTestCase(TestCase):
config_list
=
[{
'sparsity'
:
0.6
,
'op_types'
:
[
'BatchNorm2d'
]}]
model
.
bn1
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
model
.
bn2
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
pruner
=
torch_pruner
.
SlimPruner
(
model
,
config_list
)
pruner
=
torch_pruner
.
SlimPruner
(
model
,
config_list
,
optimizer
=
None
,
trainer
=
None
,
criterion
=
None
)
mask1
=
pruner
.
calc_mask
(
model
.
bn1
)
mask2
=
pruner
.
calc_mask
(
model
.
bn2
)
...
...
@@ -202,8 +201,8 @@ class CompressorTestCase(TestCase):
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
pruner
=
torch_pruner
.
TaylorFOWeightFilterPruner
(
model
,
config_list
,
optimizer
,
statistics_batch_num
=
1
)
pruner
=
torch_pruner
.
TaylorFOWeightFilterPruner
(
model
,
config_list
,
optimizer
,
trainer
=
None
,
criterion
=
None
,
sparsity_training_epochs
=
1
)
x
=
torch
.
rand
((
1
,
1
,
28
,
28
),
requires_grad
=
True
)
model
.
conv1
.
module
.
weight
.
data
=
torch
.
tensor
(
w1
).
float
()
model
.
conv2
.
module
.
weight
.
data
=
torch
.
tensor
(
w2
).
float
()
...
...
@@ -345,7 +344,7 @@ class CompressorTestCase(TestCase):
],
[
{
'sparsity'
:
0.2
},
{
'sparsity'
:
0.6
,
'op_names'
:
'abc'
}
{
'sparsity'
:
0.6
,
'op_names'
:
'abc'
}
]
]
model
=
TorchModel
()
...
...
@@ -353,7 +352,13 @@ class CompressorTestCase(TestCase):
for
pruner_class
in
pruner_classes
:
for
config_list
in
bad_configs
:
try
:
pruner_class
(
model
,
config_list
,
optimizer
)
kwargs
=
{}
if
pruner_class
in
(
torch_pruner
.
SlimPruner
,
torch_pruner
.
AGPPruner
,
torch_pruner
.
ActivationMeanRankFilterPruner
,
torch_pruner
.
ActivationAPoZRankFilterPruner
):
kwargs
=
{
'optimizer'
:
None
,
'trainer'
:
None
,
'criterion'
:
None
}
print
(
'kwargs'
,
kwargs
)
pruner_class
(
model
,
config_list
,
**
kwargs
)
print
(
config_list
)
assert
False
,
'Validation error should be raised for bad configuration'
except
schema
.
SchemaError
:
...
...
test/ut/sdk/test_dependecy_aware.py
View file @
92f6754e
...
...
@@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model):
'sparsity'
:
sparsity
})
return
cfg_list
def
train
(
model
,
criterion
,
optimizer
,
callback
=
None
):
model
.
train
()
device
=
next
(
model
.
parameters
()).
device
data
=
torch
.
randn
(
2
,
3
,
224
,
224
).
to
(
device
)
target
=
torch
.
tensor
([
0
,
1
]).
long
().
to
(
device
)
optimizer
.
zero_grad
()
output
=
model
(
data
)
loss
=
criterion
(
output
,
target
)
loss
.
backward
()
# callback should be inserted between loss.backward() and optimizer.step()
if
callback
:
callback
()
optimizer
.
step
()
def
trainer
(
model
,
optimizer
,
criterion
,
epoch
,
callback
=
None
):
return
train
(
model
,
criterion
,
optimizer
,
callback
=
callback
)
class
DependencyawareTest
(
TestCase
):
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.3.0"
,
"not supported"
)
...
...
@@ -55,6 +73,7 @@ class DependencyawareTest(TestCase):
sparsity
=
0.7
cfg_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
sparsity
}]
dummy_input
=
torch
.
ones
(
1
,
3
,
224
,
224
)
for
model_name
in
model_zoo
:
for
pruner
in
pruners
:
print
(
'Testing on '
,
pruner
)
...
...
@@ -72,16 +91,12 @@ class DependencyawareTest(TestCase):
momentum
=
0.9
,
weight_decay
=
4e-5
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
tmp_pruner
=
pruner
(
net
,
cfg_list
,
optimizer
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
# train one single batch so that the the pruner can collect the
# statistic
optimizer
.
zero_grad
()
out
=
net
(
dummy_input
)
batchsize
=
dummy_input
.
size
(
0
)
loss
=
criterion
(
out
,
torch
.
zeros
(
batchsize
,
dtype
=
torch
.
int64
))
loss
.
backward
()
optimizer
.
step
()
if
pruner
==
TaylorFOWeightFilterPruner
:
tmp_pruner
=
pruner
(
net
,
cfg_list
,
optimizer
,
trainer
=
trainer
,
criterion
=
criterion
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
else
:
tmp_pruner
=
pruner
(
net
,
cfg_list
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
tmp_pruner
.
compress
()
tmp_pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
...
...
@@ -91,7 +106,7 @@ class DependencyawareTest(TestCase):
ms
.
speedup_model
()
for
name
,
module
in
net
.
named_modules
():
if
isinstance
(
module
,
nn
.
Conv2d
):
expected
=
int
(
ori_filters
[
name
]
*
(
1
-
sparsity
))
expected
=
int
(
ori_filters
[
name
]
*
(
1
-
sparsity
))
filter_diff
=
abs
(
expected
-
module
.
out_channels
)
errmsg
=
'%s Ori: %d, Expected: %d, Real: %d'
%
(
name
,
ori_filters
[
name
],
expected
,
module
.
out_channels
)
...
...
@@ -124,16 +139,13 @@ class DependencyawareTest(TestCase):
momentum
=
0.9
,
weight_decay
=
4e-5
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
tmp_pruner
=
pruner
(
net
,
cfg_list
,
optimizer
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
# train one single batch so that the the pruner can collect the
# statistic
optimizer
.
zero_grad
()
out
=
net
(
dummy_input
)
batchsize
=
dummy_input
.
size
(
0
)
loss
=
criterion
(
out
,
torch
.
zeros
(
batchsize
,
dtype
=
torch
.
int64
))
loss
.
backward
()
optimizer
.
step
()
if
pruner
in
(
TaylorFOWeightFilterPruner
,
ActivationMeanRankFilterPruner
,
ActivationAPoZRankFilterPruner
):
tmp_pruner
=
pruner
(
net
,
cfg_list
,
optimizer
,
trainer
=
trainer
,
criterion
=
criterion
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
else
:
tmp_pruner
=
pruner
(
net
,
cfg_list
,
dependency_aware
=
True
,
dummy_input
=
dummy_input
)
tmp_pruner
.
compress
()
tmp_pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
...
...
test/ut/sdk/test_model_speedup.py
View file @
92f6754e
...
...
@@ -17,7 +17,7 @@ from unittest import TestCase, main
from
nni.compression.pytorch
import
ModelSpeedup
,
apply_compression_results
from
nni.algorithms.compression.pytorch.pruning
import
L1FilterPruner
from
nni.algorithms.compression.pytorch.pruning.weight_masker
import
WeightMasker
from
nni.algorithms.compression.pytorch.pruning.
one_shot
import
_StructuredFilter
Pruner
from
nni.algorithms.compression.pytorch.pruning.
dependency_aware_pruner
import
DependencyAware
Pruner
torch
.
manual_seed
(
0
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
...
...
@@ -205,7 +205,7 @@ class L1ChannelMasker(WeightMasker):
return
{
'weight_mask'
:
mask_weight
.
detach
(),
'bias_mask'
:
mask_bias
}
class
L1ChannelPruner
(
_StructuredFilter
Pruner
):
class
L1ChannelPruner
(
DependencyAware
Pruner
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
dependency_aware
=
False
,
dummy_input
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'l1'
,
optimizer
=
optimizer
,
dependency_aware
=
dependency_aware
,
dummy_input
=
dummy_input
)
...
...
test/ut/sdk/test_pruners.py
View file @
92f6754e
...
...
@@ -42,13 +42,10 @@ prune_config = {
'agp'
:
{
'pruner_class'
:
AGPPruner
,
'config_list'
:
[{
'initial_sparsity'
:
0.
,
'final_sparsity'
:
0.8
,
'start_epoch'
:
0
,
'end_epoch'
:
10
,
'frequency'
:
1
,
'sparsity'
:
0.8
,
'op_types'
:
[
'Conv2d'
]
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[]
},
'slim'
:
{
...
...
@@ -57,6 +54,7 @@ prune_config = {
'sparsity'
:
0.7
,
'op_types'
:
[
'BatchNorm2d'
]
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[
lambda
model
:
validate_sparsity
(
model
.
bn1
,
0.7
,
model
.
bias
)
]
...
...
@@ -97,6 +95,7 @@ prune_config = {
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
],
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[
lambda
model
:
validate_sparsity
(
model
.
conv1
,
0.5
,
model
.
bias
)
]
...
...
@@ -107,6 +106,7 @@ prune_config = {
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
],
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[
lambda
model
:
validate_sparsity
(
model
.
conv1
,
0.5
,
model
.
bias
)
]
...
...
@@ -117,6 +117,7 @@ prune_config = {
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
],
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[
lambda
model
:
validate_sparsity
(
model
.
conv1
,
0.5
,
model
.
bias
)
]
...
...
@@ -127,7 +128,7 @@ prune_config = {
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
]
}],
'short_term_fine_tuner'
:
lambda
model
:
model
,
'short_term_fine_tuner'
:
lambda
model
:
model
,
'evaluator'
:
lambda
model
:
0.9
,
'validators'
:
[]
},
...
...
@@ -146,7 +147,7 @@ prune_config = {
'sparsity'
:
0.5
,
'op_types'
:
[
'Conv2d'
],
}],
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
,
callback
:
model
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'validators'
:
[
lambda
model
:
validate_sparsity
(
model
.
conv1
,
0.5
,
model
.
bias
)
]
...
...
@@ -158,7 +159,7 @@ prune_config = {
'op_types'
:
[
'Conv2d'
],
}],
'base_algo'
:
'l1'
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
,
callback
:
model
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'evaluator'
:
lambda
model
:
0.9
,
'dummy_input'
:
torch
.
randn
([
64
,
1
,
28
,
28
]),
'validators'
:
[]
...
...
@@ -170,7 +171,7 @@ prune_config = {
'op_types'
:
[
'Conv2d'
],
}],
'base_algo'
:
'l2'
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
,
callback
:
model
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'evaluator'
:
lambda
model
:
0.9
,
'dummy_input'
:
torch
.
randn
([
64
,
1
,
28
,
28
]),
'validators'
:
[]
...
...
@@ -182,7 +183,7 @@ prune_config = {
'op_types'
:
[
'Conv2d'
],
}],
'base_algo'
:
'fpgm'
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
,
callback
:
model
,
'trainer'
:
lambda
model
,
optimizer
,
criterion
,
epoch
:
model
,
'evaluator'
:
lambda
model
:
0.9
,
'dummy_input'
:
torch
.
randn
([
64
,
1
,
28
,
28
]),
'validators'
:
[]
...
...
@@ -206,88 +207,87 @@ class Model(nn.Module):
def
forward
(
self
,
x
):
return
self
.
fc
(
self
.
pool
(
self
.
bn1
(
self
.
conv1
(
x
))).
view
(
x
.
size
(
0
),
-
1
))
class
SimpleDataset
:
def
__getitem__
(
self
,
index
):
return
torch
.
randn
(
3
,
32
,
32
),
1.
def
__len__
(
self
):
return
1000
def
train
(
model
,
train_loader
,
criterion
,
optimizer
):
model
.
train
()
device
=
next
(
model
.
parameters
()).
device
x
=
torch
.
randn
(
2
,
1
,
28
,
28
).
to
(
device
)
y
=
torch
.
tensor
([
0
,
1
]).
long
().
to
(
device
)
# print('hello...')
for
_
in
range
(
2
):
out
=
model
(
x
)
loss
=
criterion
(
out
,
y
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
def
pruners_test
(
pruner_names
=
[
'level'
,
'agp'
,
'slim'
,
'fpgm'
,
'l1'
,
'l2'
,
'taylorfo'
,
'mean_activation'
,
'apoz'
,
'netadapt'
,
'simulatedannealing'
,
'admm'
,
'autocompress_l1'
,
'autocompress_l2'
,
'autocompress_fpgm'
,],
bias
=
True
):
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
dummy_input
=
torch
.
randn
(
2
,
1
,
28
,
28
).
to
(
device
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
SimpleDataset
(),
batch_size
=
16
,
shuffle
=
False
,
drop_last
=
True
)
def
trainer
(
model
,
optimizer
,
criterion
,
epoch
):
return
train
(
model
,
train_loader
,
criterion
,
optimizer
)
for
pruner_name
in
pruner_names
:
print
(
'testing {}...'
.
format
(
pruner_name
))
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
=
Model
(
bias
=
bias
).
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
config_list
=
prune_config
[
pruner_name
][
'config_list'
]
x
=
torch
.
randn
(
2
,
1
,
28
,
28
).
to
(
device
)
y
=
torch
.
tensor
([
0
,
1
]).
long
().
to
(
device
)
out
=
model
(
x
)
loss
=
F
.
cross_entropy
(
out
,
y
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
if
pruner_name
==
'netadapt'
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
short_term_fine_tuner
=
prune_config
[
pruner_name
][
'short_term_fine_tuner'
],
evaluator
=
prune_config
[
pruner_name
][
'evaluator'
])
elif
pruner_name
==
'simulatedannealing'
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
evaluator
=
prune_config
[
pruner_name
][
'evaluator'
])
elif
pruner_name
in
(
'agp'
,
'slim'
,
'taylorfo'
,
'apoz'
,
'mean_activation'
):
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
trainer
,
optimizer
=
optimizer
,
criterion
=
criterion
)
elif
pruner_name
==
'admm'
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
prune_config
[
pruner_name
][
'
trainer
'
]
)
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
trainer
)
elif
pruner_name
.
startswith
(
'autocompress'
):
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
prune_config
[
pruner_name
][
'trainer'
],
evaluator
=
prune_config
[
pruner_name
][
'evaluator'
],
dummy_input
=
x
,
base_algo
=
prune_config
[
pruner_name
][
'base_algo'
])
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
prune_config
[
pruner_name
][
'trainer'
],
evaluator
=
prune_config
[
pruner_name
][
'evaluator'
],
criterion
=
torch
.
nn
.
CrossEntropyLoss
(),
dummy_input
=
dummy_input
,
base_algo
=
prune_config
[
pruner_name
][
'base_algo'
])
else
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
optimizer
)
pruner
.
compress
()
x
=
torch
.
randn
(
2
,
1
,
28
,
28
).
to
(
device
)
y
=
torch
.
tensor
([
0
,
1
]).
long
().
to
(
device
)
out
=
model
(
x
)
loss
=
F
.
cross_entropy
(
out
,
y
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
if
pruner_name
==
'taylorfo'
:
# taylorfo algorithm calculate contributions at first iteration(step), and do pruning
# when iteration >= statistics_batch_num (default 1)
optimizer
.
step
()
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
)
pruner
.
compress
()
pruner
.
export_model
(
'./model_tmp.pth'
,
'./mask_tmp.pth'
,
'./onnx_tmp.pth'
,
input_shape
=
(
2
,
1
,
28
,
28
),
device
=
device
)
for
v
in
prune_config
[
pruner_name
][
'validators'
]:
v
(
model
)
filePaths
=
[
'./model_tmp.pth'
,
'./mask_tmp.pth'
,
'./onnx_tmp.pth'
,
'./search_history.csv'
,
'./search_result.json'
]
for
f
in
filePaths
:
if
os
.
path
.
exists
(
f
):
os
.
remove
(
f
)
def
_test_agp
(
pruning_algorithm
):
model
=
Model
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
config_list
=
prune_config
[
'agp'
][
'config_list'
]
pruner
=
AGPPruner
(
model
,
config_list
,
optimizer
,
pruning_algorithm
=
pruning_algorithm
)
pruner
.
compress
()
x
=
torch
.
randn
(
2
,
1
,
28
,
28
)
y
=
torch
.
tensor
([
0
,
1
]).
long
()
def
_test_agp
(
pruning_algorithm
):
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
SimpleDataset
(),
batch_size
=
16
,
shuffle
=
False
,
drop_last
=
True
)
model
=
Model
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
for
epoch
in
range
(
config_list
[
0
][
'start_epoch'
],
config_list
[
0
][
'end_epoch'
]
+
1
):
pruner
.
update_epoch
(
epoch
)
out
=
model
(
x
)
loss
=
F
.
cross_entropy
(
out
,
y
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
def
trainer
(
model
,
optimizer
,
criterion
,
epoch
):
return
train
(
model
,
train_loader
,
criterion
,
optimizer
)
target_sparsity
=
pruner
.
compute_target_sparsity
(
config_list
[
0
])
actual_sparsity
=
(
model
.
conv1
.
weight_mask
==
0
).
sum
().
item
()
/
model
.
conv1
.
weight_mask
.
numel
()
# set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
assert
math
.
isclose
(
actual_sparsity
,
target_sparsity
,
abs_tol
=
0.2
)
config_list
=
prune_config
[
'agp'
][
'config_list'
]
pruner
=
AGPPruner
(
model
,
config_list
,
optimizer
=
optimizer
,
trainer
=
trainer
,
criterion
=
torch
.
nn
.
CrossEntropyLoss
(),
pruning_algorithm
=
pruning_algorithm
)
pruner
.
compress
()
class
SimpleDataset
:
def
__getitem__
(
self
,
index
):
return
torch
.
randn
(
3
,
32
,
32
),
1.
target_sparsity
=
pruner
.
compute_target_sparsity
(
config_list
[
0
])
actual_sparsity
=
(
model
.
conv1
.
weight_mask
==
0
).
sum
().
item
()
/
model
.
conv1
.
weight_mask
.
numel
()
# set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
assert
math
.
isclose
(
actual_sparsity
,
target_sparsity
,
abs_tol
=
0.2
)
def
__len__
(
self
):
return
1000
class
PrunerTestCase
(
TestCase
):
def
test_pruners
(
self
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment