Unverified Commit b65830e0 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Support UnSqueeze (#3960)

parent 442342cb
......@@ -253,6 +253,13 @@ def squeeze_python(node, speedup):
new_squeeze = partial(torch.squeeze, dim=dim)
return new_squeeze
def unsqueeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = parse_constant(inputs[1], speedup)
new_unsqueeze = partial(torch.unsqueeze, dim=dim)
return new_unsqueeze
##########################################################
# Split Line
# Following module/functions cannot be translated into a
......@@ -517,6 +524,7 @@ trans_from_jit_to_python = {
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment