Commit b04a0157 authored by Kai Chen

add fp16 support for forwarding

parent 7f9d2eb5
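The change itself is small: the forward launchers switch from AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF, so the same templated kernels are also instantiated for at::Half. A minimal sketch of that dispatch pattern, not part of the commit (hypothetical copy_kernel / copy_launcher names; CUDA_1D_KERNEL_LOOP, GET_BLOCKS and THREADS_PER_BLOCK are the helpers used in the files below, and the 0.4-era .type()/.data<T>() API matches this diff):

// Sketch only: AT_DISPATCH_FLOATING_TYPES_AND_HALF expands the lambda once per
// dtype (double, float, at::Half) and binds scalar_t to the concrete type,
// so one template covers fp16 as well.
template <typename scalar_t>
__global__ void copy_kernel(const int n, const scalar_t *in, scalar_t *out) {
  CUDA_1D_KERNEL_LOOP(i, n) { out[i] = in[i]; }
}

int copy_launcher(const at::Tensor input, at::Tensor output) {
  const int n = input.numel();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.type(), "copy_launcher", ([&] {
        copy_kernel<scalar_t><<<GET_BLOCKS(n), THREADS_PER_BLOCK>>>(
            n, input.data<scalar_t>(), output.data<scalar_t>());
      }));
  return 1;
}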
 #include <ATen/ATen.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <math.h>
-#include <stdio.h>
-#include <vector>
+#include <THC/THCAtomics.cuh>
+
+using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)

 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)

 #define THREADS_PER_BLOCK 1024
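The kernel launches below rely on GET_BLOCKS(output_size), which is defined outside the hunks shown here. A typical definition consistent with THREADS_PER_BLOCK = 1024 would look like the following (an assumption for illustration, not text from this commit):

// Hypothetical helper: enough blocks to cover N work items with
// THREADS_PER_BLOCK threads each, capped to keep the grid size sane;
// leftover items are picked up by CUDA_1D_KERNEL_LOOP's grid-stride loop.
inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 65000;
  return optimal_block_num < max_block_num ? optimal_block_num : max_block_num;
}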
@@ -28,10 +24,8 @@ __device__ scalar_t bilinear_interpolate(const scalar_t *bottom_data,
     return 0;
   }
-  if (y <= 0)
-    y = 0;
-  if (x <= 0)
-    x = 0;
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;

   int y_low = (int)y;
   int x_low = (int)x;
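For context, the rest of bilinear_interpolate (outside the lines shown) combines the four neighbours of (y, x) with weights proportional to the opposite cell areas. A sketch of that standard formula, with v1..v4 standing for the values at (y_low, x_low), (y_low, x_high), (y_high, x_low) and (y_high, x_high):

// Sketch of the standard bilinear combination the function implements.
template <typename scalar_t>
__device__ scalar_t bilinear_combine(scalar_t v1, scalar_t v2, scalar_t v3,
                                     scalar_t v4, scalar_t y, scalar_t x,
                                     int y_low, int x_low) {
  scalar_t ly = y - y_low, lx = x - x_low;  // fractional offsets inside the cell
  scalar_t hy = 1. - ly, hx = 1. - lx;
  scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}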
@@ -69,12 +63,13 @@ __device__ scalar_t bilinear_interpolate(const scalar_t *bottom_data,
 }

 template <typename scalar_t>
-__global__ void
-ROIAlignForward(const int nthreads, const scalar_t *bottom_data,
-                const scalar_t *bottom_rois, const scalar_t spatial_scale,
-                const int sample_num, const int channels, const int height,
-                const int width, const int pooled_height,
-                const int pooled_width, scalar_t *top_data) {
+__global__ void ROIAlignForward(const int nthreads, const scalar_t *bottom_data,
+                                const scalar_t *bottom_rois,
+                                const scalar_t spatial_scale,
+                                const int sample_num, const int channels,
+                                const int height, const int width,
+                                const int pooled_height, const int pooled_width,
+                                scalar_t *top_data) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the aligned output
     int pw = index % pooled_width;
@@ -101,7 +96,7 @@ ROIAlignForward(const int nthreads, const scalar_t *bottom_data,
     int sample_num_h = (sample_num > 0)
                            ? sample_num
                            : ceil(roi_height / pooled_height);  // e.g., = 2
     int sample_num_w =
         (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
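As the `// e.g., = 2` comment hints, when sample_num is not set the per-bin sampling grid adapts to the ROI size; a worked example with purely illustrative numbers:

// roi_height = 12.6, pooled_height = 7  ->  ceil(12.6 / 7) = ceil(1.8)  = 2
// roi_width  = 20.0, pooled_width  = 7  ->  ceil(20.0 / 7) = ceil(2.86) = 3
// so each output bin averages a 2 x 3 grid of bilinearly sampled points.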
@@ -137,17 +132,17 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
                            const int pooled_height, const int pooled_width,
                            at::Tensor output) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       features.type(), "ROIAlignLaucherForward", ([&] {
         const scalar_t *bottom_data = features.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         scalar_t *top_data = output.data<scalar_t>();

-        ROIAlignForward<
-            scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
-            output_size, bottom_data, rois_data, scalar_t(spatial_scale),
-            sample_num, channels, height, width, pooled_height, pooled_width,
-            top_data);
+        ROIAlignForward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
+                sample_num, channels, height, width, pooled_height,
+                pooled_width, top_data);
       }));
   cudaError_t err = cudaGetLastError();
   if (cudaSuccess != err) {
@@ -159,11 +154,12 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
 }

 template <typename scalar_t>
-__device__ void
-bilinear_interpolate_gradient(const int height, const int width, scalar_t y,
-                              scalar_t x, scalar_t &w1, scalar_t &w2,
-                              scalar_t &w3, scalar_t &w4, int &x_low,
-                              int &x_high, int &y_low, int &y_high) {
+__device__ void bilinear_interpolate_gradient(const int height, const int width,
+                                               scalar_t y, scalar_t x,
+                                               scalar_t &w1, scalar_t &w2,
+                                               scalar_t &w3, scalar_t &w4,
+                                               int &x_low, int &x_high,
+                                               int &y_low, int &y_high) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     w1 = w2 = w3 = w4 = 0.;
@@ -171,10 +167,8 @@ bilinear_interpolate_gradient(const int height, const int width, scalar_t y,
     return;
   }
-  if (y <= 0)
-    y = 0;
-  if (x <= 0)
-    x = 0;
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;

   y_low = (int)y;
   x_low = (int)x;
@@ -204,12 +198,11 @@ bilinear_interpolate_gradient(const int height, const int width, scalar_t y,
 }

 template <typename scalar_t>
-__global__ void
-ROIAlignBackward(const int nthreads, const scalar_t *top_diff,
-                 const scalar_t *bottom_rois, const scalar_t spatial_scale,
-                 const int sample_num, const int channels, const int height,
-                 const int width, const int pooled_height,
-                 const int pooled_width, scalar_t *bottom_diff) {
+__global__ void ROIAlignBackward(
+    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
+    const scalar_t spatial_scale, const int sample_num, const int channels,
+    const int height, const int width, const int pooled_height,
+    const int pooled_width, scalar_t *bottom_diff) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the aligned output
     int pw = index % pooled_width;
@@ -239,7 +232,7 @@ ROIAlignBackward(const int nthreads, const scalar_t *top_diff,
     int sample_num_h = (sample_num > 0)
                            ? sample_num
                            : ceil(roi_height / pooled_height);  // e.g., = 2
     int sample_num_w =
         (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
@@ -279,13 +272,6 @@ ROIAlignBackward(const int nthreads, const scalar_t *top_diff,
   }
 }

-template <>
-__global__ void ROIAlignBackward<double>(
-    const int nthreads, const double *top_diff, const double *bottom_rois,
-    const double spatial_scale, const int sample_num, const int channels,
-    const int height, const int width, const int pooled_height,
-    const int pooled_width, double *bottom_diff) {}
-
 int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             const float spatial_scale, const int sample_num,
                             const int channels, const int height,
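The empty ROIAlignBackward<double> specialization removed above (and its ROIPoolBackward<double> counterpart later in this commit) presumably existed as a stub so that the double instantiation would not require atomicAdd(double*, double), which plain CUDA only provides natively on compute capability 6.0 and newer. With THC/THCAtomics.cuh now included, atomicAdd overloads are available for the dispatched types, so the real template can be instantiated for double too. A simplified sketch of the gradient scatter the backward kernels depend on (hypothetical helper name and offset argument, shown only to illustrate why the atomic matters):

// Several output bins may write to the same input pixel, so the accumulation
// into bottom_diff must be atomic; THC/THCAtomics.cuh supplies the overloads
// for the types dispatched here.
template <typename scalar_t>
__device__ void scatter_grad(scalar_t *bottom_diff, int offset, scalar_t g) {
  atomicAdd(bottom_diff + offset, g);
}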
@@ -294,6 +280,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;

+  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
   AT_DISPATCH_FLOATING_TYPES(
       top_grad.type(), "ROIAlignLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
@@ -304,10 +291,11 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
           exit(-1);
         }

-        ROIAlignBackward<
-            scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
-            output_size, top_diff, rois_data, spatial_scale, sample_num,
-            channels, height, width, pooled_height, pooled_width, bottom_diff);
+        ROIAlignBackward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, top_diff, rois_data, spatial_scale, sample_num,
+                channels, height, width, pooled_height, pooled_width,
+                bottom_diff);
       }));
   cudaError_t err = cudaGetLastError();
   if (cudaSuccess != err) {
......
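The backward launchers keep AT_DISPATCH_FLOATING_TYPES for now (per the TODO above): the gradient scatter relies on atomicAdd for the dispatched type, the half case is not yet resolved in the headers this file uses, and native __half atomics only arrive with newer GPU architectures (compute capability 7.0+). In other words, fp16 support in this commit is forward-only. A rough illustration of one way the half path could be handled once that is resolved (an assumption, not part of the commit):

// Hypothetical helper: accumulate an fp16 gradient through a float staging
// buffer, sidestepping the missing at::Half atomicAdd. atomicAdd(float*, float)
// is available on all architectures this code targets.
__device__ void add_half_grad(float *staging, int offset, at::Half g) {
  atomicAdd(staging + offset, static_cast<float>(g));
}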
@@ -4,7 +4,7 @@ from torch.autograd import gradcheck
 import os.path as osp
 import sys

 sys.path.append(osp.abspath(osp.join(__file__, '../../')))
-from roi_pooling import RoIPool
+from roi_pool import RoIPool

 feat = torch.randn(4, 16, 15, 15, requires_grad=True).cuda()
 rois = torch.Tensor([[0, 0, 0, 50, 50], [0, 10, 30, 43, 55],
......
 #include <ATen/ATen.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <math.h>
-#include <stdio.h>
-#include <vector>
+#include <THC/THCAtomics.cuh>
+
+using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)

 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)

 #define THREADS_PER_BLOCK 1024
@@ -44,8 +40,7 @@ __global__ void ROIPoolForward(const int nthreads, const scalar_t *bottom_data,
     // force malformed rois to be 1x1
     scalar_t roi_w = roi_x2 - roi_x1;
     scalar_t roi_h = roi_y2 - roi_y1;
-    if (roi_w <= 0 || roi_h <= 0)
-      continue;
+    if (roi_w <= 0 || roi_h <= 0) continue;

     scalar_t bin_size_w = roi_w / static_cast<scalar_t>(pooled_w);
     scalar_t bin_size_h = roi_h / static_cast<scalar_t>(pooled_h);
@@ -68,7 +63,8 @@ __global__ void ROIPoolForward(const int nthreads, const scalar_t *bottom_data,
     bottom_data += (roi_batch_ind * channels + c) * height * width;

     // Define an empty pooling region to be zero
-    scalar_t max_val = is_empty ? 0 : bottom_data[bin_y1 * width + bin_x1] - 1;
+    scalar_t max_val = is_empty ? static_cast<scalar_t>(0)
+                                : bottom_data[bin_y1 * width + bin_x1] - 1;

     for (int h = bin_y1; h < bin_y2; ++h) {
       for (int w = bin_x1; w < bin_x2; ++w) {
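The explicit static_cast<scalar_t>(0) is the kind of small change fp16 dispatch tends to force: with scalar_t = at::Half, mixing a plain int 0 with the half-typed branch of the conditional leans on implicit conversions that are fragile for at::Half, so both branches are kept explicitly typed. A minimal illustration with a hypothetical helper (same idea as the line above, not code from the commit):

template <typename scalar_t>
__device__ scalar_t empty_or_seed(bool is_empty, const scalar_t *data, int i) {
  // Explicitly typed zero: avoids relying on an int -> at::Half implicit
  // conversion inside the conditional when scalar_t is at::Half.
  return is_empty ? static_cast<scalar_t>(0) : data[i] - 1;
}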
@@ -80,8 +76,7 @@ __global__ void ROIPoolForward(const int nthreads, const scalar_t *bottom_data,
       }
     }
     top_data[index] = max_val;
-    if (argmax_data != NULL)
-      argmax_data[index] = max_idx;
+    if (argmax_data != NULL) argmax_data[index] = max_idx;
   }
 }
@@ -92,17 +87,18 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
                           at::Tensor output, at::Tensor argmax) {
   const int output_size = num_rois * channels * pooled_h * pooled_w;
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       features.type(), "ROIPoolLaucherForward", ([&] {
         const scalar_t *bottom_data = features.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         scalar_t *top_data = output.data<scalar_t>();
         int *argmax_data = argmax.data<int>();

-        ROIPoolForward<
-            scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
-            output_size, bottom_data, rois_data, scalar_t(spatial_scale),
-            channels, height, width, pooled_h, pooled_w, top_data, argmax_data);
+        ROIPoolForward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
+                channels, height, width, pooled_h, pooled_w, top_data,
+                argmax_data);
       }));
   cudaError_t err = cudaGetLastError();
   if (cudaSuccess != err) {
@@ -135,28 +131,6 @@ __global__ void ROIPoolBackward(const int nthreads, const scalar_t *top_diff,
   }
 }

-template <>
-__global__ void
-ROIPoolBackward<double>(const int nthreads, const double *top_diff,
-                        const double *rois, const int *argmax_data,
-                        const double spatial_scale, const int channels,
-                        const int height, const int width, const int pooled_h,
-                        const int pooled_w, double *bottom_diff) {
-  // CUDA_1D_KERNEL_LOOP(index, nthreads) {
-  //   int pw = index % pooled_w;
-  //   int ph = (index / pooled_w) % pooled_h;
-  //   int c = (index / pooled_w / pooled_h) % channels;
-  //   int n = index / pooled_w / pooled_h / channels;
-  //   int roi_batch_ind = rois[n * 5];
-  //   int bottom_index = argmax_data[(n * channels + c) * pooled_h * pooled_w +
-  //                                  ph * pooled_w + pw];
-  //   *(bottom_diff + (roi_batch_ind * channels + c) * height * width +
-  //     bottom_index) += top_diff[index];
-  // }
-}
-
 int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const at::Tensor argmax, const float spatial_scale,
                            const int batch_size, const int channels,
@@ -165,6 +139,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const int pooled_w, at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_h * pooled_w * channels;

+  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
   AT_DISPATCH_FLOATING_TYPES(
       top_grad.type(), "ROIPoolLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
@@ -177,11 +152,11 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
           exit(-1);
         }

-        ROIPoolBackward<
-            scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
-            output_size, top_diff, rois_data, argmax_data,
-            scalar_t(spatial_scale), channels, height, width, pooled_h,
-            pooled_w, bottom_diff);
+        ROIPoolBackward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, top_diff, rois_data, argmax_data,
+                scalar_t(spatial_scale), channels, height, width, pooled_h,
+                pooled_w, bottom_diff);
       }));
   cudaError_t err = cudaGetLastError();
   if (cudaSuccess != err) {
......