Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
11aff9df
Unverified
Commit
11aff9df
authored
Jun 10, 2022
by
J-shang
Committed by
GitHub
Jun 10, 2022
Browse files
[Bugbash] fix speedup replacement related (#4906)
parent
993109bb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
76 additions
and
30 deletions
+76
-30
docs/source/conf.py
docs/source/conf.py
+1
-1
nni/common/graph_utils.py
nni/common/graph_utils.py
+2
-1
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+28
-22
nni/compression/pytorch/speedup/compressor.py
nni/compression/pytorch/speedup/compressor.py
+9
-5
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+30
-0
nni/compression/pytorch/utils/shape_dependency.py
nni/compression/pytorch/utils/shape_dependency.py
+6
-1
No files found.
docs/source/conf.py
View file @
11aff9df
...
@@ -110,7 +110,7 @@ linkcheck_ignore = [
...
@@ -110,7 +110,7 @@ linkcheck_ignore = [
r
'https://www\.msra\.cn/'
,
# MSRA
r
'https://www\.msra\.cn/'
,
# MSRA
r
'https://1drv\.ms/'
,
# OneDrive (shortcut)
r
'https://1drv\.ms/'
,
# OneDrive (shortcut)
r
'https://onedrive\.live\.com/'
,
# OneDrive
r
'https://onedrive\.live\.com/'
,
# OneDrive
r
'https://www\.openml\.org/'
,
r
'https://www\.openml\.org/'
,
# OpenML
]
]
# Ignore all links located in release.rst
# Ignore all links located in release.rst
...
...
nni/common/graph_utils.py
View file @
11aff9df
...
@@ -775,7 +775,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -775,7 +775,8 @@ class TorchModuleGraph(TorchGraph):
"""
"""
# extract the input & output shape for the view and flatten
# extract the input & output shape for the view and flatten
for
node_group
in
self
.
nodes_py
.
nodes_op
:
for
node_group
in
self
.
nodes_py
.
nodes_op
:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
,
'aten::expand_as'
]:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
,
'aten::expand_as'
,
'aten::pixel_shuffle'
]:
# get shape infor for view (aten::view) func
# get shape infor for view (aten::view) func
cpp_node
=
list
(
filter
(
lambda
x
:
x
.
kind
()
==
node_group
.
op_type
,
cpp_node
=
list
(
filter
(
lambda
x
:
x
.
kind
()
==
node_group
.
op_type
,
node_group
.
node_cpps
))[
0
]
node_group
.
node_cpps
))[
0
]
...
...
nni/compression/pytorch/speedup/compress_modules.py
View file @
11aff9df
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.error_code
import
EmptyLayerError
,
ShapeMisMatchError
,
InputsNumberError
,
OutputTypeError
,
UnBalancedGroupError
from
.error_code
import
EmptyLayerError
,
ShapeMisMatchError
,
InputsNumberError
,
OutputTypeError
,
UnBalancedGroupError
...
@@ -595,37 +596,42 @@ def replace_layernorm(layernorm, masks):
...
@@ -595,37 +596,42 @@ def replace_layernorm(layernorm, masks):
def
replace_pixelshuffle
(
pixelshuffle
,
masks
):
def
replace_pixelshuffle
(
pixelshuffle
,
masks
):
"""
"""
Parameters
This is a nearly `no_replace` function.
----------
norm : torch.nn.PixelShuffle
The pixelshuffle module to be replace
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
We can not replace pixelshuffle easily right now, pixelshuffle is a kind of location mapping.
-------
It will map tensor with shape (r^2 * C, H, W) to (C, r * H, r* W). So we have a dependency here,
torch.nn.PixelShuffle
the preserved input channel number should be a multiple of C, and the multiple can be squared to positive integer.
The new pixelshuffle module
This dependence is similar to the group dependency in ConvXD, but more restrictive,
i.e., each `r^2 input channels` group can not be free to preserve any number of channels, must be a number in [1, 4, 9, 16, ... , r^2].
"""
"""
in_masks
,
output_mask
,
_
=
masks
in_masks
,
output_mask
,
_
=
masks
assert
isinstance
(
pixelshuffle
,
torch
.
nn
.
PixelShuffle
)
assert
isinstance
(
pixelshuffle
,
torch
.
nn
.
PixelShuffle
)
if
len
(
in_masks
)
!=
1
:
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
raise
InputsNumberError
()
in_mask
=
in_masks
[
0
]
in_mask
=
in_masks
[
0
]
# N, C, H, W
# FIXME: This should be a correct replacement logic, but since we can't correctly generate qualified masks,
# most of the time this is a no_replace.
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
upscale_factor
=
pixelshuffle
.
upscale_factor
in_channel_num
,
out_channel_num
=
remained_in
.
size
(
0
),
remained_out
.
size
(
0
)
if
remained_in
.
size
(
0
)
%
(
upscale_factor
*
upscale_factor
):
upscale_factor
=
math
.
floor
(
math
.
sqrt
(
in_channel_num
/
out_channel_num
))
_logger
.
debug
(
"Shape mismatch, remained_in:%d upscale_factor:%d"
,
remained_in
.
size
(
0
),
remained_out
.
size
(
0
))
if
in_channel_num
!=
out_channel_num
*
(
upscale_factor
*
upscale_factor
):
raise
ShapeMisMatchError
()
err_msg
=
"Your speedup model may encounter shape mismatch error during inference. "
if
remained_out
.
size
(
0
)
*
upscale_factor
*
upscale_factor
!=
remained_in
:
err_msg
+=
f
"PixelShuffle preserved input channel number is
{
in_channel_num
}
, "
raise
ShapeMisMatchError
()
err_msg
+=
f
"preserved output channel number is
{
out_channel_num
}
, "
err_msg
+=
"unable to find a suitable upscale_factor, keep it as it is, please replace this module manually, "
err_msg
+=
"or adjust the module sparsity ratio before this module to ensure that a suitable upscale_factor can be found."
# Don't raise an error because the user maybe know how to manually replace this function.
_logger
.
error
(
err_msg
)
# NOTE: no_replace, use the orignal upscale_factor if we can not find a suitable upscale_factor.
upscale_factor
=
pixelshuffle
.
upscale_factor
if
upscale_factor
!=
pixelshuffle
.
upscale_factor
:
war_msg
=
f
"Change PixelShuffle upscale_factor from
{
pixelshuffle
.
upscale_factor
}
to
{
upscale_factor
}
, "
war_msg
+=
"subsequent computation semantics may have changed."
_logger
.
warning
(
war_msg
)
new_pixelshuffle
=
torch
.
nn
.
PixelShuffle
(
upscale_factor
)
new_pixelshuffle
=
torch
.
nn
.
PixelShuffle
(
upscale_factor
)
return
new_pixelshuffle
return
new_pixelshuffle
\ No newline at end of file
nni/compression/pytorch/speedup/compressor.py
View file @
11aff9df
...
@@ -459,13 +459,17 @@ class ModelSpeedup:
...
@@ -459,13 +459,17 @@ class ModelSpeedup:
self
.
bound_model
,
g_node
.
name
)
self
.
bound_model
,
g_node
.
name
)
m_type
=
g_node
.
op_type
m_type
=
g_node
.
op_type
if
(
not
m_type
in
replace_module
)
and
(
m_type
not
in
self
.
customized_replace_func
):
if
(
not
m_type
in
replace_module
)
and
(
m_type
not
in
self
.
customized_replace_func
):
raise
RuntimeError
(
err_msg
=
f
"Has not supported replacing module with type:
{
m_type
}
, "
"Has not supported replacing the module: `{}`"
.
format
(
m_type
))
err_msg
+=
f
"you could report an issue at https://github.com/microsoft/nni. "
err_msg
+=
f
"If you know how to replace
{
m_type
}
, "
err_msg
+=
f
"you could implement module replacement by passing in"
err_msg
+=
f
"`customized_replace_func` to `
{
self
.
__class__
.
__name__
}
`. "
err_msg
+=
f
"You are welcome to contribute back to nni as native support if you have implemented the replacement function, "
err_msg
+=
f
"so that more users can benefit from your contributions."
raise
RuntimeError
(
err_msg
)
_logger
.
info
(
"replace module (name: %s, op_type: %s)"
,
_logger
.
info
(
"replace module (name: %s, op_type: %s)"
,
g_node
.
name
,
m_type
)
g_node
.
name
,
m_type
)
replace_function
=
replace_module
[
m_type
]
replace_function
=
self
.
customized_replace_func
.
get
(
m_type
,
replace_module
.
get
(
m_type
,
None
))
if
m_type
in
self
.
customized_replace_func
:
replace_function
=
self
.
customized_replace_func
[
m_type
]
compressed_module
=
replace_function
(
compressed_module
=
replace_function
(
leaf_module
,
auto_infer
.
get_masks
())
leaf_module
,
auto_infer
.
get_masks
())
new_submodule
=
compressed_module
new_submodule
=
compressed_module
...
...
nni/compression/pytorch/speedup/jit_translate.py
View file @
11aff9df
...
@@ -492,6 +492,35 @@ def upsample_bilinear2d_python(node, speedup):
...
@@ -492,6 +492,35 @@ def upsample_bilinear2d_python(node, speedup):
return
UpsampleModule
(
size_list
,
scale_list
)
return
UpsampleModule
(
size_list
,
scale_list
)
def
upsample_nearest2d_python
(
node
,
speedup
):
class
UpsampleModule
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
size_list
,
scale_list
):
super
(
UpsampleModule
,
self
).
__init__
()
self
.
size_list
=
size_list
self
.
scale_list
=
scale_list
def
forward
(
self
,
*
args
):
"""
The first input of args is the target tensor to upsample
, the following parameters is useless, because we already
get the size_list and the scale_list by parsing the cpp_nodes.
"""
return
torch
.
nn
.
functional
.
upsample_nearest
(
args
[
0
],
size
=
self
.
size_list
,
scale_factor
=
self
.
scale_list
)
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
size_list_node
=
inputs
[
1
].
node
()
scale_list_node
=
inputs
[
2
].
node
()
size_list
=
None
scale_list
=
None
if
size_list_node
.
kind
()
==
'prim::ListConstruct'
:
size_list
=
translate_list
(
inputs
[
1
],
speedup
)
if
scale_list_node
.
kind
()
==
'prim::ListConstruct'
:
scale_list
=
translate_list
(
inputs
[
2
],
speedup
)
return
UpsampleModule
(
size_list
,
scale_list
)
def
typeas_python
(
node
,
speedup
):
def
typeas_python
(
node
,
speedup
):
"""
"""
currently only support type_as float.
currently only support type_as float.
...
@@ -583,6 +612,7 @@ trans_from_jit_to_python = {
...
@@ -583,6 +612,7 @@ trans_from_jit_to_python = {
'aten::to'
:
to_python
,
'aten::to'
:
to_python
,
'aten::type_as'
:
typeas_python
,
'aten::type_as'
:
typeas_python
,
'aten::upsample_bilinear2d'
:
upsample_bilinear2d_python
,
'aten::upsample_bilinear2d'
:
upsample_bilinear2d_python
,
'aten::upsample_nearest2d'
:
upsample_nearest2d_python
,
'aten::exp'
:
exp_python
,
'aten::exp'
:
exp_python
,
'aten::squeeze'
:
squeeze_python
,
'aten::squeeze'
:
squeeze_python
,
'aten::unsqueeze'
:
unsqueeze_python
,
'aten::unsqueeze'
:
unsqueeze_python
,
...
...
nni/compression/pytorch/utils/shape_dependency.py
View file @
11aff9df
...
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
...
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
CAT_TYPE
=
'aten::cat'
CAT_TYPE
=
'aten::cat'
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
,
'aten::expand_as'
]
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
,
'aten::expand_as'
,
'aten::pixel_shuffle'
]
def
lcm_list
(
L
):
def
lcm_list
(
L
):
...
@@ -85,6 +85,11 @@ def reshape_break_channel_dependency(op_node):
...
@@ -85,6 +85,11 @@ def reshape_break_channel_dependency(op_node):
"""
"""
in_shape
=
op_node
.
auxiliary
[
'in_shape'
]
in_shape
=
op_node
.
auxiliary
[
'in_shape'
]
out_shape
=
op_node
.
auxiliary
[
'out_shape'
]
out_shape
=
op_node
.
auxiliary
[
'out_shape'
]
# FIXME: e.g., in_shape will be None if the input comes from a buffer, should be fixed in next release
if
not
in_shape
or
not
out_shape
:
return
True
if
len
(
in_shape
)
<=
1
or
len
(
out_shape
)
<=
1
:
return
True
in_channel
=
in_shape
[
1
]
in_channel
=
in_shape
[
1
]
out_channel
=
out_shape
[
1
]
out_channel
=
out_shape
[
1
]
return
in_channel
!=
out_channel
return
in_channel
!=
out_channel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment