Unverified Commit 4558bfbd authored by xcnick's avatar xcnick Committed by GitHub
Browse files

[Fix] Fix onnx unit tests (#2155)

parent 1248da3c
......@@ -67,7 +67,8 @@ def process_grid_sample(func, input, grid, ort_custom_op_path=''):
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
ort_result = sess.run(None, {
'input': input.detach().numpy(),
'grid': grid.detach().numpy()
......@@ -160,7 +161,8 @@ def test_nms():
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_dets, _ = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
......@@ -234,7 +236,8 @@ def test_softnms():
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_dets, onnx_inds = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
......@@ -302,7 +305,8 @@ def test_roialign():
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_output = sess.run(None, {
'input': input.detach().numpy(),
'rois': rois.detach().numpy()
......@@ -378,7 +382,8 @@ def test_roialign_rotated():
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_output = sess.run(None, {
'features': input.detach().numpy(),
'rois': rois.detach().numpy()
......@@ -440,7 +445,8 @@ def test_roipool():
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file)
sess = rt.InferenceSession(
onnx_file, providers=['CPUExecutionProvider'])
onnx_output = sess.run(
None, {
'input': input.detach().cpu().numpy(),
......@@ -470,7 +476,7 @@ def test_interpolate():
onnx_file,
input_names=['input'],
opset_version=opset_version)
sess = rt.InferenceSession(onnx_file)
sess = rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy()
......@@ -581,7 +587,8 @@ def test_rotated_feature_align():
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_output = sess.run(None, {
'feature': feature.detach().numpy(),
'bbox': bbox.detach().numpy()
......@@ -629,7 +636,8 @@ def test_corner_pool(mode, opset=11):
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
ort_result = sess.run(None, {'input': input.detach().numpy()})
pytorch_results = wrapped_model(input.clone())
......@@ -698,7 +706,8 @@ def test_cummax_cummin(key, opset=11):
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
ort_output, ort_inds = sess.run(None,
{'input': input.detach().numpy()})
pytorch_output, pytorch_inds = wrapped_model(input.clone())
......@@ -737,7 +746,7 @@ def test_roll(shifts_dims_pair):
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(onnx_file)
sess = rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
ort_output = sess.run(None, {'input': input.detach().numpy()})[0]
with torch.no_grad():
......@@ -822,7 +831,8 @@ def test_modulated_deform_conv2d():
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_output = sess.run(
None, {
'input': input.cpu().detach().numpy(),
......@@ -902,7 +912,8 @@ def test_deform_conv2d(threshold=1e-3):
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
sess = rt.InferenceSession(onnx_file, session_options)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_output = sess.run(
None, {
'input': x.cpu().detach().numpy(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment