Commit ae4da3a3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change quant_convolution following changes from convolution.

parent 681560bb
......@@ -52,8 +52,6 @@ struct quant_convolution
}
t = shape::int32_type;
if(padding_mode == default_)
{
return {t,
{
input.lens()[0],
......@@ -72,32 +70,6 @@ struct quant_convolution
1)),
}};
}
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
}
else if(padding_mode == valid)
{
return {
t,
{input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
}
else
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: invalid padding mode");
}
}
};
} // namespace op
......
......@@ -941,9 +941,6 @@ TEST_CASE(softmax_simple_test)
auto result = p.eval({});
std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
std::cout << v << "\t";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, s));
}
......@@ -1357,76 +1354,6 @@ TEST_CASE(quant_conv2d_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_default_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {
10197, 10548, 6939, 3420, 11601, 11952, 7839, 3852, 7383, 7590, 4953, 2421, 3480,
3570, 2316, 1125, 25506, 26586, 17874, 9009, 29826, 30906, 20718, 10413, 20505, 21198,
14187, 7119, 10527, 10860, 7257, 3636, 27045, 27396, 17739, 8604, 28449, 28800, 18639,
9036, 17319, 17526, 11289, 5445, 7800, 7890, 5052, 2421, 77346, 78426, 52002, 25857,
81666, 82746, 54846, 27261, 53769, 54462, 36075, 17919, 26511, 26844, 17769, 8820};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test_valid_mode)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
......
......@@ -94,16 +94,6 @@ TEST_CASE(quant_convolution_shape)
throws_shape(migraphx::op::quant_convolution{}, input3, weights);
throws_shape(migraphx::op::quant_convolution{}, input, weight3);
throws_shape(migraphx::op::quant_convolution{}, input3, weight3);
migraphx::shape output_same_mode{migraphx::shape::int32_type, {4, 4, 3, 3}};
expect_shape(output_same_mode,
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
input,
weights);
expect_shape(output,
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
input,
weights);
}
TEST_CASE(transpose_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment