Unverified Commit 6aaa2d92 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] support aten::sub & aten::constant_pad_nd in speedup (#4644)

parent 45cefc7c
...@@ -16,6 +16,7 @@ replace_module = { ...@@ -16,6 +16,7 @@ replace_module = {
'MaxPool2d': lambda module, masks: no_replace(module, masks), 'MaxPool2d': lambda module, masks: no_replace(module, masks),
'AvgPool2d': lambda module, masks: no_replace(module, masks), 'AvgPool2d': lambda module, masks: no_replace(module, masks),
'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks), 'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
'ZeroPad2d': lambda module, masks: no_replace(module, masks),
'ReLU': lambda module, masks: no_replace(module, masks), 'ReLU': lambda module, masks: no_replace(module, masks),
'ReLU6': lambda module, masks: no_replace(module, masks), 'ReLU6': lambda module, masks: no_replace(module, masks),
'LeakyReLU': lambda module, masks: no_replace(module, masks), 'LeakyReLU': lambda module, masks: no_replace(module, masks),
......
...@@ -142,6 +142,29 @@ def add_python(node, speedup): ...@@ -142,6 +142,29 @@ def add_python(node, speedup):
return new_add return new_add
def sub_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = [None, None]
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
# this input is a constant value
# TODO: what if this input is a constant tensor
if input_i.toIValue() is not None:
constant[i] = parse_constant(input_i, speedup)
break
if constant[0] is None and constant[1] is None:
new_sub = torch.sub
elif constant[0] is not None:
new_sub = partial(torch.sub, input=constant)
else:
new_sub = partial(torch.sub, other=constant)
return new_sub
def floor_div_python(node, speedup): def floor_div_python(node, speedup):
c_node = node.key_node c_node = node.key_node
inputs = list(c_node.inputs()) inputs = list(c_node.inputs())
...@@ -228,6 +251,10 @@ def gelu_python(node, speedup): ...@@ -228,6 +251,10 @@ def gelu_python(node, speedup):
return torch.nn.GELU() return torch.nn.GELU()
def silu_python(node, speedup):
return torch.nn.SiLU()
def avgpool2d_python(node, speedup): def avgpool2d_python(node, speedup):
c_node = node.key_node c_node = node.key_node
inputs = list(c_node.inputs()) inputs = list(c_node.inputs())
...@@ -277,6 +304,14 @@ def unsqueeze_python(node, speedup): ...@@ -277,6 +304,14 @@ def unsqueeze_python(node, speedup):
new_unsqueeze = partial(torch.unsqueeze, dim=dim) new_unsqueeze = partial(torch.unsqueeze, dim=dim)
return new_unsqueeze return new_unsqueeze
def constant_pad_nd_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
pad = translate_list(inputs[1], speedup)
value = parse_constant(inputs[2], speedup)
new_constant_pad_nd = partial(torch.nn.functional.pad, pad=pad, value=value)
return new_constant_pad_nd
########################################################## ##########################################################
# Split Line # Split Line
# Following module/functions cannot be translated into a # Following module/functions cannot be translated into a
...@@ -379,7 +414,7 @@ def reshape_python(node, speedup): ...@@ -379,7 +414,7 @@ def reshape_python(node, speedup):
logger.info('Reshape Module output size: %s', str(self.shape)) logger.info('Reshape Module output size: %s', str(self.shape))
def forward(self, *args): def forward(self, *args):
return args[0].view(self.shape) return args[0].reshape(self.shape)
c_node = node.key_node c_node = node.key_node
inputs = list(c_node.inputs()) inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup) shape = translate_list(inputs[1], speedup)
...@@ -505,6 +540,8 @@ def cat_python(node, speedup): ...@@ -505,6 +540,8 @@ def cat_python(node, speedup):
trans_from_jit_to_python = { trans_from_jit_to_python = {
'aten::add': add_python, 'aten::add': add_python,
'aten::add_': add_python, 'aten::add_': add_python,
'aten::sub': sub_python,
'aten::sub_': sub_python,
'aten::mul': mul_python, 'aten::mul': mul_python,
'aten::mul_': mul_python, 'aten::mul_': mul_python,
'aten::relu': relu_python, 'aten::relu': relu_python,
...@@ -542,6 +579,8 @@ trans_from_jit_to_python = { ...@@ -542,6 +579,8 @@ trans_from_jit_to_python = {
'aten::exp': exp_python, 'aten::exp': exp_python,
'aten::squeeze': squeeze_python, 'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python, 'aten::unsqueeze': unsqueeze_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'prim::TupleUnpack': tupleunpack_python, 'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python, 'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python, 'prim::NumToTensor': num2tensor_python,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment