Commit 90665ffd authored by wsttiger's avatar wsttiger
Browse files

fixed up for PR

parent 0aa847f9
...@@ -141,7 +141,7 @@ struct cpu_im2col ...@@ -141,7 +141,7 @@ struct cpu_im2col
static std::string name() { return "cpu::im2col"; } static std::string name() { return "cpu::im2col"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto input_shape = args[0].get_shape(); auto input_shape = args[0].get_shape();
...@@ -157,18 +157,15 @@ struct cpu_im2col ...@@ -157,18 +157,15 @@ struct cpu_im2col
const std::size_t& stride_h = op.stride[0]; const std::size_t& stride_h = op.stride[0];
const std::size_t& stride_w = op.stride[1]; const std::size_t& stride_w = op.stride[1];
int ksize = kernel_h * kernel_w * channels;
int kdiv2_h, kdiv2_w; int kdiv2_h, kdiv2_w;
kdiv2_h = kernel_h / 2; kdiv2_h = kernel_h / 2;
kdiv2_w = kernel_w / 2; kdiv2_w = kernel_w / 2;
// calculate output sizes // calculate output sizes
const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1; const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
const std::size_t col_width = (width - kernel_w + 2 * pad_w) / stride_w + 1; const std::size_t col_width = (width - kernel_w + 2 * pad_w) / stride_w + 1;
// calculate number of pixels in frame // account for padding for the starting position of the input pixels
const std::size_t npixels = height * width;
// starting pixel positions
std::size_t iinput = kdiv2_h - pad_h; std::size_t iinput = kdiv2_h - pad_h;
// loop over output pixels // loop over output pixels (ioutput, joutput)
for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h) for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h)
{ {
std::size_t jinput = kdiv2_w - pad_w; std::size_t jinput = kdiv2_w - pad_w;
...@@ -181,14 +178,8 @@ struct cpu_im2col ...@@ -181,14 +178,8 @@ struct cpu_im2col
[&](std::size_t c, std::size_t koffset, std::size_t loffset) { [&](std::size_t c, std::size_t koffset, std::size_t loffset) {
int idx = iinput + koffset - kdiv2_h; int idx = iinput + koffset - kdiv2_h;
int jdx = jinput + loffset - kdiv2_w; int jdx = jinput + loffset - kdiv2_w;
if((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width)) col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width)) ?
{ input(0, c, idx, jdx) : 0;
col[ldx * ksize + p] = input[c * npixels + idx * width + jdx];
}
else
{
col[ldx * ksize + p] = 0;
}
p++; p++;
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment