Commit 5ac5ab9e authored by Negin Raoof's avatar Negin Raoof Committed by Francisco Massa
Browse files

[ONNX] Export new_empty_tensor (#1733)

* adding new_empty_tensor symbolic

* flake8

* fix for feedback

* skipping the ORT test

* fix for ORT test
parent 40c99eae
......@@ -28,13 +28,13 @@ class ONNXExporterTester(unittest.TestCase):
def setUpClass(cls):
torch.manual_seed(123)
def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
model.eval()
onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=True, opset_version=_onnx_opset_version)
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
......@@ -74,6 +74,20 @@ class ONNXExporterTester(unittest.TestCase):
else:
raise
@unittest.skip("Disable test until Split w/ zero sizes is implemented in ORT")
def test_new_empty_tensor(self):
class Module(torch.nn.Module):
def __init__(self):
super(Module, self).__init__()
self.conv2 = ops.misc.ConvTranspose2d(16, 33, (3, 5))
def forward(self, input2):
return self.conv2(input2)
input = torch.rand(0, 16, 10, 10)
test_input = torch.rand(0, 16, 20, 20)
self.run_model(Module(), [(input, ), (test_input,)], do_constant_folding=False)
def test_nms(self):
boxes = torch.rand(5, 4)
boxes[:, 2:] += torch.rand(5, 2)
......
......@@ -5,7 +5,8 @@ _onnx_opset_version = 11
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx, scalar_type_to_pytorch_type, \
cast_pytorch_to_onnx
from torch.onnx.symbolic_opset9 import select, unsqueeze, squeeze, _cast_Long, reshape
@parse_args('v', 'v', 'f')
......@@ -31,7 +32,18 @@ def _register_custom_op():
pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale)
return roi_pool, None
@parse_args('v', 'is')
def new_empty_tensor_op(g, input, shape):
dtype = input.type().scalarType()
if dtype is None:
dtype = 'Float'
dtype = scalar_type_to_onnx.index(cast_pytorch_to_onnx[dtype])
shape = g.op("Constant", value_t=torch.tensor(shape))
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version)
register_custom_op_symbolic('torchvision::_new_empty_tensor_op', new_empty_tensor_op, _onnx_opset_version)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment