"vscode:/vscode.git/clone" did not exist on "81b6fbf19d115de02b53dc34afe23b4833fa45da"
Unverified Commit 77cb5786 authored by Zaida Zhou, committed by GitHub

[Fix] Remove duplicate files in csrc (#1284)

parent 357b0dfb
// Copyright (c) OpenMMLab. All rights reserved
#include "deform_conv_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor data_col) {
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels,
deformable_group, height_col, width_col, data_col_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
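// deformable_col2im is the adjoint of deformable_im2col: it scatters the
// gradient held in the column buffer back onto grad_im, accumulating the
// bilinear-interpolation weights at each sampled location.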
void deformable_col2im(Tensor data_col, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor grad_im) {
// todo: make sure parallel_imgs is passed in correctly
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, channel_per_deformable_group, parallel_imgs,
deformable_group, height_col, width_col, grad_im_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
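// deformable_col2im_coord computes the gradient with respect to the sampling
// offsets; the launch covers every offset element (2 * ksize_h * ksize_w *
// deformable_group channels per output location) and writes into grad_offset.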
void deformable_col2im_coord(
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, Tensor grad_offset) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
deformable_group * parallel_imgs;
int channel_per_deformable_group =
channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height,
width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
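// deform_conv_shape_check validates the weight / input / offset shapes (and
// gradOutput when it is non-NULL) against the convolution hyper-parameters
// before any kernel is launched.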
void deform_conv_shape_check(Tensor input, Tensor offset, Tensor *gradOutput,
Tensor weight, int kH, int kW, int dH, int dW,
int padH, int padW, int dilationH, int dilationW,
int group, int deformable_group) {
TORCH_CHECK(
weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s",
weight.ndimension());
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d",
kH, kW);
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
kH, kW, weight.size(2), weight.size(3));
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH,
dW);
TORCH_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(ndim == 3 || ndim == 4,
"3D or 4D input tensor expected but got: %s", ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
TORCH_CHECK(nInputPlane % deformable_group == 0,
"the number of input channels must be divisible by deformable_group");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
TORCH_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
TORCH_CHECK(
(offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
TORCH_CHECK(
gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
TORCH_CHECK(
(gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
Tensor offset, Tensor output,
Tensor columns, Tensor ones, int kW,
int kH, int dW, int dH, int padW,
int padH, int dilationW, int dilationH,
int group, int deformable_group,
int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output)
// todo: possibly change data indexing because of parallel_imgs
deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
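// Process im2col_step images per iteration: build the column buffer with
// deformable_im2col, then accumulate each group's output through addmm_
// (output_buffer[elt][g] += weight[g] @ columns[g]).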
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
}
void DeformConvBackwardInputCUDAKernelLauncher(
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput,
Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, group,
deformable_group);
at::DeviceGuard guard(input.device());
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
}
void DeformConvBackwardParametersCUDAKernelLauncher(
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
int padH, int dilationW, int dilationH, int group, int deformable_group,
float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH,
dW, padH, padW, dilationH, dilationW, group,
deformable_group);
at::DeviceGuard guard(input.device());
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer = gradOutputBuffer.contiguous();
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into groups
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
}
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "roi_align_rotated_cuda_kernel.cuh"
void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = features.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
scalar_t *top_data = output.data<scalar_t>();
roi_align_rotated_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sample_num, aligned, clockwise, channels, height, width,
pooled_height, pooled_width, top_data);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.type(), "ROIAlignLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
scalar_t *bottom_diff = bottom_grad.data<scalar_t>();
roi_align_rotated_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, top_diff, rois_data, spatial_scale, sample_num,
aligned, clockwise, channels, height, width, pooled_height,
pooled_width, bottom_diff);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "bbox_overlaps_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset) {
int output_size = ious.numel();
int num_bbox1 = bboxes1.size(0);
int num_bbox2 = bboxes2.size(0);
at::cuda::CUDAGuard device_guard(bboxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
bboxes1.scalar_type(), "bbox_overlaps_cuda_kernel", ([&] {
bbox_overlaps_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
bboxes1.data_ptr<scalar_t>(), bboxes2.data_ptr<scalar_t>(),
ious.data_ptr<scalar_t>(), num_bbox1, num_bbox2, mode, aligned,
offset);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "border_align_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size) {
// shape assertion
AT_ASSERTM(input.ndimension() == 4,
"non-empty 4D(batch mode) tensor expected for input feature");
AT_ASSERTM(boxes.ndimension() == 3,
"boxes must be 3D tensor with size of [B, H*W, 4]");
int batch_size = input.size(0);
int feat_channels = input.size(1);
int channels = feat_channels / 4;
int height = input.size(2);
int width = input.size(3);
// shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
int box_size = boxes.size(1);
// shape [N, channels, box_size, 4] for output
int nthreads = batch_size * channels * box_size;
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "border_align_forward_cuda_kernel", [&] {
border_align_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, input.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax_idx.data_ptr<int>(), channels, box_size, height, width,
pool_size);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size) {
int batch_size = grad_input.size(0);
int feat_channels = grad_input.size(1);
int channels = feat_channels / 4;
int height = grad_input.size(2);
int width = grad_input.size(3);
int box_size = boxes.size(1);
int nthreads = batch_size * channels * box_size;
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "border_align_backward_cuda_kernel", [&] {
border_align_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, grad_output.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), argmax_idx.data_ptr<int>(),
grad_input.data_ptr<scalar_t>(), channels, box_size, height,
width, pool_size);
});
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "carafe_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
Tensor rfeatures, Tensor routput,
Tensor rmasks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor) {
const int batch_size = output.size(0);
const int channels = output.size(1);
const int output_height = output.size(2);
const int output_width = output.size(3);
const int input_height = features.size(2);
const int input_width = features.size(3);
const int mask_channels = masks.size(1);
rfeatures.resize_({batch_size, input_height, input_width, channels});
routput.resize_({batch_size, output_height, output_width, channels});
rmasks.resize_({batch_size, output_height, output_width, mask_channels});
// one warp per pixel
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Feature", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
scalar_t *top_data = rfeatures.data_ptr<scalar_t>();
const int dh = divideUP(channels, kTileDim);
const int dw = divideUP(input_height * input_width, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, channels, input_height * input_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Masks", ([&] {
const scalar_t *bottom_data = masks.data_ptr<scalar_t>();
scalar_t *top_data = rmasks.data_ptr<scalar_t>();
const int dh = divideUP(mask_channels, kTileDim);
const int dw = divideUP(output_height * output_width, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, mask_channels, output_height * output_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "CARAFELaucherForward", ([&] {
const int num_kernels =
batch_size * output_height * output_width * THREADS_PER_PIXEL;
const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
const scalar_t *bottom_masks = rmasks.data_ptr<scalar_t>();
scalar_t *top_data = routput.data_ptr<scalar_t>();
CARAFEForward<scalar_t><<<divideUP(num_kernels, THREADS_PER_BLOCK),
THREADS_PER_BLOCK, 0, stream>>>(
num_kernels, bottom_data, bottom_masks, kernel_size, group_size,
scale_factor, channels, input_height, input_width, output_height,
output_width, mask_channels, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NHWC2NCHW", ([&] {
const scalar_t *bottom_data = routput.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();
const int dh = divideUP(output_height * output_width, kTileDim);
const int dw = divideUP(channels, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, output_height * output_width, channels, dh, dw,
bottom_data, top_data);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void CARAFEBackwardCUDAKernelLauncher(
const Tensor top_grad, const Tensor rfeatures, const Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs, Tensor rbottom_grad,
Tensor rmask_grad, Tensor bottom_grad, Tensor mask_grad,
const int kernel_size, const int group_size, const int scale_factor) {
const int batch_size = top_grad.size(0);
const int channels = top_grad.size(1);
const int output_height = top_grad.size(2);
const int output_width = top_grad.size(3);
const int input_height = bottom_grad.size(2);
const int input_width = bottom_grad.size(3);
const int mask_channels = masks.size(1);
rtop_grad.resize_({batch_size, output_height, output_width, channels});
rbottom_grad.resize_({batch_size, input_height, input_width, channels});
rbottom_grad_hs.resize_({batch_size, output_height, output_width, channels});
rmask_grad.resize_({batch_size, output_height, output_width, mask_channels});
at::cuda::CUDAGuard device_guard(top_grad.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] {
const scalar_t *bottom_data = top_grad.data_ptr<scalar_t>();
scalar_t *top_data = rtop_grad.data_ptr<scalar_t>();
const int dh = divideUP(channels, kTileDim);
const int dw = divideUP(output_height * output_width, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, channels, output_height * output_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFELaucherBackward_Feature", ([&] {
const int num_kernels =
batch_size * output_height * output_width * THREADS_PER_PIXEL;
const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
const scalar_t *bottom_masks = masks.data_ptr<scalar_t>();
scalar_t *bottom_diff = rbottom_grad_hs.data_ptr<scalar_t>();
CARAFEBackward_Feature<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, top_diff, bottom_masks, kernel_size,
group_size, scale_factor, channels, input_height,
input_width, output_height, output_width,
mask_channels, bottom_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "FeatureSum", ([&] {
const int num_kernels =
batch_size * input_height * input_width * THREADS_PER_PIXEL;
const scalar_t *bottom_diff_hs = rbottom_grad_hs.data_ptr<scalar_t>();
scalar_t *bottom_diff = rbottom_grad.data_ptr<scalar_t>();
FeatureSum<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, bottom_diff_hs, scale_factor, channels,
input_height, input_width, bottom_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NHWC2NCHW_Bottom_Grad", ([&] {
const scalar_t *bottom_data = rbottom_grad.data_ptr<scalar_t>();
scalar_t *top_data = bottom_grad.data_ptr<scalar_t>();
const int dh = divideUP(input_height * input_width, kTileDim);
const int dw = divideUP(channels, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, input_height * input_width, channels, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFELaucherBackward_Mask", ([&] {
const int num_kernels = batch_size * output_height * output_width *
mask_channels * WARP_SIZE;
const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
scalar_t *mask_diff = rmask_grad.data_ptr<scalar_t>();
CARAFEBackward_Mask<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, top_diff, bottom_data, kernel_size,
group_size, scale_factor, channels, input_height,
input_width, output_height, output_width,
mask_channels, mask_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NHWC2NCHW_Mask_Grad", ([&] {
const scalar_t *bottom_data = rmask_grad.data_ptr<scalar_t>();
scalar_t *top_data = mask_grad.data_ptr<scalar_t>();
const int dh = divideUP(output_height * output_width, kTileDim);
const int dw = divideUP(mask_channels, kTileDim);
BatchTranspose2DCUDAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, output_height * output_width, mask_channels, dh, dw,
bottom_data, top_data);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "carafe_naive_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void CARAFENAIVEForwardCUDAKernelLauncher(const Tensor features,
const Tensor masks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor) {
int output_size = output.numel();
int channels = output.size(1);
int height = output.size(2);
int width = output.size(3);
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "CARAFENAIVEForward", ([&] {
carafe_naive_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, features.data_ptr<scalar_t>(),
masks.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
kernel_size, group_size, scale_factor, channels, height, width);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void CARAFENAIVEBackwardCUDAKernelLauncher(
const Tensor top_grad, const Tensor features, const Tensor masks,
Tensor bottom_grad, Tensor mask_grad, const int kernel_size,
const int group_size, const int scale_factor) {
int output_size = top_grad.numel();
int channels = top_grad.size(1);
int height = top_grad.size(2);
int width = top_grad.size(3);
at::cuda::CUDAGuard device_guard(top_grad.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFENAIVEBackward", ([&] {
carafe_naive_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, top_grad.data_ptr<scalar_t>(),
features.data_ptr<scalar_t>(), masks.data_ptr<scalar_t>(),
bottom_grad.data_ptr<scalar_t>(),
mask_grad.data_ptr<scalar_t>(), kernel_size, group_size,
scale_factor, channels, height, width);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/LikeLy-Journey/SegmenTron/blob/master/segmentron/modules/csrc/criss_cross_attention/ca_cuda.cu
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include "cc_attention_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
Tensor weight) {
AT_ASSERTM(t.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(f.device().is_cuda(), "input must be a CUDA tensor");
auto n = t.size(0);
auto c = t.size(1);
auto h = t.size(2);
auto w = t.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
const Tensor f, Tensor dt, Tensor df) {
AT_ASSERTM(dw.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(t.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(f.device().is_cuda(), "input must be a CUDA tensor");
auto n = t.size(0);
auto c = t.size(1);
auto h = t.size(2);
auto w = t.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
ca_backward_kernel_t<scalar_t><<<blocks, threads, 0, stream>>>(
dw.contiguous().data_ptr<scalar_t>(),
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
dt.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
AT_DISPATCH_FLOATING_TYPES(f.scalar_type(), "ca_backward_kernel_f", [&] {
ca_backward_kernel_f<scalar_t><<<blocks, threads, 0, stream>>>(
dw.contiguous().data_ptr<scalar_t>(),
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
df.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
Tensor out) {
AT_ASSERTM(weight.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(g.device().is_cuda(), "input must be a CUDA tensor");
auto n = g.size(0);
auto c = g.size(1);
auto h = g.size(2);
auto w = g.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
ca_map_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
out.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg) {
AT_ASSERTM(dout.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(weight.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(g.device().is_cuda(), "input must be a CUDA tensor");
auto n = dout.size(0);
auto c = dout.size(1);
auto h = dout.size(2);
auto w = dout.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
ca_map_backward_kernel_w<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
d3 = c * n;
blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
dg.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "deform_roi_pool_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void DeformRoIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale,
int sampling_ratio, float gamma) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deform_roi_pool_forward_cuda_kernel", [&] {
deform_roi_pool_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), offset.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio,
static_cast<scalar_t>(gamma), channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void DeformRoIPoolBackwardCUDAKernelLauncher(
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio, float gamma) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "deform_roi_pool_backward_cuda_kernel", [&] {
deform_roi_pool_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), rois.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
grad_offset.data_ptr<scalar_t>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio,
static_cast<scalar_t>(gamma), channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "sigmoid_focal_loss_cuda_kernel.cuh"
#include "softmax_focal_loss_cuda_kernel.cuh"
void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma,
const float alpha) {
int output_size = output.numel();
int num_classes = input.size(1);
AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
"target label should be smaller than or equal to num_classes");
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "sigmoid_focal_loss_forward_cuda_kernel", [&] {
sigmoid_focal_loss_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
Tensor weight,
Tensor grad_input,
const float gamma,
const float alpha) {
int output_size = grad_input.numel();
int num_classes = input.size(1);
at::cuda::CUDAGuard device_guard(grad_input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "sigmoid_focal_loss_backward_cuda_kernel", [&] {
sigmoid_focal_loss_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor output,
const float gamma,
const float alpha) {
int output_size = output.numel();
int num_classes = softmax.size(1);
AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
"target label should be smaller than or equal to num_classes");
at::cuda::CUDAGuard device_guard(softmax.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
softmax_focal_loss_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
}
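// The softmax focal loss backward pass is split into two kernels: the first
// fills the intermediate buffer `buff` (launched over buff.numel() elements),
// and the second combines `buff` with the softmax output to produce
// grad_input (launched over grad_input.numel() elements).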
void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input,
const float gamma,
const float alpha) {
int num_classes = softmax.size(1);
int output_size = buff.numel();
at::cuda::CUDAGuard device_guard(grad_input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(), "softmax_focal_loss_backward_cuda1_kernel", [&] {
softmax_focal_loss_backward_cuda1_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
output_size = grad_input.numel();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(), "softmax_focal_loss_backward_cuda2_kernel", [&] {
softmax_focal_loss_backward_cuda2_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(), num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "masked_conv2d_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void MaskedIm2colForwardCUDAKernelLauncher(const Tensor bottom_data,
const Tensor mask_h_idx,
const Tensor mask_w_idx,
Tensor top_data, const int kernel_h,
const int kernel_w, const int pad_h,
const int pad_w) {
int channels = bottom_data.size(1);
int height = bottom_data.size(2);
int width = bottom_data.size(3);
int mask_cnt = mask_h_idx.size(0);
int output_size = mask_cnt * channels;
at::cuda::CUDAGuard device_guard(bottom_data.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
bottom_data.scalar_type(), "MaskedIm2colLaucherForward", ([&] {
const scalar_t *bottom_data_ = bottom_data.data_ptr<scalar_t>();
const int64_t *mask_h_idx_ = mask_h_idx.data_ptr<int64_t>();
const int64_t *mask_w_idx_ = mask_w_idx.data_ptr<int64_t>();
scalar_t *top_data_ = top_data.data_ptr<scalar_t>();
MaskedIm2colForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, bottom_data_, height, width, kernel_h, kernel_w,
pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void MaskedCol2imForwardCUDAKernelLauncher(
const Tensor bottom_data, const Tensor mask_h_idx, const Tensor mask_w_idx,
Tensor top_data, const int height, const int width, const int channels) {
int mask_cnt = mask_h_idx.size(0);
int output_size = mask_cnt * channels;
at::cuda::CUDAGuard device_guard(bottom_data.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
bottom_data.scalar_type(), "MaskedCol2imLaucherForward", ([&] {
const scalar_t *bottom_data_ = bottom_data.data_ptr<scalar_t>();
const int64_t *mask_h_idx_ = mask_h_idx.data_ptr<int64_t>();
const int64_t *mask_w_idx_ = mask_w_idx.data_ptr<int64_t>();
scalar_t *top_data_ = top_data.data_ptr<scalar_t>();
MaskedCol2imForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, bottom_data_, height, width, channels, mask_h_idx_,
mask_w_idx_, mask_cnt, top_data_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "modulated_deform_conv_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void modulated_deformable_im2col_cuda(
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor data_col) {
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im,
width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, batch_size,
channels, deformable_group, height_col, width_col, data_col_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void modulated_deformable_col2im_cuda(
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor grad_im) {
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels =
channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels,
height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void modulated_deformable_col2im_coord_cuda(
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
const Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
Tensor grad_offset, Tensor grad_mask) {
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h *
kernel_w * deformable_group;
const int channel_per_deformable_group =
channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_,
channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, batch_size,
2 * kernel_h * kernel_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_, grad_mask_);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void ModulatedDeformConvForwardCUDAKernelLauncher(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into groups
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void ModulatedDeformConvBackwardCUDAKernelLauncher(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
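// Hypothetical usage sketch (not part of the original mmcv sources): a
// host-side helper showing how the gradient buffers consumed by the backward
// launcher above might be allocated before the call. The helper name and the
// fixed 3x3 / stride-1 / single-group hyper-parameters are illustrative
// assumptions only, not mmcv API.
static void modulated_deform_conv_backward_example(
    Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor mask,
    Tensor grad_output) {
  // Gradients mirror the shapes of the tensors they differentiate.
  Tensor grad_input = at::zeros_like(input);
  Tensor grad_weight = at::zeros_like(weight);
  Tensor grad_bias = at::zeros_like(bias);
  Tensor grad_offset = at::zeros_like(offset);
  Tensor grad_mask = at::zeros_like(mask);
  // Scratch buffers; both are (re)allocated or resized inside the launcher.
  Tensor columns = at::zeros({1}, input.options());
  Tensor ones = at::ones({1, 1}, input.options());
  ModulatedDeformConvBackwardCUDAKernelLauncher(
      input, weight, bias, ones, offset, mask, columns, grad_input,
      grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
      /*kernel_h=*/3, /*kernel_w=*/3, /*stride_h=*/1, /*stride_w=*/1,
      /*pad_h=*/1, /*pad_w=*/1, /*dilation_h=*/1, /*dilation_w=*/1,
      /*group=*/1, /*deformable_group=*/1, /*with_bias=*/true);
}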
// Copyright (c) OpenMMLab. All rights reserved
#include "nms_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
int offset) {
at::cuda::CUDAGuard device_guard(boxes.device());
if (boxes.numel() == 0) {
return at::empty({0}, boxes.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /*descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t);
int boxes_num = boxes.size(0);
const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
nms_cuda<<<blocks, threads, 0, stream>>>(
boxes_num, iou_threshold, offset, boxes_sorted.data_ptr<float>(),
(unsigned long long*)mask.data_ptr<int64_t>());
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep_t =
at::zeros({boxes_num}, boxes.options().dtype(at::kBool).device(at::kCPU));
bool* keep = keep_t.data_ptr<bool>();
for (int i = 0; i < boxes_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep[i] = true;
      // mark every box that overlaps box i by setting its bit in remv
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.masked_select(keep_t.to(at::kCUDA));
}
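// Hypothetical usage sketch (not part of the original mmcv sources): builds a
// tiny random problem and feeds it to the NMS launcher above. The (x1, y1,
// x2, y2) box layout, the helper name, and the meaning attached to `offset`
// are illustrative assumptions.
static Tensor nms_example() {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  // 100 random boxes with x2 >= x1 and y2 >= y1.
  Tensor xy1 = at::rand({100, 2}, opts) * 100;
  Tensor wh = at::rand({100, 2}, opts) * 20;
  Tensor boxes = at::cat({xy1, xy1 + wh}, 1);
  Tensor scores = at::rand({100}, opts);
  // offset = 0 is assumed to select the x2 - x1 width convention
  // (no extra +1 pixel).
  return NMSCUDAKernelLauncher(boxes, scores, /*iou_threshold=*/0.5f,
                               /*offset=*/0);
}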
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/hszhao/semseg/blob/master/lib/psa/src
#include <THC/THC.h>
#include <torch/serialize/tensor.h>
#include <THC/THCDeviceUtils.cuh>
#include "psamask_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void PSAMaskForwardCUDAKernelLauncher(const int psa_type, const Tensor input,
Tensor output, const int num_,
const int h_feature, const int w_feature,
const int h_mask, const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int nthreads = num_ * h_feature * w_feature;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (psa_type == 0)
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "psamask_collect_forward_cuda", [&] {
psamask_collect_forward_cuda<scalar_t><<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, input.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
else
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "psamask_distribute_forward_cuda", [&] {
psamask_distribute_forward_cuda<scalar_t>
<<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, input.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
}
void PSAMaskBackwardCUDAKernelLauncher(
const int psa_type, const Tensor grad_output, Tensor grad_input,
const int num_, const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask, const int half_w_mask) {
int nthreads = num_ * h_feature * w_feature;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (psa_type == 0)
AT_DISPATCH_FLOATING_TYPES(
grad_input.scalar_type(), "psamask_collect_backward_cuda", [&] {
psamask_collect_backward_cuda<scalar_t><<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, grad_output.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>());
});
else
AT_DISPATCH_FLOATING_TYPES(
grad_input.scalar_type(), "psamask_distribute_backward_cuda", [&] {
psamask_distribute_backward_cuda<scalar_t>
<<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, grad_output.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>());
});
}
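// Hypothetical usage sketch (not part of the original mmcv sources): shows the
// tensor shapes the PSAMask forward launcher above is assumed to expect. The
// half-mask arithmetic, the input/output shapes, and the helper name are
// illustrative assumptions based on the PSANet formulation.
static void psamask_forward_example() {
  const int psa_type = 0;  // 0 = collect, 1 = distribute (per the dispatch names above)
  const int num = 2, h_feature = 16, w_feature = 16;
  const int h_mask = 7, w_mask = 7;
  const int half_h_mask = (h_mask - 1) / 2;  // 3
  const int half_w_mask = (w_mask - 1) / 2;  // 3
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  // Input: one mask value per (mask position, feature position) pair.
  Tensor input = at::rand({num, h_mask * w_mask, h_feature, w_feature}, opts);
  // Output: relates every feature position to every other feature position.
  Tensor output =
      at::zeros({num, h_feature * w_feature, h_feature, w_feature}, opts);
  PSAMaskForwardCUDAKernelLauncher(psa_type, input, output, num, h_feature,
                                   w_feature, h_mask, w_mask, half_h_mask,
                                   half_w_mask);
}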
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "roi_align_cuda_kernel.cuh"
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_align_forward_cuda_kernel", [&] {
roi_align_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
aligned_height, aligned_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
aligned, channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void ROIAlignBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax_y, Tensor argmax_x,
Tensor grad_input, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode,
bool aligned) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "roi_align_backward_cuda_kernel", [&] {
roi_align_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
aligned_height, aligned_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
aligned, channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
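// Hypothetical usage sketch (not part of the original mmcv sources): drives
// the ROIAlign forward launcher above with a single ROI. `input` is assumed to
// be a float NCHW CUDA tensor; the (batch_idx, x1, y1, x2, y2) ROI layout and
// the pool_mode encoding (0 = max, 1 = avg) are assumptions.
static void roi_align_forward_example(Tensor input) {
  const int aligned_height = 7, aligned_width = 7;
  auto opts = input.options();
  // A single ROI covering the region (8, 8) - (40, 40) of image 0.
  Tensor rois = at::zeros({1, 5}, opts);
  rois.slice(1, 1, 3).fill_(8);   // x1 = y1 = 8
  rois.slice(1, 3, 5).fill_(40);  // x2 = y2 = 40
  Tensor output =
      at::zeros({1, input.size(1), aligned_height, aligned_width}, opts);
  // argmax_y / argmax_x are only meaningful in max mode; allocate them anyway.
  Tensor argmax_y = at::zeros_like(output);
  Tensor argmax_x = at::zeros_like(output);
  ROIAlignForwardCUDAKernelLauncher(
      input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
      /*spatial_scale=*/1.0f, /*sampling_ratio=*/2, /*pool_mode=*/1,
      /*aligned=*/true);
}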
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "roi_pool_cuda_kernel.cuh"
void ROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,
int pooled_width, float spatial_scale) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_pool_forward_cuda_kernel", [&] {
roi_pool_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax.data_ptr<int>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void ROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "roi_pool_backward_cuda_kernel", [&] {
roi_pool_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), argmax.data_ptr<int>(),
grad_input.data_ptr<scalar_t>(), pooled_height, pooled_width,
channels, height, width);
});
AT_CUDA_CHECK(cudaGetLastError());
}
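// Hypothetical usage sketch (not part of the original mmcv sources): mirrors
// the ROIAlign example but for the ROIPool launcher above. The argmax buffer
// is read as int in the kernel, so it is allocated as kInt here; the
// (batch_idx, x1, y1, x2, y2) ROI layout is an assumption.
static void roi_pool_forward_example(Tensor input) {
  const int pooled_height = 7, pooled_width = 7;
  auto opts = input.options();
  Tensor rois = at::zeros({1, 5}, opts);
  rois.slice(1, 1, 3).fill_(8);   // x1 = y1 = 8
  rois.slice(1, 3, 5).fill_(40);  // x2 = y2 = 40
  Tensor output =
      at::zeros({1, input.size(1), pooled_height, pooled_width}, opts);
  Tensor argmax = at::zeros({1, input.size(1), pooled_height, pooled_width},
                            opts.dtype(at::kInt));
  ROIPoolForwardCUDAKernelLauncher(input, rois, output, argmax, pooled_height,
                                   pooled_width, /*spatial_scale=*/1.0f);
}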
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "sync_bn_cuda_kernel.cuh"
void SyncBNForwardMeanCUDAKernelLauncher(const Tensor input, Tensor mean) {
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] {
sync_bn_forward_mean_cuda_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(), num,
channels, spatial);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SyncBNForwardVarCUDAKernelLauncher(const Tensor input, const Tensor mean,
Tensor var) {
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] {
sync_bn_forward_var_cuda_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
var.data_ptr<float>(), num, channels, spatial);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SyncBNForwardOutputCUDAKernelLauncher(
const Tensor input, const Tensor mean, const Tensor var,
Tensor running_mean, Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
float momentum, int group_size) {
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] {
sync_bn_forward_output_cuda_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
var.data_ptr<float>(), running_mean.data_ptr<float>(),
running_var.data_ptr<float>(), weight.data_ptr<float>(),
bias.data_ptr<float>(), norm.data_ptr<float>(),
std.data_ptr<float>(), output.data_ptr<scalar_t>(), num,
channels, spatial, eps, momentum, group_size);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SyncBNBackwardParamCUDAKernelLauncher(const Tensor grad_output,
const Tensor norm,
Tensor grad_weight,
Tensor grad_bias) {
int num = grad_output.size(0);
int channels = grad_output.size(1);
int spatial = grad_output.size(2);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "sync_bn_backward_param_cuda_kernel", [&] {
sync_bn_backward_param_cuda_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
grad_output.data_ptr<scalar_t>(), norm.data_ptr<float>(),
grad_weight.data_ptr<float>(), grad_bias.data_ptr<float>(), num,
channels, spatial);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SyncBNBackwardDataCUDAKernelLauncher(const Tensor grad_output,
const Tensor weight,
const Tensor grad_weight,
const Tensor grad_bias,
const Tensor norm, const Tensor std,
Tensor grad_input) {
int output_size = grad_input.numel();
int num = grad_input.size(0);
int channels = grad_input.size(1);
int spatial = grad_input.size(2);
at::cuda::CUDAGuard device_guard(grad_input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "sync_bn_backward_data_cuda_kernel", [&] {
sync_bn_backward_data_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
weight.data_ptr<float>(), grad_weight.data_ptr<float>(),
grad_bias.data_ptr<float>(), norm.data_ptr<float>(),
std.data_ptr<float>(), grad_input.data_ptr<scalar_t>(), num,
channels, spatial);
});
AT_CUDA_CHECK(cudaGetLastError());
}
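// Hypothetical usage sketch (not part of the original mmcv sources): the
// SyncBN launchers above index input as (N, C, spatial), so an NCHW tensor is
// assumed to be flattened first, and the per-channel statistics are read as
// float32 regardless of the input dtype (data_ptr<float>() in the launchers).
static void sync_bn_mean_var_example(Tensor input_nchw) {
  // Flatten H x W into a single spatial axis: (N, C, H, W) -> (N, C, H * W).
  Tensor input =
      input_nchw.view({input_nchw.size(0), input_nchw.size(1), -1});
  auto fopts = input.options().dtype(at::kFloat);
  Tensor mean = at::zeros({input.size(1)}, fopts);
  Tensor var = at::zeros({input.size(1)}, fopts);
  // One block per channel reduces over the N * spatial elements of that
  // channel (see the <<<channels, THREADS_PER_BLOCK>>> launches above).
  SyncBNForwardMeanCUDAKernelLauncher(input, mean);
  SyncBNForwardVarCUDAKernelLauncher(input, mean, var);
}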
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "tin_shift_cuda_kernel.cuh"
void TINShiftForwardCUDAKernelLauncher(Tensor input, Tensor shift,
Tensor output) {
int output_size = output.numel();
int batch_size = input.size(0);
int t_size = input.size(1);
int channels = input.size(2);
int hw_size = input.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "tin_shift_forward_cuda_kernel", [&] {
tin_shift_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(), shift.data_ptr<int>(),
output.data_ptr<scalar_t>(), batch_size, channels, t_size,
hw_size, group_size, group_channel);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void TINShiftBackwardCUDAKernelLauncher(Tensor grad_output, Tensor shift,
Tensor grad_input) {
int output_size = grad_output.numel();
int batch_size = grad_output.size(0);
int t_size = grad_output.size(1);
int channels = grad_output.size(2);
int hw_size = grad_output.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "tin_shift_backward_cuda_kernel", [&] {
tin_shift_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
shift.data_ptr<int>(), grad_input.data_ptr<scalar_t>(),
batch_size, channels, t_size, hw_size, group_size,
group_channel);
});
AT_CUDA_CHECK(cudaGetLastError());
}
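// Hypothetical usage sketch (not part of the original mmcv sources): the
// TINShift forward launcher above indexes input as (batch, T, C, H*W) and
// shift as (batch, groups); the concrete sizes and the helper name are
// illustrative assumptions.
static void tin_shift_forward_example() {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  const int batch = 2, t = 8, channels = 16, hw = 49, groups = 4;
  Tensor input = at::rand({batch, t, channels, hw}, opts);
  // Integer temporal shift per (sample, group); the kernel reads it as int.
  Tensor shift = at::randint(-2, 3, {batch, groups}, opts.dtype(at::kInt));
  Tensor output = at::zeros_like(input);
  TINShiftForwardCUDAKernelLauncher(input, shift, output);
}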