Commit 9b53cf55 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'extend_gemm_op' into seq2seq_example

parents 0db15370 ad8f88f5
...@@ -524,7 +524,7 @@ TEST_CASE(constant_test) ...@@ -524,7 +524,7 @@ TEST_CASE(constant_test)
TEST_CASE(constant_test_scalar) TEST_CASE(constant_test_scalar)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx"); auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -572,6 +572,27 @@ TEST_CASE(gemm_test) ...@@ -572,6 +572,27 @@ TEST_CASE(gemm_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_ex)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1);
auto beta = 0.8f;
auto l_beta = p.add_literal(beta);
auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
p.add_instruction(migraphx::op::add{}, res_ab, res_c);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_scalar_test) TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -263,7 +263,7 @@ TEST_CASE(gather) ...@@ -263,7 +263,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4; int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
...@@ -273,7 +273,7 @@ TEST_CASE(gather) ...@@ -273,7 +273,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = 3; int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
...@@ -283,9 +283,9 @@ TEST_CASE(gather) ...@@ -283,9 +283,9 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}}; migraphx::shape indices{migraphx::shape::int32_type};
int axis = 0; int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, expect_shape(migraphx::shape{migraphx::shape::float_type},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
...@@ -316,6 +316,80 @@ TEST_CASE(gather) ...@@ -316,6 +316,80 @@ TEST_CASE(gather)
} }
} }
TEST_CASE(dot)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 7}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 7}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {3, 1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {2, 2, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
...@@ -22,10 +22,8 @@ struct target ...@@ -22,10 +22,8 @@ struct target
{ {
/// A unique name used to identify the target /// A unique name used to identify the target
std::string name() const; std::string name() const;
/// The transformation passes to be run
/** /**
* @brief The transformation pass to be run during compilation. * @brief The transformation pass to be run during compilation.
* @details [long description]
* *
* @param ctx This is the target-dependent context that is created by `get_context` * @param ctx This is the target-dependent context that is created by `get_context`
* @return The passes to be ran * @return The passes to be ran
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment