Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a254f058
Unverified
Commit
a254f058
authored
Dec 28, 2021
by
J-shang
Committed by
GitHub
Dec 28, 2021
Browse files
fix pre-masks inherit (#4428)
parent
6e643b00
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
1 deletion
+6
-1
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
...ompression/v2/pytorch/pruning/tools/sparsity_allocator.py
+6
-1
No files found.
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
View file @
a254f058
...
...
@@ -36,6 +36,8 @@ class NormalSparsityAllocator(SparsityAllocator):
threshold
=
torch
.
topk
(
metric
.
view
(
-
1
),
prune_num
,
largest
=
False
)[
0
].
max
()
mask
=
torch
.
gt
(
metric
,
threshold
).
type_as
(
metric
)
masks
[
name
]
=
self
.
_expand_mask
(
name
,
mask
)
if
self
.
continuous_mask
:
masks
[
name
][
'weight'
]
*=
wrapper
.
weight_mask
return
masks
...
...
@@ -55,6 +57,8 @@ class GlobalSparsityAllocator(SparsityAllocator):
for
name
,
metric
in
group_metric_dict
.
items
():
mask
=
torch
.
gt
(
metric
,
min
(
threshold
,
sub_thresholds
[
name
])).
type_as
(
metric
)
masks
[
name
]
=
self
.
_expand_mask
(
name
,
mask
)
if
self
.
continuous_mask
:
masks
[
name
][
'weight'
]
*=
self
.
pruner
.
get_modules_wrapper
()[
name
].
weight_mask
return
masks
def
_calculate_threshold
(
self
,
group_metric_dict
:
Dict
[
str
,
Tensor
])
->
Tuple
[
float
,
Dict
[
str
,
float
]]:
...
...
@@ -158,7 +162,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
threshold
=
torch
.
topk
(
metric
,
pruned_num
,
largest
=
False
)[
0
].
max
()
mask
=
torch
.
gt
(
metric
,
threshold
).
type_as
(
metric
)
masks
[
name
]
=
self
.
_expand_mask
(
name
,
mask
)
if
self
.
continuous_mask
:
masks
[
name
][
'weight'
]
*=
self
.
pruner
.
get_modules_wrapper
()[
name
].
weight_mask
return
masks
def
_group_metric_calculate
(
self
,
group_metrics
:
Union
[
Dict
[
str
,
Tensor
],
List
[
Tensor
]])
->
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment