Commit 121ded22 authored by Scott Thornton's avatar Scott Thornton
Browse files
parents eb6452fa 4148ca64
...@@ -91,6 +91,13 @@ struct convolution ...@@ -91,6 +91,13 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}}; std::array<std::size_t, 2> dilation = {{1, 1}};
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
padding_mode_t padding_mode = default_;
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -99,23 +106,51 @@ struct convolution ...@@ -99,23 +106,51 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type(); auto t = input.type();
return {t, if(padding_mode == default_)
{ {
input.lens()[0], return {t,
weights.lens()[0], {
std::size_t(std::max<std::ptrdiff_t>( input.lens()[0],
1, weights.lens()[0],
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + std::size_t(std::max<std::ptrdiff_t>(
2 * padding[0]) / 1,
stride[0] + (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
1)), 2 * padding[0]) /
std::size_t(std::max<std::ptrdiff_t>( stride[0] +
1, 1)),
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + std::size_t(std::max<std::ptrdiff_t>(
2 * padding[1]) / 1,
stride[1] + (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
1)), 2 * padding[1]) /
}}; stride[1] +
1)),
}};
}
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
}
else if(padding_mode == valid)
{
return {
t,
{input.lens()[0],
weights.lens()[0],
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
}
else
{
RTG_THROW("Invalid padding mode");
}
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
...@@ -15,6 +15,8 @@ namespace rtg { ...@@ -15,6 +15,8 @@ namespace rtg {
struct program_impl; struct program_impl;
const operation& get_operation(instruction_ref ins);
/** /**
* @brief Stores the instruction stream * @brief Stores the instruction stream
*/ */
......
...@@ -12,6 +12,8 @@ struct program_impl ...@@ -12,6 +12,8 @@ struct program_impl
std::list<instruction> instructions; std::list<instruction> instructions;
}; };
const operation& get_operation(instruction_ref ins) { return ins->op; }
program::program() : impl(std::make_unique<program_impl>()) {} program::program() : impl(std::make_unique<program_impl>()) {}
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment