Unverified Commit 9f5a03dc authored by pc's avatar pc Committed by GitHub
Browse files

[Feature] Add diff_iiou_rotated op in parrots (#1911)

parent 057c0323
...@@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda); ...@@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda); REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda);
void ROIAlignRotatedForwardCUDAKernelLauncher( void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale, const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output); const int pooled_height, const int pooled_width, at::Tensor output);
void ROIAlignRotatedBackwardCUDAKernelLauncher( void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad); const int pooled_height, const int pooled_width, at::Tensor bottom_grad);
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) { bool aligned, bool clockwise) {
// Number of ROIs // Number of ROIs
int num_rois = rois.size(0); int num_rois = rois.size(0);
...@@ -947,11 +947,11 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, ...@@ -947,11 +947,11 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
AT_ERROR("wrong roi size"); AT_ERROR("wrong roi size");
} }
int num_channels = features.size(1); int num_channels = input.size(1);
int data_height = features.size(2); int data_height = input.size(2);
int data_width = features.size(3); int data_width = input.size(3);
ROIAlignRotatedForwardCUDAKernelLauncher( ROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, sample_ratio, aligned, clockwise, input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height, num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output); aligned_width, output);
} }
...@@ -959,7 +959,7 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, ...@@ -959,7 +959,7 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise) { bool clockwise) {
// Number of ROIs // Number of ROIs
int num_rois = rois.size(0); int num_rois = rois.size(0);
...@@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois, ...@@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
int data_height = bottom_grad.size(2); int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3); int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardCUDAKernelLauncher( ROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, sample_ratio, aligned, clockwise, top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height, num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad); aligned_width, bottom_grad);
} }
void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise); bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA, REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
roi_align_rotated_forward_cuda); roi_align_rotated_forward_cuda);
...@@ -1564,3 +1564,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons, ...@@ -1564,3 +1564,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda); REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda); REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);
Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices,
Tensor mask,
Tensor num_valid);
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask,
num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
vertices, mask, num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask,
Tensor num_valid) {
return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid);
}
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "diff_iou_rotated_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void diff_iou_rotated_sort_vertices_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
at::Tensor boxes, scores, dets;
auto vertices = buildATensor(ctx, ins[0]);
auto mask = buildATensor(ctx, ins[1]);
auto num_valid = buildATensor(ctx, ins[2]);
auto out =
diff_iou_rotated_sort_vertices_forward_cuda(vertices, mask, num_valid);
updateDArray(ctx, out, outs[0]);
}
PARROTS_EXTENSION_REGISTER(diff_iou_rotated_sort_vertices_forward)
.input(3)
.output(1)
.apply(diff_iou_rotated_sort_vertices_forward_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef DIFF_IOU_ROTATED_PYTORCH_H
#define DIFF_IOU_ROTATED_PYTORCH_H
#include <torch/extension.h>
using namespace at;
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid);
#endif // DIFF_IOU_ROTATED_PYTORCH_H
...@@ -17,7 +17,8 @@ class SortVertices(Function): ...@@ -17,7 +17,8 @@ class SortVertices(Function):
def forward(ctx, vertices, mask, num_valid): def forward(ctx, vertices, mask, num_valid):
idx = ext_module.diff_iou_rotated_sort_vertices_forward( idx = ext_module.diff_iou_rotated_sort_vertices_forward(
vertices, mask, num_valid) vertices, mask, num_valid)
ctx.mark_non_differentiable(idx) if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx return idx
@staticmethod @staticmethod
......
...@@ -36,6 +36,7 @@ else: ...@@ -36,6 +36,7 @@ else:
'ms_deform_attn_forward', 'ms_deform_attn_forward',
'pixel_group', 'pixel_group',
'contour_expand', 'contour_expand',
'diff_iou_rotated_sort_vertices_forward',
] ]
def get_fake_func(name, e): def get_fake_func(name, e):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment