Commit 18eed505 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add gpu and shape tests for quant dot

parent 05affc68
...@@ -14,20 +14,24 @@ namespace device { ...@@ -14,20 +14,24 @@ namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg) void pack_a(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
auto dim_0 = output_shape.lens().size() - 2; auto out_lens = output_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = output_shape.strides()[dim_0]; std::size_t lda = output_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data()); auto* in_ptr = device_cast(input.data());
visit_tensor_size(output_shape.lens().size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape); hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_m = idx[1]; std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[0]; std::size_t i_k = idx[dim_0];
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb] = in_ptr[i_m + i_k * lda]; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] = in_ptr[i_m + i_k * lda + offset];
}); });
}); });
}); });
...@@ -36,20 +40,24 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -36,20 +40,24 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
void pack_b(hipStream_t stream, const argument& result, const argument& arg) void pack_b(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = output_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1; auto dim_1 = output_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1]; std::size_t ldb = output_shape.strides()[dim_1];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data()); auto* in_ptr = device_cast(input.data());
visit_tensor_size(output_shape.lens().size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape); hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_n = idx[1]; std::size_t i_n = idx[1];
std::size_t i_k = idx[0]; std::size_t i_k = idx[0];
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb] = in_ptr[i_n + i_k * ldb]; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] = in_ptr[i_n + i_k * ldb + offset];
}); });
}); });
}); });
......
...@@ -1415,4 +1415,83 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1415,4 +1415,83 @@ TEST_CASE(quant_dot_3args_general)
} }
} }
TEST_CASE(quant_dot_3args_batch)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};
std::vector<int8_t> data1(4 * 2 * 4);
std::vector<int8_t> data2(4 * 4 * 7);
std::vector<int> data3(4 * 2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150,
284, 308, 332, 356, 380, 404, 428,
1530, 1570, 1610, 1650, 1690, 1730, 1770,
2160, 2216, 2272, 2328, 2384, 2440, 2496,
4750, 4822, 4894, 4966, 5038, 5110, 5182,
5828, 5916, 6004, 6092, 6180, 6268, 6356,
9762, 9866, 9970, 10074, 10178, 10282, 10386,
11288, 11408, 11528, 11648, 11768, 11888, 12008
};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};
std::vector<int8_t> data1(48);
std::vector<int8_t> data2(96);
std::vector<int> data3(72);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825,
120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205,
3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493,
3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005,
11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153,
24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273,
25224, 26587, 27950, 29313, 30676, 32039
};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1309,6 +1309,42 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> ...@@ -1309,6 +1309,42 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
} }
}; };
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous> struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -584,6 +584,68 @@ TEST_CASE(gemm) ...@@ -584,6 +584,68 @@ TEST_CASE(gemm)
} }
} }
// quant_dot
TEST_CASE(quant_dot_2args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::op::quant_dot{1, 0},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{},
s_m1,
s_m2);
}
}
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2},
s_m1,
s_m2,
s_m3);
}
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment