Commit 28a644f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 18eed505
...@@ -31,7 +31,8 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -31,7 +31,8 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
std::size_t i_m = idx[dim_1]; std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0]; std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] = in_ptr[i_m + i_k * lda + offset]; out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
}); });
}); });
}); });
...@@ -57,7 +58,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg) ...@@ -57,7 +58,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
std::size_t i_n = idx[1]; std::size_t i_n = idx[1];
std::size_t i_k = idx[0]; std::size_t i_k = idx[0];
std::size_t offset = ii / m_size * m_size; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] = in_ptr[i_n + i_k * ldb + offset]; out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
}); });
}); });
}); });
......
...@@ -1435,15 +1435,11 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1435,15 +1435,11 @@ TEST_CASE(quant_dot_3args_batch)
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3); p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
284, 308, 332, 356, 380, 404, 428, 404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
1530, 1570, 1610, 1650, 1690, 1730, 1770, 2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
2160, 2216, 2272, 2328, 2384, 2440, 2496, 5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
4750, 4822, 4894, 4966, 5038, 5110, 5182, 10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
5828, 5916, 6004, 6092, 6180, 6268, 6356,
9762, 9866, 9970, 10074, 10178, 10282, 10386,
11288, 11408, 11528, 11648, 11768, 11888, 12008
};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1472,19 +1468,12 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1472,19 +1468,12 @@ TEST_CASE(quant_dot_3args_batch)
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3); p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
120, 299, 478, 657, 836, 1015, 150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
150, 361, 572, 783, 994, 1205, 3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
3456, 3987, 4518, 5049, 5580, 6111, 11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
3678, 4241, 4804, 5367, 5930, 6493, 12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
3900, 4495, 5090, 5685, 6280, 6875, 24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
11430, 12345, 13260, 14175, 15090, 16005,
11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153,
24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273,
25224, 26587, 27950, 29313, 30676, 32039
};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -608,17 +608,13 @@ TEST_CASE(quant_dot_2args) ...@@ -608,17 +608,13 @@ TEST_CASE(quant_dot_2args)
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{}, throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{}, throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
s_m1,
s_m2);
} }
} }
...@@ -639,10 +635,7 @@ TEST_CASE(quant_dot_3args) ...@@ -639,10 +635,7 @@ TEST_CASE(quant_dot_3args)
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}}; migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2}, throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
s_m1,
s_m2,
s_m3);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment