Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2566badb
Unverified
Commit
2566badb
authored
Mar 23, 2022
by
J-shang
Committed by
GitHub
Mar 23, 2022
Browse files
[Model Compression] Pruning Wrapper Refactor (#4488)
parent
8d5f643c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
145 additions
and
108 deletions
+145
-108
nni/algorithms/compression/v2/pytorch/base/compressor.py
nni/algorithms/compression/v2/pytorch/base/compressor.py
+1
-8
nni/algorithms/compression/v2/pytorch/base/pruner.py
nni/algorithms/compression/v2/pytorch/base/pruner.py
+82
-10
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
...algorithms/compression/v2/pytorch/pruning/basic_pruner.py
+1
-1
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
...orithms/compression/v2/pytorch/pruning/movement_pruner.py
+5
-80
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
+1
-1
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py
...ms/compression/v2/pytorch/pruning/tools/data_collector.py
+2
-2
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
...ompression/v2/pytorch/pruning/tools/sparsity_allocator.py
+1
-1
test/ut/compression/v2/test_pruning_tools_torch.py
test/ut/compression/v2/test_pruning_tools_torch.py
+5
-5
test/ut/compression/v2/test_pruning_wrapper.py
test/ut/compression/v2/test_pruning_wrapper.py
+47
-0
No files found.
nni/algorithms/compression/v2/pytorch/base/compressor.py
View file @
2566badb
...
...
@@ -257,14 +257,7 @@ class Compressor:
Dict[str, str]
Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
"""
if
self
.
is_wrapped
:
wrapped_param_names
=
{
id
(
param
):
name
for
name
,
param
in
self
.
bound_model
.
named_parameters
()}
self
.
_unwrap_model
()
parameter_name_map
=
{
name
:
wrapped_param_names
[
id
(
param
)]
for
name
,
param
in
self
.
bound_model
.
named_parameters
()}
self
.
_wrap_model
()
return
parameter_name_map
else
:
raise
Exception
(
'When only the model is wrapped can get the parameter_name_map.'
)
raise
NotImplementedError
()
def
_wrap_modules
(
self
,
layer
:
LayerInfo
,
config
:
Dict
):
"""
...
...
nni/algorithms/compression/v2/pytorch/base/pruner.py
View file @
2566badb
...
...
@@ -6,9 +6,9 @@ from typing import Dict, List, Optional, Tuple
import
torch
from
torch
import
Tensor
from
torch.nn
import
Module
from
torch.nn
import
Module
,
Parameter
from
.compressor
import
Compressor
,
LayerInfo
from
.compressor
import
Compressor
,
LayerInfo
,
_setattr
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -27,31 +27,57 @@ class PrunerModuleWrapper(Module):
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
,
pruner
:
Compressor
):
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
):
super
().
__init__
()
# origin layer information
self
.
module
=
module
self
.
name
=
module_name
# config
and pruner
# config
information
self
.
config
=
config
self
.
pruner
=
pruner
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
module
.
weight
.
size
()))
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
module
.
bias
.
size
()))
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
def
_weight2buffer
(
self
):
"""
When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
The best place to call this function is in `Pruner._wrap_model()`.
"""
self
.
weight
.
data
=
self
.
module
.
weight
.
data
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
weight
.
data
)
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
bias
.
data
=
self
.
module
.
bias
.
data
delattr
(
self
.
module
,
'bias'
)
self
.
module
.
register_buffer
(
'bias'
,
self
.
bias
.
data
)
def
_weight2parameter
(
self
):
"""
When don't need to record score or need to export the model, call `_weight2parameter()` to make the original weight trainable.
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
weight
=
Parameter
(
torch
.
empty
(
self
.
weight
.
size
()))
self
.
module
.
weight
.
data
=
torch
.
mul
(
self
.
weight
,
self
.
weight_mask
)
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
delattr
(
self
.
module
,
'bias'
)
self
.
module
.
bias
=
Parameter
(
torch
.
empty
(
self
.
bias
.
size
()))
self
.
module
.
bias
.
data
=
torch
.
mul
(
self
.
bias
,
self
.
bias_mask
)
def
forward
(
self
,
*
inputs
):
# apply mask to weight, bias
self
.
module
.
weight
.
data
=
self
.
module
.
weight
.
data
.
mul_
(
self
.
weight_mask
)
self
.
module
.
weight
=
torch
.
mul
(
self
.
weight
,
self
.
weight_mask
)
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
module
.
bias
.
data
=
self
.
module
.
bias
.
data
.
mul_
(
self
.
bias_mask
)
self
.
module
.
bias
=
torch
.
mul
(
self
.
bias
,
self
.
bias_mask
)
return
self
.
module
(
*
inputs
)
...
...
@@ -75,12 +101,58 @@ class Pruner(Compressor):
The configuration for generating the mask.
"""
_logger
.
debug
(
"Module detected to compress : %s."
,
layer
.
name
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
,
self
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
)
assert
hasattr
(
layer
.
module
,
'weight'
),
"module %s does not have 'weight' attribute"
%
layer
.
name
# move newly registered buffers to the same device of weight
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
# The following `_wrap_model`, `_unwrap_model`, `get_origin2wrapped_parameter_name_map` can merge to `Compressor`,
# if quantizer use the similar structure wrapper.
def
_wrap_model
(
self
):
"""
Wrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
"""
if
not
self
.
is_wrapped
:
for
_
,
wrapper
in
reversed
(
self
.
get_modules_wrapper
().
items
()):
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
wrapper
.
_weight2buffer
()
self
.
is_wrapped
=
True
def
_unwrap_model
(
self
):
"""
Unwrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
"""
if
self
.
is_wrapped
:
for
_
,
wrapper
in
self
.
get_modules_wrapper
().
items
():
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
wrapper
.
_weight2parameter
()
self
.
is_wrapped
=
False
def
get_origin2wrapped_parameter_name_map
(
self
)
->
Dict
[
str
,
str
]:
"""
Get the name mapping of parameters from original model to wrapped model.
Returns
-------
Dict[str, str]
Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
"""
if
self
.
is_wrapped
:
wrapped_param_names
=
{
id
(
param
):
name
for
name
,
param
in
self
.
bound_model
.
named_parameters
()}
self
.
_unwrap_model
()
parameter_name_map
=
{}
for
name
,
param
in
self
.
bound_model
.
named_parameters
():
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
# If the parameter name in under wrapped module is others, the name `xxx.param` will change to `xxx.module.param` after wrap.
parameter_name_map
[
name
]
=
wrapped_param_names
[
id
(
param
)]
if
id
(
param
)
in
wrapped_param_names
else
name
self
.
_wrap_model
()
return
parameter_name_map
else
:
raise
Exception
(
'When only the model is wrapped can get the parameter_name_map.'
)
def
load_masks
(
self
,
masks
:
Dict
[
str
,
Dict
[
str
,
Tensor
]]):
"""
Load an exist masks on the wrapper. You can train the model with an exist masks after load the masks.
...
...
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
View file @
2566badb
...
...
@@ -999,7 +999,7 @@ class TaylorFOWeightPruner(BasicPruner):
return
(
weight_tensor
.
detach
()
*
grad
.
detach
()).
data
.
pow
(
2
)
def
reset_tools
(
self
):
hook_targets
=
{
layer_info
.
name
:
layer_info
.
module
.
weight
for
layer_info
,
_
in
self
.
_detec
t_modules_
to_compres
s
()}
hook_targets
=
{
name
:
wrapper
.
weight
for
name
,
wrapper
in
self
.
ge
t_modules_
wrapper
().
item
s
()}
collector_info
=
HookCollectorInfo
(
hook_targets
,
'tensor'
,
self
.
_collector
)
if
self
.
data_collector
is
None
:
self
.
data_collector
=
SingleHookTrainerBasedDataCollector
(
self
,
self
.
trainer
,
self
.
optimizer_helper
,
self
.
criterion
,
...
...
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
View file @
2566badb
...
...
@@ -10,7 +10,7 @@ from torch import autograd, Tensor
from
torch.nn
import
Module
,
Parameter
from
torch.optim
import
Optimizer
,
Adam
from
nni.algorithms.compression.v2.pytorch.base
.compressor
import
Compressor
,
_setatt
r
,
LayerInfo
from
nni.algorithms.compression.v2.pytorch.base
import
PrunerModuleWrappe
r
,
LayerInfo
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
BasicPruner
,
NORMAL_SCHEMA
,
EXCLUDE_SCHEMA
,
INTERNAL_SCHEMA
from
nni.algorithms.compression.v2.pytorch.utils
import
CompressorSchema
,
OptimizerConstructHelper
from
nni.common.serializer
import
Traceable
...
...
@@ -25,7 +25,7 @@ from .tools import (
_logger
=
logging
.
getLogger
(
__name__
)
class
PrunerScoredModuleWrapper
(
Module
):
class
PrunerScoredModuleWrapper
(
PrunerModuleWrapper
):
"""
Wrap a module to enable data parallel, forward method customization and buffer registeration.
Different from `PrunerModuleWrapper`, `PrunerScoredModuleWrapper` will record the gradient.
...
...
@@ -38,56 +38,12 @@ class PrunerScoredModuleWrapper(Module):
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
,
pruner
:
Compressor
):
super
().
__init__
()
# origin layer information
self
.
module
=
module
self
.
name
=
module_name
# config and pruner
self
.
config
=
config
self
.
pruner
=
pruner
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
module
.
weight
.
size
()))
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
):
super
().
__init__
(
module
,
module_name
,
config
)
self
.
weight_score
=
Parameter
(
torch
.
empty
(
self
.
weight
.
size
()))
torch
.
nn
.
init
.
constant_
(
self
.
weight_score
,
val
=
0.0
)
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
module
.
bias
.
size
()))
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
def
_weight2buffer
(
self
):
"""
When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
The best place to call this function is in `Pruner._wrap_model()`.
"""
self
.
weight
.
data
=
self
.
module
.
weight
.
data
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
weight
.
data
)
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
bias
.
data
=
self
.
module
.
bias
.
data
delattr
(
self
.
module
,
'bias'
)
self
.
module
.
register_buffer
(
'bias'
,
self
.
bias
.
data
)
def
_weight2parameter
(
self
):
"""
When don't need to record score or need to export the model, call `_weight2parameter()` to make the original weight trainable.
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
weight
=
Parameter
(
torch
.
empty
(
self
.
weight
.
size
()))
self
.
module
.
weight
.
data
=
torch
.
mul
(
self
.
weight
,
self
.
weight_mask
)
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
delattr
(
self
.
module
,
'bias'
)
self
.
module
.
bias
=
Parameter
(
torch
.
empty
(
self
.
bias
.
size
()))
self
.
module
.
bias
.
data
=
torch
.
mul
(
self
.
bias
,
self
.
bias_mask
)
def
forward
(
self
,
*
inputs
):
# apply mask to weight, bias
self
.
module
.
weight
=
torch
.
mul
(
self
.
weight
,
_StraightThrough
.
apply
(
self
.
weight_score
,
self
.
weight_mask
))
...
...
@@ -259,28 +215,6 @@ class MovementPruner(BasicPruner):
else
:
self
.
data_collector
.
reset
()
def
_wrap_model
(
self
):
"""
Wrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
"""
if
not
self
.
is_wrapped
:
for
_
,
wrapper
in
reversed
(
self
.
get_modules_wrapper
().
items
()):
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
wrapper
.
_weight2buffer
()
self
.
is_wrapped
=
True
def
_unwrap_model
(
self
):
"""
Unwrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
"""
if
self
.
is_wrapped
:
for
_
,
wrapper
in
self
.
get_modules_wrapper
().
items
():
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
wrapper
.
_weight2parameter
()
self
.
is_wrapped
=
False
def
_wrap_modules
(
self
,
layer
:
LayerInfo
,
config
:
Dict
):
"""
Create a wrapper module to replace the original one.
...
...
@@ -294,21 +228,12 @@ class MovementPruner(BasicPruner):
The configuration for generating the mask.
"""
_logger
.
debug
(
"Module detected to compress : %s."
,
layer
.
name
)
wrapper
=
PrunerScoredModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
,
self
)
wrapper
=
PrunerScoredModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
)
assert
hasattr
(
layer
.
module
,
'weight'
),
"module %s does not have 'weight' attribute"
%
layer
.
name
# move newly registered buffers to the same device of weight
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
def
get_origin2wrapped_parameter_name_map
(
self
)
->
Dict
[
str
,
str
]:
if
self
.
is_wrapped
:
self
.
_unwrap_model
()
parameter_name_map
=
{
name
:
name
for
name
,
_
in
self
.
bound_model
.
named_parameters
()}
self
.
_wrap_model
()
return
parameter_name_map
else
:
raise
Exception
(
'When only the model is wrapped can get the parameter_name_map.'
)
def
compress
(
self
)
->
Tuple
[
Module
,
Dict
]:
# sparsity grow from 0
for
_
,
wrapper
in
self
.
get_modules_wrapper
().
items
():
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
View file @
2566badb
...
...
@@ -384,7 +384,7 @@ class SparsityAllocator:
weight_mask
=
weight_mask
.
expand
(
expand_size
).
reshape
(
reshape_size
)
wrapper
=
self
.
pruner
.
get_modules_wrapper
()[
name
]
weight_size
=
wrapper
.
module
.
weight
.
data
.
size
()
weight_size
=
wrapper
.
weight
.
data
.
size
()
if
self
.
dim
is
None
:
assert
weight_mask
.
size
()
==
weight_size
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py
View file @
2566badb
...
...
@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):
def
collect
(
self
)
->
Dict
[
str
,
Tensor
]:
data
=
{}
for
_
,
wrapper
in
self
.
compressor
.
get_modules_wrapper
().
items
():
data
[
wrapper
.
name
]
=
wrapper
.
module
.
weight
.
data
data
[
wrapper
.
name
]
=
wrapper
.
weight
.
data
return
data
...
...
@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
data
=
{}
for
_
,
wrapper
in
self
.
compressor
.
get_modules_wrapper
().
items
():
data
[
wrapper
.
name
]
=
wrapper
.
module
.
weight
.
data
data
[
wrapper
.
name
]
=
wrapper
.
weight
.
data
return
data
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
View file @
2566badb
...
...
@@ -132,7 +132,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
if
self
.
continuous_mask
:
metric
=
metric
*
self
.
_compress_mask
(
wrapper
.
weight_mask
)
layer_weight_num
=
wrapper
.
module
.
weight
.
data
.
numel
()
layer_weight_num
=
wrapper
.
weight
.
data
.
numel
()
total_weight_num
+=
layer_weight_num
expend_times
=
int
(
layer_weight_num
/
metric
.
numel
())
...
...
test/ut/compression/v2/test_pruning_tools_torch.py
View file @
2566badb
...
...
@@ -83,17 +83,17 @@ class PruningToolsTestCase(unittest.TestCase):
# Test WeightDataCollector
data_collector
=
WeightDataCollector
(
pruner
)
data
=
data_collector
.
collect
()
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
module
.
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
# Test WeightTrainerBasedDataCollector
def
opt_after
():
model
.
conv1
.
module
.
weight
.
data
=
torch
.
ones
(
5
,
1
,
5
,
5
)
model
.
conv2
.
module
.
weight
.
data
=
torch
.
ones
(
10
,
5
,
5
,
5
)
model
.
conv1
.
weight
.
data
=
torch
.
ones
(
5
,
1
,
5
,
5
)
model
.
conv2
.
weight
.
data
=
torch
.
ones
(
10
,
5
,
5
,
5
)
optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
model
,
get_optimizer
(
model
))
data_collector
=
WeightTrainerBasedDataCollector
(
pruner
,
trainer
,
optimizer_helper
,
criterion
,
1
,
opt_after_tasks
=
[
opt_after
])
data
=
data_collector
.
collect
()
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
module
.
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
assert
all
(
t
.
numel
()
==
(
t
==
1
).
type_as
(
t
).
sum
().
item
()
for
t
in
data
.
values
())
# Test SingleHookTrainerBasedDataCollector
...
...
@@ -102,7 +102,7 @@ class PruningToolsTestCase(unittest.TestCase):
if
len
(
buffer
)
<
2
:
buffer
.
append
(
grad
.
clone
().
detach
())
return
collect_taylor
hook_targets
=
{
'conv1'
:
model
.
conv1
.
module
.
weight
,
'conv2'
:
model
.
conv2
.
module
.
weight
}
hook_targets
=
{
'conv1'
:
model
.
conv1
.
weight
,
'conv2'
:
model
.
conv2
.
weight
}
collector_info
=
HookCollectorInfo
(
hook_targets
,
'tensor'
,
_collector
)
optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
model
,
get_optimizer
(
model
))
...
...
test/ut/compression/v2/test_pruning_wrapper.py
0 → 100644
View file @
2566badb
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
unittest
import
torch
import
torch.nn.functional
as
F
from
nni.algorithms.compression.v2.pytorch.pruning
import
L1NormPruner
class
TorchModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
5
,
5
,
1
)
self
.
bn1
=
torch
.
nn
.
BatchNorm2d
(
5
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
5
,
10
,
5
,
1
)
self
.
bn2
=
torch
.
nn
.
BatchNorm2d
(
10
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
10
,
100
)
self
.
fc2
=
torch
.
nn
.
Linear
(
100
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
bn1
(
self
.
conv1
(
x
)))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
relu
(
self
.
bn2
(
self
.
conv2
(
x
)))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
x
.
view
(
-
1
,
4
*
4
*
10
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
class
PrunerTestCase
(
unittest
.
TestCase
):
def
test_pruner_module_wrapper
(
self
):
model
=
TorchModel
()
conv1_weight
=
model
.
conv1
.
weight
.
data
.
clone
()
conv2_weight
=
model
.
conv2
.
weight
.
data
.
clone
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
L1NormPruner
(
model
,
config_list
)
_
,
masks
=
pruner
.
compress
()
model
(
torch
.
rand
(
10
,
1
,
28
,
28
))
assert
torch
.
equal
(
model
.
conv1
.
weight
.
data
,
conv1_weight
)
assert
torch
.
equal
(
model
.
conv2
.
weight
.
data
,
conv2_weight
)
assert
torch
.
equal
(
model
.
conv1
.
module
.
weight
.
data
,
conv1_weight
*
masks
[
'conv1'
][
'weight'
])
assert
torch
.
equal
(
model
.
conv2
.
module
.
weight
.
data
,
conv2_weight
*
masks
[
'conv2'
][
'weight'
])
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment