Commit 80a23cfb authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix Format

parent fad4da8a
......@@ -2320,7 +2320,8 @@ def if_then_else_diff_shape_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 1])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 5])
onnx.TensorProto.FLOAT,
[2, 5])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 1])
......@@ -2369,11 +2370,13 @@ def if_then_else_diff_shape_test():
@onnx_test
def if_then_else_incompatible_shape_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 4, 5])
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 4, 5])
onnx.TensorProto.FLOAT,
[2, 3, 4, 5])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 1])
......@@ -2422,11 +2425,13 @@ def if_then_else_incompatible_shape_test():
@onnx_test
def if_then_else_incompatible_shape_test2():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1])
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 2])
......@@ -2475,18 +2480,20 @@ def if_then_else_incompatible_shape_test2():
@onnx_test
def if_then_else_incompatible_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1])
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 2])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 2])
onnx.TensorProto.FLOAT,
[2, 2])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
......@@ -2512,12 +2519,11 @@ def if_then_else_incompatible_output_shapes_test():
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
[else_out, else_out2])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
'else_body', [], [else_out, else_out2])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
......@@ -2605,9 +2611,9 @@ def if_then_else_both_empty_test():
empty_val_y = np.array([]).astype(np.int64)
empty_ts_y = helper.make_tensor(name='empty_tensor_y',
data_type=TensorProto.INT64,
dims=empty_val_y.shape,
vals=empty_val_y.flatten().astype(int))
data_type=TensorProto.INT64,
dims=empty_val_y.shape,
vals=empty_val_y.flatten().astype(int))
shape_const_y = helper.make_node(
'Constant_y',
inputs=[],
......@@ -2615,13 +2621,13 @@ def if_then_else_both_empty_test():
value=empty_ts_y,
)
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, [])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
then_out = onnx.helper.make_tensor_value_info('shape_const_y',
onnx.TensorProto.INT64,
[])
onnx.TensorProto.INT64, [])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out])
[else_out])
then_body = onnx.helper.make_graph([shape_const_y], 'then_body', [],
[then_out])
......@@ -2659,7 +2665,8 @@ def if_else_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, [])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.INT64,
[2, 1])
......@@ -2675,7 +2682,7 @@ def if_else_empty_constant_test():
outputs=['then_out'])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out])
[else_out])
then_body = onnx.helper.make_graph([then_mul_node], 'then_body', [],
[then_out])
......@@ -2713,7 +2720,8 @@ def if_then_empty_constant_test():
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, [])
then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
[])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.INT64,
[2, 1])
......@@ -2729,7 +2737,7 @@ def if_then_empty_constant_test():
outputs=['else_out'])
then_body = onnx.helper.make_graph([shape_const], 'then_body', [],
[then_out])
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment