Commit 80a23cfb authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix Format

parent fad4da8a
...@@ -2320,7 +2320,8 @@ def if_then_else_diff_shape_test(): ...@@ -2320,7 +2320,8 @@ def if_then_else_diff_shape_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 1])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 5]) onnx.TensorProto.FLOAT,
[2, 5])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 1]) [2, 1])
...@@ -2369,11 +2370,13 @@ def if_then_else_diff_shape_test(): ...@@ -2369,11 +2370,13 @@ def if_then_else_diff_shape_test():
@onnx_test @onnx_test
def if_then_else_incompatible_shape_test(): def if_then_else_incompatible_shape_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 4, 5]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 4, 5]) onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 1]) [2, 1])
...@@ -2422,11 +2425,13 @@ def if_then_else_incompatible_shape_test(): ...@@ -2422,11 +2425,13 @@ def if_then_else_incompatible_shape_test():
@onnx_test @onnx_test
def if_then_else_incompatible_shape_test2(): def if_then_else_incompatible_shape_test2():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1]) onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 2]) [2, 2])
...@@ -2475,11 +2480,13 @@ def if_then_else_incompatible_shape_test2(): ...@@ -2475,11 +2480,13 @@ def if_then_else_incompatible_shape_test2():
@onnx_test @onnx_test
def if_then_else_incompatible_output_shapes_test(): def if_then_else_incompatible_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1]) onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 2]) [2, 2])
...@@ -2512,12 +2519,11 @@ def if_then_else_incompatible_output_shapes_test(): ...@@ -2512,12 +2519,11 @@ def if_then_else_incompatible_output_shapes_test():
inputs=['y', 'yt'], inputs=['y', 'yt'],
outputs=['else_out2']) outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out]) [then_out])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [], else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
[else_out, else_out2]) 'else_body', [], [else_out, else_out2])
cond = np.array([1]).astype(np.bool) cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond", cond_tensor = helper.make_tensor(name="cond",
...@@ -2615,10 +2621,10 @@ def if_then_else_both_empty_test(): ...@@ -2615,10 +2621,10 @@ def if_then_else_both_empty_test():
value=empty_ts_y, value=empty_ts_y,
) )
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
then_out = onnx.helper.make_tensor_value_info('shape_const_y',
onnx.TensorProto.INT64,
[]) [])
then_out = onnx.helper.make_tensor_value_info('shape_const_y',
onnx.TensorProto.INT64, [])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [], else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out]) [else_out])
...@@ -2659,7 +2665,8 @@ def if_else_empty_constant_test(): ...@@ -2659,7 +2665,8 @@ def if_else_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.INT64, onnx.TensorProto.INT64,
[2, 1]) [2, 1])
...@@ -2713,7 +2720,8 @@ def if_then_empty_constant_test(): ...@@ -2713,7 +2720,8 @@ def if_then_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.INT64, onnx.TensorProto.INT64,
[2, 1]) [2, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment