Commit 21b7dad4 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Create gatherND with transpose and Nonzero inputs

Used to generate a subset of the broken op in retinanet to speed up
debugging/matching
parent c930d79a
......@@ -2177,6 +2177,37 @@ def gathernd_batch_dims_test():
return ([node], [x, i], [y])
@onnx_test()
def gatherND_gtn_filter():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [1, 10])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10, 1])
yt = helper.make_tensor_value_info('yt', TensorProto.FLOAT, [10, 1])
yn = helper.make_tensor_value_info('yn', TensorProto.FLOAT, [1, 10])
n_node = onnx.helper.make_node(
'NonZero',
inputs=['indices'],
outputs=['yn'],
)
t_node = onnx.helper.make_node(
'Transpose',
inputs=['yn'],
outputs=['yt'],
perm=[1, 0],
)
node = onnx.helper.make_node(
'GatherND',
inputs=['data', 'yt'],
outputs=['y'],
batch_dims=1,
)
return ([n_node, t_node, node], [data, indices], [y])
@onnx_test()
def gemm_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment