OpenDAS / nni

Commit 4cf68009 (unverified), authored Aug 26, 2022 by Xing, committed by GitHub on Aug 26, 2022

Add Group Norm support for Pruning model (#5069)

parent 858daf9f
Showing 4 changed files with 143 additions and 20 deletions.
nni/common/graph_utils.py (+4, -2)
nni/compression/pytorch/speedup/compress_modules.py (+104, -14)
nni/compression/pytorch/speedup/infer_mask.py (+2, -0)
nni/compression/pytorch/utils/shape_dependency.py (+33, -4)
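For orientation, here is a minimal sketch of the workflow this commit extends. It assumes the nni v2-style pruning and speedup entry points (`L1NormPruner`, `ModelSpeedup`); the toy model, config, and shapes are illustrative, not taken from the commit.

```python
import torch
import torch.nn as nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

# Toy model: a conv whose output channels feed the newly supported GroupNorm.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.GroupNorm(4, 16),
    nn.ReLU(),
)

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()

# With this commit, speedup can rebuild the GroupNorm around the surviving
# channels instead of failing on an unsupported module type.
ModelSpeedup(model, torch.rand(1, 3, 32, 32), masks).speedup_model()
```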
nni/common/graph_utils.py
```diff
@@ -6,6 +6,7 @@ import logging
 import queue
 import re
 from collections import defaultdict
+from typing import List, Dict
 import torch
 from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy

 CLASSTYPE_KIND = 'ClassType'
```
```diff
@@ -262,6 +263,7 @@ class TorchModuleGraph(TorchGraph):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super().__init__(model, dummy_input, traced_model)
+        self.name_to_node: Dict[str, NodePyOP]
         self.global_count = 0
         self.reused_module = set()
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
```
```diff
@@ -802,7 +804,7 @@ class TorchModuleGraph(TorchGraph):
                 node_group.auxiliary = self._extract_cat_info(
                     node_group, cpp_node)

-    def find_predecessors(self, unique_name):
+    def find_predecessors(self, unique_name) -> List[str]:
         """
         Find predecessor node of the given node
```
@@ -825,7 +827,7 @@ class TorchModuleGraph(TorchGraph):
predecessors
.
append
(
node_py
.
unique_name
)
return
predecessors
def
find_successors
(
self
,
unique_name
):
def
find_successors
(
self
,
unique_name
)
->
List
[
str
]
:
"""
Find successor nodes of the given node
...
...
nni/compression/pytorch/speedup/compress_modules.py
```diff
@@ -48,7 +48,8 @@ replace_module = {
     'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
     'Embedding': lambda module, masks: replace_embedding(module, masks),
     'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
-    'Flatten': lambda module, masks: no_replace(module, masks)
+    'Flatten': lambda module, masks: no_replace(module, masks),
+    'GroupNorm': lambda module, masks: replace_groupnorm(module, masks),
 }
```
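The table above is the registration point for module replacement during speedup. A rough, self-contained sketch of how such a table is consumed; keying the lookup on the module's class name (`type(module).__name__`) is an assumption for illustration, not code from this commit:

```python
def replace_submodule(module, masks, replace_map):
    # Dispatch on the module's class name, e.g. 'GroupNorm' for torch.nn.GroupNorm.
    op_type = type(module).__name__
    if op_type not in replace_map:
        raise RuntimeError(f'no replacement rule registered for {op_type}')
    # Build a slimmed-down module from the original one plus its masks.
    return replace_map[op_type](module, masks)
```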
```diff
@@ -310,6 +311,90 @@ def replace_batchnorm2d(norm, masks):
     return new_norm


+def replace_groupnorm(norm: nn.GroupNorm, masks):
+    """
+    Parameters
+    ----------
+    norm : torch.nn.GroupNorm
+        The group norm module to be replaced
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight': weight_m})
+
+    Returns
+    -------
+    torch.nn.GroupNorm
+        The new group norm module
+    """
+    in_masks, output_mask, _ = masks
+    assert isinstance(norm, nn.GroupNorm)
+    in_mask = in_masks[0]
+
+    # N, C, H, W
+    _, remained_in = convert_to_coarse_mask(in_mask, 1)
+    _, remained_out = convert_to_coarse_mask(output_mask, 1)
+    assert len(remained_in.size()) == 1
+    if remained_in.size(0) != remained_out.size(0):
+        raise ShapeMisMatchError()
+
+    ori_channel_step = norm.num_channels // norm.num_groups
+    for groupid in range(norm.num_groups):
+        in_start = groupid * ori_channel_step
+        in_end = in_start + ori_channel_step
+        new_channel_step = torch.logical_and(
+            in_start <= remained_in, remained_in < in_end,
+        ).sum().item()
+        # this group is fully pruned
+        if new_channel_step == 0:
+            continue
+        break
+
+    new_groups = 0
+    # Validate
+    for groupid in range(norm.num_groups):
+        in_start = groupid * ori_channel_step
+        in_end = in_start + ori_channel_step
+        num_item = torch.logical_and(
+            in_start <= remained_in, remained_in < in_end,
+        ).sum().item()
+        if num_item == 0:
+            continue
+        # check that each remaining group keeps the same number of channels
+        if num_item != new_channel_step:
+            raise UnBalancedGroupError()
+        new_groups += 1
+
+    new_num_channels = remained_in.size()[0]
+    new_module = nn.GroupNorm(new_groups, new_num_channels, eps=norm.eps, affine=norm.affine)
+    if new_module.affine:
+        new_module.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
+        new_module.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
+
+    return new_module
+
+
 def replace_instancenorm2d(norm, masks):
     """
     Parameters
```
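The core invariant of `replace_groupnorm` is easiest to see standalone: every group that still owns channels after pruning must keep the same number of them, so the survivors can be regrouped into a smaller GroupNorm. A self-contained sketch with made-up surviving indices:

```python
import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=4, num_channels=16)
remained_in = torch.tensor([0, 1, 4, 5, 12, 13])   # surviving channel indices

step = norm.num_channels // norm.num_groups        # original channels per group
counts = [
    torch.logical_and(g * step <= remained_in, remained_in < (g + 1) * step).sum().item()
    for g in range(norm.num_groups)
]
print(counts)                                      # [2, 2, 0, 2]: group 2 is fully pruned

kept = [c for c in counts if c > 0]
# Unbalanced survivors would trigger UnBalancedGroupError in the real code.
assert len(set(kept)) == 1, 'unbalanced groups cannot be re-normalized'
new_norm = nn.GroupNorm(len(kept), len(remained_in))   # GroupNorm(3, 6)
```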
@@ -409,18 +494,20 @@ def replace_conv2d(conv, masks):
in_end
=
in_start
+
ori_inchannel_step
out_start
=
groupid
*
ori_outchannel_step
out_end
=
out_start
+
ori_outchannel_step
current_input_index
=
list
(
filter
(
lambda
x
:
in_start
<=
x
and
x
<
in_end
,
remained_in
.
tolist
()))
current_output_index
=
list
(
filter
(
lambda
x
:
out_start
<=
x
and
x
<
out_end
,
remained_out
.
tolist
()))
new_inchannel_step
:
int
=
torch
.
logical_and
(
in_start
<=
remained_in
,
remained_in
<
in_end
).
sum
().
item
()
new_outchannel_step
:
int
=
torch
.
logical_and
(
out_start
<=
remained_out
,
remained_out
<
out_end
).
sum
().
item
()
# remap the global index to the group index
if
len
(
current_input_index
)
==
0
:
if
new_inchannel_step
==
0
:
# if the whole group are pruned
continue
else
:
new_inchannel_step
=
len
(
current_input_index
)
new_outchannel_step
=
len
(
current_output_index
)
break
tmp_weight
=
torch
.
ones
(
n_remained_out
,
new_inchannel_step
,
k_size1
,
k_size2
)
...
...
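The replacement of the `list(filter(...))` counting with `torch.logical_and(...).sum()` is behavior-preserving; a quick equivalence check with toy values:

```python
import torch

remained_in = torch.tensor([0, 2, 5, 6, 9])   # surviving global channel indices
in_start, in_end = 4, 8                       # one group's channel range

# Old style: materialize a Python list, then count it.
old_count = len(list(filter(lambda x: in_start <= x and x < in_end, remained_in.tolist())))
# New style: count directly on the tensor.
new_count = torch.logical_and(in_start <= remained_in, remained_in < in_end).sum().item()
assert old_count == new_count == 2            # channels 5 and 6 fall in [4, 8)
```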
```diff
@@ -436,12 +523,15 @@ def replace_conv2d(conv, masks):
         in_end = in_start + ori_inchannel_step
         out_start = groupid * ori_outchannel_step
         out_end = out_start + ori_outchannel_step
-        current_input_index = list(
-            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
-        current_output_index = list(
-            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
+        current_input_mask = torch.logical_and(in_start <= remained_in, remained_in < in_end)
+        current_input_index = remained_in[current_input_mask]
+        current_output_mask = torch.logical_and(out_start <= remained_out, remained_out < out_end)
+        current_output_index = remained_out[current_output_mask]
         # remap the global index to the group index
-        current_input_index = [
-            x - in_start for x in current_input_index]
+        current_input_index = current_input_index - in_start
         if len(current_input_index) == 0:
             # if the whole group is pruned
             assert len(current_output_index) == 0
```
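Likewise for the remapping hunk: boolean indexing plus a scalar subtraction keeps the group-local indices as a tensor rather than a Python list. A toy check:

```python
import torch

remained_in = torch.tensor([0, 2, 5, 6, 9])
in_start, in_end = 4, 8

mask = torch.logical_and(in_start <= remained_in, remained_in < in_end)
group_local = remained_in[mask] - in_start    # tensor([1, 2])
# Same values as the old list comprehension, but still a tensor.
assert group_local.tolist() == [1, 2]
```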
nni/compression/pytorch/speedup/infer_mask.py
```diff
@@ -82,6 +82,8 @@ class AutoMaskInference:
         if output_mask is not None:
             # assume the given output mask is right
             self.output_mask = output_mask
+        elif isinstance(module, nn.GroupNorm):
+            self.output_mask = self.in_masks[0]
         else:
             if isinstance(self.output, torch.Tensor):
                 self.output_mask = torch.ones_like(self.output)
```
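The new branch hard-codes a GroupNorm's output mask to its first input mask. That is sound because GroupNorm preserves the tensor layout, so a channel mask on the input maps one-to-one onto the output; a one-line sanity check:

```python
import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=4, num_channels=8)
x = torch.randn(2, 8, 16, 16)
# Same (N, C, H, W) shape in and out, so the same channel mask applies.
assert gn(x).shape == x.shape
```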
nni/compression/pytorch/utils/shape_dependency.py
```diff
@@ -49,7 +49,7 @@ class Dependency:
         # user should provide model & dummy_input to trace
         # the model, or an already traced model
         assert model is not None and dummy_input is not None
-        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
+        self.graph: TorchModuleGraph = TorchModuleGraph(model, dummy_input, traced_model)
         self.model = model
         self.dependency = dict()
         self.build_dependency()
```
@@ -123,6 +123,9 @@ class ChannelDependency(Dependency):
elif
self
.
prune_type
==
'Batchnorm'
:
self
.
target_types
.
append
(
'BatchNorm2d'
)
from
typing
import
Dict
,
Set
self
.
dependency
:
Dict
[
str
,
Set
[
str
]]
super
(
ChannelDependency
,
self
).
__init__
(
model
,
dummy_input
,
traced_model
)
...
...
```diff
@@ -351,7 +354,7 @@ class GroupDependency(Dependency):
     ----------
     model : torch.nn.Module
         The model to be analyzed.
-    data : torch.Tensor
+    dummy_input : torch.Tensor
         The example input data to trace the network architecture.
     traced_model : torch._C.Graph
         if we already have the traced graph of the target model, we do not
```
```diff
@@ -418,6 +421,29 @@ class GroupDependency(Dependency):
             return 1
         return group

+    def _get_group_norm_condition(self, node_group) -> int:
+        """
+        Get the number of groups for a group norm layer.
+
+        Parameters
+        ----------
+        node_group : NodePyGroup
+            target node.
+
+        Returns
+        -------
+        condition : int
+            the number that the layer's channel count
+            must be divisible by
+        """
+        node_name = node_group.name
+        _, leaf_module = get_module_by_name(self.model, node_name)
+        if isinstance(leaf_module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
+            leaf_module = leaf_module.module
+        assert isinstance(leaf_module, (torch.nn.GroupNorm))
+        return leaf_module.num_groups
+
     def build_dependency(self):
         """
         Build the channel dependency for the conv layers
```
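The "condition" returned here is simply `num_groups`, because GroupNorm itself rejects a channel count that its group count does not divide; any pruner touching the preceding layer must therefore keep a multiple of `num_groups` channels alive:

```python
import torch.nn as nn

nn.GroupNorm(4, 16)        # fine: 16 % 4 == 0
try:
    nn.GroupNorm(4, 10)    # 10 % 4 != 0
except ValueError as e:
    print(e)               # num_channels must be divisible by num_groups
```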
```diff
@@ -441,8 +467,11 @@ class GroupDependency(Dependency):
         """
         self.groups = {}
         for node in self.graph.nodes_py.nodes_op:
-            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
-                group = self._get_conv_groups(node)
+            if node.op_type in ['Conv2d', 'ConvTranspose2d', "GroupNorm"]:
+                if node.op_type in ['Conv2d', 'ConvTranspose2d']:
+                    group = self._get_conv_groups(node)
+                elif node.op_type == "GroupNorm":
+                    group = self._get_group_norm_condition(node)
                 if node.name in self.groups:
                     # a conv layer whose group count is larger than 1 requires
                     # its number of output channels to be divisible by the number of groups.
```
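A hypothetical end-to-end check of the new dependency; the exact contents of `dep.dependency` are an assumption for illustration, while the module path and class follow the file above:

```python
import torch
import torch.nn as nn
from nni.compression.pytorch.utils.shape_dependency import GroupDependency

# After this commit, GroupDependency should report the GroupNorm's num_groups
# as a constraint on the conv layer feeding it.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.GroupNorm(4, 16))
dep = GroupDependency(model, torch.rand(1, 3, 32, 32))
print(dep.dependency)   # e.g. maps the conv's node name to 4: keep a multiple of 4 channels
```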