Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cd98c48f
Unverified
Commit
cd98c48f
authored
Aug 11, 2022
by
J-shang
Committed by
GitHub
Aug 11, 2022
Browse files
[Compression] Improve bank sparse execution efficiency (#5033)
parent
d03c411c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
21 deletions
+22
-21
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
...ompression/v2/pytorch/pruning/tools/sparsity_allocator.py
+22
-21
No files found.
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
View file @
cd98c48f
...
...
@@ -3,7 +3,7 @@
from
__future__
import
annotations
import
itertools
from
functools
import
reduce
from
typing
import
Any
,
Dict
import
numpy
as
np
...
...
@@ -69,28 +69,29 @@ class BankSparsityAllocator(SparsityAllocator):
assert
n_dim
>=
len
(
self
.
balance_gran
),
'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran
balance_gran
=
[
1
]
*
(
n_dim
-
len
(
self
.
balance_gran
))
+
self
.
balance_gran
balance_numel
=
reduce
(
lambda
x
,
y
:
x
*
y
,
balance_gran
)
reshape_size_split
=
[]
reshape_size_balance
=
[]
for
i
,
j
in
zip
(
target_metric
.
shape
,
balance_gran
):
assert
i
%
j
==
0
,
'Length of {} {} is not aligned with balance granularity'
.
format
(
module_name
,
target_name
)
# FIXME: The following code need refactor, do it after scaling refactor is done.
shrinked_mask
=
torch
.
ones
(
target_metric
.
shape
).
type_as
(
target_metric
)
loop_iters
=
[
range
(
int
(
i
/
j
))
for
i
,
j
in
zip
(
target_metric
.
shape
,
balance_gran
)]
for
iter_params
in
itertools
.
product
(
*
loop_iters
):
index_str_list
=
[
f
"
{
iter_param
*
gran
}
:
{
(
iter_param
+
1
)
*
gran
}
"
\
for
iter_param
,
gran
in
zip
(
iter_params
,
balance_gran
)]
index_str
=
","
.
join
(
index_str_list
)
sub_metric_str
=
"target_metric[{}]"
.
format
(
index_str
)
sub_mask_str
=
"shrinked_mask[{}] = mask_bank"
.
format
(
index_str
)
metric_bank
:
Tensor
=
eval
(
sub_metric_str
)
prune_num
=
int
(
sparsity_rate
*
metric_bank
.
numel
())
# mask_bank will be used in exec(sub_mask_str)
if
prune_num
!=
0
:
threshold
=
torch
.
topk
(
metric_bank
.
reshape
(
-
1
),
prune_num
,
largest
=
False
)[
0
].
max
()
mask_bank
=
torch
.
gt
(
metric_bank
,
threshold
).
type_as
(
metric_bank
)
# type: ignore
else
:
mask_bank
=
torch
.
ones_like
(
metric_bank
)
# type: ignore
mask_bank
=
mask_bank
# `type: ignore` is useless for unused-variable error, add this line to workaround
exec
(
sub_mask_str
)
reshape_size_split
.
extend
([
i
//
j
,
j
])
reshape_size_balance
.
append
(
i
//
j
)
reshape_size_balance
.
append
(
balance_numel
)
permute_dims_balance
=
[
_
*
2
for
_
in
range
(
n_dim
)]
+
[
_
*
2
+
1
for
_
in
range
(
n_dim
)]
_target_metric
=
target_metric
.
reshape
(
reshape_size_split
).
permute
(
permute_dims_balance
)
reshape_size_split_p
=
_target_metric
.
shape
balance_metric
=
_target_metric
.
reshape
(
reshape_size_balance
)
kept_num
=
balance_numel
-
int
(
sparsity_rate
*
balance_numel
)
kept_indices
=
torch
.
topk
(
balance_metric
,
kept_num
).
indices
shrinked_mask
=
torch
.
zeros_like
(
balance_metric
).
scatter
(
-
1
,
kept_indices
,
1.0
).
reshape
(
reshape_size_split_p
)
permute_dims_split
=
[]
for
i
in
range
(
n_dim
):
permute_dims_split
.
extend
([
i
,
i
+
n_dim
])
shrinked_mask
=
shrinked_mask
.
permute
(
permute_dims_split
).
reshape_as
(
target_metric
)
masks
[
module_name
][
target_name
]
=
self
.
_expand_mask
(
module_name
,
target_name
,
shrinked_mask
)
return
masks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment