Commit 3a848f0d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
deconv_input_pads_test:®
=
x
wy" ConvTranspose*
pads@@@@ *
strides@@ deconv_input_pads_testZ
x




Z
w




b
y




B
deconv_output_padding_test:
C
x
wy" ConvTranspose*
output_padding@@*
strides@@deconv_output_padding_testZ
x




Z
w




b
y




B
deconv_output_shape_test:
A
x
wy" ConvTranspose*
output_shape@
@*
strides@@deconv_output_shape_testZ
x




Z
w




b
y




B
 deconv_test:…

x
wyconv1" ConvTranspose deconv_testZ
x




Z
w




b
y




B
...@@ -38,6 +38,20 @@ def acos_test(): ...@@ -38,6 +38,20 @@ def acos_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def acosh_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
node = onnx.helper.make_node(
'Acosh',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])
@onnx_test @onnx_test
def add_bcast_test(): def add_bcast_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
...@@ -130,6 +144,20 @@ def asin_test(): ...@@ -130,6 +144,20 @@ def asin_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def asinh_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
node = onnx.helper.make_node(
'Asinh',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])
@onnx_test @onnx_test
def atan_test(): def atan_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
...@@ -144,6 +172,64 @@ def atan_test(): ...@@ -144,6 +172,64 @@ def atan_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def atanh_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
node = onnx.helper.make_node(
'Atanh',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])
@onnx_test
def averagepool_notset_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1])
node = onnx.helper.make_node('AveragePool',
inputs=['x'],
outputs=['y'],
kernel_shape=[6, 6],
strides=[2, 2],
pads=[0, 0, 1, 1],
auto_pad='NOTSET')
return ([node], [x], [y])
@onnx_test
def averagepool_same_lower_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5])
node = onnx.helper.make_node('AveragePool',
inputs=['x'],
outputs=['y'],
kernel_shape=[2, 2],
auto_pad='SAME_LOWER')
return ([node], [x], [y])
@onnx_test
def averagepool_same_upper_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5])
node = onnx.helper.make_node('AveragePool',
inputs=['x'],
outputs=['y'],
kernel_shape=[2, 2],
auto_pad='SAME_UPPER')
return ([node], [x], [y])
@onnx_test @onnx_test
def cast_test(): def cast_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [10])
...@@ -406,6 +492,22 @@ def conv_autopad_fail_test(): ...@@ -406,6 +492,22 @@ def conv_autopad_fail_test():
return ([node], [x, y], [out]) return ([node], [x, y], [out])
@onnx_test
def conv_autopad_same_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3, 3])
out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 32, 32])
node = onnx.helper.make_node('Conv',
inputs=['0', '1'],
outputs=['2'],
dilations=[1, 1],
strides=[1, 1],
auto_pad='SAME')
return ([node], [x, y], [out])
@onnx_test @onnx_test
def conv_bias_test(): def conv_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])
...@@ -528,6 +630,22 @@ def conv_relu_maxpool_x2_test(): ...@@ -528,6 +630,22 @@ def conv_relu_maxpool_x2_test():
return ([node1, node2, node3, node4, node5, node6], [x, y, z, m, n], [out]) return ([node1, node2, node3, node4, node5, node6], [x, y, z, m, n], [out])
@onnx_test
def convinteger_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5])
z = helper.make_tensor_value_info('2', TensorProto.INT32, [1])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28])
node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2'],
outputs=['3'],
dilations=[1, 1],
strides=[1, 1])
return ([node], [x, y, z], [out])
@onnx_test @onnx_test
def cos_test(): def cos_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
...@@ -556,6 +674,109 @@ def cosh_test(): ...@@ -556,6 +674,109 @@ def cosh_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def deconv_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 1, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5])
node = onnx.helper.make_node('ConvTranspose',
name='conv1',
inputs=['x', 'w'],
outputs=['y'])
return ([node], [x, w], [y])
@onnx_test
def deconv_bias_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 1, 3, 3])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5])
node = onnx.helper.make_node('ConvTranspose',
name='conv1',
inputs=['x', 'w', 'b'],
outputs=['y'])
return ([node], [x, w, b], [y])
@onnx_test
def deconv_input_pads_strides_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 7, 5])
node = onnx.helper.make_node('ConvTranspose',
inputs=['x', 'w'],
outputs=['y'],
strides=[3, 2],
pads=[1, 1, 1, 1])
return ([node], [x, w], [y])
@onnx_test
def deconv_input_pads_asymm_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 8, 6])
node = onnx.helper.make_node('ConvTranspose',
inputs=['x', 'w'],
outputs=['y'],
strides=[3, 2],
pads=[0, 0, 1, 1])
return ([node], [x, w], [y])
@onnx_test
def deconv_output_shape_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8])
node = onnx.helper.make_node('ConvTranspose',
inputs=['x', 'w'],
outputs=['y'],
strides=[3, 2],
output_shape=[10, 8])
return ([node], [x, w], [y])
@onnx_test
def deconv_output_padding_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8])
node = onnx.helper.make_node('ConvTranspose',
inputs=['x', 'w'],
outputs=['y'],
strides=[3, 2],
output_padding=[1, 1])
return ([node], [x, w], [y])
@onnx_test
def deconv_stride_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 7, 3])
node = onnx.helper.make_node('ConvTranspose',
inputs=['x', 'w'],
outputs=['y'],
strides=[3, 2])
return ([node], [x, w], [y])
@onnx_test @onnx_test
def dropout_test(): def dropout_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 2, 2]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 2, 2])
...@@ -791,6 +1012,20 @@ def imagescaler_test(): ...@@ -791,6 +1012,20 @@ def imagescaler_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def imagescaler_half_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 3, 16, 16])
node = onnx.helper.make_node('ImageScaler',
inputs=['0'],
outputs=['1'],
bias=[0.01, 0.02, 0.03],
scale=0.5)
return ([node], [x], [y])
@onnx_test @onnx_test
def implicit_add_bcast_test(): def implicit_add_bcast_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
...@@ -858,6 +1093,50 @@ def initializer_not_an_input(): ...@@ -858,6 +1093,50 @@ def initializer_not_an_input():
return ([node], [x], [y], [w]) return ([node], [x], [y], [w])
@onnx_test
def instance_norm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3, 3])
scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2])
bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 2, 3, 3])
node = onnx.helper.make_node('InstanceNormalization',
inputs=['0', '1', '2'],
outputs=['3'])
return ([node], [x, scale, bias], [y])
@onnx_test
def instance_norm_val_test():
x = np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]])
scale = np.array([1, 2])
bias = np.array([0, 1])
x_tensor = helper.make_tensor(name='x_tensor',
data_type=TensorProto.FLOAT,
dims=x.shape,
vals=x.flatten().astype(np.float))
scale_tensor = helper.make_tensor(name='scale_tensor',
data_type=TensorProto.FLOAT,
dims=scale.shape,
vals=scale.flatten().astype(np.float))
bias_tensor = helper.make_tensor(name='bias_tensor',
data_type=TensorProto.FLOAT,
dims=bias.shape,
vals=bias.flatten().astype(np.float))
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 3, 3])
node = onnx.helper.make_node(
'InstanceNormalization',
inputs=['x_tensor', 'scale_tensor', 'bias_tensor'],
outputs=['y'])
return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor])
@onnx_test @onnx_test
def leaky_relu_test(): def leaky_relu_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -1004,6 +1283,21 @@ def matmul_vv_test(): ...@@ -1004,6 +1283,21 @@ def matmul_vv_test():
return ([node], [m1, m2], [y]) return ([node], [m1, m2], [y])
@onnx_test
def matmulinteger_test():
m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [3, 6, 16])
m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 16, 8])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [3, 6, 8])
node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2'],
outputs=['y'],
)
return ([node], [m1, m2], [y])
@onnx_test @onnx_test
def max_test(): def max_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -1020,6 +1314,36 @@ def max_test(): ...@@ -1020,6 +1314,36 @@ def max_test():
return ([node], [a, b, c], [y]) return ([node], [a, b, c], [y])
@onnx_test
def maxpool_notset_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1])
node = onnx.helper.make_node('MaxPool',
inputs=['x'],
outputs=['y'],
kernel_shape=[6, 6],
strides=[2, 2],
pads=[0, 0, 1, 1],
auto_pad='NOTSET')
return ([node], [x], [y])
@onnx_test
def maxpool_same_upper_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5])
node = onnx.helper.make_node('MaxPool',
inputs=['x'],
outputs=['y'],
kernel_shape=[2, 2],
auto_pad='SAME_UPPER')
return ([node], [x], [y])
@onnx_test @onnx_test
def min_test(): def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -1078,10 +1402,86 @@ def pow_test(): ...@@ -1078,10 +1402,86 @@ def pow_test():
return ([node], [arg0, arg1], [arg_out]) return ([node], [arg0, arg1], [arg_out])
@onnx_test
def prelu_brcst_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
[2, 3, 4, 5])
node = onnx.helper.make_node(
'PRelu',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def reducel1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6])
axes = [-2]
node = onnx.helper.make_node('ReduceL1',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=0)
return ([node], [x], [y])
@onnx_test
def reducel2_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5])
axes = [-1]
node = onnx.helper.make_node('ReduceL2',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=0)
return ([node], [x], [y])
@onnx_test
def reduce_log_sum_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 1, 5, 6])
axes = [-3]
node = onnx.helper.make_node('ReduceLogSum',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=1)
return ([node], [x], [y])
@onnx_test
def reduce_log_sum_exp_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 5, 6])
axes = [-4]
node = onnx.helper.make_node('ReduceLogSumExp',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=1)
return ([node], [x], [y])
@onnx_test @onnx_test
def reducemax_test(): def reducemax_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6])
axes = [2] axes = [2]
node = onnx.helper.make_node('ReduceMax', node = onnx.helper.make_node('ReduceMax',
...@@ -1139,12 +1539,12 @@ def reducemin_test(): ...@@ -1139,12 +1539,12 @@ def reducemin_test():
@onnx_test @onnx_test
def reducesum_test(): def reduceprod_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6])
axes = [2] axes = [2]
node = onnx.helper.make_node('ReduceSum', node = onnx.helper.make_node('ReduceProd',
inputs=['x'], inputs=['x'],
outputs=['y'], outputs=['y'],
axes=axes, axes=axes,
...@@ -1154,10 +1554,10 @@ def reducesum_test(): ...@@ -1154,10 +1554,10 @@ def reducesum_test():
@onnx_test @onnx_test
def reducesum_multiaxis_test(): def reducesum_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6])
axes = [2, 3] axes = [2]
node = onnx.helper.make_node('ReduceSum', node = onnx.helper.make_node('ReduceSum',
inputs=['x'], inputs=['x'],
...@@ -1183,6 +1583,36 @@ def reducesum_keepdims_test(): ...@@ -1183,6 +1583,36 @@ def reducesum_keepdims_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def reducesum_multiaxis_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1])
axes = [2, 3]
node = onnx.helper.make_node('ReduceSum',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=0)
return ([node], [x], [y])
@onnx_test
def reducesum_square_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6])
axes = [-2]
node = onnx.helper.make_node('ReduceSumSquare',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=0)
return ([node], [x], [y])
@onnx_test @onnx_test
def reshape_test(): def reshape_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3])
...@@ -1241,7 +1671,7 @@ def shape_test(): ...@@ -1241,7 +1671,7 @@ def shape_test():
@onnx_test @onnx_test
def shape_gather_test(): def shape_gather_test():
values = np.array([1]) values = np.array([1])
value = helper.make_tensor_value_info('value', TensorProto.INT32, [1]) # value = helper.make_tensor_value_info('value', TensorProto.INT32, [1])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [7, 3, 10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [7, 3, 10])
y = helper.make_tensor_value_info('y', TensorProto.INT64, [3]) y = helper.make_tensor_value_info('y', TensorProto.INT64, [3])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [1]) z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [1])
...@@ -1423,23 +1853,6 @@ def sub_scalar_test(): ...@@ -1423,23 +1853,6 @@ def sub_scalar_test():
return ([arg_const, node], [arg_node], [arg_out]) return ([arg_const, node], [arg_node], [arg_out])
@onnx_test
def sum_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])
node = onnx.helper.make_node(
'Sum',
inputs=['0', '1', '2'],
outputs=['3'],
)
return ([node], [a, b, c], [y])
@onnx_test @onnx_test
def sum_test(): def sum_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -1533,7 +1946,9 @@ def transpose_gather_test(): ...@@ -1533,7 +1946,9 @@ def transpose_gather_test():
def unknown_test(): def unknown_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5]) a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5])
node = onnx.helper.make_node('Unknown', inputs=['0', '1'], outputs=['2']) node = onnx.helper.make_node('Unknown', inputs=['0', '1'], outputs=['2'])
...@@ -1541,3 +1956,26 @@ def unknown_test(): ...@@ -1541,3 +1956,26 @@ def unknown_test():
node2 = onnx.helper.make_node('Unknown', inputs=['2'], outputs=['3']) node2 = onnx.helper.make_node('Unknown', inputs=['2'], outputs=['3'])
return ([node, node2], [x, y], [a]) return ([node, node2], [x, y], [a])
@onnx_test
def variable_batch_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
[None, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT,
[None, 3, 16, 16])
node = onnx.helper.make_node('Identity', inputs=['0'], outputs=['1'])
return ([node], [x], [y])
@onnx_test
def variable_batch_leq_zero_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [0, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [-1, 3, 16, 16])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [-1, 3, 16, 16])
node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])
return ([node], [x, y], [z])
instance_norm_test:
#
0
1
23"InstanceNormalizationinstance_norm_testZ
0




Z
1

Z
2

b
3




B
matmulinteger_test:y

1
2y" MatMulIntegermatmulinteger_testZ
1



Z
2



b
y



B
\ No newline at end of file
maxpool_same_upper_test:
A
xy"MaxPool*
auto_pad"
SAME_UPPER*
kernel_shape@@maxpool_same_upper_testZ
x




b
y




B
\ No newline at end of file
...@@ -4,9 +4,28 @@ ...@@ -4,9 +4,28 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true)
{
auto prog = migraphx::parse_onnx(name);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(prog.end());
if(last_ins->name() == "@return")
{
prog.remove_instruction(last_ins);
}
return prog;
}
TEST_CASE(rnn_test_bidirectional) TEST_CASE(rnn_test_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -43,7 +62,7 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -43,7 +62,7 @@ TEST_CASE(rnn_test_bidirectional)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -85,7 +104,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -85,7 +104,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -111,7 +130,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -111,7 +130,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -135,7 +154,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -135,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -163,7 +182,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -163,7 +182,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -207,7 +226,7 @@ TEST_CASE(gru_test) ...@@ -207,7 +226,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -241,7 +260,7 @@ TEST_CASE(gru_test) ...@@ -241,7 +260,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -278,7 +297,7 @@ TEST_CASE(gru_test) ...@@ -278,7 +297,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -317,7 +336,7 @@ TEST_CASE(gru_test_args) ...@@ -317,7 +336,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -349,7 +368,7 @@ TEST_CASE(gru_test_args) ...@@ -349,7 +368,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -386,7 +405,7 @@ TEST_CASE(gru_test_args) ...@@ -386,7 +405,7 @@ TEST_CASE(gru_test_args)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -432,7 +451,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -432,7 +451,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -469,7 +488,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -469,7 +488,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -506,7 +525,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -506,7 +525,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -543,7 +562,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -543,7 +562,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -577,7 +596,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -577,7 +596,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -611,7 +630,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -611,7 +630,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -660,8 +679,7 @@ TEST_CASE(lstm_forward) ...@@ -660,8 +679,7 @@ TEST_CASE(lstm_forward)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_forward.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -690,8 +708,93 @@ TEST_CASE(lstm_forward) ...@@ -690,8 +708,93 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
}
// 3 args, hs output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto prog = optimize_onnx("onnx_lstm_hs.onnx");
EXPECT(p == prog);
}
// 3 args, last output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
}
// 3 args, cell output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx"); auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -721,8 +824,7 @@ TEST_CASE(lstm_forward) ...@@ -721,8 +824,7 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -754,7 +856,7 @@ TEST_CASE(lstm_forward) ...@@ -754,7 +856,7 @@ TEST_CASE(lstm_forward)
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx"); auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -787,7 +889,7 @@ TEST_CASE(lstm_forward) ...@@ -787,7 +889,7 @@ TEST_CASE(lstm_forward)
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx"); auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -821,7 +923,7 @@ TEST_CASE(lstm_forward) ...@@ -821,7 +923,7 @@ TEST_CASE(lstm_forward)
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx"); auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -866,8 +968,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -866,8 +968,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -897,8 +998,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -897,8 +998,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -930,7 +1030,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -930,7 +1030,7 @@ TEST_CASE(lstm_forward_actv_func)
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -979,8 +1079,7 @@ TEST_CASE(lstm_reverse) ...@@ -979,8 +1079,7 @@ TEST_CASE(lstm_reverse)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1012,7 +1111,7 @@ TEST_CASE(lstm_reverse) ...@@ -1012,7 +1111,7 @@ TEST_CASE(lstm_reverse)
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx"); auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1041,8 +1140,7 @@ TEST_CASE(lstm_reverse) ...@@ -1041,8 +1140,7 @@ TEST_CASE(lstm_reverse)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1095,8 +1193,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1095,8 +1193,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1129,8 +1226,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1129,8 +1226,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1164,8 +1260,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1164,8 +1260,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1200,8 +1295,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1200,8 +1295,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1237,8 +1331,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1237,8 +1331,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1275,8 +1368,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1275,8 +1368,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1326,8 +1418,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1326,8 +1418,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1361,8 +1452,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1361,8 +1452,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1397,8 +1487,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1397,8 +1487,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1434,8 +1523,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1434,8 +1523,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1472,8 +1560,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1472,8 +1560,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1506,8 +1593,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1506,8 +1593,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
...@@ -5,16 +5,46 @@ ...@@ -5,16 +5,46 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
{
auto prog = migraphx::parse_onnx(name);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(prog.end());
if(last_ins->name() == "@return")
{
prog.remove_instruction(last_ins);
}
return prog;
}
TEST_CASE(acos_test) TEST_CASE(acos_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acos{}, input); p.add_instruction(migraphx::op::acos{}, input);
auto prog = migraphx::parse_onnx("acos_test.onnx"); auto prog = optimize_onnx("acos_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(acosh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acosh{}, input);
auto prog = optimize_onnx("acosh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -27,7 +57,7 @@ TEST_CASE(add_bcast_test) ...@@ -27,7 +57,7 @@ TEST_CASE(add_bcast_test)
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); auto prog = optimize_onnx("add_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -40,7 +70,7 @@ TEST_CASE(add_fp16_test) ...@@ -40,7 +70,7 @@ TEST_CASE(add_fp16_test)
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
p.add_instruction(migraphx::op::add{}, l0, l1); p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_onnx("add_fp16_test.onnx"); auto prog = optimize_onnx("add_fp16_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -50,10 +80,9 @@ TEST_CASE(add_scalar_test) ...@@ -50,10 +80,9 @@ TEST_CASE(add_scalar_test)
migraphx::program p; migraphx::program p;
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}}); auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, l0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = optimize_onnx("add_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -64,7 +93,7 @@ TEST_CASE(argmax_test) ...@@ -64,7 +93,7 @@ TEST_CASE(argmax_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0); auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins); p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx"); auto prog = optimize_onnx("argmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -75,7 +104,7 @@ TEST_CASE(argmin_test) ...@@ -75,7 +104,7 @@ TEST_CASE(argmin_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0); auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins); p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx"); auto prog = optimize_onnx("argmin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -86,7 +115,18 @@ TEST_CASE(asin_test) ...@@ -86,7 +115,18 @@ TEST_CASE(asin_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asin{}, input); p.add_instruction(migraphx::op::asin{}, input);
auto prog = migraphx::parse_onnx("asin_test.onnx"); auto prog = optimize_onnx("asin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(asinh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asinh{}, input);
auto prog = optimize_onnx("asinh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -97,7 +137,63 @@ TEST_CASE(atan_test) ...@@ -97,7 +137,63 @@ TEST_CASE(atan_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atan{}, input); p.add_instruction(migraphx::op::atan{}, input);
auto prog = migraphx::parse_onnx("atan_test.onnx"); auto prog = optimize_onnx("atan_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(atanh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atanh{}, input);
auto prog = optimize_onnx("atanh_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(averagepool_notset_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = p.add_instruction(migraphx::op::pad{pads}, input);
p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
auto prog = optimize_onnx("averagepool_notset_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(averagepool_same_lower_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = p.add_instruction(migraphx::op::pad{pads}, input);
p.add_instruction(
migraphx::op::pooling{
"average", {0, 0}, {1, 1}, {2, 2}, migraphx::op::padding_mode_t::same},
ins_pad);
auto prog = optimize_onnx("averagepool_same_lower_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(averagepool_same_upper_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = p.add_instruction(migraphx::op::pad{pads}, input);
p.add_instruction(
migraphx::op::pooling{
"average", {0, 0}, {1, 1}, {2, 2}, migraphx::op::padding_mode_t::same},
ins_pad);
auto prog = optimize_onnx("averagepool_same_upper_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -108,7 +204,7 @@ TEST_CASE(cast_test) ...@@ -108,7 +204,7 @@ TEST_CASE(cast_test)
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}}); auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l); p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx"); auto prog = optimize_onnx("cast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -118,7 +214,7 @@ TEST_CASE(ceil_test) ...@@ -118,7 +214,7 @@ TEST_CASE(ceil_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::ceil{}, input); p.add_instruction(migraphx::op::ceil{}, input);
auto prog = migraphx::parse_onnx("ceil_test.onnx"); auto prog = optimize_onnx("ceil_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -128,7 +224,7 @@ TEST_CASE(clip_test) ...@@ -128,7 +224,7 @@ TEST_CASE(clip_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_onnx("clip_test.onnx"); auto prog = optimize_onnx("clip_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -139,7 +235,7 @@ TEST_CASE(concat_test) ...@@ -139,7 +235,7 @@ TEST_CASE(concat_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
p.add_instruction(migraphx::op::concat{0}, l0, l1); p.add_instruction(migraphx::op::concat{0}, l0, l1);
auto prog = migraphx::parse_onnx("concat_test.onnx"); auto prog = optimize_onnx("concat_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -148,7 +244,7 @@ TEST_CASE(constant_test) ...@@ -148,7 +244,7 @@ TEST_CASE(constant_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto prog = migraphx::parse_onnx("constant_test.onnx"); auto prog = optimize_onnx("constant_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -160,7 +256,7 @@ TEST_CASE(constant_fill_test) ...@@ -160,7 +256,7 @@ TEST_CASE(constant_fill_test)
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0); std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value}); p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("constant_fill_test.onnx"); auto prog = optimize_onnx("constant_fill_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -175,7 +271,7 @@ TEST_CASE(constant_fill_input_as_shape_test) ...@@ -175,7 +271,7 @@ TEST_CASE(constant_fill_input_as_shape_test)
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0); std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value}); p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("constant_fill_input_as_shape_test.onnx"); auto prog = optimize_onnx("constant_fill_input_as_shape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -184,7 +280,7 @@ TEST_CASE(constant_scalar_test) ...@@ -184,7 +280,7 @@ TEST_CASE(constant_scalar_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar_test.onnx"); auto prog = optimize_onnx("constant_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -197,7 +293,7 @@ TEST_CASE(const_of_shape_empty_input_test) ...@@ -197,7 +293,7 @@ TEST_CASE(const_of_shape_empty_input_test)
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_empty_input_test.onnx"); auto prog = optimize_onnx("const_of_shape_empty_input_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -210,7 +306,7 @@ TEST_CASE(const_of_shape_float_test) ...@@ -210,7 +306,7 @@ TEST_CASE(const_of_shape_float_test)
std::vector<float> vec(s.elements(), 10.0f); std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_float_test.onnx"); auto prog = optimize_onnx("const_of_shape_float_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -223,7 +319,7 @@ TEST_CASE(const_of_shape_int64_test) ...@@ -223,7 +319,7 @@ TEST_CASE(const_of_shape_int64_test)
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_int64_test.onnx"); auto prog = optimize_onnx("const_of_shape_int64_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -236,13 +332,27 @@ TEST_CASE(const_of_shape_no_value_attr_test) ...@@ -236,13 +332,27 @@ TEST_CASE(const_of_shape_no_value_attr_test)
std::vector<float> vec(s.elements(), 0.0f); std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_no_value_attr_test.onnx"); auto prog = optimize_onnx("const_of_shape_no_value_attr_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(conv_autopad_fail_test) TEST_CASE(conv_autopad_fail_test)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("conv_autopad_fail_test.onnx"); })); EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); }));
}
TEST_CASE(conv_autopad_same_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
migraphx::op::convolution op;
op.padding = {1, 1};
op.padding_mode = migraphx::op::padding_mode_t::same;
p.add_instruction(op, l0, l1);
auto prog = optimize_onnx("conv_autopad_same_test.onnx");
EXPECT(p == prog);
} }
TEST_CASE(conv_bias_test) TEST_CASE(conv_bias_test)
...@@ -256,7 +366,7 @@ TEST_CASE(conv_bias_test) ...@@ -256,7 +366,7 @@ TEST_CASE(conv_bias_test)
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv_bias_test.onnx"); auto prog = optimize_onnx("conv_bias_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -279,7 +389,7 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -279,7 +389,7 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l7 = p.add_instruction(migraphx::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraphx::parse_onnx("conv_bn_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -296,7 +406,7 @@ TEST_CASE(conv_relu_maxpool_test) ...@@ -296,7 +406,7 @@ TEST_CASE(conv_relu_maxpool_test)
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -321,18 +431,33 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -321,18 +431,33 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l13 = p.add_instruction(migraphx::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraphx::parse_onnx("conv_relu_maxpool_x2_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(convinteger_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::int32_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::quant_convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = optimize_onnx("convinteger_bias_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(cos_test) TEST_CASE(cos_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::cos{}, input); p.add_instruction(migraphx::op::cos{}, input);
auto prog = migraphx::parse_onnx("cos_test.onnx"); auto prog = optimize_onnx("cos_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -342,18 +467,91 @@ TEST_CASE(cosh_test) ...@@ -342,18 +467,91 @@ TEST_CASE(cosh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::cosh{}, input); p.add_instruction(migraphx::op::cosh{}, input);
auto prog = migraphx::parse_onnx("cosh_test.onnx"); auto prog = optimize_onnx("cosh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(deconv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
p.add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto prog = optimize_onnx("deconv_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(deconv_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l2 = p.add_parameter("b", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = optimize_onnx("deconv_bias_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(deconv_input_pads_strides_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
p.add_instruction(migraphx::op::deconvolution{{1, 1}, {3, 2}}, l0, l1);
auto prog = optimize_onnx("deconv_input_pads_strides_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(deconv_input_pads_asymm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::slice{{0, 1, 2, 3}, {0, 0, 0, 0}, {1, 2, 8, 6}}, l2);
auto prog = optimize_onnx("deconv_input_pads_asymm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(deconv_output_shape_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_shape_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(deconv_output_padding_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_padding_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(dropout_test) TEST_CASE(dropout_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraphx::op::identity{}, input); p.add_instruction(migraphx::op::identity{}, input);
auto prog = migraphx::parse_onnx("dropout_test.onnx"); auto prog = optimize_onnx("dropout_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -364,7 +562,7 @@ TEST_CASE(elu_test) ...@@ -364,7 +562,7 @@ TEST_CASE(elu_test)
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::elu{0.01}, input); p.add_instruction(migraphx::op::elu{0.01}, input);
auto prog = migraphx::parse_onnx("elu_test.onnx"); auto prog = optimize_onnx("elu_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -375,7 +573,7 @@ TEST_CASE(erf_test) ...@@ -375,7 +573,7 @@ TEST_CASE(erf_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::erf{}, input); p.add_instruction(migraphx::op::erf{}, input);
auto prog = migraphx::parse_onnx("erf_test.onnx"); auto prog = optimize_onnx("erf_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -385,7 +583,7 @@ TEST_CASE(exp_test) ...@@ -385,7 +583,7 @@ TEST_CASE(exp_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::exp{}, input); p.add_instruction(migraphx::op::exp{}, input);
auto prog = migraphx::parse_onnx("exp_test.onnx"); auto prog = optimize_onnx("exp_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -398,7 +596,7 @@ TEST_CASE(expand_test) ...@@ -398,7 +596,7 @@ TEST_CASE(expand_test)
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5})); p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param); p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx"); auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -408,7 +606,7 @@ TEST_CASE(flatten_test) ...@@ -408,7 +606,7 @@ TEST_CASE(flatten_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{2}, l0); p.add_instruction(migraphx::op::flatten{2}, l0);
p.add_instruction(migraphx::op::flatten{1}, l0); p.add_instruction(migraphx::op::flatten{1}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx"); auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -419,7 +617,7 @@ TEST_CASE(floor_test) ...@@ -419,7 +617,7 @@ TEST_CASE(floor_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::floor{}, input); p.add_instruction(migraphx::op::floor{}, input);
auto prog = migraphx::parse_onnx("floor_test.onnx"); auto prog = optimize_onnx("floor_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -431,7 +629,7 @@ TEST_CASE(gather_test) ...@@ -431,7 +629,7 @@ TEST_CASE(gather_test)
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1); p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx"); auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -439,15 +637,16 @@ TEST_CASE(gather_test) ...@@ -439,15 +637,16 @@ TEST_CASE(gather_test)
TEST_CASE(gemm_test) TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}}); auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{7, 11}}, l2);
auto alpha = 2.f; auto alpha = 2.f;
auto beta = 2.0f; auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -462,7 +661,7 @@ TEST_CASE(gemm_ex_test) ...@@ -462,7 +661,7 @@ TEST_CASE(gemm_ex_test)
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = migraphx::parse_onnx("gemm_ex_test.onnx"); auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -479,7 +678,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -479,7 +678,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = migraphx::parse_onnx("gemm_ex_brcst_test.onnx"); auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -493,7 +692,7 @@ TEST_CASE(globalavgpool_test) ...@@ -493,7 +692,7 @@ TEST_CASE(globalavgpool_test)
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraphx::parse_onnx("globalavgpool_test.onnx"); auto prog = optimize_onnx("globalavgpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -507,7 +706,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -507,7 +706,7 @@ TEST_CASE(globalmaxpool_test)
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraphx::parse_onnx("globalmaxpool_test.onnx"); auto prog = optimize_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -520,7 +719,7 @@ TEST_CASE(group_conv_test) ...@@ -520,7 +719,7 @@ TEST_CASE(group_conv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.group = 4; op.group = 4;
p.add_instruction(op, l0, l1); p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx"); auto prog = optimize_onnx("group_conv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -538,7 +737,26 @@ TEST_CASE(imagescaler_test) ...@@ -538,7 +737,26 @@ TEST_CASE(imagescaler_test)
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); auto prog = optimize_onnx("imagescaler_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(imagescaler_half_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s);
auto scale_val =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5f}});
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::half_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_half_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -548,11 +766,10 @@ TEST_CASE(implicit_add_bcast_test) ...@@ -548,11 +766,10 @@ TEST_CASE(implicit_add_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx"); auto prog = optimize_onnx("implicit_add_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -562,11 +779,10 @@ TEST_CASE(implicit_pow_bcast_test) ...@@ -562,11 +779,10 @@ TEST_CASE(implicit_pow_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3); p.add_instruction(migraphx::op::pow{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_pow_bcast_test.onnx"); auto prog = optimize_onnx("implicit_pow_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -576,11 +792,10 @@ TEST_CASE(implicit_sub_bcast_test) ...@@ -576,11 +792,10 @@ TEST_CASE(implicit_sub_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3); p.add_instruction(migraphx::op::sub{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx"); auto prog = optimize_onnx("implicit_sub_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -593,7 +808,39 @@ TEST_CASE(initializer_not_an_input) ...@@ -593,7 +808,39 @@ TEST_CASE(initializer_not_an_input)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
p.add_instruction(migraphx::op::dot{}, l0, l1); p.add_instruction(migraphx::op::dot{}, l0, l1);
auto prog = migraphx::parse_onnx("initializer_not_an_input.onnx"); auto prog = optimize_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog);
}
TEST_CASE(instance_norm_test)
{
std::vector<size_t> dims{1, 2, 3, 3};
migraphx::shape s1{migraphx::shape::float_type, dims};
migraphx::shape s2{migraphx::shape::float_type, {2}};
migraphx::program p;
auto x = p.add_parameter("0", s1);
auto scale = p.add_parameter("1", s2);
auto bias = p.add_parameter("2", s2);
auto mean = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, x);
auto mean_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, mean);
auto l0 = p.add_instruction(migraphx::op::sqdiff{}, x, mean_bcast);
auto variance = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
auto l1 = p.add_instruction(migraphx::op::sub{}, x, mean_bcast);
auto epsilon_literal = p.add_literal(1e-5f);
auto epsilon_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, variance);
auto l2 = p.add_instruction(migraphx::op::add{}, variance_bcast, epsilon_bcast);
auto l3 = p.add_instruction(migraphx::op::rsqrt{}, l2);
auto l4 = p.add_instruction(migraphx::op::mul{}, l1, l3);
auto scale_bcast = p.add_instruction(migraphx::op::broadcast{1, dims}, scale);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, dims}, bias);
auto l5 = p.add_instruction(migraphx::op::mul{}, l4, scale_bcast);
p.add_instruction(migraphx::op::add{}, l5, bias_bcast);
auto prog = optimize_onnx("instance_norm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -605,7 +852,7 @@ TEST_CASE(leaky_relu_test) ...@@ -605,7 +852,7 @@ TEST_CASE(leaky_relu_test)
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::leaky_relu{alpha}, l0); p.add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto prog = migraphx::parse_onnx("leaky_relu_test.onnx"); auto prog = optimize_onnx("leaky_relu_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -616,7 +863,7 @@ TEST_CASE(log_test) ...@@ -616,7 +863,7 @@ TEST_CASE(log_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::log{}, input); p.add_instruction(migraphx::op::log{}, input);
auto prog = migraphx::parse_onnx("log_test.onnx"); auto prog = optimize_onnx("log_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -626,7 +873,7 @@ TEST_CASE(logsoftmax_test) ...@@ -626,7 +873,7 @@ TEST_CASE(logsoftmax_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1; int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0); p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx"); auto prog = optimize_onnx("logsoftmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -641,7 +888,7 @@ TEST_CASE(lrn_test) ...@@ -641,7 +888,7 @@ TEST_CASE(lrn_test)
op.beta = 0.75; op.beta = 0.75;
op.bias = 1.0; op.bias = 1.0;
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("lrn_test.onnx"); auto prog = optimize_onnx("lrn_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -655,7 +902,7 @@ TEST_CASE(matmul_bmbm_test) ...@@ -655,7 +902,7 @@ TEST_CASE(matmul_bmbm_test)
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1); p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = migraphx::parse_onnx("matmul_bmbm_test.onnx"); auto prog = optimize_onnx("matmul_bmbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -670,7 +917,7 @@ TEST_CASE(matmul_bmv_test) ...@@ -670,7 +917,7 @@ TEST_CASE(matmul_bmv_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res); p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = migraphx::parse_onnx("matmul_bmv_test.onnx"); auto prog = optimize_onnx("matmul_bmv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -684,7 +931,7 @@ TEST_CASE(matmul_mv_test) ...@@ -684,7 +931,7 @@ TEST_CASE(matmul_mv_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_mv_test.onnx"); auto prog = optimize_onnx("matmul_mv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -696,12 +943,10 @@ TEST_CASE(matmul_vbm_test) ...@@ -696,12 +943,10 @@ TEST_CASE(matmul_vbm_test)
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0); auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
std::cout << "ONNX_TEST" << std::endl; auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
std::cout << "After Dot" << std::endl;
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_vbm_test.onnx"); auto prog = optimize_onnx("matmul_vbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -715,7 +960,7 @@ TEST_CASE(matmul_vm_test) ...@@ -715,7 +960,7 @@ TEST_CASE(matmul_vm_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res); p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto prog = migraphx::parse_onnx("matmul_vm_test.onnx"); auto prog = optimize_onnx("matmul_vm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -731,7 +976,19 @@ TEST_CASE(matmul_vv_test) ...@@ -731,7 +976,19 @@ TEST_CASE(matmul_vv_test)
auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res); auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sr0); p.add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto prog = migraphx::parse_onnx("matmul_vv_test.onnx"); auto prog = optimize_onnx("matmul_vv_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmulinteger_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}});
p.add_instruction(migraphx::op::quant_dot{1, 0}, l0, l1);
auto prog = optimize_onnx("matmulinteger_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -745,7 +1002,37 @@ TEST_CASE(max_test) ...@@ -745,7 +1002,37 @@ TEST_CASE(max_test)
auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1);
p.add_instruction(migraphx::op::max{}, l0, input2); p.add_instruction(migraphx::op::max{}, l0, input2);
migraphx::parse_onnx("max_test.onnx"); optimize_onnx("max_test.onnx");
}
TEST_CASE(maxpool_notset_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad = p.add_instruction(migraphx::op::pad{pads, val}, input);
p.add_instruction(migraphx::op::pooling{"max", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
auto prog = optimize_onnx("maxpool_notset_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(maxpool_same_upper_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad = p.add_instruction(migraphx::op::pad{pads, val}, input);
p.add_instruction(
migraphx::op::pooling{"max", {0, 0}, {1, 1}, {2, 2}, migraphx::op::padding_mode_t::same},
ins_pad);
auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
EXPECT(p == prog);
} }
TEST_CASE(min_test) TEST_CASE(min_test)
...@@ -757,7 +1044,7 @@ TEST_CASE(min_test) ...@@ -757,7 +1044,7 @@ TEST_CASE(min_test)
auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1);
p.add_instruction(migraphx::op::min{}, l0, input2); p.add_instruction(migraphx::op::min{}, l0, input2);
migraphx::parse_onnx("min_test.onnx"); optimize_onnx("min_test.onnx");
} }
TEST_CASE(no_pad_test) TEST_CASE(no_pad_test)
...@@ -765,7 +1052,7 @@ TEST_CASE(no_pad_test) ...@@ -765,7 +1052,7 @@ TEST_CASE(no_pad_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::identity{}, l0); p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_onnx("no_pad_test.onnx"); auto prog = optimize_onnx("no_pad_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -775,7 +1062,7 @@ TEST_CASE(pad_test) ...@@ -775,7 +1062,7 @@ TEST_CASE(pad_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0); p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
auto prog = migraphx::parse_onnx("pad_test.onnx"); auto prog = optimize_onnx("pad_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -787,7 +1074,69 @@ TEST_CASE(pow_test) ...@@ -787,7 +1074,69 @@ TEST_CASE(pow_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1); p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_test.onnx"); auto prog = optimize_onnx("pow_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(prelu_brcst_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{l0->get_shape().lens()}, l1);
auto ret = p.add_instruction(migraphx::op::prelu{}, l0, bl1);
p.add_return({ret});
auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducel1_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto abs_l0 = p.add_instruction(migraphx::op::abs{}, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-2}}, abs_l0);
p.add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto prog = optimize_onnx("reducel1_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducel2_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto square_l0 = p.add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-1}}, square_l0);
auto squ_l0 = p.add_instruction(migraphx::op::squeeze{{-1}}, sum_l0);
p.add_instruction(migraphx::op::sqrt{}, squ_l0);
auto prog = optimize_onnx("reducel2_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reduce_log_sum_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-3}}, l0);
p.add_instruction(migraphx::op::log{}, sum_l0);
auto prog = optimize_onnx("reduce_log_sum_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reduce_log_sum_exp_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto exp_l0 = p.add_instruction(migraphx::op::exp{}, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-4}}, exp_l0);
p.add_instruction(migraphx::op::log{}, sum_l0);
auto prog = optimize_onnx("reduce_log_sum_exp_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -797,7 +1146,7 @@ TEST_CASE(reducemax_test) ...@@ -797,7 +1146,7 @@ TEST_CASE(reducemax_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_max{{2}}, l0); p.add_instruction(migraphx::op::reduce_max{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemax_test.onnx"); auto prog = optimize_onnx("reducemax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -808,7 +1157,7 @@ TEST_CASE(reducemean_test) ...@@ -808,7 +1157,7 @@ TEST_CASE(reducemean_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemean_test.onnx"); auto prog = optimize_onnx("reducemean_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -818,7 +1167,7 @@ TEST_CASE(reducemean_keepdims_test) ...@@ -818,7 +1167,7 @@ TEST_CASE(reducemean_keepdims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0); p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemean_keepdims_test.onnx"); auto prog = optimize_onnx("reducemean_keepdims_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -829,7 +1178,17 @@ TEST_CASE(reducemin_test) ...@@ -829,7 +1178,17 @@ TEST_CASE(reducemin_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemin_test.onnx"); auto prog = optimize_onnx("reducemin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reduceprod_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_prod{{2}}, l0);
auto prog = optimize_onnx("reduceprod_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -840,7 +1199,7 @@ TEST_CASE(reducesum_test) ...@@ -840,7 +1199,7 @@ TEST_CASE(reducesum_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, l1); p.add_instruction(migraphx::op::squeeze{{2}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test.onnx"); auto prog = optimize_onnx("reducesum_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -851,7 +1210,7 @@ TEST_CASE(reducesum_multiaxis_test) ...@@ -851,7 +1210,7 @@ TEST_CASE(reducesum_multiaxis_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducesum_multiaxis_test.onnx"); auto prog = optimize_onnx("reducesum_multiaxis_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -861,7 +1220,19 @@ TEST_CASE(reducesum_keepdims_test) ...@@ -861,7 +1220,19 @@ TEST_CASE(reducesum_keepdims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto prog = migraphx::parse_onnx("reducesum_keepdims_test.onnx"); auto prog = optimize_onnx("reducesum_keepdims_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducesum_square_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto squ_l0 = p.add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-2}}, squ_l0);
p.add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto prog = optimize_onnx("reducesum_square_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -877,7 +1248,7 @@ TEST_CASE(reshape_test) ...@@ -877,7 +1248,7 @@ TEST_CASE(reshape_test)
op.dims = reshape_dims; op.dims = reshape_dims;
p.add_instruction(op, l0); p.add_instruction(op, l0);
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("reshape_test.onnx"); auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -892,7 +1263,7 @@ TEST_CASE(reshape_non_standard_test) ...@@ -892,7 +1263,7 @@ TEST_CASE(reshape_non_standard_test)
auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x); auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x); auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x);
p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x); p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
auto prog = migraphx::parse_onnx("reshape_non_standard_test.onnx"); auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -903,7 +1274,7 @@ TEST_CASE(round_test) ...@@ -903,7 +1274,7 @@ TEST_CASE(round_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input); p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx"); auto prog = optimize_onnx("round_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -914,7 +1285,7 @@ TEST_CASE(shape_test) ...@@ -914,7 +1285,7 @@ TEST_CASE(shape_test)
auto l0 = p.add_parameter("x", s); auto l0 = p.add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}}; migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens()); p.add_literal(s_shape, l0->get_shape().lens());
auto prog = migraphx::parse_onnx("shape_test.onnx"); auto prog = optimize_onnx("shape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -929,7 +1300,7 @@ TEST_CASE(shape_gather_test) ...@@ -929,7 +1300,7 @@ TEST_CASE(shape_gather_test)
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
int axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2); p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather_test.onnx"); auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -940,7 +1311,7 @@ TEST_CASE(sign_test) ...@@ -940,7 +1311,7 @@ TEST_CASE(sign_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::sign{}, input); p.add_instruction(migraphx::op::sign{}, input);
auto prog = migraphx::parse_onnx("sign_test.onnx"); auto prog = optimize_onnx("sign_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -950,7 +1321,7 @@ TEST_CASE(sin_test) ...@@ -950,7 +1321,7 @@ TEST_CASE(sin_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sin{}, input); p.add_instruction(migraphx::op::sin{}, input);
auto prog = migraphx::parse_onnx("sin_test.onnx"); auto prog = optimize_onnx("sin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -960,7 +1331,7 @@ TEST_CASE(sinh_test) ...@@ -960,7 +1331,7 @@ TEST_CASE(sinh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sinh{}, input); p.add_instruction(migraphx::op::sinh{}, input);
auto prog = migraphx::parse_onnx("sinh_test.onnx"); auto prog = optimize_onnx("sinh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -970,7 +1341,7 @@ TEST_CASE(slice_test) ...@@ -970,7 +1341,7 @@ TEST_CASE(slice_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0); p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
auto prog = migraphx::parse_onnx("slice_test.onnx"); auto prog = optimize_onnx("slice_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -980,7 +1351,7 @@ TEST_CASE(softmax_test) ...@@ -980,7 +1351,7 @@ TEST_CASE(softmax_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{1}, l0); p.add_instruction(migraphx::op::softmax{1}, l0);
auto prog = migraphx::parse_onnx("softmax_test.onnx"); auto prog = optimize_onnx("softmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -991,7 +1362,7 @@ TEST_CASE(sqrt_test) ...@@ -991,7 +1362,7 @@ TEST_CASE(sqrt_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input); p.add_instruction(migraphx::op::sqrt{}, input);
auto prog = migraphx::parse_onnx("sqrt_test.onnx"); auto prog = optimize_onnx("sqrt_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1004,7 +1375,7 @@ TEST_CASE(squeeze_unsqueeze_test) ...@@ -1004,7 +1375,7 @@ TEST_CASE(squeeze_unsqueeze_test)
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}}); p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0); auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1); p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
auto prog = migraphx::parse_onnx("squeeze_unsqueeze_test.onnx"); auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1017,7 +1388,7 @@ TEST_CASE(sub_bcast_test) ...@@ -1017,7 +1388,7 @@ TEST_CASE(sub_bcast_test)
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2); p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); auto prog = optimize_onnx("sub_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1028,10 +1399,9 @@ TEST_CASE(sub_scalar_test) ...@@ -1028,10 +1399,9 @@ TEST_CASE(sub_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, m0, m1); p.add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = migraphx::parse_onnx("sub_scalar_test.onnx"); auto prog = optimize_onnx("sub_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1045,7 +1415,7 @@ TEST_CASE(sum_test) ...@@ -1045,7 +1415,7 @@ TEST_CASE(sum_test)
auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1);
p.add_instruction(migraphx::op::add{}, l0, input2); p.add_instruction(migraphx::op::add{}, l0, input2);
auto prog = migraphx::parse_onnx("sum_test.onnx"); auto prog = optimize_onnx("sum_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1055,7 +1425,7 @@ TEST_CASE(tan_test) ...@@ -1055,7 +1425,7 @@ TEST_CASE(tan_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::tan{}, input); p.add_instruction(migraphx::op::tan{}, input);
auto prog = migraphx::parse_onnx("tan_test.onnx"); auto prog = optimize_onnx("tan_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1065,7 +1435,7 @@ TEST_CASE(tanh_test) ...@@ -1065,7 +1435,7 @@ TEST_CASE(tanh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::tanh{}, input); p.add_instruction(migraphx::op::tanh{}, input);
auto prog = migraphx::parse_onnx("tanh_test.onnx"); auto prog = optimize_onnx("tanh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1077,7 +1447,7 @@ TEST_CASE(transpose_test) ...@@ -1077,7 +1447,7 @@ TEST_CASE(transpose_test)
std::vector<int64_t> perm{0, 3, 1, 2}; std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraphx::op::transpose{perm}, input); p.add_instruction(migraphx::op::transpose{perm}, input);
auto prog = migraphx::parse_onnx("transpose_test.onnx"); auto prog = optimize_onnx("transpose_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1103,7 +1473,7 @@ TEST_CASE(transpose_gather_test) ...@@ -1103,7 +1473,7 @@ TEST_CASE(transpose_gather_test)
p.add_instruction( p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind)); migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather_test.onnx"); auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1115,7 +1485,28 @@ TEST_CASE(unknown_test) ...@@ -1115,7 +1485,28 @@ TEST_CASE(unknown_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1); auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::op::unknown{"Unknown"}, l2); p.add_instruction(migraphx::op::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx"); auto prog = optimize_onnx("unknown_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(variable_batch_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("variable_batch_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(variable_batch_leq_zero_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = optimize_onnx("variable_batch_leq_zero_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
prelu_brcst_test:w

0
1out"PReluprelu_brcst_testZ
0




Z
1


b
out




B
\ No newline at end of file
reduce_log_sum_exp_test:
>
xy"ReduceLogSumExp*
axes@*
keepdimsreduce_log_sum_exp_testZ
x




b
y



B
\ No newline at end of file
reduce_log_sum_test:
;
xy" ReduceLogSum*
axes@*
keepdimsreduce_log_sum_testZ
x




b
y




B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment