Commit 3f412c39 authored by Kai Chen
Browse files

support half tensors

parent 826a5613
from torch.autograd import Function, Variable
from torch.autograd import Function
from .. import roi_align_cuda
......@@ -49,11 +49,11 @@ class RoIAlignFunction(Function):
grad_input = grad_rois = None
if ctx.needs_input_grad[0]:
grad_input = Variable(
rois.new(batch_size, num_channels, data_height, data_width)
.zero_())
roi_align_cuda.backward(grad_output, rois, out_h, out_w,
spatial_scale, sample_num, grad_input)
grad_input = rois.new_zeros(batch_size, num_channels, data_height,
data_width)
roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
out_w, spatial_scale, sample_num,
grad_input)
return grad_input, grad_rois, None, None, None
......
#include <torch/torch.h>
#include <torch/extension.h>
#include <cmath>
#include <vector>
......
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
......@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
sample_num, channels, height, width, pooled_height,
pooled_width, top_data);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
THCudaCheck(cudaGetLastError());
return 1;
}
......@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.type(), "ROIAlignLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
......@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
channels, height, width, pooled_height, pooled_width,
bottom_diff);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
THCudaCheck(cudaGetLastError());
return 1;
}
......@@ -24,9 +24,8 @@ class RoIPoolFunction(Function):
num_channels = features.size(1)
num_rois = rois.size(0)
out_size = (num_rois, num_channels, out_h, out_w)
output = features.new_zeros(*out_size)
argmax = features.new_zeros(*out_size, dtype=torch.int)
output = features.new_zeros(out_size)
argmax = features.new_zeros(out_size, dtype=torch.int)
roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
output, argmax)
ctx.spatial_scale = spatial_scale
......@@ -46,9 +45,9 @@ class RoIPoolFunction(Function):
grad_input = grad_rois = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.new(feature_size).zero_()
roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale,
grad_input)
grad_input = grad_output.new_zeros(feature_size)
roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
spatial_scale, grad_input)
return grad_input, grad_rois, None, None
......
#include <torch/torch.h>
#include <torch/extension.h>
#include <cmath>
#include <vector>
......
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
......@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
channels, height, width, pooled_h, pooled_w, top_data,
argmax_data);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
THCudaCheck(cudaGetLastError());
return 1;
}
......@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int pooled_w, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_h * pooled_w * channels;
// TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.type(), "ROIPoolLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
......@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff);
}));
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
THCudaCheck(cudaGetLastError());
return 1;
}
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment