Commit 839a69f1 authored by Paul's avatar Paul
Browse files

Add padding mode to conv

parent 32b1fda9
......@@ -83,6 +83,13 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
padding_mode_t padding_mode = default_;
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -91,23 +98,41 @@ struct convolution
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
if (padding_mode == default_) {
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
} else if(padding_mode == same) {
return {t, {
input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))
}};
} else if(padding_mode == valid) {
return {t, {
input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))
}};
} else {
RTG_THROW("Invalid padding mode");
}
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment