OpenDAS / nni, commit 1abaf627 (unverified)
[Compression v2] bugfix & improvement (#4307)
Authored Dec 01, 2021 by J-shang, committed by GitHub on Dec 01, 2021
Parent: 0845d79b
Showing 3 changed files with 10 additions and 4 deletions.
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py (+1, -1)
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py (+2, -2)
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py (+7, -1)
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
@@ -414,7 +414,7 @@ class SlimPruner(BasicPruner):
        def patched_criterion(input_tensor: Tensor, target: Tensor):
            sum_l1 = 0
            for _, wrapper in self.get_modules_wrapper().items():
-               sum_l1 += torch.norm(wrapper.module.weight.data, p=1)
+               sum_l1 += torch.norm(wrapper.module.weight, p=1)
            return criterion(input_tensor, target) + self._scale * sum_l1
        return patched_criterion
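The one-line change drops `.data` from the regularized weight, so the L1 term stays in the autograd graph and actually drives the regularized weights toward zero during sparsity training; computed on `.data`, the penalty is detached and contributes no gradient. A minimal standalone sketch of the difference (the layer and scale below are illustrative, not NNI code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
scale = 1e-4

# Old behaviour: .data detaches the tensor, so the penalty has no grad_fn
# and backpropagation never sees it.
detached_penalty = scale * torch.norm(layer.weight.data, p=1)
assert detached_penalty.grad_fn is None

# New behaviour: the penalty is part of the graph and produces a gradient
# that pushes the weights toward zero.
penalty = scale * torch.norm(layer.weight, p=1)
penalty.backward()
assert layer.weight.grad is not None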
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py
@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):
    def collect(self) -> Dict[str, Tensor]:
        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
-           data[wrapper.name] = wrapper.module.weight.data.clone().detach()
+           data[wrapper.name] = wrapper.module.weight.data
        return data
@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
-           data[wrapper.name] = wrapper.module.weight.data.clone().detach()
+           data[wrapper.name] = wrapper.module.weight.data
        return data
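Both collectors now return references to the live weight tensors instead of cloning and detaching them, presumably because the downstream metric calculators only read the values, so the per-layer copy was wasted memory. A rough sketch of the difference on a hypothetical model (not NNI code):

import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

# Old behaviour: one extra copy of every weight matrix is allocated.
copied = {name: p.data.clone().detach() for name, p in model.named_parameters()}

# New behaviour: the dict holds views of the existing storage, so there is no
# extra allocation and the values always reflect the current model state.
referenced = {name: p.data for name, p in model.named_parameters()}
assert referenced['0.weight'].data_ptr() == model[0].weight.data_ptr()
assert copied['0.weight'].data_ptr() != model[0].weight.data_ptr()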
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
@@ -24,6 +24,8 @@ class NormalSparsityAllocator(SparsityAllocator):
            sparsity_rate = wrapper.config['total_sparsity']
            assert name in metrics, 'Metric of %s is not calculated.'
            # We assume the metric value are all positive right now.
            metric = metrics[name]
            if self.continuous_mask:
                metric *= self._compress_mask(wrapper.weight_mask)
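The `continuous_mask` branch folds the mask from the previous pruning iteration into the metric, so positions that were already pruned get a zero score and cannot be revived. `_compress_mask` is an NNI internal that reshapes the stored weight mask to the metric's shape; the sketch below only illustrates the effect, using a hand-made mask in place of that helper:

import torch

metric = torch.tensor([0.8, 0.2, 0.6, 0.4])      # importance scores, assumed positive
previous_mask = torch.tensor([1., 0., 1., 1.])    # 0 = pruned in an earlier iteration

# Same spirit as: metric *= self._compress_mask(wrapper.weight_mask)
metric = metric * previous_mask                   # -> [0.8, 0.0, 0.6, 0.4]
# Whatever threshold is chosen next, the zeroed entry stays pruned.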
@@ -66,8 +68,11 @@ class GlobalSparsityAllocator(SparsityAllocator):
        for name, metric in group_metric_dict.items():
            wrapper = self.pruner.get_modules_wrapper()[name]
            # We assume the metric value are all positive right now.
            if self.continuous_mask:
                metric = metric * self._compress_mask(wrapper.weight_mask)
            layer_weight_num = wrapper.module.weight.data.numel()
            total_weight_num += layer_weight_num
            expend_times = int(layer_weight_num / metric.numel())
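Besides masking the metric under `continuous_mask`, this hunk keeps the bookkeeping that makes a single global threshold fair across layers: each metric entry (for example one value per channel) stands for `layer_weight_num / metric.numel()` weights, so it is expanded by that factor before all layers are ranked together. A rough, hedged sketch of that idea (hypothetical sizes, not the allocator's actual code):

import torch

# Per-layer channel metrics (assumed positive) and the number of weights per layer.
metrics = {'conv1': torch.tensor([0.2, 0.7]), 'conv2': torch.tensor([0.1, 0.4, 0.9])}
weight_nums = {'conv1': 128, 'conv2': 48}

expanded = []
for name, metric in metrics.items():
    expend_times = weight_nums[name] // metric.numel()   # weights represented by one metric entry
    expanded.append(metric.repeat_interleave(expend_times))
all_metrics = torch.cat(expanded)

# A single global threshold then prunes the smallest 50% of all represented weights.
pruned_num = int(0.5 * all_metrics.numel())
threshold = torch.topk(all_metrics, pruned_num, largest=False)[0].max()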
@@ -147,7 +152,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
            group_mask = torch.cat(group_mask, dim=0)
            for name, metric in group_metric_dict.items():
-               metric = (metric - metric.min()) * group_mask
+               # We assume the metric value are all positive right now.
+               metric = metric * group_mask
                pruned_num = int(sparsities[name] * len(metric))
                threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
                mask = torch.gt(metric, threshold).type_as(metric)
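The per-tensor minimum subtraction is dropped: with the commit's stated assumption that metric values are all positive, multiplying by `group_mask` alone is enough to push channels pruned elsewhere in the dependency group to the bottom of the ranking, so the extra shift is no longer needed. A small numeric sketch of the threshold step that follows (illustrative values only):

import torch

metric = torch.tensor([0.9, 0.1, 0.5, 0.3])     # per-channel importance, assumed positive
group_mask = torch.tensor([1., 1., 0., 1.])      # 0 = channel pruned by the dependency group

metric = metric * group_mask                     # -> [0.9, 0.1, 0.0, 0.3]
pruned_num = int(0.5 * len(metric))              # e.g. 50% sparsity for this layer
threshold = torch.topk(metric, pruned_num, largest=False)[0].max()   # -> 0.1
mask = torch.gt(metric, threshold).type_as(metric)
print(mask)                                      # tensor([1., 0., 0., 1.])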