Commit e762116b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'int8_miopen_call' into int8_quantize

parents 5656cc4e 5fe89b69
......@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE(quant_dot_2args_multi4)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
std::vector<int> gold = {112, 118, 124, 130, 136, 142, 148, 154, 304, 326, 348,
370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686,
724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704,
736, 768, 592, 628, 664, 700, 736, 772, 808, 844};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822,
974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708,
836, 964, 74, 218, 362, 506, 650, 794, 938, 1082};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_2args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {
210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(
migraphx::op::quant_dot{
2,
},
l1,
tl2);
std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2);
std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_batch)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};
std::vector<int8_t> data1(4 * 2 * 4);
std::vector<int8_t> data2(4 * 4 * 7);
std::vector<int> data3(4 * 2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};
std::vector<int8_t> data1(48);
std::vector<int8_t> data2(96);
std::vector<int> data3(72);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -1338,6 +1338,177 @@ TEST_CASE(conv2d_padding_stride_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_default_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
10197, 10548, 6939, 3420, 11601, 11952, 7839, 3852, 7383, 7590, 4953, 2421, 3480,
3570, 2316, 1125, 25506, 26586, 17874, 9009, 29826, 30906, 20718, 10413, 20505, 21198,
14187, 7119, 10527, 10860, 7257, 3636, 27045, 27396, 17739, 8604, 28449, 28800, 18639,
9036, 17319, 17526, 11289, 5445, 7800, 7890, 5052, 2421, 77346, 78426, 52002, 25857,
81666, 82746, 54846, 27261, 53769, 54462, 36075, 17919, 26511, 26844, 17769, 8820};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_valid_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893,
51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {4521,
7014,
7830,
11952,
10515,
16734,
19737,
30906,
13161,
19542,
19494,
28800,
34707,
52590,
54729,
82746};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(transpose_test)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
......
......@@ -1238,6 +1238,114 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
}
};
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
return p;
}
};
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
return p;
}
};
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
return p;
}
};
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
......@@ -1383,6 +1491,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
}
};
struct quant_conv : verify_program<quant_conv>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{}, pa, pc);
return p;
}
};
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
pa,
pc);
return p;
}
};
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
pa,
pc);
return p;
}
};
struct quant_conv_padding : verify_program<quant_conv_padding>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
return p;
}
};
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
......
......@@ -76,6 +76,41 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::op::convolution{}, input2, weights);
}
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::op::quant_convolution{}, input, weights);
throws_shape(migraphx::op::quant_convolution{}, input);
migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input2, weights2);
throws_shape(migraphx::op::quant_convolution{}, input2, weights);
migraphx::shape input3{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input3, weights);
throws_shape(migraphx::op::quant_convolution{}, input, weight3);
throws_shape(migraphx::op::quant_convolution{}, input3, weight3);
migraphx::shape output_same_mode{migraphx::shape::float_type, {4, 4, 3, 3}};
expect_shape(output_same_mode,
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
input,
weights);
expect_shape(output,
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
input,
weights);
throws_shape(
migraphx::op::quant_convolution{
{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::padding_mode_t(9999)},
input,
weights);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......@@ -584,6 +619,61 @@ TEST_CASE(gemm)
}
}
// quant_dot
TEST_CASE(quant_dot_2args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::op::quant_dot{1, 0},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
}
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
}
}
TEST_CASE(rnn)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment