Unverified Commit 459d728c authored by RunningLeon's avatar RunningLeon Committed by GitHub
Browse files

[Fix]: parse scales of list in PyTorch==1.7.0 (#451 in MMPose) (#815)

* modify symbolic

* add comments
parent be2616a0
......@@ -39,6 +39,8 @@ def _parse_arg(value, desc):
return tval
elif desc == 'is':
return [int(v) for v in tval]
elif desc == 'fs':
return [float(v) for v in tval]
else:
raise RuntimeError(
"ONNX symbolic doesn't know to interpret Constant node")
......@@ -226,21 +228,18 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
available_scales = _maybe_get_const(scales[0], 'f') != -1 and not _is_none(
scales[0])
available_scales = _maybe_get_const(scales[0],
'fs') != -1 and not _is_none(scales[0])
if not available_scales:
return None
scales_list = []
for scale in scales:
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
# ONNX only supports float for the scales. double -> float.
unsqueezed_scale = g.op(
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
scales_list.append(unsqueezed_scale)
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
scales_list = g.op(
'Constant', value_t=torch.tensor(_maybe_get_const(scales[0], 'fs')))
# modify to support PyTorch==1.7.0
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
scales = g.op('Concat', offsets, scales_list, axis_i=0)
return scales
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment