Commit 837304f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug in cpu_gemm and add tests to cover that.

parent 1c3b16d2
...@@ -55,10 +55,7 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -55,10 +55,7 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
if(beta != 0.0) c = beta * c;
{
c = beta * c;
}
if(alpha != 0.0) if(alpha != 0.0)
{ {
......
...@@ -1025,6 +1025,123 @@ TEST_CASE(gemm_mutli_dim_2) ...@@ -1025,6 +1025,123 @@ TEST_CASE(gemm_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim_2_beta0)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
std::vector<float> m3 = {0.18208394, -0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 2, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_beta_0)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140};
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
std::vector<float> m3 = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_3) TEST_CASE(gemm_mutli_dim_2_3)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -902,6 +902,46 @@ struct gemm_mutli_3args ...@@ -902,6 +902,46 @@ struct gemm_mutli_3args
} }
}; };
struct gemm_mutli_3args_beta0
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_mutli_3args_alpha0
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.0f;
float beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct test_contiguous struct test_contiguous
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3065,6 +3105,8 @@ int main() ...@@ -3065,6 +3105,8 @@ int main()
verify_program<gemm_mutli_dim_2>(); verify_program<gemm_mutli_dim_2>();
verify_program<gemm_mutli_dim_2_3>(); verify_program<gemm_mutli_dim_2_3>();
verify_program<gemm_mutli_3args>(); verify_program<gemm_mutli_3args>();
verify_program<gemm_mutli_3args_beta0>();
verify_program<gemm_mutli_3args_alpha0>();
verify_program<test_contiguous>(); verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>(); verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>(); verify_program<test_transpose>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment