Commit ec48e189 authored by Paul's avatar Paul
Browse files

Use floor for pooling for now

parent 4886f3e8
...@@ -145,8 +145,8 @@ struct pooling ...@@ -145,8 +145,8 @@ struct pooling
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
assert(lengths[0] < (input.lens()[2] + 2 * padding[0])); assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] < (input.lens()[3] + 2 * padding[1])); assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
return {t, return {t,
{ {
...@@ -154,12 +154,12 @@ struct pooling ...@@ -154,12 +154,12 @@ struct pooling
input.lens()[1], input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::ceil((input.lens()[2] + 2 * padding[0] - lengths[0]) / std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0]))) + static_cast<float>(stride[0]))) +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::ceil((input.lens()[3] + 2 * padding[1] - lengths[1]) / std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) + static_cast<float>(stride[1]))) +
1)), 1)),
}}; }};
...@@ -236,6 +236,13 @@ struct transpose ...@@ -236,6 +236,13 @@ struct transpose
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const transpose& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
}
}; };
struct contiguous struct contiguous
...@@ -305,7 +312,7 @@ struct reshape ...@@ -305,7 +312,7 @@ struct reshape
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
os << op.name() << "["; os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}, "; os << "dims={" << stream_range(op.dims) << "}";
os << "]"; os << "]";
return os; return os;
} }
...@@ -443,6 +450,13 @@ struct flatten ...@@ -443,6 +450,13 @@ struct flatten
{ {
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const flatten& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
}; };
struct broadcast struct broadcast
{ {
...@@ -476,6 +490,13 @@ struct broadcast ...@@ -476,6 +490,13 @@ struct broadcast
{ {
return {output_shape, std::move(args.at(1).data)}; return {output_shape, std::move(args.at(1).data)};
} }
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
}; };
struct binary struct binary
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment