Commit 28a644f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 18eed505
...@@ -13,11 +13,11 @@ namespace device { ...@@ -13,11 +13,11 @@ namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg) void pack_a(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto dim_0 = out_lens.size() - 2; auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1; auto dim_1 = out_lens.size() - 1;
std::size_t lda = output_shape.strides()[dim_0]; std::size_t lda = output_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1]; std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
...@@ -26,12 +26,13 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -26,12 +26,13 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
visit_tensor_size(out_lens.size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape); hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1]; std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0]; std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] = in_ptr[i_m + i_k * lda + offset]; out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
}); });
}); });
}); });
...@@ -39,11 +40,11 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -39,11 +40,11 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
void pack_b(hipStream_t stream, const argument& result, const argument& arg) void pack_b(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto dim_0 = output_shape.lens().size() - 2; auto dim_0 = output_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1; auto dim_1 = output_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1]; std::size_t ldb = output_shape.strides()[dim_1];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1]; std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
...@@ -52,12 +53,13 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg) ...@@ -52,12 +53,13 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
visit_tensor_size(out_lens.size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(output_shape); hip_tensor_descriptor<out_dim> desc(output_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_n = idx[1]; std::size_t i_n = idx[1];
std::size_t i_k = idx[0]; std::size_t i_k = idx[0];
std::size_t offset = ii / m_size * m_size; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] = in_ptr[i_n + i_k * ldb + offset]; out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
}); });
}); });
}); });
......
...@@ -1435,15 +1435,11 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1435,15 +1435,11 @@ TEST_CASE(quant_dot_3args_batch)
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3); p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
284, 308, 332, 356, 380, 404, 428, 404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
1530, 1570, 1610, 1650, 1690, 1730, 1770, 2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
2160, 2216, 2272, 2328, 2384, 2440, 2496, 5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
4750, 4822, 4894, 4966, 5038, 5110, 5182, 10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
5828, 5916, 6004, 6092, 6180, 6268, 6356,
9762, 9866, 9970, 10074, 10178, 10282, 10386,
11288, 11408, 11528, 11648, 11768, 11888, 12008
};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1464,27 +1460,20 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1464,27 +1460,20 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1); auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2); auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3); p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
120, 299, 478, 657, 836, 1015, 150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
150, 361, 572, 783, 994, 1205, 3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
3456, 3987, 4518, 5049, 5580, 6111, 11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
3678, 4241, 4804, 5367, 5930, 6493, 12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
3900, 4495, 5090, 5685, 6280, 6875, 24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
11430, 12345, 13260, 14175, 15090, 16005,
11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153,
24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273,
25224, 26587, 27950, 29313, 30676, 32039
};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -1337,9 +1337,9 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> ...@@ -1337,9 +1337,9 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape); auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape); auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape); auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3); p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p; return p;
} }
......
...@@ -608,17 +608,13 @@ TEST_CASE(quant_dot_2args) ...@@ -608,17 +608,13 @@ TEST_CASE(quant_dot_2args)
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{}, throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{}, throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
} }
...@@ -639,10 +635,7 @@ TEST_CASE(quant_dot_3args) ...@@ -639,10 +635,7 @@ TEST_CASE(quant_dot_3args)
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}}; migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2}, throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
s_m1,
s_m2,
s_m3);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment