Commit 53349569 authored by Khalique's avatar Khalique
Browse files

return tuple for gen onnx script

parent 6b36c82e
......@@ -6,7 +6,23 @@ from onnx import AttributeProto, TensorProto, GraphProto
def onnx_test(op_test):
def run_test():
model_def = helper.make_model(op_test(), producer_name=op_test.__name__)
op_info = op_test()
if len(op_info) > 3:
graph_def = helper.make_graph(
op_info[0],
op_test.__name__,
op_info[1],
op_info[2],
initializer=op_info[3]
)
else:
graph_def = helper.make_graph(
op_info[0],
op_test.__name__,
op_info[1],
op_info[2]
)
model_def = helper.make_model(graph_def, producer_name=op_test.__name__)
onnx.save(model_def, '{}.onnx'.format(op_test.__name__))
return run_test
......@@ -21,12 +37,7 @@ def acos_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_acos',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def add_bcast_test():
......@@ -42,12 +53,7 @@ def add_bcast_test():
outputs=['2']
)
return helper.make_graph(
[node],
'test-add_bcast',
[x,y],
[z]
)
return ([node], [x,y], [z])
@onnx_test
def add_fp16_test():
......@@ -61,14 +67,13 @@ def add_fp16_test():
outputs=['2'],
)
return helper.make_graph(
return (
[node],
'test-add-fp16',
[x,y],
[z],
# '0' -> 1.5, '1' -> 2.5
initializer=[onnx.helper.make_tensor('0', TensorProto.FLOAT16, [1], [15872]),
onnx.helper.make_tensor('1', TensorProto.FLOAT16, [1], [16640])]
[onnx.helper.make_tensor('0', TensorProto.FLOAT16, [1], [15872]),
onnx.helper.make_tensor('1', TensorProto.FLOAT16, [1], [16640])]
)
model_def = helper.make_model(graph_def, producer_name=('add-fp16-example'))
......@@ -86,12 +91,11 @@ def add_scalar_test():
outputs=['2']
)
return helper.make_graph(
return (
[node],
'test-add-scalar',
[x,y],
[z],
initializer=[helper.make_tensor('1', TensorProto.FLOAT, [], [1])]
[helper.make_tensor('1', TensorProto.FLOAT, [], [1])]
)
@onnx_test
......@@ -107,13 +111,7 @@ def argmax_test():
keepdims = 0
)
return helper.make_graph(
[node],
'test_argmax',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def argmin_test():
......@@ -128,12 +126,7 @@ def argmin_test():
keepdims = 0
)
return helper.make_graph(
[node],
'test_argmin',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def asin_test():
......@@ -146,12 +139,8 @@ def asin_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_asin',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def atan_test():
......@@ -164,12 +153,7 @@ def atan_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_atan',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def cast_test():
......@@ -183,12 +167,7 @@ def cast_test():
to = 1
)
return helper.make_graph(
[node],
'test_cast',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def clip_test():
......@@ -204,12 +183,7 @@ def clip_test():
min=0.0
)
return helper.make_graph(
[node],
'test-model',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def concat_test():
......@@ -224,12 +198,7 @@ def concat_test():
outputs=['2'],
)
return helper.make_graph(
[node],
'test-concat',
[x,y],
[z]
)
return ([node], [x,y], [z])
@onnx_test
def constant_test():
......@@ -248,12 +217,7 @@ def constant_test():
),
)
return helper.make_graph(
[node],
'test-constant',
[],
[y]
)
return ([node], [], [y])
@onnx_test
def constant_fill_test():
......@@ -269,12 +233,7 @@ def constant_fill_test():
input_as_shape = 0,
)
return helper.make_graph(
[node],
'constant_fill',
[],
[value],
)
return ([node], [], [value])
@onnx_test
def constant_fill_input_as_shape_test():
......@@ -305,12 +264,7 @@ def constant_fill_input_as_shape_test():
input_as_shape = 1,
)
return helper.make_graph(
[const_shape_node, node],
'constant_fill',
[],
[value],
)
return ([const_shape_node, node], [], [value])
@onnx_test
def constant_scalar_test():
......@@ -329,12 +283,7 @@ def constant_scalar_test():
),
)
return helper.make_graph(
[node],
'test-constant',
[],
[y]
)
return ([node], [], [y])
@onnx_test
def const_of_shape_empty_input_test():
......@@ -365,12 +314,7 @@ def const_of_shape_empty_input_test():
value = tensor_val,
)
return helper.make_graph(
[shape_const, node],
'constant_of_shape',
[],
[y],
)
return ([shape_const, node], [], [y])
@onnx_test
def const_of_shape_float_test():
......@@ -401,12 +345,7 @@ def const_of_shape_float_test():
value = tensor_val
)
return helper.make_graph(
[shape_const, node],
'constant_of_shape',
[],
[y],
)
return ([shape_const, node], [], [y])
@onnx_test
def const_of_shape_int64_test():
......@@ -436,12 +375,7 @@ def const_of_shape_int64_test():
value = tensor_val
)
return helper.make_graph(
[shape_const, node],
'constant_of_shape',
[],
[y],
)
return ([shape_const, node], [], [y])
@onnx_test
def const_of_shape_no_value_attr_test():
......@@ -466,12 +400,7 @@ def const_of_shape_no_value_attr_test():
outputs=['y'],
)
return helper.make_graph(
[shape_const, node],
'constant_of_shape',
[],
[y],
)
return ([shape_const, node], [], [y])
@onnx_test
def conv_autopad_fail_test():
......@@ -489,12 +418,7 @@ def conv_autopad_fail_test():
pads = [0,0,1,1,0,0,1,1]
)
return helper.make_graph(
[node],
'test_conv',
[x, y],
[out],
)
return ([node], [x,y], [out])
@onnx_test
def conv_bias_test():
......@@ -511,12 +435,7 @@ def conv_bias_test():
strides = [1, 1]
)
return helper.make_graph(
[node],
'test_conv',
[x, y, z],
[out],
)
return ([node], [x,y,z], [out])
@onnx_test
def conv_bn_relu_maxpool_test():
......@@ -560,11 +479,10 @@ def conv_bn_relu_maxpool_test():
kernel_shape=[2,2]
)
return helper.make_graph(
return (
[node0, node1, node2, node3],
'test_conv_bn_relu',
[x, y, z, m, n, k, l],
[out],
[out]
)
@onnx_test
......@@ -598,11 +516,10 @@ def conv_relu_maxpool_test():
kernel_shape=[2,2]
)
return helper.make_graph(
return (
[node1, node2, node3],
'test_conv_relu',
[x, y, z],
[out],
[out]
)
@onnx_test
......@@ -662,11 +579,10 @@ def conv_relu_maxpool_x2_test():
kernel_shape=[2,2]
)
return helper.make_graph(
return (
[node1, node2, node3, node4, node5, node6],
'test_conv_relu2',
[x, y, z, m, n],
[out],
[out]
)
@onnx_test
......@@ -680,12 +596,7 @@ def cos_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_cos',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def cosh_test():
......@@ -698,12 +609,7 @@ def cosh_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_cosh',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def dropout_test():
......@@ -716,12 +622,7 @@ def dropout_test():
outputs=['1'],
)
return helper.make_graph(
[node],
'test-dropout',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def elu_test():
......@@ -735,12 +636,7 @@ def elu_test():
alpha=0.01
)
return helper.make_graph(
[node],
'test-model',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def erf_test():
......@@ -753,12 +649,7 @@ def erf_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_erf',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def exp_test():
......@@ -771,12 +662,7 @@ def exp_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_exp',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def expand_test():
......@@ -802,12 +688,7 @@ def expand_test():
outputs=['y']
)
return helper.make_graph(
[shape_const, node],
'expand',
[x],
[y],
)
return ([shape_const,node], [x], [y])
@onnx_test
def flatten_test():
......@@ -828,12 +709,7 @@ def flatten_test():
outputs=['3']
)
return helper.make_graph(
[node,node2],
'test-flatten',
[x],
[y,y2]
)
return ([node,node2], [x], [y,y2])
@onnx_test
def gather_test():
......@@ -848,12 +724,7 @@ def gather_test():
axis=1,
)
return helper.make_graph(
[node],
'test_gather',
[x, i],
[y],
)
return ([node], [x,i], [y])
@onnx_test
def gemm_test():
......@@ -872,12 +743,7 @@ def gemm_test():
transB=1
)
return helper.make_graph(
[node],
'test-gemm',
[x, y, z],
[a]
)
return ([node], [x,y,z], [a])
@onnx_test
def gemm_ex_test():
......@@ -895,12 +761,7 @@ def gemm_ex_test():
transA = 1
)
return helper.make_graph(
[node],
'test_gemm_ex',
[m1, m2, m3],
[y],
)
return ([node], [m1,m2,m3], [y])
@onnx_test
def gemm_ex_brcst_test():
......@@ -918,12 +779,7 @@ def gemm_ex_brcst_test():
transA = 1
)
return helper.make_graph(
[node],
'test_gemm_ex',
[m1, m2, m3],
[y],
)
return ([node], [m1,m2,m3], [y])
@onnx_test
def globalavgpool_test():
......@@ -936,12 +792,7 @@ def globalavgpool_test():
outputs=['1'],
)
return helper.make_graph(
[node],
'test-globalavgpool',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def globalmaxpool_test():
......@@ -954,12 +805,7 @@ def globalmaxpool_test():
outputs=['1'],
)
return helper.make_graph(
[node],
'test-globalmaxpool',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def group_conv_test():
......@@ -974,12 +820,7 @@ def group_conv_test():
outputs=['2'],
)
return helper.make_graph(
[node],
'test-group_conv',
[x,y],
[z]
)
return ([node], [x,y], [z])
@onnx_test
def imagescaler_test():
......@@ -994,12 +835,7 @@ def imagescaler_test():
scale=0.5
)
return helper.make_graph(
[node],
'test-imagescaler',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def implicit_add_bcast_test():
......@@ -1013,12 +849,7 @@ def implicit_add_bcast_test():
outputs=['2'],
)
return helper.make_graph(
[node],
'test-multi_bcast',
[x,y],
[z]
)
return ([node], [x,y], [z])
@onnx_test
def implicit_pow_bcast_test():
......@@ -1032,12 +863,7 @@ def implicit_pow_bcast_test():
outputs=['out'],
)
return helper.make_graph(
[node],
'pow_test',
[arg0, arg1],
[arg_out],
)
return ([node], [arg0,arg1], [arg_out])
@onnx_test
def implicit_sub_bcast_test():
......@@ -1051,12 +877,7 @@ def implicit_sub_bcast_test():
outputs=['out'],
)
return helper.make_graph(
[node],
'subtraction2',
[arg0, arg1],
[arg_out],
)
return ([node], [arg0,arg1], [arg_out])
@onnx_test
def leaky_relu_test():
......@@ -1070,12 +891,7 @@ def leaky_relu_test():
alpha=0.01
)
return helper.make_graph(
[node],
'test-model',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def log_test():
......@@ -1088,12 +904,7 @@ def log_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_log',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def logsoftmax_test():
......@@ -1107,12 +918,7 @@ def logsoftmax_test():
axis = 1
)
return helper.make_graph(
[node],
'test_logsoftmax',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def lrn_test():
......@@ -1129,12 +935,7 @@ def lrn_test():
outputs=['1']
)
return helper.make_graph(
[node],
'test-lrn',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def matmul_bmbm_test():
......@@ -1148,12 +949,7 @@ def matmul_bmbm_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def matmul_bmv_test():
......@@ -1167,12 +963,7 @@ def matmul_bmv_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def matmul_mv_test():
......@@ -1186,12 +977,7 @@ def matmul_mv_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def matmul_vbm_test():
......@@ -1205,12 +991,7 @@ def matmul_vbm_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def matmul_vm_test():
......@@ -1224,12 +1005,7 @@ def matmul_vm_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def matmul_vv_test():
......@@ -1243,12 +1019,7 @@ def matmul_vv_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_matmul',
[m1, m2],
[y],
)
return ([node], [m1,m2], [y])
@onnx_test
def max_test():
......@@ -1263,12 +1034,7 @@ def max_test():
outputs=['3'],
)
return helper.make_graph(
[node],
'test-dropout',
[a, b, c],
[y]
)
return ([node], [a,b,c], [y])
@onnx_test
def min_test():
......@@ -1283,12 +1049,7 @@ def min_test():
outputs=['3'],
)
return helper.make_graph(
[node],
'test-dropout',
[a, b, c],
[y]
)
return ([node], [a,b,c], [y])
@onnx_test
def no_pad_test():
......@@ -1302,13 +1063,7 @@ def no_pad_test():
outputs=['1']
)
return helper.make_graph(
[node],
'test-no-pad',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def pad_test():
......@@ -1322,13 +1077,7 @@ def pad_test():
outputs=['1']
)
return helper.make_graph(
[node],
'test-pad',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def pow_test():
......@@ -1343,12 +1092,7 @@ def pow_test():
)
return helper.make_graph(
[node],
'pow_test',
[arg0, arg1],
[arg_out],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def reducemean_test():
......@@ -1364,12 +1108,7 @@ def reducemean_test():
keepdims = 0
)
return helper.make_graph(
[node],
'test_reducemean',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def reducemean_keepdims_test():
......@@ -1385,12 +1124,7 @@ def reducemean_keepdims_test():
keepdims = 1
)
return helper.make_graph(
[node],
'test_reducemean',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def reducesum_test():
......@@ -1406,12 +1140,7 @@ def reducesum_test():
keepdims = 0
)
return helper.make_graph(
[node],
'test_reducesum',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def reducesum_multiaxis_test():
......@@ -1427,12 +1156,7 @@ def reducesum_multiaxis_test():
keepdims = 0
)
return helper.make_graph(
[node],
'test_reducesum',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def reducesum_keepdims_test():
......@@ -1448,12 +1172,7 @@ def reducesum_keepdims_test():
keepdims = 1
)
return helper.make_graph(
[node],
'test_reducesum',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def reshape_test():
......@@ -1476,12 +1195,11 @@ def reshape_test():
outputs=['3']
)
return helper.make_graph(
return (
[node,node2],
'test-reshape',
[x, x_shape],
[y,y2],
initializer=[helper.make_tensor('1', TensorProto.INT64, [2], [3, 8])]
[helper.make_tensor('1', TensorProto.INT64, [2], [3, 8])]
)
@onnx_test
......@@ -1504,12 +1222,7 @@ def reshape_non_standard_test():
shape=[4, 3, 2]
)
return helper.make_graph(
[trans, res],
'reshape-ns',
[x],
[y],
)
return ([trans,res], [x], [y])
@onnx_test
def shape_test():
......@@ -1522,12 +1235,7 @@ def shape_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_shape',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def shape_gather_test():
......@@ -1563,12 +1271,7 @@ def shape_gather_test():
axis=0,
)
return helper.make_graph(
[node_const, node_shape, node_gather],
'shape_gather',
[x],
[z],
)
return ([node_const,node_shape,node_gather], [x], [z])
@onnx_test
def sign_test():
......@@ -1581,12 +1284,7 @@ def sign_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_sign',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def sin_test():
......@@ -1599,12 +1297,7 @@ def sin_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_sin',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def sinh_test():
......@@ -1617,12 +1310,7 @@ def sinh_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_sinh',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def slice_test():
......@@ -1638,12 +1326,7 @@ def slice_test():
outputs=['1']
)
return helper.make_graph(
[node],
'test-slice',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def softmax_test():
......@@ -1656,12 +1339,7 @@ def softmax_test():
outputs=['1']
)
return helper.make_graph(
[node],
'test-softmax',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def sqrt_test():
......@@ -1674,12 +1352,7 @@ def sqrt_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_sqrt',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def squeeze_unsqueeze_test():
......@@ -1701,12 +1374,7 @@ def squeeze_unsqueeze_test():
outputs=['2']
)
return helper.make_graph(
[node,node2],
'test-squeeze-unsqueeze',
[x],
[z]
)
return ([node,node2], [x], [z])
@onnx_test
def sub_bcast_test():
......@@ -1722,13 +1390,7 @@ def sub_bcast_test():
axis = 1,
)
return helper.make_graph(
[node],
'subtraction2',
[arg0, arg1],
[arg_out],
)
return ([node], [arg0,arg1], [arg_out])
@onnx_test
def sub_scalar_test():
......@@ -1757,12 +1419,7 @@ def sub_scalar_test():
outputs=['out'],
)
return helper.make_graph(
[arg_const, node],
'subtraction1',
[arg_node],
[arg_out],
)
return ([arg_const,node], [arg_node], [arg_out])
@onnx_test
def sum_test():
......@@ -1778,12 +1435,7 @@ def sum_test():
outputs=['3'],
)
return helper.make_graph(
[node],
'test-sum',
[a, b, c],
[y]
)
return ([node], [a,b,c], [y])
@onnx_test
def sum_test():
......@@ -1798,12 +1450,7 @@ def sum_test():
outputs=['3'],
)
return helper.make_graph(
[node],
'test-sum',
[a, b, c],
[y]
)
return ([node], [a,b,c], [y])
@onnx_test
def tan_test():
......@@ -1816,12 +1463,7 @@ def tan_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_tan',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def tanh_test():
......@@ -1834,12 +1476,7 @@ def tanh_test():
outputs=['y'],
)
return helper.make_graph(
[node],
'test_tanh',
[x],
[y],
)
return ([node], [x], [y])
@onnx_test
def transpose_test():
......@@ -1853,12 +1490,7 @@ def transpose_test():
outputs=['1'],
)
return helper.make_graph(
[node],
'test-transpose',
[x],
[y]
)
return ([node], [x], [y])
@onnx_test
def transpose_gather_test():
......@@ -1888,12 +1520,7 @@ def transpose_gather_test():
)
return helper.make_graph(
[td, ti, node],
'test_gather',
[x, i],
[y],
)
return ([td, ti, node], [x, i], [y])
@onnx_test
def unknown_test():
......@@ -1914,12 +1541,4 @@ def unknown_test():
outputs=['3']
)
return helper.make_graph(
[node,node2],
'test-unknown',
[x,y],
[a]
)
model_def = helper.make_model(graph_def, producer_name='unknown-example')
onnx.save(model_def, 'unknown_test.onnx')
\ No newline at end of file
return ([node,node2], [x,y], [a])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment