Unverified commit 3d7bcc8f, authored by pc, committed by GitHub

add border_align support in parrots (#1086)

parent e05fb560
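For context, a minimal calling sketch of the op this commit wires up, using ATen tensors shaped the way the CUDA launchers below expect; the standalone example function and the concrete sizes are illustrative assumptions, not part of the commit.

// Minimal calling sketch (illustrative, not part of this commit).
#include <torch/torch.h>

// Forward declaration of the dispatcher defined in the bindings below.
void border_align_forward(const at::Tensor &input, const at::Tensor &boxes,
                          at::Tensor output, at::Tensor argmax_idx,
                          const int pool_size);

void border_align_example() {
  const int N = 2, C = 8, H = 16, W = 16, pool_size = 10;
  // Input features carry 4*C channels: one C-channel group per border
  // (top, left, bottom, right).
  auto input = torch::rand({N, 4 * C, H, W}, torch::kCUDA);
  // One (x1, y1, x2, y2) box per spatial location, kept inside the feature map.
  auto x1y1 = torch::rand({N, H * W, 2}, torch::kCUDA) * (W / 2);
  auto boxes = torch::cat({x1y1, x1y1 + W / 2}, /*dim=*/-1);
  // Pooled values and argmax indices: one entry per channel, box and border.
  auto output = torch::zeros({N, C, H * W, 4}, torch::kCUDA);
  auto argmax_idx = torch::zeros(
      {N, C, H * W, 4}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  border_align_forward(input, boxes, output, argmax_idx, pool_size);
}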
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size);
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size);
void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size) {
BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx,
pool_size);
}
void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size) {
BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx,
grad_input, pool_size);
}
#endif
void border_align_forward(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size) {
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(boxes);
CHECK_CUDA_INPUT(output);
CHECK_CUDA_INPUT(argmax_idx);
border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
#else
AT_ERROR("BorderAlign is not compiled with GPU support");
#endif
} else {
AT_ERROR("BorderAlign is not implemented on CPU");
}
}
void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size) {
if (grad_output.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(boxes);
CHECK_CUDA_INPUT(argmax_idx);
CHECK_CUDA_INPUT(grad_input);
border_align_backward_cuda(grad_output, boxes, argmax_idx, grad_input,
pool_size);
#else
AT_ERROR("BorderAlign is not compiled with GPU support");
#endif
} else {
AT_ERROR("BorderAlign is not implemented on CPU");
}
}
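For reference, a sketch of how these dispatchers are typically exposed to Python through pybind11 in the PyTorch extension build; the module boilerplate and argument labels here are assumptions, not part of this commit.

// Sketch of a pybind11 registration for the dispatchers above (assumed
// naming; the actual MMCV extension module binds many ops in one place).
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("border_align_forward", &border_align_forward,
        "border_align forward", py::arg("input"), py::arg("boxes"),
        py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size"));
  m.def("border_align_backward", &border_align_backward,
        "border_align backward", py::arg("grad_output"), py::arg("boxes"),
        py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size"));
}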
#include "border_align_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size) {
// shape assertion
AT_ASSERTM(input.ndimension() == 4,
"non-empty 4D(batch mode) tensor expected for input feature");
AT_ASSERTM(boxes.ndimension() == 3,
"boxes must be 3D tensor with size of [B, H*W, 4]");
int batch_size = input.size(0);
int feat_channels = input.size(1);
int channels = feat_channels / 4;
int height = input.size(2);
int width = input.size(3);
// shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
int box_size = boxes.size(1);
// shape [N, channels, box_size, 4] for output
int nthreads = batch_size * channels * box_size;
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "border_align_forward_cuda_kernel", [&] {
border_align_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, input.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax_idx.data_ptr<int>(), channels, box_size, height, width,
pool_size);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size) {
int batch_size = grad_input.size(0);
int feat_channels = grad_input.size(1);
int channels = feat_channels / 4;
int height = grad_input.size(2);
int width = grad_input.size(3);
int box_size = boxes.size(1);
int nthreads = batch_size * channels * box_size;
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "border_align_backward_cuda_kernel", [&] {
border_align_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, grad_output.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), argmax_idx.data_ptr<int>(),
grad_input.data_ptr<scalar_t>(), channels, box_size, height,
width, pool_size);
});
AT_CUDA_CHECK(cudaGetLastError());
}
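A brief note on the launch configuration used in both launchers: GET_BLOCKS comes from MMCV's common CUDA helpers and rounds the work count up to whole blocks. The sketch below mirrors that ceil-divide with an assumed block size; how the dim3(128, 4) block maps onto the kernel's work items is internal to the .cuh kernel and is not restated here.

// Hedged sketch of the grid sizing above; the real GET_BLOCKS helper lives in
// MMCV's common CUDA helpers and its default block size may differ.
inline int blocks_for(int nthreads, int threads_per_block = 128) {
  return (nthreads + threads_per_block - 1) / threads_per_block;
}
// nthreads = batch_size * channels * box_size, so each logical work item is
// one (batch, channel, box) triple.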
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

#include "border_align_pytorch.h"

using namespace parrots;

void border_align_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                       const OperatorBase::in_list_t& ins,
                                       OperatorBase::out_list_t& outs) {
  int pool_size;
  SSAttrs(attr).get<int>("pool_size", pool_size).done();

  // ins: input features, boxes; outs: pooled output, argmax indices
  const auto& input = buildATensor(ctx, ins[0]);
  const auto& boxes = buildATensor(ctx, ins[1]);

  auto output = buildATensor(ctx, outs[0]);
  auto argmax_idx = buildATensor(ctx, outs[1]);
  border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
}

void border_align_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                        const OperatorBase::in_list_t& ins,
                                        OperatorBase::out_list_t& outs) {
  int pool_size;
  SSAttrs(attr).get<int>("pool_size", pool_size).done();

  // ins: top_grad, boxes, argmax indices; outs: bottom_grad
  const auto& top_grad = buildATensor(ctx, ins[0]);
  const auto& boxes = buildATensor(ctx, ins[1]);
  const auto& argmax_idx = buildATensor(ctx, ins[2]);

  auto bottom_grad = buildATensor(ctx, outs[0]);
  border_align_backward_cuda(top_grad, boxes, argmax_idx, bottom_grad,
                             pool_size);
}

PARROTS_EXTENSION_REGISTER(border_align_forward)
    .attr("pool_size")
    .input(2)
    .output(2)
    .apply(border_align_forward_cuda_parrots)
    .done();

PARROTS_EXTENSION_REGISTER(border_align_backward)
    .attr("pool_size")
    .input(3)
    .output(1)
    .apply(border_align_backward_cuda_parrots)
    .done();
#ifndef BORDER_ALIGN_PYTORCH_H
#define BORDER_ALIGN_PYTORCH_H
#include <torch/extension.h>
using namespace at;

#ifdef MMCV_WITH_CUDA
void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
                               Tensor output, Tensor argmax_idx,
                               const int pool_size);

void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
                                const Tensor &argmax_idx, Tensor grad_input,
                                const int pool_size);
#endif

#endif  // BORDER_ALIGN_PYTORCH_H