Commit 828017fb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 94d13003
...@@ -78,13 +78,12 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -78,13 +78,12 @@ void migemm_impl(tensor_view<T> cmat,
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
double s = cmat(c_idx.begin(), c_idx.end()) * beta; double s = cmat(c_idx.begin(), c_idx.end()) * beta;
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
dfor(k)([&](auto kk) { dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk; a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
bmat(b_idx.begin(), b_idx.end());
}); });
cmat(c_idx.begin(), c_idx.end()) = alpha * s; cmat(c_idx.begin(), c_idx.end()) = alpha * s;
}); });
......
...@@ -836,7 +836,6 @@ struct test_gemm_transposea_ex ...@@ -836,7 +836,6 @@ struct test_gemm_transposea_ex
} }
}; };
struct test_gemm_transposeab struct test_gemm_transposeab
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -575,17 +575,17 @@ TEST_CASE(gemm_test) ...@@ -575,17 +575,17 @@ TEST_CASE(gemm_test)
TEST_CASE(gemm_ex) TEST_CASE(gemm_ex)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f; auto alpha = 0.5f;
auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1); auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1);
auto beta = 0.8f; auto beta = 0.8f;
auto l_beta = p.add_literal(beta); auto l_beta = p.add_literal(beta);
auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta); auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta); auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
p.add_instruction(migraphx::op::add{}, res_ab, res_c); p.add_instruction(migraphx::op::add{}, res_ab, res_c);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx"); auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
......
...@@ -321,10 +321,8 @@ TEST_CASE(dot) ...@@ -321,10 +321,8 @@ TEST_CASE(dot)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}}, expect_shape(
migraphx::op::dot{}, migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
{ {
...@@ -336,10 +334,8 @@ TEST_CASE(dot) ...@@ -336,10 +334,8 @@ TEST_CASE(dot)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}}, expect_shape(
migraphx::op::dot{}, migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::op::dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
{ {
...@@ -394,7 +390,6 @@ TEST_CASE(dot) ...@@ -394,7 +390,6 @@ TEST_CASE(dot)
} }
} }
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment