OpenDAS / nni

Commit 7811307c (unverified)
Authored Sep 19, 2022 by J-shang, committed by GitHub on Sep 19, 2022
Parent: c88ac7b9

[Test] fix compression ut (#5129)

* fix test often failed
* update
* fix lint
Showing 3 changed files with 23 additions and 16 deletions (+23 -16):

  nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py  +14 -7
  test/algo/compression/assets/common.py                                     +4 -4
  test/algo/compression/assets/simple_mnist/simple_torch_model.py            +5 -5
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
@@ -212,7 +212,6 @@ class DependencyAwareAllocator(SparsityAllocator):
         return fused_metrics

     def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
         # placeholder, here we need more discussion about dependence sparsity, Plan A or Plan B.
         masks = {}
         # generate public part for modules that have dependencies
         for module_names in self.channel_dependency:
@@ -228,7 +227,8 @@ class DependencyAwareAllocator(SparsityAllocator):
             group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
             max_group_nums = int(np.lcm.reduce(group_nums))
-            pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
+            numel_per_group = fused_metric.numel() // max_group_nums
+            kept_numel_per_group = numel_per_group - int(numel_per_group * min_sparsity_rate)
             group_step = fused_metric.shape[0] // max_group_nums
             # get the public part of the mask of the module with dependencies
@@ -236,9 +236,15 @@ class DependencyAwareAllocator(SparsityAllocator):
             for gid in range(max_group_nums):
                 _start = gid * group_step
                 _end = (gid + 1) * group_step
-                if pruned_numel_per_group > 0:
-                    threshold = torch.topk(fused_metric[_start:_end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
-                    dependency_mask[_start:_end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
+                if kept_numel_per_group > 0:
+                    flatten_partial_fused_metric = fused_metric[_start:_end].reshape(-1)
+                    kept_indices = torch.topk(flatten_partial_fused_metric, kept_numel_per_group).indices
+                    flatten_partial_mask = torch.zeros_like(flatten_partial_fused_metric).scatter(0, kept_indices, 1.0)
+                    dependency_mask[_start:_end] = flatten_partial_mask.reshape_as(dependency_mask[_start:_end])
+                else:
+                    # all zeros means this target will be whole masked, will break the model in most cases,
+                    # maybe replace this layer to identity layer in the future
+                    dependency_mask[_start:_end] = torch.zeros_like(dependency_mask[_start:_end])
             # change the metric value corresponding to the public mask part to the minimum value
             for module_name, targets_metric in sub_metrics.items():
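For context, this hunk is the core of the flakiness fix: the old code derived a threshold from the smallest metric values and then kept everything strictly greater than it, so tied metric values could mask away more elements than requested and the realized sparsity overshot the target. The new code selects exactly `kept_numel_per_group` indices with `topk` and scatters 1.0 into a zero mask. A minimal, hypothetical sketch (names and values are illustrative, not part of the commit):

```python
# Illustrative only: why the threshold rule is fragile with tied metric values.
import torch

metric = torch.tensor([1., 1., 1., 2., 3., 4.])   # ties among the smallest values
kept = 4                                           # want to keep exactly 4 elements
pruned = metric.numel() - kept                     # so prune 2

# Old style: threshold is the largest of the `pruned` smallest values (here 1.0),
# but strict torch.gt() also drops the third tied 1.0 -> only 3 elements kept.
threshold = torch.topk(metric, pruned, largest=False)[0].max()
old_mask = torch.gt(metric, threshold).type_as(metric)
print(old_mask.sum())   # tensor(3.)

# New style: scatter 1.0 at the indices of the top-`kept` values -> exactly 4 kept,
# ties or not.
kept_indices = torch.topk(metric, kept).indices
new_mask = torch.zeros_like(metric).scatter(0, kept_indices, 1.0)
print(new_mask.sum())   # tensor(4.)
```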
@@ -262,8 +268,9 @@ class DependencyAwareAllocator(SparsityAllocator):
             sparsity_rate = wrapper.config['total_sparsity']
             prune_num = int(sparsity_rate * target_metric.numel())
             if prune_num != 0:
-                threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
-                shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
+                flatten_metric = target_metric.reshape(-1)
+                kept_indices = torch.topk(flatten_metric, target_metric.numel() - prune_num).indices
+                shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
             else:
                 # target_metric should have the same size as shrinked_mask
                 shrinked_mask = torch.ones_like(target_metric)
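The same pattern replaces the threshold/`torch.gt` pair when shrinking the per-wrapper mask: keeping `numel() - prune_num` indices via `topk` plus `scatter` yields a mask whose sum equals the kept count exactly. A quick standalone check, not taken from the repository, with stand-in values for `target_metric` and `prune_num`:

```python
# Hypothetical sanity check of the new shrinked_mask construction.
import torch

target_metric = torch.rand(8, 4)
prune_num = int(0.75 * target_metric.numel())

flatten_metric = target_metric.reshape(-1)
kept_indices = torch.topk(flatten_metric, target_metric.numel() - prune_num).indices
shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)

# Exact kept count, regardless of ties in target_metric.
assert int(shrinked_mask.sum()) == target_metric.numel() - prune_num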
test/algo/compression/assets/common.py

@@ -15,12 +15,12 @@ log_dir = Path(__file__).parent.parent / 'logs'
 def create_model(model_type: str):
-    torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
-                         {'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.5},
+    torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.75},
+                         {'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.75},
                          {'op_names': ['fc2'], 'exclude': True}]

-    lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
-                             {'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.5},
+    lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.75},
+                             {'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.75},
                              {'op_names': ['model.fc2'], 'exclude': True}]

     if model_type == 'lightning':
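These config lists follow NNI's compression config format (`op_types`, `op_names`, `sparsity`, `exclude`); the test fixtures raise the target sparsity from 0.5 to 0.75 while keeping `fc2` excluded. A minimal, hypothetical usage sketch of such a config list, assuming one of NNI's built-in pruners and a throwaway model rather than the actual test fixtures:

```python
# Illustrative usage of a config list like the one above; pruner choice and model
# are assumptions, not what the tests themselves run.
import torch
from nni.compression.pytorch.pruning import L1NormPruner

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10))
config_list = [{'op_types': ['Linear'], 'sparsity': 0.75}]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()   # masks: {module_name: {'weight': mask, 'bias': mask}}
for name, target_masks in masks.items():
    print(name, {k: tuple(v.shape) for k, v in target_masks.items()})
```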
test/algo/compression/assets/simple_mnist/simple_torch_model.py

@@ -23,11 +23,11 @@ class SimpleTorchModel(torch.nn.Module):
         super().__init__()
         self.conv1 = torch.nn.Conv2d(1, 16, 3)
         self.bn1 = torch.nn.BatchNorm2d(16)
-        self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
-        self.bn2 = torch.nn.BatchNorm2d(8)
-        self.conv3 = torch.nn.Conv2d(16, 8, 3)
-        self.bn3 = torch.nn.BatchNorm2d(8)
-        self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
+        self.conv2 = torch.nn.Conv2d(16, 32, 3, groups=4)
+        self.bn2 = torch.nn.BatchNorm2d(32)
+        self.conv3 = torch.nn.Conv2d(16, 32, 3)
+        self.bn3 = torch.nn.BatchNorm2d(32)
+        self.fc1 = torch.nn.Linear(32 * 24 * 24, 100)
         self.fc2 = torch.nn.Linear(100, 10)

     def forward(self, x: torch.Tensor):
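Widening the intermediate channels from 8 to 32 also changes the `fc1` input size. A small shape check, assuming the standard 28x28 MNIST input implied by the asset name (the model's `forward` wiring is outside this hunk, so the layer chain below is an assumption):

```python
# Shape arithmetic behind fc1's new in_features: each 3x3 conv without padding
# trims 2 pixels per spatial dim, so 28 -> 26 -> 24 with 32 output channels.
import torch

x = torch.randn(1, 1, 28, 28)
conv1 = torch.nn.Conv2d(1, 16, 3)              # -> (1, 16, 26, 26)
conv2 = torch.nn.Conv2d(16, 32, 3, groups=4)   # -> (1, 32, 24, 24)
out = conv2(conv1(x))
print(out.flatten(1).shape)                    # torch.Size([1, 18432]) == 32 * 24 * 24
```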