Commit 28a644f7 authored by Shucai Xiao
Browse files

clang format

parent 18eed505
......@@ -13,11 +13,11 @@ namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg)
{
auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = output_shape.strides()[dim_0];
auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = output_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
......@@ -26,12 +26,13 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0];
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] = in_ptr[i_m + i_k * lda + offset];
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
});
});
});
......@@ -39,11 +40,11 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
void pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = output_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1];
auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = output_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
......@@ -52,12 +53,13 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[1];
std::size_t i_k = idx[0];
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[1];
std::size_t i_k = idx[0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] = in_ptr[i_n + i_k * ldb + offset];
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
});
});
});
......
......@@ -1435,15 +1435,11 @@ TEST_CASE(quant_dot_3args_batch)
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150,
284, 308, 332, 356, 380, 404, 428,
1530, 1570, 1610, 1650, 1690, 1730, 1770,
2160, 2216, 2272, 2328, 2384, 2440, 2496,
4750, 4822, 4894, 4966, 5038, 5110, 5182,
5828, 5916, 6004, 6092, 6180, 6268, 6356,
9762, 9866, 9970, 10074, 10178, 10282, 10386,
11288, 11408, 11528, 11648, 11768, 11888, 12008
};
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -1464,27 +1460,20 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825,
120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205,
3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493,
3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005,
11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153,
24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273,
25224, 26587, 27950, 29313, 30676, 32039
};
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......
......@@ -1337,9 +1337,9 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
......
......@@ -608,17 +608,13 @@ TEST_CASE(quant_dot_2args)
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{},
s_m1,
s_m2);
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{},
s_m1,
s_m2);
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
}
......@@ -639,10 +635,7 @@ TEST_CASE(quant_dot_3args)
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2},
s_m1,
s_m2,
s_m3);
throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment