Unverified Commit ba73bcc5 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #257 from open-mmlab/pytorch-1.0

Support Pytorch 1.0
parents b6561a1a e83e5d0f
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
- Linux (tested on Ubuntu 16.04 and CentOS 7.2) - Linux (tested on Ubuntu 16.04 and CentOS 7.2)
- Python 3.4+ - Python 3.4+
- PyTorch 0.4.1 - PyTorch 1.0
- Cython - Cython
- [mmcv](https://github.com/open-mmlab/mmcv) - [mmcv](https://github.com/open-mmlab/mmcv) >= 0.2.2
### Install mmdetection ### Install mmdetection
a. Install PyTorch 0.4.1 and torchvision following the [official instructions](https://pytorch.org/). a. Install PyTorch 1.0 and torchvision following the [official instructions](https://pytorch.org/).
b. Clone the mmdetection repository. b. Clone the mmdetection repository.
......
...@@ -10,11 +10,17 @@ ...@@ -10,11 +10,17 @@
### Software environment ### Software environment
- Python 3.6 / 3.7 - Python 3.6 / 3.7
- PyTorch 0.4.1 - PyTorch 1.0
- CUDA 9.0.176 - CUDA 9.0.176
- CUDNN 7.0.4 - CUDNN 7.0.4
- NCCL 2.1.15 - NCCL 2.1.15
Note: The training times were measured with PyTorch 0.4.1. We will update them later; with PyTorch 1.0 they should be about 0.02s ~ 0.05s faster.
## Mirror sites
We use AWS as the main site to host our model zoo, and maintain a mirror on Aliyun.
You can replace `https://s3.ap-northeast-2.amazonaws.com` with `https://open-mmlab.oss-cn-beijing.aliyuncs.com` in the model URLs.
## Common settings ## Common settings
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
## Introduction ## Introduction
The master branch works with **PyTorch 1.0**. If you would like to use PyTorch 0.4.1,
please check out the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
mmdetection is an open source object detection toolbox based on PyTorch. It is mmdetection is an open source object detection toolbox based on PyTorch. It is
a part of the open-mmlab project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/). a part of the open-mmlab project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/).
...@@ -36,6 +39,9 @@ This project is released under the [Apache 2.0 license](LICENSE). ...@@ -36,6 +39,9 @@ This project is released under the [Apache 2.0 license](LICENSE).
## Updates ## Updates
v0.6rc0 (06/02/2019)
- Migrate to PyTorch 1.0.
v0.5.7 (06/02/2019) v0.5.7 (06/02/2019)
- Add support for Deformable ConvNet v2. (Many thanks to the authors and [@chengdazhi](https://github.com/chengdazhi)) - Add support for Deformable ConvNet v2. (Many thanks to the authors and [@chengdazhi](https://github.com/chengdazhi))
- This is the last release based on PyTorch 0.4.1. - This is the last release based on PyTorch 0.4.1.
......
...@@ -34,13 +34,21 @@ def sigmoid_focal_loss(pred, ...@@ -34,13 +34,21 @@ def sigmoid_focal_loss(pred,
weight, weight,
gamma=2.0, gamma=2.0,
alpha=0.25, alpha=0.25,
reduction='elementwise_mean'): reduction='mean'):
pred_sigmoid = pred.sigmoid() pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma) weight = weight * pt.pow(gamma)
return F.binary_cross_entropy_with_logits( loss = F.binary_cross_entropy_with_logits(
pred, target, weight, reduction=reduction) pred, target, reduction='none') * weight
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weighted_sigmoid_focal_loss(pred, def weighted_sigmoid_focal_loss(pred,
...@@ -62,22 +70,22 @@ def mask_cross_entropy(pred, target, label): ...@@ -62,22 +70,22 @@ def mask_cross_entropy(pred, target, label):
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1) pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits( return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='elementwise_mean')[None] pred_slice, target, reduction='mean')[None]
def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'): def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
assert beta > 0 assert beta > 0
assert pred.size() == target.size() and target.numel() > 0 assert pred.size() == target.size() and target.numel() > 0
diff = torch.abs(pred - target) diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta, loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta) diff - 0.5 * beta)
reduction = F._Reduction.get_enum(reduction) reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2 # none: 0, mean:1, sum: 2
if reduction == 0: if reduction_enum == 0:
return loss return loss
elif reduction == 1: elif reduction_enum == 1:
return loss.sum() / pred.numel() return loss.sum() / pred.numel()
elif reduction == 2: elif reduction_enum == 2:
return loss.sum() return loss.sum()
......
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c // modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/torch.h> #include <torch/extension.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
void deformable_im2col(const at::Tensor data_im, void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_offset, const int channels, const int channels, const int height, const int width,
const int height, const int width, const int ksize_h, const int ksize_h, const int ksize_w, const int pad_h,
const int ksize_w, const int pad_h, const int pad_w, const int pad_w, const int stride_h, const int stride_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int parallel_imgs, const int deformable_group,
const int deformable_group, at::Tensor data_col); at::Tensor data_col);
// Forward declaration of deformable_col2im. No body is visible in this part of
// the file (presumably it lives in the accompanying CUDA kernel source --
// TODO confirm).
//
// data_col and data_offset are taken as const tensors; grad_im is the only
// non-const tensor argument. The visible caller,
// deform_conv_backward_input_cuda, passes its `columns` buffer, one
// `offset[elt]` slice and one `gradInput[elt]` slice (as grad_im), with the
// integer arguments supplied in this order: nInputPlane, inputHeight,
// inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, im2col_step
// (as parallel_imgs), deformable_group.
void deformable_col2im(const at::Tensor data_col,
const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs,
const int deformable_group, at::Tensor grad_im);
// Forward declaration of deformable_col2im_coord. No body is visible in this
// part of the file (presumably defined in the accompanying CUDA kernel source
// -- TODO confirm).
//
// data_col, data_im and data_offset are const tensors; grad_offset is the only
// non-const tensor argument. The visible caller,
// deform_conv_backward_input_cuda, passes `columns`, `input[elt]`,
// `offset[elt]` and `gradOffset[elt]` (as grad_offset), and calls this
// function before deformable_col2im on the same `columns` buffer. The integer
// arguments follow the same order as for deformable_col2im, with im2col_step
// passed as parallel_imgs.
void deformable_col2im_coord(const at::Tensor data_col,
const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h,
const int pad_w, const int stride_h,
const int stride_w, const int dilation_h,
const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
// Forward declaration of modulated_deformable_im2col_cuda. Neither its body
// nor any caller is visible in this part of the file.
//
// Compared with deformable_im2col, the signature additionally takes a
// data_mask tensor, an explicit batch_size, and the column extents
// height_col / width_col, and it has no parallel_imgs argument. data_col is the
// only non-const tensor argument (presumably the output buffer -- TODO
// confirm against the definition).
//
// NOTE(review): `kenerl_w` looks like a misspelling of `kernel_w`. Parameter
// names in a declaration do not affect callers, but check the definition
// before renaming.
void modulated_deformable_im2col_cuda(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col);
// Forward declaration of modulated_deformable_col2im_cuda. Neither its body
// nor any caller is visible in this part of the file.
//
// Takes data_col, data_offset and data_mask as const tensors plus the image
// (height_im / width_im) and column (height_col / width_col) extents; grad_im
// is the only non-const tensor argument (presumably where the image gradient
// is written -- TODO confirm against the definition).
//
// NOTE(review): `kenerl_w` looks like a misspelling of `kernel_w`. Parameter
// names in a declaration do not affect callers, but check the definition
// before renaming.
void modulated_deformable_col2im_cuda(const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im);
// Forward declaration of modulated_deformable_col2im_coord_cuda. Neither its
// body nor any caller is visible in this part of the file.
//
// Takes data_col, data_im, data_offset and data_mask as const tensors; the two
// non-const tensor arguments are grad_offset and grad_mask (presumably the
// gradients with respect to the offsets and the modulation mask -- TODO
// confirm against the definition).
//
// NOTE(review): `kenerl_w` looks like a misspelling of `kernel_w`. Parameter
// names in a declaration do not affect callers, but check the definition
// before renaming.
void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset,
at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
int dH, int dW, int padH, int padW, int dilationH,
int dilationW, int group, int deformable_group)
{
AT_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(),
"weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d",
kH, kW);
AT_CHECK((weight.size(2) == kH &&
weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
AT_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
AT_CHECK(dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4)
{
dimf++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 3 || ndim == 4, void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
"3D or 4D input tensor expected but got: %s", ndim); const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
long nInputPlane = weight.size(1) * group; const int pad_w, const int stride_h, const int stride_w,
long inputHeight = input.size(dimh); const int dilation_h, const int dilation_w,
long inputWidth = input.size(dimw); const int parallel_imgs, const int deformable_group,
long nOutputPlane = weight.size(0); at::Tensor grad_im);
long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
AT_CHECK(nInputPlane % deformable_group == 0, const at::Tensor data_offset, const int channels, const int height,
"input channels must divide deformable group size"); const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
if (outputWidth < 1 || outputHeight < 1) const int dilation_h, const int dilation_w, const int parallel_imgs,
AT_ERROR( const int deformable_group, at::Tensor grad_offset);
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small", void modulated_deformable_im2col_cuda(
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, const at::Tensor data_im, const at::Tensor data_offset,
outputWidth); const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
AT_CHECK(input.size(1) == nInputPlane, const int width_col, const int kernel_h, const int kenerl_w,
"invalid number of input planes, expected: %d, but got: %d", const int pad_h, const int pad_w, const int stride_h, const int stride_w,
nInputPlane, input.size(1)); const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel"); void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
AT_CHECK( const at::Tensor data_mask, const int batch_size, const int channels,
(offset.size(2) == outputHeight && offset.size(3) == outputWidth), const int height_im, const int width_im, const int height_col,
"invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", const int width_col, const int kernel_h, const int kenerl_w,
outputHeight, outputWidth, offset.size(2), offset.size(3)); const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), at::Tensor grad_im);
"invalid number of channels of offset");
void modulated_deformable_col2im_coord_cuda(
if (gradOutput != NULL) const at::Tensor data_col, const at::Tensor data_im,
{ const at::Tensor data_offset, const at::Tensor data_mask,
AT_CHECK(gradOutput->size(dimf) == nOutputPlane, const int batch_size, const int channels, const int height_im,
"invalid number of gradOutput planes, expected: %d, but got: %d", const int width_im, const int height_col, const int width_col,
nOutputPlane, gradOutput->size(dimf)); const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
AT_CHECK((gradOutput->size(dimh) == outputHeight && const int dilation_w, const int deformable_group, at::Tensor grad_offset,
gradOutput->size(dimw) == outputWidth), at::Tensor grad_mask);
"invalid size of gradOutput, expected height: %d width: %d , but got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh), gradOutput->size(dimw)); void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
} at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
AT_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
AT_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
AT_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
AT_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
AT_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
} }
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
...@@ -155,480 +153,543 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, ...@@ -155,480 +153,543 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor columns, at::Tensor ones, int kW, at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH, int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group, int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) int deformable_group, int im2col_step) {
{ // todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: resize columns to include im2col: done // todo: add new output buffer and transpose it to output (or directly
// todo: add im2col_step as input // transpose output) todo: possibly change data indexing because of
// todo: add new output buffer and transpose it to output (or directly transpose output) // parallel_imgs
// todo: possibly change data indexing because of parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group);
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
input = input.contiguous(); offset = offset.contiguous();
offset = offset.contiguous(); weight = weight.contiguous();
weight = weight.contiguous();
int batch = 1;
int batch = 1; if (input.ndimension() == 3) {
if (input.ndimension() == 3) // Force batch
{ batch = 0;
// Force batch input.unsqueeze_(0);
batch = 0; offset.unsqueeze_(0);
input.unsqueeze_(0); }
offset.unsqueeze_(0);
} // todo: assert batchsize dividable by im2col_step
// todo: assert batchsize dividable by im2col_step long batchSize = input.size(0);
long nInputPlane = input.size(1);
long batchSize = input.size(0); long inputHeight = input.size(2);
long nInputPlane = input.size(1); long inputWidth = input.size(3);
long inputHeight = input.size(2);
long inputWidth = input.size(3); long nOutputPlane = weight.size(0);
long nOutputPlane = weight.size(0); long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; long outputHeight =
long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth}); output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, input.type()); outputHeight, outputWidth});
columns = at::zeros(
if (ones.ndimension() != 2 || ones.size(0) * ones.size(1) < outputHeight * outputWidth) {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
{ input.type());
ones = at::ones({outputHeight, outputWidth}, input.type());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.type());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.type());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
} }
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); output_buffer = output_buffer.view(
offset = offset.view({batchSize / im2col_step, im2col_step, {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
deformable_group * 2 * kH * kW, outputHeight, outputWidth}); output_buffer.size(3), output_buffer.size(4)});
at::Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth}, output.type());
output_buffer = output_buffer.view({output_buffer.size(0), group, output_buffer.size(1) / group, output_buffer.size(2), output_buffer.size(3)}); output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
{ offset = offset.view(
deformable_im2col( {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)}); if (batch == 0) {
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
for (int g = 0; g < group; g++){ return 1;
output_buffer[elt][g] =
output_buffer[elt][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view({output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0)
{
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
} }
int deform_conv_backward_input_cuda( int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor input, at::Tensor offset, at::Tensor gradOutput, at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight, at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW, int dH, int padW, int padH, at::Tensor columns, int kW, int kH, int dW,
int dilationW, int dilationH, int group, int deformable_group, int im2col_step) int dH, int padW, int padH, int dilationW,
{ int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
padW, dilationH, dilationW, group, deformable_group); dilationH, dilationW, group, deformable_group);
input = input.contiguous(); input = input.contiguous();
offset = offset.contiguous(); offset = offset.contiguous();
gradOutput = gradOutput.contiguous(); gradOutput = gradOutput.contiguous();
weight = weight.contiguous(); weight = weight.contiguous();
int batch = 1; int batch = 1;
if (input.ndimension() == 3) if (input.ndimension() == 3) {
{ // Force batch
// Force batch batch = 0;
batch = 0; input = input.view({1, input.size(0), input.size(1), input.size(2)});
input = input.view({1, input.size(0), input.size(1), input.size(2)}); offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view({1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, input.type());
// change order of grad output
gradOutput = gradOutput.view( gradOutput = gradOutput.view(
{batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth}); {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
gradOutput.transpose_(1, 2); }
gradInput = gradInput.view( long batchSize = input.size(0);
{batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); long nInputPlane = input.size(1);
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); long inputHeight = input.size(2);
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, long inputWidth = input.size(3);
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view({batchSize / im2col_step, im2col_step, long nOutputPlane = weight.size(0);
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
long outputWidth =
for (int elt = 0; elt < batchSize / im2col_step; elt++) (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
{ long outputHeight =
// divide into groups (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradOutput = gradOutput.view({gradOutput.size(0), group, gradOutput.size(1) / group, gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
for (int g = 0; g < group; g++){ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), gradOutput[elt][g].flatten(1), 0.0f, 1.0f); input.type());
}
// change order of grad output
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
gradOutput = gradOutput.view({gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
deformable_col2im_coord(
columns, input[elt], offset[elt], gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, inputHeight, inputWidth});
dilationH, dilationW, im2col_step, deformable_group, gradOffset[elt]); input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
deformable_col2im( gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
columns, offset[elt], nInputPlane, inputHeight, deformable_group * 2 * kH * kW, outputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, im2col_step, outputWidth});
deformable_group, gradInput[elt]); offset =
} offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
gradOutput.transpose_(1, 2); for (int g = 0; g < group; g++) {
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0)
{
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset = gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
} }
return 1; columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
} }
int deform_conv_backward_parameters_cuda( int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput, at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias, at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int padW, int padH, int dilationW, int dilationH, int group,
float scale, int im2col_step) int deformable_group, float scale, int im2col_step) {
{ // todo: transpose and reshape outGrad
// todo: reshape columns
// todo: transpose and reshape outGrad // todo: add im2col_step as input
// todo: reshape columns
// todo: add im2col_step as input shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, group, deformable_group); input = input.contiguous();
offset = offset.contiguous();
input = input.contiguous(); gradOutput = gradOutput.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous(); int batch = 1;
int batch = 1; if (input.ndimension() == 3) {
// Force batch
if (input.ndimension() == 3) batch = 0;
{ input = input.view(
// Force batch at::IntList({1, input.size(0), input.size(1), input.size(2)}));
batch = 0;
input = input.view(at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view({1, gradOutput.size(0),
gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, input.type());
gradOutput = gradOutput.view( gradOutput = gradOutput.view(
{batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth}); {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
gradOutput.transpose_(1, 2); }
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.type());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view( gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth}); {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.copy_(gradOutput); gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
gradOutputBuffer = gradOutputBuffer.view( columns = columns.view({group, columns.size(0) / group, columns.size(1)});
{batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth}); gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradOutput.transpose_(1, 2); gradWeight.size(2), gradWeight.size(3)});
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
for (int g = 0; g < group; g++) {
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); gradWeight[g] = gradWeight[g]
offset = offset.view({batchSize / im2col_step, im2col_step, .flatten(1)
deformable_group * 2 * kH * kW, .addmm_(gradOutputBuffer[elt][g].flatten(1),
outputHeight, outputWidth}); columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
deformable_im2col(
input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight = gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++){
gradWeight[g] = gradWeight[g].flatten(1).addmm_(
gradOutputBuffer[elt][g].flatten(1), columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), gradOutputBuffer.size(1) * gradOutputBuffer.size(2), gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), gradWeight.size(2), gradWeight.size(3), gradWeight.size(4)});
} }
gradOutputBuffer = gradOutputBuffer.view(
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); {gradOutputBuffer.size(0),
offset = offset.view({batchSize, deformable_group * 2 * kH * kW, gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
outputHeight, outputWidth}); gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
if (batch == 0) columns.view({columns.size(0) * columns.size(1), columns.size(2)});
{ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); gradWeight.size(2), gradWeight.size(3),
input = input.view({nInputPlane, inputHeight, inputWidth}); gradWeight.size(4)});
} }
return 1; input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
} }
void modulated_deform_conv_cuda_forward(
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight, at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor bias, at::Tensor ones, at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
at::Tensor offset, at::Tensor mask, int kernel_h, int kernel_w, const int stride_h, const int stride_w,
at::Tensor output, at::Tensor columns, const int pad_h, const int pad_w, const int dilation_h,
int kernel_h, int kernel_w, const int dilation_w, const int group, const int deformable_group,
const int stride_h, const int stride_w, const bool with_bias) {
const int pad_h, const int pad_w, AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int dilation_h, const int dilation_w, const int group, AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int deformable_group, const bool with_bias)
{ const int batch = input.size(0);
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); const int channels = input.size(1);
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); const int height = input.size(2);
const int width = input.size(3);
const int batch = input.size(0);
const int channels = input.size(1); const int channels_out = weight.size(0);
const int height = input.size(2); const int channels_kernel = weight.size(1);
const int width = input.size(3); const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1); if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
const int kernel_h_ = weight.size(2); AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
const int kernel_w_ = weight.size(3); kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", channels, channels_kernel * group);
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group) const int height_out =
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
channels, channels_kernel * group); const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
if (ones.ndimension() != 2 || // Resize plane and fill with ones...
ones.size(0) * ones.size(1) < height_out * width_out) ones = at::ones({height_out, width_out}, input.type());
{ }
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.type()); // resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.type());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
} }
// resize output weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
output = output.view({batch, channels_out, height_out, width_out}).zero_(); weight.size(3), weight.size(4)});
// resize temporary columns columns =
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type()); columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), group, output.size(1) / group, output.size(2), output.size(3)});
for (int b = 0; b < batch; b++)
{
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++){ output = output.view({output.size(0), output.size(1) * output.size(2),
output[b][g] = output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); output.size(3), output.size(4)});
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)}); if (with_bias) {
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); output += bias.view({1, bias.size(0), 1, 1});
} }
output = output.view({output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});
if (with_bias){
output += bias.view({1, bias.size(0), 1, 1});
}
} }
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight, void modulated_deform_conv_cuda_backward(
at::Tensor bias, at::Tensor ones, at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor columns, at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
at::Tensor grad_bias, at::Tensor grad_offset, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
at::Tensor grad_mask, at::Tensor grad_output, int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
int kernel_h, int kernel_w, const bool with_bias) {
int stride_h, int stride_w, AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
int pad_h, int pad_w, AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
int dilation_h, int dilation_w, int group,
int deformable_group, const bool with_bias) const int batch = input.size(0);
{ const int channels = input.size(1);
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); const int height = input.size(2);
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); const int width = input.size(3);
const int batch = input.size(0); const int channels_kernel = weight.size(1);
const int channels = input.size(1); const int kernel_h_ = weight.size(2);
const int height = input.size(2); const int kernel_w_ = weight.size(3);
const int width = input.size(3); if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
const int channels_kernel = weight.size(1); kernel_h_, kernel_w, kernel_h_, kernel_w_);
const int kernel_h_ = weight.size(2); if (channels != channels_kernel * group)
const int kernel_w_ = weight.size(3); AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) channels, channels_kernel * group);
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_); const int height_out =
if (channels != channels_kernel * group) (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", const int width_out =
channels, channels_kernel * group); (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; if (ones.ndimension() != 2 ||
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
if (ones.ndimension() != 2 || ones = at::ones({height_out, width_out}, input.type());
ones.size(0) * ones.size(1) < height_out * width_out) }
{
// Resize plane and fill with ones... grad_input = grad_input.view({batch, channels, height, width});
ones = at::ones({height_out, width_out}, input.type()); columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.type());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
} }
grad_input = grad_input.view({batch, channels, height, width}); columns =
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type()); columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
grad_output = grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, grad_output.size(2), grad_output.size(3)}); weight.size(3), weight.size(4)});
for (int b = 0; b < batch; b++) // gradient w.r.t. input coordinate data
{ modulated_deformable_col2im_coord_cuda(
// divide int group columns, input[b], offset[b], mask[b], 1, channels, height, width,
columns = columns.view({group, columns.size(0) / group, columns.size(1)}); height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
for (int g = 0; g < group; g++){ // gradient w.r.t. input data
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_output[b][g].flatten(1), 0.0f, 1.0f); modulated_deformable_col2im_cuda(
} columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); dilation_h, dilation_w, deformable_group, grad_input[b]);
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// gradient w.r.t. input coordinate data // group
modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b], modulated_deformable_im2col_cuda(
1, channels, height, width, input[b], offset[b], mask[b], 1, channels, height, width, height_out,
height_out, width_out, kernel_h, kernel_w, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns);
dilation_h, dilation_w, deformable_group,
grad_offset[b], grad_mask[b]); columns = columns.view({group, columns.size(0) / group, columns.size(1)});
// gradient w.r.t. input data grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
modulated_deformable_col2im_cuda(columns, offset[b], mask[b], grad_weight.size(1), grad_weight.size(2),
1, channels, height, width, grad_weight.size(3)});
height_out, width_out, kernel_h, kernel_w, if (with_bias)
pad_h, pad_w, stride_h, stride_w, grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
dilation_h, dilation_w, deformable_group,
grad_input[b]); for (int g = 0; g < group; g++) {
grad_weight[g] =
// gradient w.r.t. weight, dWeight should accumulate across the batch and group grad_weight[g]
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], .flatten(1)
1, channels, height, width, .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
height_out, width_out, kernel_h, kernel_w, .view_as(grad_weight[g]);
pad_h, pad_w, stride_h, stride_w, if (with_bias) {
dilation_h, dilation_w, deformable_group, grad_bias[g] =
columns); grad_bias[g]
.view({-1, 1})
columns = columns.view({group, columns.size(0) / group, columns.size(1)}); .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
grad_weight = grad_weight.view({group, grad_weight.size(0) / group, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)}); .view(-1);
if (with_bias) }
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++){
grad_weight[g] = grad_weight[g].flatten(1).addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)).view_as(grad_weight[g]);
if (with_bias){
grad_bias[g] = grad_bias[g].view({-1, 1}).addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})).view(-1);
}
}
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
} }
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), grad_output.size(2), grad_output.size(3), grad_output.size(4)});
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
// Python bindings: exposes the deformable-convolution entry points of this
// translation unit under the extension module's name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
        "deform forward (CUDA)");
  m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
        "deform_conv_backward_input (CUDA)");
  m.def("deform_conv_backward_parameters_cuda",
        &deform_conv_backward_parameters_cuda,
        "deform_conv_backward_parameters (CUDA)");
  m.def("modulated_deform_conv_cuda_forward",
        &modulated_deform_conv_cuda_forward,
        "modulated deform conv forward (CUDA)");
  m.def("modulated_deform_conv_cuda_backward",
        &modulated_deform_conv_cuda_backward,
        "modulated deform conv backward (CUDA)");
}
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang // author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob /mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c #include <torch/extension.h>
#include <torch/torch.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
// Forward launcher for deformable PS-RoI pooling. Only declared here; the
// definition lives outside this file (presumably in the accompanying CUDA
// source - confirm against the build).
// `out` and `top_count` are the only non-const tensors, i.e. the outputs.
void DeformablePSROIPoolForward(
    const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
    at::Tensor out, at::Tensor top_count, const int batch, const int channels,
    const int height, const int width, const int num_bbox,
    const int channels_trans, const int no_trans, const float spatial_scale,
    const int output_dim, const int group_size, const int pooled_size,
    const int part_size, const int sample_per_part, const float trans_std);
// Backward launcher for deformable PS-RoI pooling. Only declared here; the
// definition lives outside this file (presumably in the accompanying CUDA
// source - confirm against the build).
// `in_grad` and `trans_grad` are the only non-const tensors, i.e. the
// outputs.
void DeformablePSROIPoolBackwardAcc(
    const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
    const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
    at::Tensor trans_grad, const int batch, const int channels,
    const int height, const int width, const int num_bbox,
    const int channels_trans, const int no_trans, const float spatial_scale,
    const int output_dim, const int group_size, const int pooled_size,
    const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox, void deform_psroi_pooling_cuda_forward(
at::Tensor trans, at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor out, at::Tensor top_count, at::Tensor top_count, const int no_trans, const float spatial_scale,
const int no_trans, const int output_dim, const int group_size, const int pooled_size,
const float spatial_scale, const int part_size, const int sample_per_part, const float trans_std) {
const int output_dim, AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
const int height = input.size(2); const int height = input.size(2);
const int width = input.size(3); const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1); const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0); const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0)) if (num_bbox != out.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox); out.size(0), num_bbox);
DeformablePSROIPoolForward(input, bbox, trans, out, top_count, DeformablePSROIPoolForward(
batch, channels, height, width, input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
channels_trans, pooled_size, part_size, sample_per_part, trans_std);
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
} }
void deform_psroi_pooling_cuda_backward(at::Tensor out_grad, void deform_psroi_pooling_cuda_backward(
at::Tensor input, at::Tensor bbox, at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor trans, at::Tensor top_count, at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
at::Tensor input_grad, at::Tensor trans_grad, const int no_trans, const float spatial_scale, const int output_dim,
const int no_trans, const int group_size, const int pooled_size, const int part_size,
const float spatial_scale, const int sample_per_part, const float trans_std) {
const int output_dim, AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
const int group_size, AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
const int height = input.size(2); const int height = input.size(2);
const int width = input.size(3); const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1); const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0); const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0)) if (num_bbox != out_grad.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox); out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(out_grad, DeformablePSROIPoolBackwardAcc(
input, out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
bbox, channels, height, width, num_bbox, channels_trans, no_trans,
trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
top_count, sample_per_part, trans_std);
input_grad,
trans_grad,
batch, channels, height, width, num_bbox,
channels_trans,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
} }
// Python bindings: exposes the deformable PS-RoI pooling entry points of this
// translation unit under the extension module's name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
        "deform psroi pooling forward(CUDA)");
  m.def("deform_psroi_pooling_cuda_backward",
        &deform_psroi_pooling_cuda_backward,
        "deform psroi pooling backward(CUDA)");
}
\ No newline at end of file
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# Written by Ross Girshick # Written by Ross Girshick
# -------------------------------------------------------- # --------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
# Modified by Kai Chen # Modified by Kai Chen
# ---------------------------------------------------------- # ----------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# Written by Ross Girshick # Written by Ross Girshick
# -------------------------------------------------------- # --------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
from torch.autograd import Function, Variable from torch.autograd import Function
from .. import roi_align_cuda from .. import roi_align_cuda
...@@ -49,11 +49,11 @@ class RoIAlignFunction(Function): ...@@ -49,11 +49,11 @@ class RoIAlignFunction(Function):
grad_input = grad_rois = None grad_input = grad_rois = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_input = Variable( grad_input = rois.new_zeros(batch_size, num_channels, data_height,
rois.new(batch_size, num_channels, data_height, data_width) data_width)
.zero_()) roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
roi_align_cuda.backward(grad_output, rois, out_h, out_w, out_w, spatial_scale, sample_num,
spatial_scale, sample_num, grad_input) grad_input)
return grad_input, grad_rois, None, None, None return grad_input, grad_rois, None, None, None
......
#include <torch/torch.h> #include <torch/extension.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)
#define CUDA_1D_KERNEL_LOOP(i, n) \ #define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x) i += blockDim.x * gridDim.x)
...@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois, ...@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
sample_num, channels, height, width, pooled_height, sample_num, channels, height, width, pooled_height,
pooled_width, top_data); pooled_width, top_data);
})); }));
cudaError_t err = cudaGetLastError(); THCudaCheck(cudaGetLastError());
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 1; return 1;
} }
...@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, ...@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
at::Tensor bottom_grad) { at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels; const int output_size = num_rois * pooled_height * pooled_width * channels;
// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES(
top_grad.type(), "ROIAlignLaucherBackward", ([&] { top_grad.type(), "ROIAlignLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>(); const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>(); const scalar_t *rois_data = rois.data<scalar_t>();
...@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, ...@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
channels, height, width, pooled_height, pooled_width, channels, height, width, pooled_height, pooled_width,
bottom_diff); bottom_diff);
})); }));
cudaError_t err = cudaGetLastError(); THCudaCheck(cudaGetLastError());
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 1; return 1;
} }
...@@ -24,9 +24,8 @@ class RoIPoolFunction(Function): ...@@ -24,9 +24,8 @@ class RoIPoolFunction(Function):
num_channels = features.size(1) num_channels = features.size(1)
num_rois = rois.size(0) num_rois = rois.size(0)
out_size = (num_rois, num_channels, out_h, out_w) out_size = (num_rois, num_channels, out_h, out_w)
output = features.new_zeros(*out_size) output = features.new_zeros(out_size)
argmax = features.new_zeros(out_size, dtype=torch.int)
argmax = features.new_zeros(*out_size, dtype=torch.int)
roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale, roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
output, argmax) output, argmax)
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
...@@ -46,9 +45,9 @@ class RoIPoolFunction(Function): ...@@ -46,9 +45,9 @@ class RoIPoolFunction(Function):
grad_input = grad_rois = None grad_input = grad_rois = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_input = grad_output.new(feature_size).zero_() grad_input = grad_output.new_zeros(feature_size)
roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale, roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
grad_input) spatial_scale, grad_input)
return grad_input, grad_rois, None, None return grad_input, grad_rois, None, None
......
#include <torch/torch.h> #include <torch/extension.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)
#define CUDA_1D_KERNEL_LOOP(i, n) \ #define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x) i += blockDim.x * gridDim.x)
...@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois, ...@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
channels, height, width, pooled_h, pooled_w, top_data, channels, height, width, pooled_h, pooled_w, top_data,
argmax_data); argmax_data);
})); }));
cudaError_t err = cudaGetLastError(); THCudaCheck(cudaGetLastError());
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 1; return 1;
} }
...@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, ...@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int pooled_w, at::Tensor bottom_grad) { const int pooled_w, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_h * pooled_w * channels; const int output_size = num_rois * pooled_h * pooled_w * channels;
// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES(
top_grad.type(), "ROIPoolLaucherBackward", ([&] { top_grad.type(), "ROIPoolLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>(); const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>(); const scalar_t *rois_data = rois.data<scalar_t>();
...@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, ...@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
scalar_t(spatial_scale), channels, height, width, pooled_h, scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff); pooled_w, bottom_diff);
})); }));
cudaError_t err = cudaGetLastError(); THCudaCheck(cudaGetLastError());
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 1; return 1;
} }
...@@ -11,8 +11,8 @@ def readme(): ...@@ -11,8 +11,8 @@ def readme():
MAJOR = 0 MAJOR = 0
MINOR = 5 MINOR = 6
PATCH = 7 PATCH = 'rc0'
SUFFIX = '' SUFFIX = ''
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment