Commit d013de49 authored by Paul
Browse files

Merge branch 'pooling'

parents 95549529 4aff18c1
......@@ -179,6 +179,10 @@ struct pooling
const shape& input = inputs.at(0);
auto t = input.type();
assert(lengths[0] < (input.lens()[3] + 2 * padding[0]));
assert(lengths[1] < (input.lens()[4] + 2 * padding[1]));
return {t,
{
input.lens()[0],
......
......@@ -54,6 +54,79 @@ struct cpu_convolution
}
};
/// Reduction policy for max pooling, consumed by cpu_pooling<Op>.
///
/// The accumulator starts at the lowest finite double so that any real
/// input element replaces it on the first call to apply().
struct max_pool
{
    /// Suffix appended to the operator name ("cpu::pooling_max").
    static std::string name() { return "max"; }

    /// Initial accumulator value: the lowest finite double.
    static double start() { return std::numeric_limits<double>::lowest(); }

    /// Folds one element into the accumulator by keeping the larger value.
    /// (Previously this summed the operands, which is the average-pool
    /// accumulation and produced wrong results for max pooling.)
    static double apply(double x, double y) { return std::max(x, y); }

    /// Final result is the accumulated maximum; the pool size is unused.
    static double final(double x, double) { return x; }
};
/// Reduction policy for average pooling, consumed by cpu_pooling<Op>.
///
/// Elements are summed into the accumulator and the sum is divided by the
/// pool size in final().
struct avg_pool
{
    /// Suffix appended to the operator name ("cpu::pooling_average").
    static std::string name() { return "average"; }

    /// Initial accumulator value: the additive identity.
    static double start() { return 0.0; }

    /// Folds one element into the running sum.
    /// (Previously this kept the maximum of the operands, so final()
    /// divided the largest element by the pool size instead of the sum.)
    static double apply(double x, double y) { return x + y; }

    /// Divides the accumulated sum `x` by the pool size `y`.
    static double final(double x, double y) { return x / y; }
};
// CPU implementation of the pooling operator, parameterized by a reduction
// policy Op. Op must provide the static functions used below: name(),
// start(), apply(acc, value) and final(acc, pool_size).
template <class Op>
struct cpu_pooling
{
// Wrapped pooling operator; its stride, padding and lengths are read in compute().
pooling op;
// Operator name: "cpu::pooling_" followed by the policy name.
std::string name() const { return "cpu::pooling_" + Op::name(); }
// Shape computation is delegated to the wrapped pooling operator.
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
// Computes the pooled output for args[0] into a newly created argument of
// shape output_shape and returns it.
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
// Extents of input dimensions 2 and 3 (presumably height and width of an
// NCHW tensor -- TODO confirm layout).
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3];
// Visit every element of the 4-D output.
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
// Unclamped window origin: output index * stride - padding. Stored as int;
// presumably negative when the window begins inside the padding region.
const int start_x0 = i * op.stride[0] - op.padding[0];
const int start_y0 = j * op.stride[1] - op.padding[1];
// Window end (origin + window length), clamped to the input extent.
// NOTE(review): operand types of op.lengths/in_h are not visible here; this
// relies on implicit integer conversions -- confirm they are as intended.
const int hend = std::min(start_x0 + op.lengths[0], in_h);
const int wend = std::min(start_y0 + op.lengths[1], in_w);
// Window origin clamped to 0.
const int start_x = std::max(start_x0, 0);
const int start_y = std::max(start_y0, 0);
// Size of the clamped window along each dimension.
const int w_h = (hend - start_x);
const int w_w = (wend - start_y);
// Element count of the clamped window, forced to at least 1; passed to
// Op::final (avg_pool::final uses it as the divisor).
const int pool_size = std::max(w_h * w_w, 1);
double acc = Op::start();
// Fold every in-bounds element of the window into the accumulator.
dfor(w_h, w_w)([&](int x, int y) {
const int in_x = start_x + x;
const int in_y = start_y + y;
// Only coordinates inside the input are read.
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{
acc = Op::apply(acc, input(o, w, in_x, in_y));
}
});
// Finalize the accumulator and convert it to the output element type.
output(o, w, i, j) = type(Op::final(acc, pool_size));
});
});
return result;
}
};
struct cpu_transpose
{
transpose op;
......@@ -485,6 +558,15 @@ struct cpu_apply
if(op.mode == "relu")
prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->arguments);
}
// Lowers a pooling instruction to its CPU implementation. The reduction
// policy is chosen from the operator's mode string: "max" selects max_pool
// and "average" selects avg_pool. Any other mode leaves the instruction
// untouched.
void apply_pooling(instruction_ref ins)
{
    auto&& pool_op = any_cast<pooling>(ins->op);

    if(pool_op.mode == "max")
    {
        prog->replace_instruction(ins, cpu_pooling<max_pool>{pool_op}, ins->arguments);
        return;
    }

    if(pool_op.mode == "average")
    {
        prog->replace_instruction(ins, cpu_pooling<avg_pool>{pool_op}, ins->arguments);
    }
}
};
// Returns the name of this target, the literal string "cpu".
std::string cpu_target::name() const
{
    return "cpu";
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment