Commit 1959394f authored by Khalique's avatar Khalique
Browse files

formatting

parent 6d2e8e92
...@@ -65,7 +65,7 @@ struct convolution ...@@ -65,7 +65,7 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}}; std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
int group = 1; int group = 1;
...@@ -191,7 +191,7 @@ struct pooling ...@@ -191,7 +191,7 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}}; std::array<std::size_t, 2> lengths = {{1, 1}};
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -215,10 +215,10 @@ struct pooling ...@@ -215,10 +215,10 @@ struct pooling
assert(lengths[0] <= (input.lens()[2] + 2 * padding[0])); assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] <= (input.lens()[3] + 2 * padding[1])); assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
if(padding_mode == default_) if(padding_mode == default_)
{ {
return {t, return {
t,
{ {
input.lens()[0], input.lens()[0],
input.lens()[1], input.lens()[1],
...@@ -247,26 +247,25 @@ struct pooling ...@@ -247,26 +247,25 @@ struct pooling
else if(padding_mode == valid) else if(padding_mode == valid)
{ {
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
input.lens()[1], input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) / std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) + static_cast<float>(stride[0]))) +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) / std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) + static_cast<float>(stride[1]))) +
1)), 1)),
}}; }};
} }
else else
{ {
MIGRAPHX_THROW("Invalid padding mode"); MIGRAPHX_THROW("Invalid padding mode");
} }
} }
}; };
......
...@@ -232,13 +232,13 @@ struct onnx_parser ...@@ -232,13 +232,13 @@ struct onnx_parser
{ {
// insert zeros for pad op (args[0] has 4 dims) // insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}; padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0); l0 = prog.add_instruction(op::pad{padding}, l0);
} }
else else
{ {
op.padding[0] = padding[0]; op.padding[0] = padding[0];
op.padding[1] = padding[1]; op.padding[1] = padding[1];
} }
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
...@@ -298,7 +298,7 @@ struct onnx_parser ...@@ -298,7 +298,7 @@ struct onnx_parser
{ {
// insert zeros for pad op (args[0] has 4 dims) // insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}; padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0); l0 = prog.add_instruction(op::pad{padding}, l0);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment