Unverified Commit 5f9e6b61 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Feature] Add ONNX export support to torch.roll (#1194)

* add torch.roll to onnx

* remove skip test

* add support to torch<170

* add dim=0 for torch==1.9.0

* update fixture function name, add comment in roll symbolic
parent 366c628a
...@@ -90,7 +90,7 @@ def topk(g, self, k, dim, largest, sorted, out=None): ...@@ -90,7 +90,7 @@ def topk(g, self, k, dim, largest, sorted, out=None):
def masked_select(g, self, mask): def masked_select(g, self, mask):
from torch.onnx.symbolic_opset9 import nonzero, expand_as from torch.onnx.symbolic_opset9 import expand_as, nonzero
index = nonzero(g, expand_as(g, mask, self)) index = nonzero(g, expand_as(g, mask, self))
return g.op('GatherND', self, index) return g.op('GatherND', self, index)
...@@ -406,6 +406,65 @@ def cummin(g, input, dim): ...@@ -406,6 +406,65 @@ def cummin(g, input, dim):
return g.op('mmcv::cummin', input, dim_i=dim, outputs=2) return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)
@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
    """ONNX symbolic for ``torch.roll``.

    Implements roll as "slice at the wrap point, then concat the two pieces
    in swapped order", one pass per rolled dimension, so the exported graph
    uses only plain ONNX ops.

    Args:
        g: the ONNX graph-building context supplied by the exporter.
        shifts: a 1-D tensor value holding one shift amount per entry of
            ``dims``.
        dims: list of ints ('is' in parse_args); an empty list means
            "flatten, roll the flat tensor, then restore the shape",
            matching ``torch.roll``'s dims-unspecified behavior.
    """
    from torch.onnx.symbolic_opset9 import squeeze
    from packaging import version
    input_shape = g.op('Shape', input)

    need_flatten = len(dims) == 0
    # If dims is not specified, the tensor will be flattened before
    # rolling and then restored to the original shape.
    if need_flatten:
        resize_shape = input_shape
        input = g.op('Reshape', input,
                     g.op('Constant', value_t=torch.LongTensor([1, -1])))
        input_shape = g.op('Shape', input)
        dims = [1]

    for index, dim in enumerate(dims):
        # end_size = length of this dim; shift_size = the matching shift,
        # both extracted as 1-element tensors via dynamic Slice.
        end_size = sym_help._slice_helper(
            g, input_shape, axes=[0], ends=[dim + 1], starts=[dim])
        shift_size = sym_help._slice_helper(
            g, shifts, axes=[0], ends=[index + 1], starts=[index])
        slice_size = g.op('Sub', end_size, shift_size)
        # Can not use Mod because tensorrt does not support
        # Emulate `slice_size % end_size` as slice_size - end_size * (slice_size / end_size)
        # so negative shifts still yield a split point inside the dim.
        # NOTE(review): relies on ONNX Div truncation semantics for the
        # shift magnitudes produced by torch.roll — confirm for |shift|
        # larger than the dimension size.
        div_size = g.op('Div', slice_size, end_size)
        slice_size = g.op('Sub', slice_size, g.op('Mul', end_size, div_size))

        # Squeeze the 1-element tensors to scalars; newer torch requires an
        # explicit squeeze dim argument.
        if version.parse(torch.__version__) >= version.parse('1.7.0'):
            # add dim=0 for pytorch 1.9.0
            end_size = squeeze(g, end_size, 0)
            slice_size = squeeze(g, slice_size, 0)
        else:
            end_size = g.op('Squeeze', end_size)
            slice_size = g.op('Squeeze', slice_size)
        dim = torch.LongTensor([dim])

        # Split [0, slice_size) and [slice_size, end_size), then concat the
        # tail piece first — this is the roll.
        input_slice0 = sym_help._slice_helper(
            g,
            input,
            axes=dim,
            starts=torch.LongTensor([0]),
            ends=slice_size,
            dynamic_slice=True)
        input_slice1 = sym_help._slice_helper(
            g,
            input,
            axes=dim,
            ends=end_size,
            starts=slice_size,
            dynamic_slice=True)
        input = g.op('Concat', input_slice1, input_slice0, axis_i=dim)

    if need_flatten:
        input = g.op('Reshape', input, resize_shape)
    return input
def register_extra_symbolics(opset=11): def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset) register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset) register_op('im2col', im2col, '', opset)
...@@ -433,3 +492,4 @@ def register_extra_symbolics(opset=11): ...@@ -433,3 +492,4 @@ def register_extra_symbolics(opset=11):
register_op('grid_sampler', grid_sampler, '', opset) register_op('grid_sampler', grid_sampler, '', opset)
register_op('cummax', cummax, '', opset) register_op('cummax', cummax, '', opset)
register_op('cummin', cummin, '', opset) register_op('cummin', cummin, '', opset)
register_op('roll', roll, '', opset)
...@@ -13,6 +13,19 @@ from packaging import version ...@@ -13,6 +13,19 @@ from packaging import version
onnx_file = 'tmp.onnx' onnx_file = 'tmp.onnx'
@pytest.fixture(autouse=True)
def run_before_and_after_test():
    """Guarantee ``onnx_file`` is absent both before and after every test."""

    def _remove_onnx_file():
        # Best-effort cleanup: a missing file is not an error.
        try:
            os.remove(onnx_file)
        except FileNotFoundError:
            pass

    _remove_onnx_file()  # pre-test cleanup
    yield
    _remove_onnx_file()  # post-test cleanup
class WrapFunction(nn.Module): class WrapFunction(nn.Module):
def __init__(self, wrapped_function): def __init__(self, wrapped_function):
...@@ -56,7 +69,6 @@ def process_grid_sample(func, input, grid, ort_custom_op_path=''): ...@@ -56,7 +69,6 @@ def process_grid_sample(func, input, grid, ort_custom_op_path=''):
'grid': grid.detach().numpy() 'grid': grid.detach().numpy()
}) })
pytorch_results = wrapped_model(input.clone(), grid.clone()) pytorch_results = wrapped_model(input.clone(), grid.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-3) assert np.allclose(pytorch_results, ort_result, atol=1e-3)
...@@ -149,7 +161,6 @@ def test_nms(): ...@@ -149,7 +161,6 @@ def test_nms():
'boxes': boxes.detach().numpy() 'boxes': boxes.detach().numpy()
}) })
onnx_score = onnx_dets[:, 4] onnx_score = onnx_dets[:, 4]
os.remove(onnx_file)
assert np.allclose(pytorch_score, onnx_score, atol=1e-3) assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
...@@ -225,7 +236,7 @@ def test_softnms(): ...@@ -225,7 +236,7 @@ def test_softnms():
'scores': scores.detach().numpy(), 'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy() 'boxes': boxes.detach().numpy()
}) })
os.remove(onnx_file)
assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3) assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3)
assert np.allclose(onnx_inds, onnx_inds, atol=1e-3) assert np.allclose(onnx_inds, onnx_inds, atol=1e-3)
...@@ -299,7 +310,7 @@ def test_roialign(): ...@@ -299,7 +310,7 @@ def test_roialign():
onnx_output = onnx_output[0] onnx_output = onnx_output[0]
# allclose # allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3) assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
...@@ -378,7 +389,7 @@ def test_roialign_rotated(): ...@@ -378,7 +389,7 @@ def test_roialign_rotated():
onnx_output = onnx_output[0] onnx_output = onnx_output[0]
# allclose # allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3) assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
...@@ -443,7 +454,6 @@ def test_roipool(): ...@@ -443,7 +454,6 @@ def test_roipool():
onnx_output = onnx_output[0] onnx_output = onnx_output[0]
# allclose # allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3) assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
...@@ -468,8 +478,7 @@ def test_interpolate(): ...@@ -468,8 +478,7 @@ def test_interpolate():
sess = rt.InferenceSession(onnx_file) sess = rt.InferenceSession(onnx_file)
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()}) onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy() pytorch_result = func(dummy_input).detach().numpy()
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3) assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
...@@ -515,7 +524,7 @@ def test_corner_pool(mode, opset=11): ...@@ -515,7 +524,7 @@ def test_corner_pool(mode, opset=11):
sess = rt.InferenceSession(onnx_file, session_options) sess = rt.InferenceSession(onnx_file, session_options)
ort_result = sess.run(None, {'input': input.detach().numpy()}) ort_result = sess.run(None, {'input': input.detach().numpy()})
pytorch_results = wrapped_model(input.clone()) pytorch_results = wrapped_model(input.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-5) assert np.allclose(pytorch_results, ort_result, atol=1e-5)
...@@ -591,4 +600,41 @@ def test_cummax_cummin(key, opset=11): ...@@ -591,4 +600,41 @@ def test_cummax_cummin(key, opset=11):
pytorch_inds = pytorch_inds.detach().numpy() pytorch_inds = pytorch_inds.detach().numpy()
assert np.allclose(pytorch_output, ort_output, atol=1e-5) assert np.allclose(pytorch_output, ort_output, atol=1e-5)
assert np.all(pytorch_inds == ort_inds) assert np.all(pytorch_inds == ort_inds)
os.remove(onnx_file)
@pytest.mark.parametrize('shifts_dims_pair', [([-3, 5], [2, 0]), (5, None)])
def test_roll(shifts_dims_pair):
    """Export ``torch.roll`` to ONNX and compare ONNX Runtime with eager."""
    opset = 11
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    # Small 3-D tensor with all-distinct values so any mis-roll shows up.
    dummy_input = torch.arange(0, 4 * 5 * 6, dtype=torch.float32).view(4, 5, 6)
    shifts, dims = shifts_dims_pair
    wrapped_model = WrapFunction(
        partial(torch.roll, shifts=shifts, dims=dims)).eval()

    with torch.no_grad():
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=['input'],
            output_names=['output'],
            opset_version=opset)

    # The exported graph must expose exactly one non-initializer input.
    onnx_model = onnx.load(onnx_file)
    graph_inputs = {node.name for node in onnx_model.graph.input}
    initializers = {node.name for node in onnx_model.graph.initializer}
    assert len(graph_inputs - initializers) == 1

    sess = rt.InferenceSession(onnx_file)
    ort_output = sess.run(None, {'input': dummy_input.detach().numpy()})[0]

    with torch.no_grad():
        pytorch_output = wrapped_model(dummy_input.clone())
    torch.testing.assert_allclose(ort_output, pytorch_output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment