Commit 4418bf77 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add greater operator to gatherND filter model

parent eeedc246
......@@ -2184,10 +2184,24 @@ def gatherND_gtn_filter():
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10, 1])
yt = helper.make_tensor_value_info('yt', TensorProto.FLOAT, [10, 1])
yn = helper.make_tensor_value_info('yn', TensorProto.FLOAT, [1, 10])
yg = helper.make_tensor_value_info('yg', TensorProto.FLOAT, [1, 10])
thresh_const = helper.make_node(
'Constant',
inputs=[],
outputs=['thresh'],
value=0.5,
)
g_node = onnx.helper.make_node(
'Greater',
inputs=['indices', 'thresh'],
outputs=['yg'],
)
n_node = onnx.helper.make_node(
'NonZero',
inputs=['indices'],
inputs=['yg'],
outputs=['yn'],
)
......@@ -2205,7 +2219,7 @@ def gatherND_gtn_filter():
batch_dims=1,
)
return ([n_node, t_node, node], [data, indices], [y])
return ([thresh_const, g_node, n_node, t_node, node], [data, indices], [y])
@onnx_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment