Unverified commit 11aff9df authored by J-shang, committed by GitHub

[Bugbash] fix speedup replacement related (#4906)

parent 993109bb
@@ -110,7 +110,7 @@ linkcheck_ignore = [
     r'https://www\.msra\.cn/',  # MSRA
     r'https://1drv\.ms/',  # OneDrive (shortcut)
     r'https://onedrive\.live\.com/',  # OneDrive
-    r'https://www\.openml\.org/',
+    r'https://www\.openml\.org/',  # OpenML
 ]

 # Ignore all links located in release.rst
...
@@ -775,7 +775,8 @@ class TorchModuleGraph(TorchGraph):
         """
         # extract the input & output shape for the view and flatten
         for node_group in self.nodes_py.nodes_op:
-            if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as']:
+            if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as',
+                                      'aten::pixel_shuffle']:
                 # get shape infor for view (aten::view) func
                 cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
                                        node_group.node_cpps))[0]
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import logging
+import math
 import torch
 import torch.nn as nn
 from .error_code import EmptyLayerError, ShapeMisMatchError, InputsNumberError, OutputTypeError, UnBalancedGroupError
@@ -595,37 +596,42 @@ def replace_layernorm(layernorm, masks):
 def replace_pixelshuffle(pixelshuffle, masks):
     """
-    Parameters
-    ----------
-    norm : torch.nn.PixelShuffle
-        The pixelshuffle module to be replace
-    masks : Tuple of the input masks, output masks and weight masks
-        Tuple of the masks, for example
-        ([input_m1, input_m2], [output_m], {'weight':weight_m})
-
-    Returns
-    -------
-    torch.nn.PixelShuffle
-        The new pixelshuffle module
+    This is nearly a `no_replace` function.
+
+    We cannot easily replace PixelShuffle right now, because pixel shuffle is a kind of location
+    mapping: it maps a tensor of shape (r^2 * C, H, W) to (C, r * H, r * W). So we have a dependency
+    here: the number of preserved input channels must be the number of preserved output channels
+    times a perfect square. This dependency is similar to the group dependency in ConvXd, but more
+    restrictive, i.e., each group of r^2 input channels cannot preserve an arbitrary number of
+    channels; the preserved count must be one of [1, 4, 9, 16, ..., r^2].
     """
     in_masks, output_mask, _ = masks
     assert isinstance(pixelshuffle, torch.nn.PixelShuffle)
     if len(in_masks) != 1:
         raise InputsNumberError()
     in_mask = in_masks[0]
-    # N, C, H, W
+
+    # FIXME: This should be the correct replacement logic, but since we cannot yet generate
+    # qualified masks, most of the time this ends up as a no_replace.
     _, remained_in = convert_to_coarse_mask(in_mask, 1)
     _, remained_out = convert_to_coarse_mask(output_mask, 1)
-    upscale_factor = pixelshuffle.upscale_factor
-    if remained_in.size(0) % (upscale_factor * upscale_factor):
-        _logger.debug("Shape mismatch, remained_in:%d upscale_factor:%d",
-                      remained_in.size(0), remained_out.size(0))
-        raise ShapeMisMatchError()
-    if remained_out.size(0) * upscale_factor * upscale_factor != remained_in:
-        raise ShapeMisMatchError()
+    in_channel_num, out_channel_num = remained_in.size(0), remained_out.size(0)
+    upscale_factor = math.floor(math.sqrt(in_channel_num / out_channel_num))
+
+    if in_channel_num != out_channel_num * (upscale_factor * upscale_factor):
+        err_msg = "Your speedup model may encounter a shape mismatch error during inference. "
+        err_msg += f"PixelShuffle preserved input channel number is {in_channel_num} and "
+        err_msg += f"preserved output channel number is {out_channel_num}, so no suitable "
+        err_msg += "upscale_factor can be found and the module is kept as is. Please replace this "
+        err_msg += "module manually, or adjust the sparsity ratio of the modules before this one "
+        err_msg += "so that a suitable upscale_factor can be found."
+        # Don't raise an error here, because the user may know how to replace this module manually.
+        _logger.error(err_msg)
+        # NOTE: no_replace, keep the original upscale_factor if we cannot find a suitable one.
+        upscale_factor = pixelshuffle.upscale_factor
+
+    if upscale_factor != pixelshuffle.upscale_factor:
+        war_msg = f"Changed PixelShuffle upscale_factor from {pixelshuffle.upscale_factor} to {upscale_factor}; "
+        war_msg += "the semantics of subsequent computation may have changed."
+        _logger.warning(war_msg)
+
     new_pixelshuffle = torch.nn.PixelShuffle(upscale_factor)
-    return new_pixelshuffle
\ No newline at end of file
+    return new_pixelshuffle
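
To make the channel constraint described above concrete, here is a small standalone sketch (illustrative only, not the NNI API) of the check the new code performs: the preserved input channel count must equal the preserved output channel count times a perfect square, and that square root becomes the new upscale_factor.

import math
import torch

def infer_upscale_factor(preserved_in, preserved_out):
    # Candidate factor is floor(sqrt(in / out)); it is only valid if it reproduces preserved_in exactly.
    factor = math.floor(math.sqrt(preserved_in / preserved_out))
    return factor if preserved_in == preserved_out * factor * factor else None

print(infer_upscale_factor(16, 4))   # 2: 16 = 4 * 2^2, so PixelShuffle(2) maps (16, H, W) to (4, 2H, 2W)
print(infer_upscale_factor(12, 4))   # None: 12 / 4 = 3 is not a perfect square, so the module is kept as is

x = torch.randn(1, 16, 8, 8)
print(torch.nn.PixelShuffle(2)(x).shape)  # torch.Size([1, 4, 16, 16])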
@@ -459,13 +459,17 @@ class ModelSpeedup:
                 self.bound_model, g_node.name)
             m_type = g_node.op_type
             if (not m_type in replace_module) and (m_type not in self.customized_replace_func):
-                raise RuntimeError(
-                    "Has not supported replacing the module: `{}`".format(m_type))
+                err_msg = f"Replacing modules of type `{m_type}` is not supported yet; "
+                err_msg += "you could report an issue at https://github.com/microsoft/nni. "
+                err_msg += f"If you know how to replace `{m_type}`, you could implement the module "
+                err_msg += "replacement yourself and pass it in through `customized_replace_func` "
+                err_msg += f"to `{self.__class__.__name__}`. "
+                err_msg += "You are welcome to contribute the replacement function back to nni as native "
+                err_msg += "support so that more users can benefit from it."
+                raise RuntimeError(err_msg)
             _logger.info("replace module (name: %s, op_type: %s)",
                          g_node.name, m_type)
-            replace_function = replace_module[m_type]
-            if m_type in self.customized_replace_func:
-                replace_function = self.customized_replace_func[m_type]
+            replace_function = self.customized_replace_func.get(m_type, replace_module.get(m_type, None))
            compressed_module = replace_function(
                 leaf_module, auto_infer.get_masks())
             new_submodule = compressed_module
...
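
The replace_function lookup change above gives user-supplied functions priority over the built-in table. A minimal standalone sketch of that lookup order (stand-in dictionaries, not the NNI API):

def no_replace(module, masks):
    return module

replace_module = {'Linear': no_replace, 'Conv2d': no_replace}          # stand-in for the built-in table
customized_replace_func = {'Conv2d': lambda m, masks: ('custom', m)}   # user override for Conv2d

for m_type in ['Linear', 'Conv2d', 'MyCustomOp']:
    replace_function = customized_replace_func.get(m_type, replace_module.get(m_type, None))
    print(m_type, '->', replace_function)
# Linear falls back to the built-in, Conv2d uses the user override, and MyCustomOp yields None,
# which is exactly the case where the new error message suggests passing in `customized_replace_func`.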
@@ -492,6 +492,35 @@ def upsample_bilinear2d_python(node, speedup):
     return UpsampleModule(size_list, scale_list)


+def upsample_nearest2d_python(node, speedup):
+    class UpsampleModule(torch.nn.Module):
+        def __init__(self, size_list, scale_list):
+            super(UpsampleModule, self).__init__()
+            self.size_list = size_list
+            self.scale_list = scale_list
+
+        def forward(self, *args):
+            """
+            The first element of args is the tensor to upsample; the remaining
+            arguments are ignored, because size_list and scale_list were already
+            obtained by parsing the cpp_nodes.
+            """
+            return torch.nn.functional.upsample_nearest(args[0],
+                                                        size=self.size_list, scale_factor=self.scale_list)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    size_list_node = inputs[1].node()
+    scale_list_node = inputs[2].node()
+    size_list = None
+    scale_list = None
+    if size_list_node.kind() == 'prim::ListConstruct':
+        size_list = translate_list(inputs[1], speedup)
+    if scale_list_node.kind() == 'prim::ListConstruct':
+        scale_list = translate_list(inputs[2], speedup)
+    return UpsampleModule(size_list, scale_list)
+
+
 def typeas_python(node, speedup):
     """
     currently only support type_as float.
@@ -583,6 +612,7 @@ trans_from_jit_to_python = {
     'aten::to': to_python,
     'aten::type_as': typeas_python,
     'aten::upsample_bilinear2d': upsample_bilinear2d_python,
+    'aten::upsample_nearest2d': upsample_nearest2d_python,
     'aten::exp': exp_python,
     'aten::squeeze': squeeze_python,
     'aten::unsqueeze': unsqueeze_python,
...
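
As a sanity check of what the generated wrapper computes, here is a standalone sketch (assumed shapes; it uses torch.nn.functional.interpolate, the non-deprecated equivalent of upsample_nearest) that freezes the parsed size/scale and ignores extra runtime arguments, mirroring the UpsampleModule above:

import torch

class UpsampleNearestSketch(torch.nn.Module):
    def __init__(self, size_list=None, scale_list=None):
        super().__init__()
        self.size_list = size_list
        self.scale_list = scale_list

    def forward(self, x, *unused):
        # Extra positional arguments are ignored; size/scale were fixed at construction time.
        return torch.nn.functional.interpolate(
            x, size=self.size_list, scale_factor=self.scale_list, mode='nearest')

x = torch.randn(1, 3, 8, 8)
print(UpsampleNearestSketch(scale_list=[2.0, 2.0])(x).shape)  # torch.Size([1, 3, 16, 16])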
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
 RESHAPE_OPS = [CAT_TYPE, 'aten::view',
-               'aten::reshape', 'aten::flatten', 'aten::mean', 'aten::expand_as']
+               'aten::reshape', 'aten::flatten', 'aten::mean', 'aten::expand_as', 'aten::pixel_shuffle']


 def lcm_list(L):
@@ -85,6 +85,11 @@ def reshape_break_channel_dependency(op_node):
     """
     in_shape = op_node.auxiliary['in_shape']
     out_shape = op_node.auxiliary['out_shape']
+    # FIXME: e.g., in_shape will be None if the input comes from a buffer, should be fixed in next release
+    if not in_shape or not out_shape:
+        return True
+    if len(in_shape) <= 1 or len(out_shape) <= 1:
+        return True
     in_channel = in_shape[1]
     out_channel = out_shape[1]
     return in_channel != out_channel
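
To illustrate the rule above with concrete (hypothetical) shapes: a reshape-like op breaks channel dependency when the channel dimension changes or cannot be determined, and keeps it when the size of dim 1 is unchanged.

def breaks_channel_dependency(in_shape, out_shape):
    # Standalone mirror of the logic above, for illustration only.
    if not in_shape or not out_shape:
        return True   # shape info missing, e.g. the input comes from a buffer
    if len(in_shape) <= 1 or len(out_shape) <= 1:
        return True
    return in_shape[1] != out_shape[1]

print(breaks_channel_dependency((1, 64, 7, 7), (1, 3136)))      # True: aten::flatten changes the channel dim
print(breaks_channel_dependency((1, 64, 7, 7), (1, 64, 1, 1)))  # False: aten::mean over H, W keeps 64 channels
print(breaks_channel_dependency(None, (1, 3136)))               # True: unknown input shape (the FIXME case above)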