"...serve/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "6904053f1b40842a214a4704863c12ecc3957430"
Unverified Commit 11aff9df authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Bugbash] Fix speedup replacement-related issues (#4906)

parent 993109bb
......@@ -110,7 +110,7 @@ linkcheck_ignore = [
r'https://www\.msra\.cn/', # MSRA
r'https://1drv\.ms/', # OneDrive (shortcut)
r'https://onedrive\.live\.com/', # OneDrive
r'https://www\.openml\.org/',
r'https://www\.openml\.org/', # OpenML
]
# Ignore all links located in release.rst
......
......@@ -775,7 +775,8 @@ class TorchModuleGraph(TorchGraph):
"""
# extract the input & output shape for the view and flatten
for node_group in self.nodes_py.nodes_op:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as']:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as',
'aten::pixel_shuffle']:
# get shape infor for view (aten::view) func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import logging
import math
import torch
import torch.nn as nn
from .error_code import EmptyLayerError, ShapeMisMatchError, InputsNumberError, OutputTypeError, UnBalancedGroupError
......@@ -595,37 +596,42 @@ def replace_layernorm(layernorm, masks):
def replace_pixelshuffle(pixelshuffle, masks):
    """
    Build the replacement for a ``torch.nn.PixelShuffle`` module after pruning.

    PixelShuffle maps a tensor of shape (r^2 * C, H, W) to (C, r * H, r * W),
    so the preserved channel counts on both sides are coupled: the preserved
    input channel number must equal the preserved output channel number times
    r^2 for some positive integer r. This dependency is similar to the group
    dependency in ConvXD, but more restrictive: each ``r^2 input channels``
    group is not free to preserve any number of channels, only a number in
    [1, 4, 9, 16, ..., r^2]. When no suitable r exists, the original
    ``upscale_factor`` is kept (effectively a no_replace) and the user is told
    to replace the module manually or adjust the preceding sparsity.

    Parameters
    ----------
    pixelshuffle : torch.nn.PixelShuffle
        The pixelshuffle module to be replaced.
    masks : tuple
        Tuple of (input masks, output masks, weight masks), for example
        ([input_m1, input_m2], [output_m], {'weight': weight_m}).

    Returns
    -------
    torch.nn.PixelShuffle
        The new pixelshuffle module.

    Raises
    ------
    InputsNumberError
        If the module does not have exactly one input mask.
    """
    in_masks, output_mask, _ = masks
    assert isinstance(pixelshuffle, torch.nn.PixelShuffle)
    if len(in_masks) != 1:
        raise InputsNumberError()
    in_mask = in_masks[0]
    # Masks are laid out as (N, C, H, W); the channel dimension is 1.
    # FIXME: This should be a correct replacement logic, but since we can't
    # correctly generate qualified masks, most of the time this is a no_replace.
    _, remained_in = convert_to_coarse_mask(in_mask, 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    in_channel_num, out_channel_num = remained_in.size(0), remained_out.size(0)
    # Infer the largest integer r with out * r^2 <= in; verified for exactness below.
    upscale_factor = math.floor(math.sqrt(in_channel_num / out_channel_num))
    if in_channel_num != out_channel_num * (upscale_factor * upscale_factor):
        err_msg = "Your speedup model may encounter shape mismatch error during inference. "
        err_msg += f"PixelShuffle preserved input channel number is {in_channel_num}, "
        err_msg += f"preserved output channel number is {out_channel_num}, "
        err_msg += "unable to find a suitable upscale_factor, keep it as it is, please replace this module manually, "
        err_msg += "or adjust the module sparsity ratio before this module to ensure that a suitable upscale_factor can be found."
        # Don't raise an error because the user may know how to manually replace this module.
        _logger.error(err_msg)
        # NOTE: no_replace, use the original upscale_factor if we can not find a suitable one.
        upscale_factor = pixelshuffle.upscale_factor
    if upscale_factor != pixelshuffle.upscale_factor:
        war_msg = f"Change PixelShuffle upscale_factor from {pixelshuffle.upscale_factor} to {upscale_factor}, "
        war_msg += "subsequent computation semantics may have changed."
        _logger.warning(war_msg)
    new_pixelshuffle = torch.nn.PixelShuffle(upscale_factor)
    return new_pixelshuffle
......@@ -459,13 +459,17 @@ class ModelSpeedup:
self.bound_model, g_node.name)
m_type = g_node.op_type
if (not m_type in replace_module) and (m_type not in self.customized_replace_func):
raise RuntimeError(
"Has not supported replacing the module: `{}`".format(m_type))
err_msg = f"Has not supported replacing module with type: {m_type}, "
err_msg += f"you could report an issue at https://github.com/microsoft/nni. "
err_msg += f"If you know how to replace {m_type}, "
err_msg += f"you could implement module replacement by passing in"
err_msg += f"`customized_replace_func` to `{self.__class__.__name__}`. "
err_msg += f"You are welcome to contribute back to nni as native support if you have implemented the replacement function, "
err_msg += f"so that more users can benefit from your contributions."
raise RuntimeError(err_msg)
_logger.info("replace module (name: %s, op_type: %s)",
g_node.name, m_type)
replace_function = replace_module[m_type]
if m_type in self.customized_replace_func:
replace_function = self.customized_replace_func[m_type]
replace_function = self.customized_replace_func.get(m_type, replace_module.get(m_type, None))
compressed_module = replace_function(
leaf_module, auto_infer.get_masks())
new_submodule = compressed_module
......
......@@ -492,6 +492,35 @@ def upsample_bilinear2d_python(node, speedup):
return UpsampleModule(size_list, scale_list)
def upsample_nearest2d_python(node, speedup):
    """
    Build a stand-in module for an ``aten::upsample_nearest2d`` graph node.

    Parameters
    ----------
    node
        Node group wrapping the upsample op; assumed to expose the underlying
        cpp node as ``node.key_node`` — consistent with the sibling
        ``upsample_bilinear2d_python`` translator.
    speedup
        The speedup context, passed through to ``translate_list`` when
        resolving the size/scale list inputs.

    Returns
    -------
    torch.nn.Module
        A module whose forward performs the same nearest-neighbor upsample.
    """
    class UpsampleModule(torch.nn.Module):
        def __init__(self, size_list, scale_list):
            super(UpsampleModule, self).__init__()
            self.size_list = size_list
            self.scale_list = scale_list

        def forward(self, *args):
            """
            The first element of args is the tensor to upsample; the
            following parameters are ignored, because we already got the
            size_list and the scale_list by parsing the cpp_nodes.
            """
            # F.upsample_nearest is deprecated; interpolate(mode='nearest')
            # is its documented equivalent.
            return torch.nn.functional.interpolate(
                args[0], size=self.size_list,
                scale_factor=self.scale_list, mode='nearest')

    c_node = node.key_node
    inputs = list(c_node.inputs())
    size_list = None
    scale_list = None
    # Each optional argument is only materialized when the graph constructs
    # it as a concrete list; otherwise it stays None.
    if inputs[1].node().kind() == 'prim::ListConstruct':
        size_list = translate_list(inputs[1], speedup)
    if inputs[2].node().kind() == 'prim::ListConstruct':
        scale_list = translate_list(inputs[2], speedup)
    return UpsampleModule(size_list, scale_list)
def typeas_python(node, speedup):
"""
currently only support type_as float.
......@@ -583,6 +612,7 @@ trans_from_jit_to_python = {
'aten::to': to_python,
'aten::type_as': typeas_python,
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::upsample_nearest2d': upsample_nearest2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
......
......@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
# Concatenation op, treated as a (multi-input) reshape-like op below.
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
# Ops that may remap the channel dimension between their input and output and
# can therefore break a channel-pruning dependency chain.
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
               'aten::reshape', 'aten::flatten', 'aten::mean', 'aten::expand_as',
               'aten::pixel_shuffle']
def lcm_list(L):
......@@ -85,6 +85,11 @@ def reshape_break_channel_dependency(op_node):
"""
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
# FIXME: e.g., in_shape will be None if the input comes from a buffer, should be fixed in next release
if not in_shape or not out_shape:
return True
if len(in_shape) <= 1 or len(out_shape) <= 1:
return True
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment