Commit 661e6287 authored by Paul's avatar Paul
Browse files

Formatting

parent 839a69f1
...@@ -98,7 +98,8 @@ struct convolution ...@@ -98,7 +98,8 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type(); auto t = input.type();
if (padding_mode == default_) { if(padding_mode == default_)
{
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
...@@ -116,21 +117,30 @@ struct convolution ...@@ -116,21 +117,30 @@ struct convolution
stride[1] + stride[1] +
1)), 1)),
}}; }};
} else if(padding_mode == same) { }
return {t, { else if(padding_mode == same)
input.lens()[0], {
weights.lens()[0], return {t,
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2]) / stride[0])), {input.lens()[0],
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3]) / stride[1])) weights.lens()[0],
}}; static_cast<std::size_t>(
} else if(padding_mode == valid) { std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
return {t, { static_cast<std::size_t>(
input.lens()[0], std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
weights.lens()[0], }
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])), else if(padding_mode == valid)
static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1])) {
}}; return {
} else { t,
{input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
}
else
{
RTG_THROW("Invalid padding mode"); RTG_THROW("Invalid padding mode");
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment