Commit cbd244d1 authored by Khalique's avatar Khalique
Browse files

formatting

parent 5e59bf45
......@@ -17,11 +17,11 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace op {
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
{
default_, // NOLINT
same,
valid
};
struct not_computable
{
......@@ -65,7 +65,7 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_;
int group = 1;
......@@ -191,7 +191,7 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
padding_mode_t padding_mode = default_;
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -217,7 +217,8 @@ struct pooling
if(padding_mode == default_)
{
return {t,
return {
t,
{
input.lens()[0],
input.lens()[1],
......@@ -246,20 +247,20 @@ struct pooling
else if(padding_mode == valid)
{
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
}
else
{
......
......@@ -206,14 +206,12 @@ struct tf_parser
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3];
op.stride[1] = stride[3];
}
}
if(contains(attributes, "dilations"))
{
......@@ -233,7 +231,6 @@ struct tf_parser
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
}
auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]);
......@@ -276,7 +273,6 @@ struct tf_parser
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
}
if(contains(attributes, "ksize"))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment