Commit 80a23cfb authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix Format

parent fad4da8a
...@@ -2320,7 +2320,8 @@ def if_then_else_diff_shape_test(): ...@@ -2320,7 +2320,8 @@ def if_then_else_diff_shape_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 1])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 5]) onnx.TensorProto.FLOAT,
[2, 5])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 1]) [2, 1])
...@@ -2369,11 +2370,13 @@ def if_then_else_diff_shape_test(): ...@@ -2369,11 +2370,13 @@ def if_then_else_diff_shape_test():
@onnx_test @onnx_test
def if_then_else_incompatible_shape_test(): def if_then_else_incompatible_shape_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 4, 5]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 4, 5]) onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 1]) [2, 1])
...@@ -2422,11 +2425,13 @@ def if_then_else_incompatible_shape_test(): ...@@ -2422,11 +2425,13 @@ def if_then_else_incompatible_shape_test():
@onnx_test @onnx_test
def if_then_else_incompatible_shape_test2(): def if_then_else_incompatible_shape_test2():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1]) onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 2]) [2, 2])
...@@ -2475,18 +2480,20 @@ def if_then_else_incompatible_shape_test2(): ...@@ -2475,18 +2480,20 @@ def if_then_else_incompatible_shape_test2():
@onnx_test @onnx_test
def if_then_else_incompatible_output_shapes_test(): def if_then_else_incompatible_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1]) onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 2]) [2, 2])
else_out2 = onnx.helper.make_tensor_value_info('else_out2', else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 2]) [2, 2])
xt = np.ones((2, 3, 1)).astype(np.float) xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt', xt_tensor = helper.make_tensor(name='xt',
...@@ -2512,12 +2519,11 @@ def if_then_else_incompatible_output_shapes_test(): ...@@ -2512,12 +2519,11 @@ def if_then_else_incompatible_output_shapes_test():
inputs=['y', 'yt'], inputs=['y', 'yt'],
outputs=['else_out2']) outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out]) [then_out])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [], else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
[else_out, else_out2]) 'else_body', [], [else_out, else_out2])
cond = np.array([1]).astype(np.bool) cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond", cond_tensor = helper.make_tensor(name="cond",
...@@ -2605,9 +2611,9 @@ def if_then_else_both_empty_test(): ...@@ -2605,9 +2611,9 @@ def if_then_else_both_empty_test():
empty_val_y = np.array([]).astype(np.int64) empty_val_y = np.array([]).astype(np.int64)
empty_ts_y = helper.make_tensor(name='empty_tensor_y', empty_ts_y = helper.make_tensor(name='empty_tensor_y',
data_type=TensorProto.INT64, data_type=TensorProto.INT64,
dims=empty_val_y.shape, dims=empty_val_y.shape,
vals=empty_val_y.flatten().astype(int)) vals=empty_val_y.flatten().astype(int))
shape_const_y = helper.make_node( shape_const_y = helper.make_node(
'Constant_y', 'Constant_y',
inputs=[], inputs=[],
...@@ -2615,13 +2621,13 @@ def if_then_else_both_empty_test(): ...@@ -2615,13 +2621,13 @@ def if_then_else_both_empty_test():
value=empty_ts_y, value=empty_ts_y,
) )
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
then_out = onnx.helper.make_tensor_value_info('shape_const_y', then_out = onnx.helper.make_tensor_value_info('shape_const_y',
onnx.TensorProto.INT64, onnx.TensorProto.INT64, [])
[])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [], else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out]) [else_out])
then_body = onnx.helper.make_graph([shape_const_y], 'then_body', [], then_body = onnx.helper.make_graph([shape_const_y], 'then_body', [],
[then_out]) [then_out])
...@@ -2659,7 +2665,8 @@ def if_else_empty_constant_test(): ...@@ -2659,7 +2665,8 @@ def if_else_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.INT64, onnx.TensorProto.INT64,
[2, 1]) [2, 1])
...@@ -2675,7 +2682,7 @@ def if_else_empty_constant_test(): ...@@ -2675,7 +2682,7 @@ def if_else_empty_constant_test():
outputs=['then_out']) outputs=['then_out'])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [], else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out]) [else_out])
then_body = onnx.helper.make_graph([then_mul_node], 'then_body', [], then_body = onnx.helper.make_graph([then_mul_node], 'then_body', [],
[then_out]) [then_out])
...@@ -2713,7 +2720,8 @@ def if_then_empty_constant_test(): ...@@ -2713,7 +2720,8 @@ def if_then_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, []) then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
else_out = onnx.helper.make_tensor_value_info('else_out', else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.INT64, onnx.TensorProto.INT64,
[2, 1]) [2, 1])
...@@ -2729,7 +2737,7 @@ def if_then_empty_constant_test(): ...@@ -2729,7 +2737,7 @@ def if_then_empty_constant_test():
outputs=['else_out']) outputs=['else_out'])
then_body = onnx.helper.make_graph([shape_const], 'then_body', [], then_body = onnx.helper.make_graph([shape_const], 'then_body', [],
[then_out]) [then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out]) [else_out])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment