Unverified Commit 03a2e3a1 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Fix onnx interpolate conversion (#917)

* fix onnx interpolate recommit

* fix bugs on torch==1.6.0

* remove print

* fix error in torch==1.5.1
parent be6541d4
......@@ -228,8 +228,13 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
# scales[0] is NoneType in Pytorch == 1.5.1
# scales[0] is TensorType with sizes = [] in Pytorch == 1.6.0
# scales[0] is ListType in Pytorch == 1.7.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' else 'f'
# scales[0] is TensorType with sizes = [2] in Pytorch == 1.8.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' or (
scales[0].type().kind() == 'TensorType' and
(sum(scales[0].type().sizes()) > 1)) else 'f'
available_scales = _maybe_get_const(
scales[0], scale_desc) != -1 and not _is_none(scales[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment