Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4cf68009
"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "bf07ea6baf7a61ca0979715a5e27b26b6272c7d2"
Unverified
Commit
4cf68009
authored
Aug 26, 2022
by
Xing
Committed by
GitHub
Aug 26, 2022
Browse files
Add Group Norm support for Pruning model (#5069)
parent
858daf9f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
143 additions
and
20 deletions
+143
-20
nni/common/graph_utils.py
nni/common/graph_utils.py
+4
-2
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+104
-14
nni/compression/pytorch/speedup/infer_mask.py
nni/compression/pytorch/speedup/infer_mask.py
+2
-0
nni/compression/pytorch/utils/shape_dependency.py
nni/compression/pytorch/utils/shape_dependency.py
+33
-4
No files found.
nni/common/graph_utils.py
View file @
4cf68009
...
...
@@ -6,6 +6,7 @@ import logging
import
queue
import
re
from
collections
import
defaultdict
from
typing
import
List
,
Dict
import
torch
from
torch.utils.tensorboard._pytorch_graph
import
NodePy
,
NodePyIO
,
NodePyOP
,
GraphPy
CLASSTYPE_KIND
=
'ClassType'
...
...
@@ -262,6 +263,7 @@ class TorchModuleGraph(TorchGraph):
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
super
().
__init__
(
model
,
dummy_input
,
traced_model
)
self
.
name_to_node
:
Dict
[
str
,
NodePyOP
]
self
.
global_count
=
0
self
.
reused_module
=
set
()
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
...
...
@@ -802,7 +804,7 @@ class TorchModuleGraph(TorchGraph):
node_group
.
auxiliary
=
self
.
_extract_cat_info
(
node_group
,
cpp_node
)
def
find_predecessors
(
self
,
unique_name
):
def
find_predecessors
(
self
,
unique_name
)
->
List
[
str
]
:
"""
Find predecessor node of the given node
...
...
@@ -825,7 +827,7 @@ class TorchModuleGraph(TorchGraph):
predecessors
.
append
(
node_py
.
unique_name
)
return
predecessors
def
find_successors
(
self
,
unique_name
):
def
find_successors
(
self
,
unique_name
)
->
List
[
str
]
:
"""
Find successor nodes of the given node
...
...
nni/compression/pytorch/speedup/compress_modules.py
View file @
4cf68009
...
...
@@ -48,7 +48,8 @@ replace_module = {
'ConvTranspose2d'
:
lambda
module
,
masks
:
replace_convtranspose2d
(
module
,
masks
),
'Embedding'
:
lambda
module
,
masks
:
replace_embedding
(
module
,
masks
),
'PixelShuffle'
:
lambda
module
,
masks
:
replace_pixelshuffle
(
module
,
masks
),
'Flatten'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
)
'Flatten'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'GroupNorm'
:
lambda
module
,
masks
:
replace_groupnorm
(
module
,
masks
),
}
...
...
@@ -310,6 +311,90 @@ def replace_batchnorm2d(norm, masks):
return
new_norm
def replace_groupnorm(norm: nn.GroupNorm, masks):
    """
    Build a pruned replacement for a ``GroupNorm`` layer.

    Parameters
    ----------
    norm : torch.nn.GroupNorm
        The group norm module to be replaced.
    masks : tuple
        Tuple of the input masks, output masks and weight masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.GroupNorm
        The new group norm module, sized to the remaining channels.
    """
    in_masks, output_mask, _ = masks
    assert isinstance(norm, nn.GroupNorm)
    # Input layout is N, C, H, W; collapse the masks to channel indices (dim 1).
    _, remained_in = convert_to_coarse_mask(in_masks[0], 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    assert len(remained_in.size()) == 1
    if remained_in.size(0) != remained_out.size(0):
        raise ShapeMisMatchError()

    channels_per_group = norm.num_channels // norm.num_groups
    # Count how many channels survive pruning in each original group.
    survivors = []
    for gid in range(norm.num_groups):
        lo = gid * channels_per_group
        hi = lo + channels_per_group
        kept = torch.logical_and(
            lo <= remained_in,
            remained_in < hi,
        ).sum().item()
        survivors.append(kept)

    # The per-group channel count of the pruned layer is taken from the
    # first group that is not fully pruned.
    new_channel_step = next((kept for kept in survivors if kept != 0), 0)

    # Validate: every surviving group must keep the same number of channels,
    # otherwise the result cannot be expressed as a single GroupNorm.
    new_groups = 0
    for kept in survivors:
        if kept == 0:
            # this group fully pruned
            continue
        if kept != new_channel_step:
            raise UnBalancedGroupError()
        new_groups += 1

    new_module = nn.GroupNorm(
        new_groups,
        remained_in.size()[0],
        eps=norm.eps,
        affine=norm.affine,
    )
    if new_module.affine:
        # Carry the affine parameters of the surviving channels over.
        new_module.weight.data = torch.index_select(
            norm.weight.data,
            0,
            remained_in,
        )
        new_module.bias.data = torch.index_select(
            norm.bias.data,
            0,
            remained_in,
        )
    return new_module
def
replace_instancenorm2d
(
norm
,
masks
):
"""
Parameters
...
...
@@ -409,18 +494,20 @@ def replace_conv2d(conv, masks):
in_end
=
in_start
+
ori_inchannel_step
out_start
=
groupid
*
ori_outchannel_step
out_end
=
out_start
+
ori_outchannel_step
current_input_index
=
list
(
filter
(
lambda
x
:
in_start
<=
x
and
x
<
in_end
,
remained_in
.
tolist
()))
current_output_index
=
list
(
filter
(
lambda
x
:
out_start
<=
x
and
x
<
out_end
,
remained_out
.
tolist
()))
new_inchannel_step
:
int
=
torch
.
logical_and
(
in_start
<=
remained_in
,
remained_in
<
in_end
).
sum
().
item
()
new_outchannel_step
:
int
=
torch
.
logical_and
(
out_start
<=
remained_out
,
remained_out
<
out_end
).
sum
().
item
()
# remap the global index to the group index
if
len
(
current_input_index
)
==
0
:
if
new_inchannel_step
==
0
:
# if the whole group are pruned
continue
else
:
new_inchannel_step
=
len
(
current_input_index
)
new_outchannel_step
=
len
(
current_output_index
)
break
tmp_weight
=
torch
.
ones
(
n_remained_out
,
new_inchannel_step
,
k_size1
,
k_size2
)
...
...
@@ -436,12 +523,15 @@ def replace_conv2d(conv, masks):
in_end
=
in_start
+
ori_inchannel_step
out_start
=
groupid
*
ori_outchannel_step
out_end
=
out_start
+
ori_outchannel_step
current_input_index
=
list
(
filter
(
lambda
x
:
in_start
<=
x
and
x
<
in_end
,
remained_in
.
tolist
()))
current_output_index
=
list
(
filter
(
lambda
x
:
out_start
<=
x
and
x
<
out_end
,
remained_out
.
tolist
()))
current_input_mask
=
torch
.
logical_and
(
in_start
<=
remained_in
,
remained_in
<
in_end
)
current_input_index
=
remained_in
[
current_input_mask
]
current_output_mask
=
torch
.
logical_and
(
out_start
<=
remained_out
,
remained_out
<
out_end
)
current_output_index
=
remained_out
[
current_output_mask
]
# remap the global index to the group index
current_input_index
=
[
x
-
in_start
for
x
in
current_input_index
]
current_input_index
=
current_input_index
-
in_start
if
len
(
current_input_index
)
==
0
:
# if the whole group are pruned
assert
len
(
current_output_index
)
==
0
...
...
nni/compression/pytorch/speedup/infer_mask.py
View file @
4cf68009
...
...
@@ -82,6 +82,8 @@ class AutoMaskInference:
if
output_mask
is
not
None
:
# assume the given output mask is right
self
.
output_mask
=
output_mask
elif
isinstance
(
module
,
nn
.
GroupNorm
):
self
.
output_mask
=
self
.
in_masks
[
0
]
else
:
if
isinstance
(
self
.
output
,
torch
.
Tensor
):
self
.
output_mask
=
torch
.
ones_like
(
self
.
output
)
...
...
nni/compression/pytorch/utils/shape_dependency.py
View file @
4cf68009
...
...
@@ -49,7 +49,7 @@ class Dependency:
# user should provide model & dummy_input to trace
# the model or a already traced model
assert
model
is
not
None
and
dummy_input
is
not
None
self
.
graph
=
TorchModuleGraph
(
model
,
dummy_input
,
traced_model
)
self
.
graph
:
TorchModuleGraph
=
TorchModuleGraph
(
model
,
dummy_input
,
traced_model
)
self
.
model
=
model
self
.
dependency
=
dict
()
self
.
build_dependency
()
...
...
@@ -123,6 +123,9 @@ class ChannelDependency(Dependency):
elif
self
.
prune_type
==
'Batchnorm'
:
self
.
target_types
.
append
(
'BatchNorm2d'
)
from
typing
import
Dict
,
Set
self
.
dependency
:
Dict
[
str
,
Set
[
str
]]
super
(
ChannelDependency
,
self
).
__init__
(
model
,
dummy_input
,
traced_model
)
...
...
@@ -351,7 +354,7 @@ class GroupDependency(Dependency):
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
dummy_input : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
...
...
@@ -418,6 +421,29 @@ class GroupDependency(Dependency):
return
1
return
group
def _get_group_norm_condition(self, node_group) -> int:
    """
    Get the number of groups for a group norm layer.

    Parameters
    ----------
    node_group : NodePyGroup
        target node.

    Returns
    -------
    condition: int
        the number that layer's num channel
        require to be divisible to
    """
    _, leaf_module = get_module_by_name(self.model, node_group.name)
    # Unwrap a pruner wrapper (if any) to reach the underlying module.
    if isinstance(leaf_module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
        leaf_module = leaf_module.module
    assert isinstance(leaf_module, torch.nn.GroupNorm)
    return leaf_module.num_groups
def
build_dependency
(
self
):
"""
Build the channel dependency for the conv layers
...
...
@@ -441,8 +467,11 @@ class GroupDependency(Dependency):
"""
self
.
groups
=
{}
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
if
node
.
op_type
==
'Conv2d'
or
node
.
op_type
==
'ConvTranspose2d'
:
group
=
self
.
_get_conv_groups
(
node
)
if
node
.
op_type
in
[
'Conv2d'
,
'ConvTranspose2d'
,
"GroupNorm"
]:
if
node
.
op_type
in
[
'Conv2d'
,
'ConvTranspose2d'
]:
group
=
self
.
_get_conv_groups
(
node
)
elif
node
.
op_type
==
"GroupNorm"
:
group
=
self
.
_get_group_norm_condition
(
node
)
if
node
.
name
in
self
.
groups
:
# the conv layer whose group is larger than 1 will require that
# it's number of output channel to be divisible by the number of group.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment