Commit c5c7ef90 authored by Ligeng Zhu, committed by Kai Chen
Browse files

[DeformConv] Fix zero outputs when not running on cuda:0 (#1326)

* Update deform_conv_cuda.cpp

* Update deform_pool_cuda.cpp
parent df1b9043
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
...@@ -162,7 +163,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, ...@@ -162,7 +163,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group); dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous(); input = input.contiguous();
offset = offset.contiguous(); offset = offset.contiguous();
weight = weight.contiguous(); weight = weight.contiguous();
...@@ -266,6 +268,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, ...@@ -266,6 +268,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
int deformable_group, int im2col_step) { int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group); dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous(); input = input.contiguous();
offset = offset.contiguous(); offset = offset.contiguous();
...@@ -382,7 +385,8 @@ int deform_conv_backward_parameters_cuda( ...@@ -382,7 +385,8 @@ int deform_conv_backward_parameters_cuda(
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group); padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous(); input = input.contiguous();
offset = offset.contiguous(); offset = offset.contiguous();
gradOutput = gradOutput.contiguous(); gradOutput = gradOutput.contiguous();
...@@ -492,7 +496,8 @@ void modulated_deform_conv_cuda_forward( ...@@ -492,7 +496,8 @@ void modulated_deform_conv_cuda_forward(
const bool with_bias) { const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
const int height = input.size(2); const int height = input.size(2);
...@@ -573,6 +578,7 @@ void modulated_deform_conv_cuda_backward( ...@@ -573,6 +578,7 @@ void modulated_deform_conv_cuda_backward(
const bool with_bias) { const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
...@@ -33,6 +34,7 @@ void deform_psroi_pooling_cuda_forward( ...@@ -33,6 +34,7 @@ void deform_psroi_pooling_cuda_forward(
const int output_dim, const int group_size, const int pooled_size, const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std) { const int part_size, const int sample_per_part, const float trans_std) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
...@@ -59,6 +61,7 @@ void deform_psroi_pooling_cuda_backward( ...@@ -59,6 +61,7 @@ void deform_psroi_pooling_cuda_backward(
const int sample_per_part, const float trans_std) { const int sample_per_part, const float trans_std) {
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0); const int batch = input.size(0);
const int channels = input.size(1); const int channels = input.size(1);
...@@ -84,4 +87,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -84,4 +87,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_psroi_pooling_cuda_backward", m.def("deform_psroi_pooling_cuda_backward",
&deform_psroi_pooling_cuda_backward, &deform_psroi_pooling_cuda_backward,
"deform psroi pooling backward(CUDA)"); "deform psroi pooling backward(CUDA)");
} }
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment