Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
11aff9df
"...serve/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "6904053f1b40842a214a4704863c12ecc3957430"
Unverified
Commit
11aff9df
authored
Jun 10, 2022
by
J-shang
Committed by
GitHub
Jun 10, 2022
Browse files
[Bugbash] fix speedup replacement related (#4906)
parent
993109bb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
76 additions
and
30 deletions
+76
-30
docs/source/conf.py
docs/source/conf.py
+1
-1
nni/common/graph_utils.py
nni/common/graph_utils.py
+2
-1
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+28
-22
nni/compression/pytorch/speedup/compressor.py
nni/compression/pytorch/speedup/compressor.py
+9
-5
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+30
-0
nni/compression/pytorch/utils/shape_dependency.py
nni/compression/pytorch/utils/shape_dependency.py
+6
-1
No files found.
docs/source/conf.py
View file @
11aff9df
...
...
@@ -110,7 +110,7 @@ linkcheck_ignore = [
r
'https://www\.msra\.cn/'
,
# MSRA
r
'https://1drv\.ms/'
,
# OneDrive (shortcut)
r
'https://onedrive\.live\.com/'
,
# OneDrive
r
'https://www\.openml\.org/'
,
r
'https://www\.openml\.org/'
,
# OpenML
]
# Ignore all links located in release.rst
...
...
nni/common/graph_utils.py
View file @
11aff9df
...
...
@@ -775,7 +775,8 @@ class TorchModuleGraph(TorchGraph):
"""
# extract the input & output shape for the view and flatten
for
node_group
in
self
.
nodes_py
.
nodes_op
:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
,
'aten::expand_as'
]:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
,
'aten::expand_as'
,
'aten::pixel_shuffle'
]:
# get shape infor for view (aten::view) func
cpp_node
=
list
(
filter
(
lambda
x
:
x
.
kind
()
==
node_group
.
op_type
,
node_group
.
node_cpps
))[
0
]
...
...
nni/compression/pytorch/speedup/compress_modules.py
View file @
11aff9df
...
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import
logging
import
math
import
torch
import
torch.nn
as
nn
from
.error_code
import
EmptyLayerError
,
ShapeMisMatchError
,
InputsNumberError
,
OutputTypeError
,
UnBalancedGroupError
...
...
@@ -595,37 +596,42 @@ def replace_layernorm(layernorm, masks):
def replace_pixelshuffle(pixelshuffle, masks):
    """
    Build the replacement ``torch.nn.PixelShuffle`` for the speedup pass.

    We can not replace pixelshuffle freely: pixelshuffle is a kind of location
    mapping that takes a tensor with shape (r^2 * C, H, W) to (C, r * H, r * W).
    So there is a dependency here: the preserved input channel number must equal
    ``preserved_output_channels * k^2`` for some positive integer ``k``.
    This dependency is similar to the group dependency in ConvXD, but more
    restrictive, i.e., each `r^2 input channels` group can not freely preserve
    any number of channels; it must be a number in [1, 4, 9, 16, ..., r^2].

    Parameters
    ----------
    pixelshuffle : torch.nn.PixelShuffle
        The pixelshuffle module to be replaced.
    masks : tuple
        Tuple of the input masks, output masks and weight masks, for example
        ([input_m1, input_m2], [output_m], {'weight': weight_m}).

    Returns
    -------
    torch.nn.PixelShuffle
        The new pixelshuffle module.

    Raises
    ------
    InputsNumberError
        If the module does not have exactly one input mask.
    """
    in_masks, output_mask, _ = masks
    assert isinstance(pixelshuffle, torch.nn.PixelShuffle)
    if len(in_masks) != 1:
        raise InputsNumberError()
    in_mask = in_masks[0]

    # Coarse masks on dim 1 (channel dim of an N, C, H, W tensor) give the
    # indices of the preserved input/output channels.
    _, remained_in = convert_to_coarse_mask(in_mask, 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    in_channel_num, out_channel_num = remained_in.size(0), remained_out.size(0)

    # Largest integer upscale factor compatible with the preserved channels.
    upscale_factor = math.floor(math.sqrt(in_channel_num / out_channel_num))
    if in_channel_num != out_channel_num * (upscale_factor * upscale_factor):
        err_msg = "Your speedup model may encounter shape mismatch error during inference. "
        err_msg += f"PixelShuffle preserved input channel number is {in_channel_num}, "
        err_msg += f"preserved output channel number is {out_channel_num}, "
        err_msg += "unable to find a suitable upscale_factor, keep it as it is, please replace this module manually, "
        err_msg += "or adjust the module sparsity ratio before this module to ensure that a suitable upscale_factor can be found."
        # Don't raise an error because the user maybe know how to manually replace this function.
        _logger.error(err_msg)
        # NOTE: no_replace, use the original upscale_factor if we can not find a suitable upscale_factor.
        upscale_factor = pixelshuffle.upscale_factor
    if upscale_factor != pixelshuffle.upscale_factor:
        war_msg = f"Change PixelShuffle upscale_factor from {pixelshuffle.upscale_factor} to {upscale_factor}, "
        war_msg += "subsequent computation semantics may have changed."
        _logger.warning(war_msg)
    new_pixelshuffle = torch.nn.PixelShuffle(upscale_factor)
    return new_pixelshuffle
nni/compression/pytorch/speedup/compressor.py
View file @
11aff9df
...
...
@@ -459,13 +459,17 @@ class ModelSpeedup:
self
.
bound_model
,
g_node
.
name
)
m_type
=
g_node
.
op_type
if
(
not
m_type
in
replace_module
)
and
(
m_type
not
in
self
.
customized_replace_func
):
raise
RuntimeError
(
"Has not supported replacing the module: `{}`"
.
format
(
m_type
))
err_msg
=
f
"Has not supported replacing module with type:
{
m_type
}
, "
err_msg
+=
f
"you could report an issue at https://github.com/microsoft/nni. "
err_msg
+=
f
"If you know how to replace
{
m_type
}
, "
err_msg
+=
f
"you could implement module replacement by passing in "
err_msg
+=
f
"`customized_replace_func` to `
{
self
.
__class__
.
__name__
}
`. "
err_msg
+=
f
"You are welcome to contribute back to nni as native support if you have implemented the replacement function, "
err_msg
+=
f
"so that more users can benefit from your contributions."
raise
RuntimeError
(
err_msg
)
_logger
.
info
(
"replace module (name: %s, op_type: %s)"
,
g_node
.
name
,
m_type
)
replace_function
=
replace_module
[
m_type
]
if
m_type
in
self
.
customized_replace_func
:
replace_function
=
self
.
customized_replace_func
[
m_type
]
replace_function
=
self
.
customized_replace_func
.
get
(
m_type
,
replace_module
.
get
(
m_type
,
None
))
compressed_module
=
replace_function
(
leaf_module
,
auto_infer
.
get_masks
())
new_submodule
=
compressed_module
...
...
nni/compression/pytorch/speedup/jit_translate.py
View file @
11aff9df
...
...
@@ -492,6 +492,35 @@ def upsample_bilinear2d_python(node, speedup):
return
UpsampleModule
(
size_list
,
scale_list
)
def upsample_nearest2d_python(node, speedup):
    """
    Translate an ``aten::upsample_nearest2d`` jit node into a python module.

    Parameters
    ----------
    node : NodePyGroup
        The node group whose ``key_node`` is the upsample_nearest2d cpp node.
    speedup : ModelSpeedup
        The speedup object, used to resolve constant list inputs of the node.

    Returns
    -------
    torch.nn.Module
        A module whose forward upsamples its first input with nearest-neighbor
        interpolation, using the size/scale parsed from the jit graph.
    """
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            # Exactly one of these is expected to be non-None at forward time.
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first input of args is the target tensor to upsample;
            the following parameters are useless, because we already
            got the size_list and the scale_list by parsing the cpp_nodes.
            """
            # `F.upsample_nearest` is deprecated; `interpolate(mode='nearest')`
            # is the supported, behaviorally-identical replacement.
            return torch.nn.functional.interpolate(
                args[0], size=self.size_list, scale_factor=self.scale_list, mode='nearest')

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list_node = inputs[1].node()
    scale_list_node = inputs[2].node()
    size_list = None
    scale_list = None
    # Each parameter is a constant list only when produced by a
    # prim::ListConstruct node; otherwise it was None in the original call.
    if size_list_node.kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if scale_list_node.kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[2], speedup)
    return UpsampleModule(size_list, scale_list)
def
typeas_python
(
node
,
speedup
):
"""
currently only support type_as float.
...
...
@@ -583,6 +612,7 @@ trans_from_jit_to_python = {
'aten::to'
:
to_python
,
'aten::type_as'
:
typeas_python
,
'aten::upsample_bilinear2d'
:
upsample_bilinear2d_python
,
'aten::upsample_nearest2d'
:
upsample_nearest2d_python
,
'aten::exp'
:
exp_python
,
'aten::squeeze'
:
squeeze_python
,
'aten::unsqueeze'
:
unsqueeze_python
,
...
...
nni/compression/pytorch/utils/shape_dependency.py
View file @
11aff9df
...
...
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
CAT_TYPE
=
'aten::cat'
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
,
'aten::expand_as'
]
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
,
'aten::expand_as'
,
'aten::pixel_shuffle'
]
def
lcm_list
(
L
):
...
...
@@ -85,6 +85,11 @@ def reshape_break_channel_dependency(op_node):
"""
in_shape
=
op_node
.
auxiliary
[
'in_shape'
]
out_shape
=
op_node
.
auxiliary
[
'out_shape'
]
# FIXME: e.g., in_shape will be None if the input comes from a buffer, should be fixed in next release
if
not
in_shape
or
not
out_shape
:
return
True
if
len
(
in_shape
)
<=
1
or
len
(
out_shape
)
<=
1
:
return
True
in_channel
=
in_shape
[
1
]
out_channel
=
out_shape
[
1
]
return
in_channel
!=
out_channel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment