Unverified Commit aa1f71c8 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] support expand_as (#4852)

parent 39ec21ca
......@@ -775,7 +775,7 @@ class TorchModuleGraph(TorchGraph):
"""
# extract the input & output shape for the view and flatten
for node_group in self.nodes_py.nodes_op:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape', 'aten::expand_as']:
# get shape infor for view (aten::view) func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
......
......@@ -537,6 +537,13 @@ def cat_python(node, speedup):
return CatModule(dim)
def expandas_python(node, speedup):
class ExpandasModule(torch.nn.Module):
def forward(self, x, y):
return x.expand_as(y).clone()
return ExpandasModule()
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
......@@ -581,11 +588,11 @@ trans_from_jit_to_python = {
'aten::unsqueeze': unsqueeze_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'aten::expand_as': expandas_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python
}
......
......@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
'aten::reshape', 'aten::flatten', 'aten::mean']
'aten::reshape', 'aten::flatten', 'aten::mean', 'aten::expand_as']
def lcm_list(L):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment