"src/libtorio/ffmpeg/pybind/pybind.cpp" did not exist on "76fca37ac8941b72a509a6e58d623632efe04543"
Unverified Commit 1ea1b489 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix roi align symbolic for torch>=1.13 (#2443)

* fix roi align symbolic for torch>1.13

* fix lint
parent 5f58b910
...@@ -20,16 +20,25 @@ class RoIAlignFunction(Function): ...@@ -20,16 +20,25 @@ class RoIAlignFunction(Function):
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
pool_mode, aligned): pool_mode, aligned):
from torch.onnx import TensorProtoDataType from torch.onnx import TensorProtoDataType
from torch.onnx.symbolic_helper import _slice_helper from torch.onnx.symbolic_opset9 import sub
from torch.onnx.symbolic_opset9 import squeeze, sub
def _select(g, self, dim, index):
return g.op('Gather', self, index, axis_i=dim)
# batch_indices = rois[:, 0].long() # batch_indices = rois[:, 0].long()
batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1]) batch_indices = _select(
batch_indices = squeeze(g, batch_indices, 1) g, rois, 1,
g.op('Constant', value_t=torch.tensor([0], dtype=torch.long)))
batch_indices = g.op('Squeeze', batch_indices, axes_i=[1])
batch_indices = g.op( batch_indices = g.op(
'Cast', batch_indices, to_i=TensorProtoDataType.INT64) 'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
# rois = rois[:, 1:] # rois = rois[:, 1:]
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) rois = _select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
if aligned: if aligned:
# rois -= 0.5/spatial_scale # rois -= 0.5/spatial_scale
aligned_offset = g.op( aligned_offset = g.op(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment