Unverified Commit aa1f71c8 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] support expand_as (#4852)

parent 39ec21ca
...@@ -775,7 +775,7 @@ class TorchModuleGraph(TorchGraph): ...@@ -775,7 +775,7 @@ class TorchModuleGraph(TorchGraph):
""" """
# extract the input & output shape for the view and flatten # extract the input & output shape for the view and flatten
for node_group in self.nodes_py.nodes_op: for node_group in self.nodes_py.nodes_op:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']: if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as']:
# get shape infor for view (aten::view) func # get shape infor for view (aten::view) func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type, cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0] node_group.node_cpps))[0]
......
...@@ -537,6 +537,13 @@ def cat_python(node, speedup): ...@@ -537,6 +537,13 @@ def cat_python(node, speedup):
return CatModule(dim) return CatModule(dim)
def expandas_python(node, speedup):
class ExpandasModule(torch.nn.Module):
def forward(self, x, y):
return x.expand_as(y).clone()
return ExpandasModule()
trans_from_jit_to_python = { trans_from_jit_to_python = {
'aten::add': add_python, 'aten::add': add_python,
'aten::add_': add_python, 'aten::add_': add_python,
...@@ -581,11 +588,11 @@ trans_from_jit_to_python = { ...@@ -581,11 +588,11 @@ trans_from_jit_to_python = {
'aten::unsqueeze': unsqueeze_python, 'aten::unsqueeze': unsqueeze_python,
'aten::constant_pad_nd': constant_pad_nd_python, 'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python, 'aten::silu': silu_python,
'aten::expand_as': expandas_python,
'prim::TupleUnpack': tupleunpack_python, 'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python, 'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python, 'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python 'prim::GetAttr': getattr_python
} }
......
...@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_'] ...@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
CAT_TYPE = 'aten::cat' CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency') logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view', RESHAPE_OPS = [CAT_TYPE, 'aten::view',
'aten::reshape', 'aten::flatten', 'aten::mean'] 'aten::reshape', 'aten::flatten', 'aten::mean', 'aten::expand_as']
def lcm_list(L): def lcm_list(L):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment