Commit 695d873e authored by Scott Thornton's avatar Scott Thornton
Browse files

Added im2col to cpu operators

parent e9b33f76
......@@ -131,6 +131,52 @@ struct convolution
}
};
struct im2col {
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
std::string name() const { return "im2col"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input = inputs[0];
auto weights = inputs[1];
auto batch_size = input.lens()[0];
auto input_channels = weights.lens()[1];
auto kernel_height = weights.lens()[2];
auto kernel_width = weights.lens()[3];
check_shapes{inputs, *this}.has(2);
if (batch_size != 1) MIGRAPH_THROW("im2col only support batch_size 1");
auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) +
2 * padding[0]) /
stride[0] +
1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) +
2 * padding[1]) /
stride[1] +
1));
auto channels_col = kernel_height*kernel_width*input_channels;
return {input.type(), {output_height*output_width, channels_col}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct pooling
{
std::string mode = "average";
......
......@@ -134,6 +134,67 @@ struct cpu_convolution
}
};
struct cpu_im2col {
im2col op;
static std::string name() {return "cpu::im2col"; }
argument compute(context&, shape output_shape, std::vector<argument> args) const {
argument result{output_shape};
auto input_shape = args[0].get_shape();
auto weights_shape = args[1].get_shape();
visit_all(result, args[0])([&](auto col, auto input) {
const std::size_t& height = input_shape.lens()[2];
const std::size_t& width = input_shape.lens()[3];
const std::size_t& channels = weights_shape.lens()[1];
const std::size_t& kernel_h = weights_shape.lens()[2];
const std::size_t& kernel_w = weights_shape.lens()[3];
const std::size_t& pad_h = op.padding[0];
const std::size_t& pad_w = op.padding[1];
const std::size_t& stride_h = op.stride[0];
const std::size_t& stride_w = op.stride[1];
int ksize = kernel_h * kernel_w * channels;
int kdiv2_h, kdiv2_w;
kdiv2_h = kernel_h / 2; kdiv2_w = kernel_w / 2;
// calculate output sizes
const std::size_t col_height = (height - kernel_h + 2 * pad_h)/stride_h + 1;
const std::size_t col_width = (width - kernel_w + 2 * pad_w)/stride_w + 1;
// calculate number of pixels in frame
const std::size_t npixels = height*width;
// starting pixel positions
std::size_t iinput = kdiv2_h - pad_h;
// loop over output pixels
for (std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput+=stride_h) {
std::size_t jinput = kdiv2_w - pad_w;
for (std::size_t joutput = 0; joutput < col_width; joutput++, jinput+=stride_w) {
// compute linear index for output
std::size_t ldx = ioutput * col_width + joutput;
std::size_t p = 0;
for (std::size_t c = 0; c < channels; c++) {
for (int koffset = -kdiv2_h; koffset <= kdiv2_h; koffset++) {
for (int loffset = -kdiv2_w; loffset <= kdiv2_w; loffset++) {
int idx = iinput + koffset;
int jdx = jinput + loffset;
if ((idx >= 0) && (idx < height) &&
(jdx >= 0) && (jdx < width)) {
col[ldx * ksize + p] = input[c * npixels + idx * width + jdx];
}
else {
col[ldx * ksize + p] = 0;
}
p++;
}
}
}
}
}
});
return result;
}
};
struct max_pool
{
static std::string name() { return "max"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment