Unverified Commit b04ad69d authored by pc, committed by GitHub

[Feature] Add 6 rotated detection ops in parrots (#1665)

* add active_rotated_filter, convex_iou, min_area_polygons, points_in_polygons, riroi_align_rotated, rotated_feature_align in parrots

* fix lint

* fix lint
parent 86ed509a
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/ActiveRotatingFilter.h
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void active_rotated_filter_forward_impl(const Tensor input,
const Tensor indices, Tensor output) {
DISPATCH_DEVICE_IMPL(active_rotated_filter_forward_impl, input, indices,
output);
}
void active_rotated_filter_backward_impl(const Tensor grad_out,
const Tensor indices, Tensor grad_in) {
DISPATCH_DEVICE_IMPL(active_rotated_filter_backward_impl, grad_out, indices,
grad_in);
}
void active_rotated_filter_forward(const Tensor input, const Tensor indices,
Tensor output) {
active_rotated_filter_forward_impl(input, indices, output);
}
void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
Tensor grad_in) {
active_rotated_filter_backward_impl(grad_out, indices, grad_in);
}
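The *_impl wrappers above contain no computation of their own: DISPATCH_DEVICE_IMPL looks up whichever backend has registered itself for the device of the incoming tensors (the matching REGISTER_DEVICE_IMPL calls appear further down in this diff). A minimal, self-contained sketch of that registry idea, using placeholder types instead of tensors and hypothetical names, not the actual pytorch_device_registry.hpp implementation, could look like this:
// Simplified dispatch-registry sketch (illustration only; names such as
// ArfForward and register_arf_forward are hypothetical).
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <utility>

enum class Device { CPU, CUDA };

// One registry per operator: device type -> implementation.
using ArfForward = std::function<void(int /*input*/, int /*indices*/, int /*output*/)>;

std::map<Device, ArfForward>& arf_forward_registry() {
  static std::map<Device, ArfForward> registry;
  return registry;
}

// Roughly what REGISTER_DEVICE_IMPL achieves: record an implementation.
void register_arf_forward(Device dev, ArfForward fn) {
  arf_forward_registry()[dev] = std::move(fn);
}

// Roughly what DISPATCH_DEVICE_IMPL achieves: look up and forward the call.
void dispatch_arf_forward(Device dev, int input, int indices, int output) {
  auto it = arf_forward_registry().find(dev);
  if (it == arf_forward_registry().end())
    throw std::runtime_error("no implementation registered for this device");
  it->second(input, indices, output);
}

int main() {
  register_arf_forward(Device::CPU, [](int, int, int) {
    std::cout << "CPU active_rotated_filter_forward called\n";
  });
  dispatch_arf_forward(Device::CPU, 0, 0, 0);
  return 0;
}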
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "active_rotated_filter_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void active_rotated_filter_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto input = buildATensor(ctx, ins[0]);
auto indices = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
active_rotated_filter_forward(input, indices, output);
}
void active_rotated_filter_backward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto grad_out = buildATensor(ctx, ins[0]);
auto indices = buildATensor(ctx, ins[1]);
auto grad_in = buildATensor(ctx, outs[0]);
active_rotated_filter_backward(grad_out, indices, grad_in);
}
#endif
void active_rotated_filter_forward_cpu_parrots(
HostContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto input = buildATensor(ctx, ins[0]);
auto indices = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
active_rotated_filter_forward(input, indices, output);
}
void active_rotated_filter_backward_cpu_parrots(
HostContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto grad_out = buildATensor(ctx, ins[0]);
auto indices = buildATensor(ctx, ins[1]);
auto grad_in = buildATensor(ctx, outs[0]);
active_rotated_filter_backward(grad_out, indices, grad_in);
}
PARROTS_EXTENSION_REGISTER(active_rotated_filter_forward)
.input(2)
.output(1)
.apply(active_rotated_filter_forward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
.apply(active_rotated_filter_forward_cuda_parrots)
#endif
.done();
PARROTS_EXTENSION_REGISTER(active_rotated_filter_backward)
.input(2)
.output(1)
.apply(active_rotated_filter_backward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
.apply(active_rotated_filter_backward_cuda_parrots)
#endif
.done();
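The two registrations above expose a CPU applier and, when built with CUDA, a CUDA applier for the same two-input, one-output signature. Below is a hedged usage sketch of the forward entry point; the function name is hypothetical, the extension is assumed to be built with CUDA, and the tensor shapes follow the documented mmcv convention for active_rotated_filter (treat them as an assumption here):
#include "active_rotated_filter_pytorch.h"

// Hypothetical caller, for illustration only.
void run_active_rotated_filter_example() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  // Assumed shapes:
  //   input   (out_planes, in_planes, num_orientations, kH, kW)
  //   indices (num_orientations, kH, kW, num_rotations), integer rotation maps
  //   output  (out_planes * num_rotations, in_planes * num_orientations, kH, kW)
  at::Tensor input = at::rand({32, 16, 8, 3, 3}, opts);
  at::Tensor indices = at::zeros({8, 3, 3, 8}, opts.dtype(at::kInt));
  at::Tensor output = at::zeros({32 * 8, 16 * 8, 3, 3}, opts);
  active_rotated_filter_forward(input, indices, output);
}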
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ACTIVE_ROTATED_FILTER_PYTORCH_H
#define ACTIVE_ROTATED_FILTER_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void active_rotated_filter_forward(const Tensor input, const Tensor indices,
Tensor output);
void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
Tensor grad_in);
#endif // ACTIVE_ROTATED_FILTER_PYTORCH_H
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/tree/main/mmdet/ops/iou/src
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void convex_iou_impl(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
DISPATCH_DEVICE_IMPL(convex_iou_impl, pointsets, polygons, ious);
}
void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious) {
convex_iou_impl(pointsets, polygons, ious);
}
void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
Tensor output) {
DISPATCH_DEVICE_IMPL(convex_giou_impl, pointsets, polygons, output);
}
void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output) {
convex_giou_impl(pointsets, polygons, output);
}
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "convex_iou_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void convex_iou_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto pointsets = buildATensor(ctx, ins[0]);
auto polygons = buildATensor(ctx, ins[1]);
auto ious = buildATensor(ctx, outs[0]);
convex_iou(pointsets, polygons, ious);
}
void convex_giou_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto pointsets = buildATensor(ctx, ins[0]);
auto polygons = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
convex_giou(pointsets, polygons, output);
}
PARROTS_EXTENSION_REGISTER(convex_iou)
.input(2)
.output(1)
.apply(convex_iou_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(convex_giou)
.input(2)
.output(1)
.apply(convex_giou_forward_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CONVEX_IOU_PYTORCH_H
#define CONVEX_IOU_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious);
void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output);
#endif  // CONVEX_IOU_PYTORCH_H
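A hedged usage sketch of convex_iou as declared above; the parrots registration earlier in this diff makes it CUDA-only, the helper function name here is hypothetical, and the shapes assume the usual mmcv convention of N point sets of nine (x, y) points scored against K convex quadrilaterals:
#include "convex_iou_pytorch.h"

// Hypothetical caller, for illustration only.
void run_convex_iou_example() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  at::Tensor pointsets = at::rand({32, 18}, opts);  // 32 sets of 9 (x, y) points (assumed layout)
  at::Tensor polygons = at::rand({16, 8}, opts);    // 16 quadrilaterals, 4 (x, y) vertices each
  at::Tensor ious = at::zeros({32, 16}, opts);      // pairwise IoU matrix, filled by the kernel
  convex_iou(pointsets, polygons, ious);
}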
@@ -992,6 +992,81 @@ REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, CUDA,
roi_align_rotated_backward_cuda);
void RiROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor output);
void RiROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor bottom_grad);
void riroi_align_rotated_forward_cuda(Tensor features, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
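  // Each rotated RoI is expected to carry 6 values. In mmcv's rotated ops this
  // is conventionally (batch_index, center_x, center_y, width, height, theta);
  // treat the exact layout as an assumption here.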
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(rois);
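  // The input feature map is expected as (N, C * num_orientations, H, W);
  // the per-orientation channel count is recovered by the division below.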
int num_channels = features.size(1) / num_orientations;
int data_height = features.size(2);
int data_width = features.size(3);
RiROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, num_samples, clockwise, num_channels,
data_height, data_width, num_rois, pooled_height, pooled_width,
num_orientations, output);
}
void riroi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
CHECK_CONTIGUOUS(top_grad);
CHECK_CONTIGUOUS(rois);
int num_channels = bottom_grad.size(1) / num_orientations;
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
RiROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, num_samples, clockwise, num_channels,
data_height, data_width, num_rois, pooled_height, pooled_width,
num_orientations, bottom_grad);
}
void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise);
void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise);
REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, CUDA,
riroi_align_rotated_forward_cuda);
REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, CUDA,
riroi_align_rotated_backward_cuda);
void RoiawarePool3dForwardCUDAKernelLauncher(
int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
int out_y, int out_z, const Tensor rois, const Tensor pts,
@@ -1358,7 +1433,134 @@ void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim);
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
hard_voxelize_forward_cuda);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
dynamic_voxelize_forward_cuda);
void RotatedFeatureAlignForwardCUDAKernelLauncher(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor output);
void RotatedFeatureAlignBackwardCUDAKernelLauncher(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor bottom_grad);
void rotated_feature_align_forward_cuda(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output) {
RotatedFeatureAlignForwardCUDAKernelLauncher(features, best_bboxes,
spatial_scale, points, output);
}
void rotated_feature_align_backward_cuda(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad) {
RotatedFeatureAlignBackwardCUDAKernelLauncher(
top_grad, best_bboxes, spatial_scale, points, bottom_grad);
}
void rotated_feature_align_forward_impl(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output);
void rotated_feature_align_backward_impl(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad);
REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
rotated_feature_align_forward_cuda);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
rotated_feature_align_backward_cuda);
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output);
void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
output);
}
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols);
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
points_in_polygons_forward_cuda);
void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets, Tensor polygons);
void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) {
MinAreaPolygonsCUDAKernelLauncher(pointsets, polygons);
}
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons);
REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda);
void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
const Tensor indices,
Tensor output);
void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
const Tensor indices,
Tensor grad_in);
void active_rotated_filter_forward_cuda(const Tensor input,
const Tensor indices, Tensor output) {
ActiveRotatedFilterForwardCUDAKernelLauncher(input, indices, output);
}
void active_rotated_filter_backward_cuda(const Tensor grad_out,
const Tensor indices, Tensor grad_in) {
ActiveRotatedFilterBackwardCUDAKernelLauncher(grad_out, indices, grad_in);
}
void active_rotated_filter_forward_impl(const Tensor input,
const Tensor indices, Tensor output);
void active_rotated_filter_backward_impl(const Tensor grad_out,
const Tensor indices, Tensor grad_in);
REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CUDA,
active_rotated_filter_forward_cuda);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CUDA,
active_rotated_filter_backward_cuda);
void ConvexIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void ConvexGIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor output);
void convex_iou_cuda(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
ConvexIoUCUDAKernelLauncher(pointsets, polygons, ious);
}
void convex_giou_cuda(const Tensor pointsets, const Tensor polygons,
Tensor output) {
ConvexGIoUCUDAKernelLauncher(pointsets, polygons, output);
}
void convex_iou_impl(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
Tensor output);
REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons) {
DISPATCH_DEVICE_IMPL(min_area_polygons_impl, pointsets, polygons);
}
void min_area_polygons(const Tensor pointsets, Tensor polygons) {
min_area_polygons_impl(pointsets, polygons);
}
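A hedged usage sketch of min_area_polygons as dispatched above; the parrots registration that follows makes it CUDA-only, the helper function name is hypothetical, and the shapes assume the mmcv convention of reducing each set of nine (x, y) points to a four-vertex minimum-area enclosing polygon:
#include "min_area_polygons_pytorch.h"

// Hypothetical caller, for illustration only.
void run_min_area_polygons_example() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  at::Tensor pointsets = at::rand({32, 18}, opts);  // 32 sets of 9 (x, y) points (assumed layout)
  at::Tensor polygons = at::zeros({32, 8}, opts);   // 4 (x, y) vertices per set, filled by the kernel
  min_area_polygons(pointsets, polygons);
}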
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "min_area_polygons_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void min_area_polygons_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto pointsets = buildATensor(ctx, ins[0]);
auto polygons = buildATensor(ctx, outs[0]);
min_area_polygons(pointsets, polygons);
}
PARROTS_EXTENSION_REGISTER(min_area_polygons)
.input(1)
.output(1)
.apply(min_area_polygons_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef MIN_AREA_POLYGONS_PYTORCH_H
#define MIN_AREA_POLYGONS_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void min_area_polygons(const Tensor pointsets, Tensor polygons);
#endif // MIN_AREA_POLYGONS_PYTORCH_H
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
output, rows, cols);
}
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
int rows = points.size(0);
int cols = polygons.size(0);
points_in_polygons_forward_impl(points, polygons, output, rows, cols);
}
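The forward wrapper above sizes the output by one entry per (point, polygon) pair. As a plain-CPU illustration of the underlying geometric test, not the kernel's exact algorithm, a same-side cross-product check for a convex quadrilateral could look like this (standalone, standard C++):
#include <array>
#include <cstdio>

// Returns true if (px, py) lies inside, or on the boundary of, a convex
// quadrilateral given as 4 (x, y) vertices in a consistent winding order.
bool point_in_convex_quad(float px, float py, const std::array<float, 8>& poly) {
  bool has_pos = false, has_neg = false;
  for (int k = 0; k < 4; ++k) {
    float x1 = poly[2 * k], y1 = poly[2 * k + 1];
    float x2 = poly[(2 * k + 2) % 8], y2 = poly[(2 * k + 3) % 8];
    float cross = (x2 - x1) * (py - y1) - (y2 - y1) * (px - x1);
    has_pos = has_pos || cross > 0;
    has_neg = has_neg || cross < 0;
  }
  // Inside (or on an edge) when every edge cross-product shares one sign.
  return !(has_pos && has_neg);
}

int main() {
  std::array<float, 8> unit_square = {0, 0, 1, 0, 1, 1, 0, 1};
  std::printf("%d %d\n", point_in_convex_quad(0.5f, 0.5f, unit_square),
              point_in_convex_quad(2.0f, 0.5f, unit_square));  // prints "1 0"
  return 0;
}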
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "points_in_polygons_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void points_in_polygons_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto points = buildATensor(ctx, ins[0]);
auto polygons = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
points_in_polygons_forward(points, polygons, output);
}
PARROTS_EXTENSION_REGISTER(points_in_polygons_forward)
.input(2)
.output(1)
.apply(points_in_polygons_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_PYTORCH_H
#define POINTS_IN_POLYGONS_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
#endif // POINTS_IN_POLYGONS_PYTORCH_H
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
DISPATCH_DEVICE_IMPL(riroi_align_rotated_forward_impl, features, rois, output,
pooled_height, pooled_width, spatial_scale, num_samples,
num_orientations, clockwise);
}
void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
DISPATCH_DEVICE_IMPL(riroi_align_rotated_backward_impl, top_grad, rois,
bottom_grad, pooled_height, pooled_width, spatial_scale,
num_samples, num_orientations, clockwise);
}
void riroi_align_rotated_forward(Tensor features, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int num_samples,
int num_orientations, bool clockwise) {
riroi_align_rotated_forward_impl(features, rois, output, pooled_height,
pooled_width, spatial_scale, num_samples,
num_orientations, clockwise);
}
void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
riroi_align_rotated_backward_impl(top_grad, rois, bottom_grad, pooled_height,
pooled_width, spatial_scale, num_samples,
num_orientations, clockwise);
}
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "riroi_align_rotated_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void riroi_align_rotated_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int num_orientations;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("num_samples", sample_num)
.get<int>("num_orientations", num_orientations)
.get<bool>("clockwise", clockwise)
.done();
auto input = buildATensor(ctx, ins[0]);
auto rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
riroi_align_rotated_forward(input, rois, output, pooled_height, pooled_width,
spatial_scale, sample_num, num_orientations,
clockwise);
}
void riroi_align_rotated_backward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int num_orientations;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("num_samples", sample_num)
.get<int>("num_orientations", num_orientations)
.get<bool>("clockwise", clockwise)
.done();
auto grad_output = buildATensor(ctx, ins[0]);
auto rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
riroi_align_rotated_backward(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num,
num_orientations, clockwise);
}
PARROTS_EXTENSION_REGISTER(riroi_align_rotated_forward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("num_samples")
.attr("num_orientations")
.attr("clockwise")
.input(2)
.output(1)
.apply(riroi_align_rotated_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(riroi_align_rotated_backward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("num_samples")
.attr("num_orientations")
.attr("clockwise")
.input(2)
.output(1)
.apply(riroi_align_rotated_backward_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef RIROI_ALIGN_ROTATED_PYTORCH_H
#define RIROI_ALIGN_ROTATED_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void riroi_align_rotated_forward(Tensor features, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int num_samples,
int num_orientations, bool clockwise);
void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise);
#endif // RIROI_ALIGN_ROTATED_PYTORCH_H
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_cuda.cpp
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void rotated_feature_align_forward_impl(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output) {
DISPATCH_DEVICE_IMPL(rotated_feature_align_forward_impl, features,
best_bboxes, spatial_scale, points, output);
}
void rotated_feature_align_backward_impl(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad) {
DISPATCH_DEVICE_IMPL(rotated_feature_align_backward_impl, top_grad,
best_bboxes, spatial_scale, points, bottom_grad);
}
void rotated_feature_align_forward(const Tensor features,
const Tensor best_bboxes, Tensor output,
const float spatial_scale,
const int points) {
rotated_feature_align_forward_impl(features, best_bboxes, spatial_scale,
points, output);
}
void rotated_feature_align_backward(const Tensor top_grad,
const Tensor best_bboxes,
Tensor bottom_grad,
const float spatial_scale,
const int points) {
rotated_feature_align_backward_impl(top_grad, best_bboxes, spatial_scale,
points, bottom_grad);
}
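For orientation, a hedged shape note on the entry points above; the convention follows mmcv's rotated_feature_align documentation and should be treated as an assumption here:
// Assumed shapes and parameters (not verifiable from this diff alone):
//   features / output : (N, C, H, W) feature maps
//   best_bboxes       : one decoded rotated box per spatial location, typically (N, H, W, 5)
//   spatial_scale     : factor mapping feature-map coordinates back to image coordinates
//   points            : number of sampling points per box (1 or 5)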
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "rotated_feature_align_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void rotated_feature_align_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
float spatial_scale;
int points;
SSAttrs(attr)
.get<float>("spatial_scale", spatial_scale)
.get<int>("points", points)
.done();
auto features = buildATensor(ctx, ins[0]);
auto best_bboxes = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
rotated_feature_align_forward(features, best_bboxes, output, spatial_scale,
points);
}
void rotated_feature_align_backward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
float spatial_scale;
int points;
SSAttrs(attr)
.get<float>("spatial_scale", spatial_scale)
.get<int>("points", points)
.done();
auto grad_output = buildATensor(ctx, ins[0]);
auto best_bboxes = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
rotated_feature_align_backward(grad_output, best_bboxes, grad_input,
spatial_scale, points);
}
PARROTS_EXTENSION_REGISTER(rotated_feature_align_forward)
.attr("spatial_scale")
.attr("points")
.input(2)
.output(1)
.apply(rotated_feature_align_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(rotated_feature_align_backward)
.attr("spatial_scale")
.attr("points")
.input(2)
.output(1)
.apply(rotated_feature_align_backward_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROTATED_FEATURE_ALIGN_PYTORCH_H
#define ROTATED_FEATURE_ALIGN_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void rotated_feature_align_forward(const Tensor features,
const Tensor best_bboxes, Tensor output,
const float spatial_scale, const int points);
void rotated_feature_align_backward(const Tensor top_grad,
const Tensor best_bboxes,
Tensor bottom_grad,
const float spatial_scale,
const int points);
#endif // ROTATED_FEATURE_ALIGN_PYTORCH_H
@@ -15,7 +15,7 @@ void ConvexIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pointsets.scalar_type(), "convex_iou_cuda_kernel", ([&] {
convex_iou_cuda_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK / 2, 0, stream>>>(
num_pointsets, num_polygons, pointsets.data_ptr<scalar_t>(),
polygons.data_ptr<scalar_t>(), ious.data_ptr<scalar_t>());
}));