Unverified Commit c031e287 authored by Yuxin Wu, committed by GitHub

fix the use of contiguous() in kernels (#2131)



* fix the use of contiguous() in kernels

* clang-format

* add a contiguous in nms
Co-authored-by: Yuxin Wu <ppwwyyxx@users.noreply.github.com>
parent d6ee8757
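
What the fix addresses: each kernel launch previously obtained its pointers inline, e.g. `input.contiguous().data_ptr<scalar_t>()`. `Tensor::contiguous()` returns an unnamed temporary that is destroyed at the end of the full expression, so the raw pointer must not be used after that statement; for asynchronously launched CUDA kernels this is fragile (strictly, PyTorch's stream-aware caching allocator often makes it benign, but the pattern is unsafe in general), and calling `.contiguous()` twice on the same tensor can also materialize two independent copies. The diff below hoists each copy into a named local (`input_`, `rois_`, `grad_`, `argmax_`) declared before the `AT_DISPATCH_*` block, so a single contiguous buffer demonstrably outlives the kernel call. A minimal sketch of the before/after pattern; `my_kernel` and `launch` are illustrative names, not part of any real API:

#include <ATen/ATen.h>
#include <cuda_runtime.h>

// Illustrative stand-in for the ROIAlign/ROIPool kernels in this commit.
__global__ void my_kernel(const float* data, int n) { /* reads data[0..n) */ }

void launch(const at::Tensor& t, cudaStream_t stream) {
  const int n = static_cast<int>(t.numel());

  // Anti-pattern (what the commit removes): the unnamed temporary returned by
  // contiguous() dies at the end of the statement, while the kernel may still
  // be running asynchronously on `stream`:
  //   my_kernel<<<1, 256, 0, stream>>>(t.contiguous().data_ptr<float>(), n);

  // Fix (what the commit adds): a named local keeps one contiguous buffer
  // alive past the launch, and every later use refers to that same copy.
  auto t_ = t.contiguous();
  my_kernel<<<1, 256, 0, stream>>>(t_.data_ptr<float>(), n);
}
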
@@ -336,11 +336,12 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
     return std::make_tuple(output, channel_mapping);
   }
 
+  auto input_ = input.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "PSROIAlign_forward", [&] {
         PSROIAlignForwardCPU<scalar_t>(
             output_size,
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
@@ -348,7 +349,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
             pooled_height,
             pooled_width,
             sampling_ratio,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             channels_out,
             output.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>());
@@ -392,11 +393,12 @@ at::Tensor PSROIAlign_backward_cpu(
 
   int channels_out = channels / (pooled_height * pooled_width);
 
+  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "PSROIAlign_backward", [&] {
         PSROIAlignBackwardCPU<scalar_t>(
             grad.numel(),
-            grad.contiguous().data_ptr<scalar_t>(),
+            grad_.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
@@ -408,7 +410,7 @@ at::Tensor PSROIAlign_backward_cpu(
             sampling_ratio,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>());
       });
   return grad_input;
 }
@@ -178,17 +178,18 @@ std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
     return std::make_tuple(output, channel_mapping);
   }
 
+  auto input_ = input.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "PSROIPool_forward", [&] {
         PSROIPoolForward<scalar_t>(
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
             width,
             pooled_height,
             pooled_width,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             channels_out,
             num_rois,
             output.data_ptr<scalar_t>(),
@@ -232,10 +233,11 @@ at::Tensor PSROIPool_backward_cpu(
 
   int channels_out = channels / (pooled_height * pooled_width);
 
+  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "PSROIPool_backward", [&] {
         PSROIPoolBackward<scalar_t>(
-            grad.contiguous().data_ptr<scalar_t>(),
+            grad_.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
@@ -246,7 +248,7 @@ at::Tensor PSROIPool_backward_cpu(
             pooled_width,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>());
       });
   return grad_input;
 }
@@ -407,11 +407,12 @@ at::Tensor ROIAlign_forward_cpu(
   if (output.numel() == 0)
     return output;
 
+  auto input_ = input.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "ROIAlign_forward", [&] {
         ROIAlignForward<scalar_t>(
             output_size,
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
@@ -420,7 +421,7 @@ at::Tensor ROIAlign_forward_cpu(
             pooled_width,
             sampling_ratio,
             aligned,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>());
       });
   return output;
@@ -460,6 +461,7 @@ at::Tensor ROIAlign_backward_cpu(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
+  auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ROIAlign_forward", [&] {
         ROIAlignBackward<scalar_t>(
@@ -474,7 +476,7 @@ at::Tensor ROIAlign_backward_cpu(
             sampling_ratio,
             aligned,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             n_stride,
             c_stride,
             h_stride,
...
@@ -149,17 +149,18 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
     return std::make_tuple(output, argmax);
   }
 
+  auto input_ = input.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "ROIPool_forward", [&] {
         RoIPoolForward<scalar_t>(
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
             width,
             pooled_height,
             pooled_width,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             num_rois,
             output.data_ptr<scalar_t>(),
             argmax.data_ptr<int>());
@@ -204,6 +205,7 @@ at::Tensor ROIPool_backward_cpu(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
+  auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ROIPool_backward", [&] {
         RoIPoolBackward<scalar_t>(
@@ -216,7 +218,7 @@ at::Tensor ROIPool_backward_cpu(
             pooled_height,
             pooled_width,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             n_stride,
             c_stride,
             h_stride,
...
@@ -342,11 +342,13 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
+  auto input_ = input.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "PSROIAlign_forward", [&] {
         PSROIAlignForwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
             output_size,
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
@@ -354,7 +356,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
             pooled_height,
             pooled_width,
             sampling_ratio,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             channels_out,
             output.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>());
@@ -412,11 +414,13 @@ at::Tensor PSROIAlign_backward_cuda(
 
   int channels_out = channels / (pooled_height * pooled_width);
 
+  auto grad_ = grad.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "PSROIAlign_backward", [&] {
         PSROIAlignBackwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
             grad.numel(),
-            grad.contiguous().data_ptr<scalar_t>(),
+            grad_.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
@@ -428,7 +432,7 @@ at::Tensor PSROIAlign_backward_cuda(
             sampling_ratio,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
@@ -179,18 +179,20 @@ std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
+  auto input_ = input.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "PSROIPool_forward", [&] {
         PSROIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
             output_size,
-            input.contiguous().data_ptr<scalar_t>(),
+            input_.data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
             width,
             pooled_height,
             pooled_width,
-            rois.contiguous().data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>(),
             channels_out,
             output.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>());
@@ -246,11 +248,13 @@ at::Tensor PSROIPool_backward_cuda(
 
   int channels_out = channels / (pooled_height * pooled_width);
 
+  auto grad_ = grad.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "PSROIPool_backward", [&] {
         PSROIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
             grad.numel(),
-            grad.contiguous().data_ptr<scalar_t>(),
+            grad_.data_ptr<scalar_t>(),
             channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
@@ -261,7 +265,7 @@ at::Tensor PSROIPool_backward_cuda(
             pooled_width,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois.contiguous().data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
@@ -345,10 +345,12 @@ at::Tensor ROIAlign_forward_cuda(
     return output;
   }
 
+  auto input_ = input.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign_forward", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
-        input.contiguous().data_ptr<scalar_t>(),
+        input_.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
@@ -357,7 +359,7 @@ at::Tensor ROIAlign_forward_cuda(
         pooled_width,
         sampling_ratio,
         aligned,
-        rois.contiguous().data_ptr<scalar_t>(),
+        rois_.data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
@@ -409,6 +411,7 @@ at::Tensor ROIAlign_backward_cuda(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
+  auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign_backward", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
@@ -422,7 +425,7 @@ at::Tensor ROIAlign_backward_cuda(
         sampling_ratio,
         aligned,
         grad_input.data_ptr<scalar_t>(),
-        rois.contiguous().data_ptr<scalar_t>(),
+        rois_.data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
...
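
Note the asymmetry in the ROIAlign/ROIPool backward hunks: `grad` keeps no contiguous copy there, while `rois` gets one. Those backward kernels receive explicit `n_stride`/`c_stride`/`h_stride`/`w_stride` arguments (visible as context lines above), so they can address a non-contiguous gradient through its real strides; `rois` has no stride arguments and is read with packed row-major arithmetic, which is only valid on a contiguous buffer. A sketch of the two indexing styles inside such a kernel, assuming the (K, 5) ROI layout these ops use (batch index plus four box coordinates); variable names are illustrative:

// Strided read: correct for any layout, using strides taken from the tensor.
float g = grad[n * n_stride + c * c_stride + h * h_stride + w * w_stride];

// Packed read: assumes contiguous rows of 5 floats, hence rois_ = rois.contiguous().
const float* offset_rois = rois + k * 5;
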
@@ -157,17 +157,19 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
     return std::make_tuple(output, argmax);
   }
 
+  auto input_ = input.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIPool_forward", [&] {
     RoIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
-        input.contiguous().data_ptr<scalar_t>(),
+        input_.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
-        rois.contiguous().data_ptr<scalar_t>(),
+        rois_.data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>(),
         argmax.data_ptr<int>());
   });
@@ -224,11 +226,13 @@ at::Tensor ROIPool_backward_cuda(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
+  auto argmax_ = argmax.contiguous(),
+       rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIPool_backward", [&] {
     RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
-        argmax.contiguous().data_ptr<int>(),
+        argmax_.data_ptr<int>(),
         num_rois,
         spatial_scale,
         channels,
@@ -237,7 +241,7 @@ at::Tensor ROIPool_backward_cuda(
         pooled_height,
         pooled_width,
         grad_input.data_ptr<scalar_t>(),
-        rois.contiguous().data_ptr<scalar_t>(),
+        rois_.data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
...
@@ -77,7 +77,7 @@ at::Tensor nms_cuda(const at::Tensor& dets,
   at::cuda::CUDAGuard device_guard(dets.device());
 
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
-  auto dets_sorted = dets.index_select(0, order_t);
+  auto dets_sorted = dets.index_select(0, order_t).contiguous();
 
   int dets_num = dets.size(0);
...
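
The nms change is defensive rather than a strict bug fix: in current PyTorch, `index_select` along dim 0 allocates a fresh output tensor, which is contiguous, so the added `.contiguous()` should be a cheap no-op in practice; it makes the layout the NMS kernel assumes explicit rather than implicit. The kernel walks the sorted boxes with packed pointer arithmetic along these lines (illustrative, assuming 4 floats per box):

const float* cur_box = dets_sorted_ptr + i * 4;  // valid only for contiguous [N, 4] rows
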