Unverified commit 561a014b authored by Mikhail Lobanov, committed by GitHub

Fix Tensor::data<> deprecation. (#2028)

parent c6c3294d
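
The change is mechanical: PyTorch deprecated the templated at::Tensor::data<T>() accessor in favor of at::Tensor::data_ptr<T>(), which returns the same typed raw pointer into the tensor's storage. Every hunk below applies that one-line substitution at the CPU and CUDA call sites of the PSROIAlign and PSROIPool operators. A minimal sketch of the substitution, assuming a PyTorch build recent enough to carry the deprecation (the function name here is illustrative, not from the diff):

#include <ATen/ATen.h>

// Illustrative only: before/after of the substitution this commit applies.
void fill_first(at::Tensor t) {
  // Deprecated accessor; warns at compile time on recent PyTorch:
  //   float* p = t.data<float>();
  // Replacement used throughout this commit. Binding the contiguous tensor
  // keeps its storage alive; the call sites below instead pass
  // .contiguous().data_ptr<scalar_t>() directly into the kernel call.
  auto tc = t.contiguous();
  float* p = tc.data_ptr<float>();
  p[0] = 1.0f;
}
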
@@ -340,7 +340,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
       input.scalar_type(), "PSROIAlign_forward", [&] {
         PSROIAlignForwardCPU<scalar_t>(
             output_size,
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
@@ -348,10 +348,10 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
             pooled_height,
             pooled_width,
             sampling_ratio,
-            rois.contiguous().data<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>(),
             channels_out,
-            output.data<scalar_t>(),
-            channel_mapping.data<int>());
+            output.data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>());
       });
   return std::make_tuple(output, channel_mapping);
 }
@@ -396,8 +396,8 @@ at::Tensor PSROIAlign_backward_cpu(
       grad.scalar_type(), "PSROIAlign_backward", [&] {
         PSROIAlignBackwardCPU<scalar_t>(
             grad.numel(),
-            grad.contiguous().data<scalar_t>(),
-            channel_mapping.data<int>(),
+            grad.contiguous().data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
             channels,
@@ -407,8 +407,8 @@ at::Tensor PSROIAlign_backward_cpu(
             pooled_width,
             sampling_ratio,
             channels_out,
-            grad_input.data<scalar_t>(),
-            rois.contiguous().data<scalar_t>());
+            grad_input.data_ptr<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>());
       });
   return grad_input;
 }
@@ -181,18 +181,18 @@ std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "PSROIPool_forward", [&] {
         PSROIPoolForward<scalar_t>(
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
             width,
             pooled_height,
             pooled_width,
-            rois.contiguous().data<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>(),
             channels_out,
             num_rois,
-            output.data<scalar_t>(),
-            channel_mapping.data<int>());
+            output.data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>());
       });
   return std::make_tuple(output, channel_mapping);
 }
@@ -235,8 +235,8 @@ at::Tensor PSROIPool_backward_cpu(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "PSROIPool_backward", [&] {
         PSROIPoolBackward<scalar_t>(
-            grad.contiguous().data<scalar_t>(),
-            channel_mapping.data<int>(),
+            grad.contiguous().data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
             channels,
@@ -245,8 +245,8 @@ at::Tensor PSROIPool_backward_cpu(
             pooled_height,
             pooled_width,
             channels_out,
-            grad_input.data<scalar_t>(),
-            rois.contiguous().data<scalar_t>());
+            grad_input.data_ptr<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>());
       });
   return grad_input;
 }
@@ -346,7 +346,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
       input.scalar_type(), "PSROIAlign_forward", [&] {
         PSROIAlignForwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
             output_size,
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
@@ -354,10 +354,10 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
             pooled_height,
             pooled_width,
             sampling_ratio,
-            rois.contiguous().data<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>(),
             channels_out,
-            output.data<scalar_t>(),
-            channel_mapping.data<int>());
+            output.data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   cudaDeviceSynchronize();
@@ -416,8 +416,8 @@ at::Tensor PSROIAlign_backward_cuda(
       grad.scalar_type(), "PSROIAlign_backward", [&] {
         PSROIAlignBackwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
             grad.numel(),
-            grad.contiguous().data<scalar_t>(),
-            channel_mapping.data<int>(),
+            grad.contiguous().data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
             channels,
@@ -427,8 +427,8 @@ at::Tensor PSROIAlign_backward_cuda(
             pooled_width,
             sampling_ratio,
             channels_out,
-            grad_input.data<scalar_t>(),
-            rois.contiguous().data<scalar_t>());
+            grad_input.data_ptr<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
@@ -183,17 +183,17 @@ std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
       input.scalar_type(), "PSROIPool_forward", [&] {
         PSROIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
             output_size,
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height,
             width,
             pooled_height,
             pooled_width,
-            rois.contiguous().data<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>(),
             channels_out,
-            output.data<scalar_t>(),
-            channel_mapping.data<int>());
+            output.data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(output, channel_mapping);
@@ -250,8 +250,8 @@ at::Tensor PSROIPool_backward_cuda(
       grad.scalar_type(), "PSROIPool_backward", [&] {
         PSROIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
             grad.numel(),
-            grad.contiguous().data<scalar_t>(),
-            channel_mapping.data<int>(),
+            grad.contiguous().data_ptr<scalar_t>(),
+            channel_mapping.data_ptr<int>(),
             num_rois,
             spatial_scale,
             channels,
@@ -260,8 +260,8 @@ at::Tensor PSROIPool_backward_cuda(
             pooled_height,
             pooled_width,
             channels_out,
-            grad_input.data<scalar_t>(),
-            rois.contiguous().data<scalar_t>());
+            grad_input.data_ptr<scalar_t>(),
+            rois.contiguous().data_ptr<scalar_t>());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
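
All of the edited call sites follow the same ATen dispatch pattern: AT_DISPATCH_FLOATING_TYPES_AND_HALF expands its lambda once per supported dtype, binding scalar_t to the concrete element type, so data_ptr<scalar_t>() hands a correctly typed raw pointer to the kernel. A hedged, self-contained sketch of that pattern (the kernel and function names are hypothetical stand-ins, not from this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical element-wise kernel standing in for PSROIAlignForwardCPU.
template <typename scalar_t>
void scale_kernel(const scalar_t* in, scalar_t* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i)
    out[i] = in[i] * static_cast<scalar_t>(2);
}

at::Tensor scale_twice(const at::Tensor& input) {
  auto in = input.contiguous();  // kernels assume dense, row-major memory
  auto out = at::empty_like(in);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      in.scalar_type(), "scale_twice", [&] {
        scale_kernel<scalar_t>(
            in.data_ptr<scalar_t>(),   // post-deprecation accessor
            out.data_ptr<scalar_t>(),
            in.numel());
      });
  return out;
}

Note that channel_mapping.data_ptr<int>() in the hunks sits outside the dispatch on scalar_t, since the mapping tensor is always int regardless of the input dtype.
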