"...git@developer.sourcefind.cn:chenzk/alphafold2_jax.git" did not exist on "1109480e6f38d71b3b265a4a25039e51e2343368"
Unverified Commit a023ec19 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add additional simple operators (MatMulInteger, ConvInteger, Asinh, Acosh, and Atanh) (#431)



* Add initial api

* Formatting

* Add more api

* Formatting

* add more operators (asinh, acosh, atanh, MatMulInteger, ConvInteger)

* clang format

* add unit tests for new operators

* clang format

* Add auto api generation

* Formatting

* Fix some compilation errors

* Change handle struct

* Formatting

* Fix remaining compilation errors

* Formatting

* Simplify using ctype

* Formatting

* Initial c++ generation

* Formatting

* Add C++header

* Formatting

* Add test

* Formatting

* Add initial tests

* Formatting

* Try to fix formatting

* Cleanup formatting

* Formatting

* Fix constructors on the same line

* Fix tests

* Formatting

* Fix tidy issues

* Fix tidy issues

* Fix naming issue

* Add onnx API to parse buffer

* Formatting

* Add arguments api

* Formatting

* Fix verify parameters

* Fix cppcheck issues

* Formatting

* Add method to get output shapes and bytes

* Formatting

* Try formatting

* Formatting

* Improve the test coverage

* Formatting

* Add print method

* Formatting

* Fix cppcheck issue

* Fix package dependency

* Add nolint

* Try fix formatting

* Formatting

* formatting

* formatting

* Fix formatting
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent b949da7f

atanh_test:=
xy"Atanh
atanh_testZ
x


b
y


B
\ No newline at end of file
convinteger_bias_test:À
?
0
1
23" ConvInteger*
dilations@@ *
strides@@ convinteger_bias_testZ
0




 Z
1




Z
2

b
3




B
\ No newline at end of file
...@@ -38,6 +38,20 @@ def acos_test(): ...@@ -38,6 +38,20 @@ def acos_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def acosh_test():
    # Single Acosh node mapping a 1-D float tensor of length 10 to an
    # identically shaped output.
    arg_x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
    arg_y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
    acosh_node = onnx.helper.make_node('Acosh', inputs=['x'], outputs=['y'])
    return ([acosh_node], [arg_x], [arg_y])
@onnx_test @onnx_test
def add_bcast_test(): def add_bcast_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
...@@ -130,6 +144,20 @@ def asin_test(): ...@@ -130,6 +144,20 @@ def asin_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def asinh_test():
    # Single Asinh node mapping a 1-D float tensor of length 10 to an
    # identically shaped output.
    arg_x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
    arg_y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
    asinh_node = onnx.helper.make_node('Asinh', inputs=['x'], outputs=['y'])
    return ([asinh_node], [arg_x], [arg_y])
@onnx_test @onnx_test
def atan_test(): def atan_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
...@@ -144,6 +172,20 @@ def atan_test(): ...@@ -144,6 +172,20 @@ def atan_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def atanh_test():
    # Single Atanh node mapping a 1-D float tensor of length 10 to an
    # identically shaped output.
    arg_x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
    arg_y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10])
    atanh_node = onnx.helper.make_node('Atanh', inputs=['x'], outputs=['y'])
    return ([atanh_node], [arg_x], [arg_y])
@onnx_test @onnx_test
def averagepool_notset_test(): def averagepool_notset_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
...@@ -572,6 +614,22 @@ def conv_relu_maxpool_x2_test(): ...@@ -572,6 +614,22 @@ def conv_relu_maxpool_x2_test():
return ([node1, node2, node3, node4, node5, node6], [x, y, z, m, n], [out]) return ([node1, node2, node3, node4, node5, node6], [x, y, z, m, n], [out])
@onnx_test
def convinteger_bias_test():
    # ConvInteger with a third (zero-point/bias) input: int8 data and weights,
    # int32 extra input, int32 output, unit dilations and strides.
    data = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32])
    weights = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5])
    bias = helper.make_tensor_value_info('2', TensorProto.INT32, [1])
    result = helper.make_tensor_value_info('3', TensorProto.INT32,
                                           [1, 2, 28, 28])
    conv_node = onnx.helper.make_node('ConvInteger',
                                      inputs=['0', '1', '2'],
                                      outputs=['3'],
                                      dilations=[1, 1],
                                      strides=[1, 1])
    return ([conv_node], [data, weights, bias], [result])
@onnx_test @onnx_test
def cos_test(): def cos_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
...@@ -1195,6 +1253,21 @@ def matmul_vv_test(): ...@@ -1195,6 +1253,21 @@ def matmul_vv_test():
return ([node], [m1, m2], [y]) return ([node], [m1, m2], [y])
@onnx_test
def matmulinteger_test():
    # Batched MatMulInteger: int8 operands (3,6,16) x (3,16,8) producing an
    # int32 (3,6,8) result.
    lhs = helper.make_tensor_value_info('1', TensorProto.INT8, [3, 6, 16])
    rhs = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 16, 8])
    result = helper.make_tensor_value_info('y', TensorProto.INT32, [3, 6, 8])
    mm_node = onnx.helper.make_node('MatMulInteger',
                                    inputs=['1', '2'],
                                    outputs=['y'])
    return ([mm_node], [lhs, rhs], [result])
@onnx_test @onnx_test
def max_test(): def max_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
matmulinteger_test:y

1
2y" MatMulIntegermatmulinteger_testZ
1



Z
2



b
y



B
\ No newline at end of file
...@@ -38,6 +38,17 @@ TEST_CASE(acos_test) ...@@ -38,6 +38,17 @@ TEST_CASE(acos_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(acosh_test)
{
    // The ONNX Acosh node should parse directly to migraphx::op::acosh.
    migraphx::program p;
    migraphx::shape xs{migraphx::shape::float_type, {10}};
    auto x = p.add_parameter("x", xs);
    p.add_instruction(migraphx::op::acosh{}, x);
    EXPECT(p == optimize_onnx("acosh_test.onnx"));
}
TEST_CASE(add_bcast_test) TEST_CASE(add_bcast_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -109,6 +120,17 @@ TEST_CASE(asin_test) ...@@ -109,6 +120,17 @@ TEST_CASE(asin_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(asinh_test)
{
    // The ONNX Asinh node should parse directly to migraphx::op::asinh.
    migraphx::program p;
    migraphx::shape xs{migraphx::shape::float_type, {10}};
    auto x = p.add_parameter("x", xs);
    p.add_instruction(migraphx::op::asinh{}, x);
    EXPECT(p == optimize_onnx("asinh_test.onnx"));
}
TEST_CASE(atan_test) TEST_CASE(atan_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -120,6 +142,17 @@ TEST_CASE(atan_test) ...@@ -120,6 +142,17 @@ TEST_CASE(atan_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(atanh_test)
{
    // The ONNX Atanh node should parse directly to migraphx::op::atanh.
    migraphx::program p;
    migraphx::shape xs{migraphx::shape::float_type, {10}};
    auto x = p.add_parameter("x", xs);
    p.add_instruction(migraphx::op::atanh{}, x);
    EXPECT(p == optimize_onnx("atanh_test.onnx"));
}
TEST_CASE(averagepool_notset_test) TEST_CASE(averagepool_notset_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -389,6 +422,21 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -389,6 +422,21 @@ TEST_CASE(conv_relu_maxpool_x2_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(convinteger_bias_test)
{
    // ConvInteger with a third input: expect quant_convolution followed by a
    // broadcast of the extra int32 input over axis 1 and an elementwise add.
    migraphx::program p;
    auto data    = p.add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
    auto weights = p.add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
    auto bias    = p.add_parameter("2", {migraphx::shape::int32_type, {1}});
    auto conv    = p.add_instruction(migraphx::op::quant_convolution{}, data, weights);
    uint64_t axis = 1;
    auto bcast    = p.add_instruction(migraphx::op::broadcast{axis, conv->get_shape().lens()}, bias);
    p.add_instruction(migraphx::op::add{}, conv, bcast);
    EXPECT(p == optimize_onnx("convinteger_bias_test.onnx"));
}
TEST_CASE(cos_test) TEST_CASE(cos_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -862,9 +910,7 @@ TEST_CASE(matmul_vbm_test) ...@@ -862,9 +910,7 @@ TEST_CASE(matmul_vbm_test)
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0); auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
std::cout << "ONNX_TEST" << std::endl;
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
std::cout << "After Dot" << std::endl;
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = optimize_onnx("matmul_vbm_test.onnx"); auto prog = optimize_onnx("matmul_vbm_test.onnx");
...@@ -902,6 +948,18 @@ TEST_CASE(matmul_vv_test) ...@@ -902,6 +948,18 @@ TEST_CASE(matmul_vv_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(matmulinteger_test)
{
    // Batched MatMulInteger on int8 operands should parse to quant_dot{1, 0}.
    migraphx::program p;
    auto lhs = p.add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}});
    auto rhs = p.add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}});
    p.add_instruction(migraphx::op::quant_dot{1, 0}, lhs, rhs);
    EXPECT(p == optimize_onnx("matmulinteger_test.onnx"));
}
TEST_CASE(max_test) TEST_CASE(max_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment