"src/targets/vscode:/vscode.git/clone" did not exist on "5e24fdf9751e774927ab97f4a5d539dc5bddafe8"
Commit c8808d34 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add reshape to generated filter onnx test

Add in the reshape parameter to better mirror what we're seeing in the retinanet
network block that's causing the issue
parent a850f0bc
......@@ -2179,12 +2179,20 @@ def gathernd_batch_dims_test():
@onnx_test()
def gatherND_gtn_filter():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [1, 10])
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [5, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10, 1])
yt = helper.make_tensor_value_info('yt', TensorProto.FLOAT, [10, 1])
yn = helper.make_tensor_value_info('yn', TensorProto.FLOAT, [1, 10])
yg = helper.make_tensor_value_info('yg', TensorProto.FLOAT, [1, 10])
yr = helper.make_tensor_value_info('yr', TensorProto.FLOAT, [1, 10])
r_node = helper.make_node(
'Reshape',
inputs=['data'],
outputs=['yr'],
shape=[1, 10],
)
thresh_const = helper.make_node(
'Constant',
......@@ -2214,12 +2222,13 @@ def gatherND_gtn_filter():
node = onnx.helper.make_node(
'GatherND',
inputs=['data', 'yt'],
inputs=['yr', 'yt'],
outputs=['y'],
batch_dims=1,
)
return ([thresh_const, g_node, n_node, t_node, node], [data, indices], [y])
return ([r_node, thresh_const, g_node, n_node, t_node,
node], [data, indices], [y])
@onnx_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment