Commit 8a18175b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add tests for quant_convoluation.

parent 040bbf04
......@@ -1371,7 +1371,125 @@ TEST_CASE(quant_conv2d_test)
81666,
82746};
std::vector<float> results_vector(16);
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_default_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
10197, 10548, 6939, 3420, 11601, 11952,
7839, 3852, 7383, 7590, 4953, 2421,
3480, 3570, 2316, 1125, 25506, 26586,
17874, 9009, 29826, 30906, 20718, 10413,
20505, 21198, 14187, 7119, 10527, 10860,
7257, 3636, 27045, 27396, 17739, 8604,
28449, 28800, 18639, 9036, 17319, 17526,
11289, 5445, 7800, 7890, 5052, 2421,
77346, 78426, 52002, 25857, 81666, 82746,
54846, 27261, 53769, 54462, 36075, 17919,
26511, 26844, 17769, 8820};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_valid_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
10197, 10548, 11601, 11952, 25506, 26586,
29826, 30906, 27045, 27396, 28449, 28800,
77346, 78426, 81666, 82746};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
4521, 6753, 7014, 4635, 6858, 10197,
10548, 6939, 7830, 11601, 11952, 7839,
5007, 7383, 7590, 4953, 10515, 15987,
16734, 11277, 16821, 25506, 26586, 17874,
19737, 29826, 30906, 20718, 13593, 20505,
21198, 14187, 13161, 19281, 19542, 12699,
18522, 27045, 27396, 17739, 19494, 28449,
28800, 18639, 11919, 17319, 17526, 11289,
34707, 51843, 52590, 34893, 51813, 77346,
78426, 52002, 54729, 81666, 82746, 54846,
36057, 53769, 54462, 36075};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> s = {
4521, 7014, 7830, 11952, 10515, 16734,
19737, 30906, 13161, 19542, 19494, 28800,
34707, 52590, 54729, 82746};
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
......
......@@ -1490,6 +1490,72 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
}
};
struct quant_conv : verify_program<quant_conv>
{
migraphx::program create_program() {
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{}, pa, pc);
return p;
}
};
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
migraphx::program create_program() {
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, pa, pc);
return p;
}
};
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
migraphx::program create_program() {
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, pa, pc);
return p;
}
};
struct quant_conv_padding : verify_program<quant_conv_padding>
{
migraphx::program create_program() {
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{
migraphx::program create_program() {
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
return p;
}
};
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
......
......@@ -76,6 +76,34 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::op::convolution{}, input2, weights);
}
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::op::quant_convolution{}, input, weights);
throws_shape(migraphx::op::quant_convolution{}, input);
migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input2, weights2);
throws_shape(migraphx::op::quant_convolution{}, input2, weights);
migraphx::shape input3{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input3, weights);
throws_shape(migraphx::op::quant_convolution{}, input, weight3);
throws_shape(migraphx::op::quant_convolution{}, input3, weight3);
migraphx::shape output_same_mode{migraphx::shape::float_type, {4, 4, 3, 3}};
expect_shape(output_same_mode, migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}},
migraphx::op::same}, input, weights);
expect_shape(output, migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}},
migraphx::op::valid}, input, weights);
throws_shape(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}},
migraphx::op::padding_mode_t(9999)}, input, weights);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment