Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ba8aa764
Unverified
Commit
ba8aa764
authored
Jun 01, 2023
by
tudejiang79
Committed by
GitHub
Jun 01, 2023
Browse files
[Refactor] Replace the implementation of roi_align_rotated with mlu-ops (#2808)
parent
d2aecbe4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
671 deletions
+43
-671
mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
+0
-490
mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp
mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp
+0
-24
mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
+43
-156
tests/test_ops/test_roi_align_rotated.py
tests/test_ops/test_roi_align_rotated.py
+0
-1
No files found.
mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
deleted
100644 → 0
View file @
d2aecbe4
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
#define ROI_OFFSET 6
#define SAMPLING_NUM 4
__nram__ char nram_buffer[MAX_NRAM_SIZE];
// In-place exchange of two values; device-side stand-in for std::swap.
template <typename T>
__mlu_func__ void swap(T &a, T &b) {
  const T old_a = a;
  a = b;
  b = old_a;
}
// Computes the four bilinear-interpolation weights (*w1..*w4) and the integer
// corner coordinates (*x_low/*x_high, *y_low/*y_high) of the sampling point
// (x, y) inside a feature map of size input_height x input_width.
// Sets *empty = true and returns early when the point lies more than one
// pixel outside the map; on that path the weight/corner outputs are left
// untouched, so callers must check *empty before using them.
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
                                      const int input_width, T x, T y, T *w1,
                                      T *w2, T *w3, T *w4, int *x_low,
                                      int *x_high, int *y_low, int *y_high,
                                      bool *empty) {
  // deal with case that the point is out of feature map boundary
  if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
    *empty = true;
    return;
  }
  // Points in [-1, 0] are clamped to 0 so the truncation below is valid.
  if (y <= 0) y = (T)0;
  if (x <= 0) x = (T)0;
  *y_low = int(y);
  *x_low = int(x);
  // At the last row/column, collapse high onto low and snap the coordinate,
  // which makes the fractional part below zero (pure corner sample).
  if (*y_low >= input_height - 1) {
    *y_high = *y_low = input_height - 1;
    y = (T)(*y_low);
  } else {
    *y_high = *y_low + 1;
  }
  if (*x_low >= input_width - 1) {
    *x_high = *x_low = input_width - 1;
    x = T(*x_low);
  } else {
    *x_high = *x_low + 1;
  }
  // Fractional offsets within the cell; the weights are the standard
  // bilinear products: w1=(low,low), w2=(low,high), w3=(high,low),
  // w4=(high,high) in (y, x) order.
  T ly = y - *y_low;
  T lx = x - *x_low;
  T hy = 1.0 - ly;
  T hx = 1.0 - lx;
  *w1 = hy * hx;
  *w2 = hy * lx;
  *w3 = ly * hx;
  *w4 = ly * lx;
  return;
}
// Decodes the flat output-bin index `bin_i` into its ROI index (*roi_n) and
// bin coordinates (*ph, *pw), then loads and scales that ROI's record from
// `rois_dram`. Each record is ROI_OFFSET (= 6) values:
// [batch_idx, center_x, center_y, width, height, theta].
template <typename T>
__mlu_func__ void getRoiBinInfo(const T *rois_dram, const int bin_i,
                                const RoiAlignRotatedParams &params,
                                int *batch_idx, int *roi_n, int *pw, int *ph,
                                T *roi_center_x, T *roi_center_y, T *roi_width,
                                T *roi_height, T *theta) {
  // Half-pixel shift used by the `aligned` coordinate convention.
  T offset = params.aligned ? (T)0.5 : (T)0.0;
  *pw = bin_i % params.pooled_width;
  *ph = (bin_i / params.pooled_width) % params.pooled_height;
  *roi_n = bin_i / params.pooled_width / params.pooled_height;
  const T *roi_info = rois_dram + (*roi_n) * ROI_OFFSET;
  *batch_idx = (int)roi_info[0];
  // ROI coordinates are given in input-image space; spatial_scale maps them
  // onto the feature map.
  *roi_center_x = roi_info[1] * (T)params.spatial_scale - offset;
  *roi_center_y = roi_info[2] * (T)params.spatial_scale - offset;
  *roi_width = roi_info[3] * (T)params.spatial_scale;
  *roi_height = roi_info[4] * (T)params.spatial_scale;
  *theta = roi_info[5];
  // A clockwise rotation is expressed as a negated angle.
  if (params.clockwise) {
    *theta = -(*theta);
  }
  // Legacy (non-aligned) mode forces ROIs to be at least 1x1 on the map.
  if (!params.aligned) {
    *roi_width = *roi_width > (T)1.0 ? *roi_width : (T)1.0;
    *roi_height = *roi_height > (T)1.0 ? *roi_height : (T)1.0;
  }
}
// Forward pass. Output bins are distributed round-robin over tasks
// (bin_first = taskId, stride taskDim). Each bin averages
// roi_bin_grid_h * roi_bin_grid_w bilinear samples; the channel dimension is
// processed in NRAM-sized slices. The four corner rows of the current sample
// sit in a "ping" buffer while the next sample's corners are prefetched into
// a "pong" buffer, then the buffers are swapped each iteration.
// Input/output tensors are indexed as (n, h, w) * channel + c, i.e. NHWC.
template <typename T>
__mlu_func__ void roiAlignRotatedForward(const T *input_dram,
                                         const T *rois_dram, const int batch,
                                         const int height, const int width,
                                         const int channel, const int rois_num,
                                         const RoiAlignRotatedParams &params,
                                         T *output_dram) {
  // Elements of T per NFU alignment unit.
  int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
  // NRAM must hold one output slice plus ping+pong corner buffers of
  // SAMPLING_NUM rows each, hence the (2 * SAMPLING_NUM + 1) divisor.
  int channel_max_cap = MAX_NRAM_SIZE / sizeof(T) / (2 * SAMPLING_NUM + 1);
  channel_max_cap = channel_max_cap / align_base_128 * align_base_128;
  int channel_align = channel < channel_max_cap ? channel : channel_max_cap;
  channel_align = CEIL_ALIGN(channel_align, align_base_128);
  // NRAM layout: | out | ping: 4 corner rows | pong: 4 corner rows |
  T *nram_out = (T *)nram_buffer;
  T *nram_ping = nram_out + channel_align;
  T *nram_pong = nram_ping + channel_align * SAMPLING_NUM;
  int bin_first = taskId;
  int bin_end = rois_num * params.pooled_height * params.pooled_width;
  for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
    T roi_center_x, roi_center_y, roi_width, roi_height, theta;
    int batch_idx, roi_n, pw, ph;
    getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
                  &roi_center_x, &roi_center_y, &roi_width, &roi_height,
                  &theta);
    T bin_size_h = roi_height / params.pooled_height;
    T bin_size_w = roi_width / params.pooled_width;
    // Samples per bin edge: fixed when sample_ratio > 0, otherwise adaptive
    // (ceil of the bin size).
    int roi_bin_grid_h =
        (params.sample_ratio > 0)
            ? params.sample_ratio
            : __float2int_up((float)roi_height / params.pooled_height);
    int roi_bin_grid_w =
        (params.sample_ratio > 0)
            ? params.sample_ratio
            : __float2int_up((float)roi_width / params.pooled_width);
    // Sample coordinates are taken relative to the ROI center, rotated by
    // theta below, then translated back.
    T roi_start_y = -roi_height / 2;
    T roi_start_x = -roi_width / 2;
    const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
                            ? roi_bin_grid_h * roi_bin_grid_w
                            : 1;
    T cos_theta = std::cos(theta);
    T sin_theta = std::sin(theta);
    // Averaging factor: 1 / number-of-samples.
    T zero_sign = 1.0f / bin_dim;
    bool is_first_sample = true;
    int src_offset = 0;
    int dst_offset = 0;
    int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
    // Walk the channel dimension in slices of at most channel_align.
    for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
      __bang_write_value(nram_out, channel_align, (T)0);
      c_rem = channel - c_offset;
      c_slice = channel_align > c_rem ? c_rem : channel_align;
      c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
      is_first_sample = true;
      for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
        const T yy = roi_start_y + ph * bin_size_h +
                     T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
        for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
          const T xx = roi_start_x + pw * bin_size_w +
                       T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
          int sample_i = iy * roi_bin_grid_w + ix;
          // Rotate (xx, yy) around the ROI center by theta.
          T y = yy * cos_theta - xx * sin_theta + roi_center_y;
          T x = yy * sin_theta + xx * cos_theta + roi_center_x;
          T w1, w2, w3, w4;
          bool empty = false;
          int x_low, x_high, y_low, y_high;
          bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
                              &x_high, &y_low, &y_high, &empty);
          /*******************************************************
          | ping | pong |
          |------|-----|-----|-----|-----|-----|-----|-----|-----|
          |output| p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
          |------|-----|-----|-----|-----|-----|-----|-----|-----|
          ********************************************************/
          // Only the very first sample loads its own corners synchronously;
          // every later sample was prefetched into pong by the previous
          // iteration and arrives here via the swap below.
          if (is_first_sample && !empty) {
            // load input data from dram to nram
            __bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
            // Corner (y_low, x_low) -> ping row 0.
            src_offset =
                (batch_idx * height * width + y_low * width + x_low) * channel +
                c_offset;
            dst_offset = 0;
            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            // Corner (y_low, x_high) -> ping row 1.
            src_offset = (batch_idx * height * width + y_low * width + x_high) *
                             channel +
                         c_offset;
            dst_offset = c_slice_align;
            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            // Corner (y_high, x_low) -> ping row 2.
            src_offset = (batch_idx * height * width + y_high * width + x_low) *
                             channel +
                         c_offset;
            dst_offset = c_slice_align * 2;
            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            // Corner (y_high, x_high) -> ping row 3.
            src_offset =
                (batch_idx * height * width + y_high * width + x_high) *
                    channel +
                c_offset;
            dst_offset = c_slice_align * 3;
            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
          }
          // load next input data to nram
          // Prefetch the next sample's four corners into the pong buffer.
          if (sample_i + 1 < bin_dim) {
            int p_iy = (sample_i + 1) / roi_bin_grid_w;
            int p_ix = (sample_i + 1) % roi_bin_grid_w;
            const T p_yy = roi_start_y + ph * bin_size_h +
                           T(p_iy + 0.5) * bin_size_h / roi_bin_grid_h;
            const T p_xx = roi_start_x + pw * bin_size_w +
                           T(p_ix + 0.5) * bin_size_w / roi_bin_grid_w;
            T p_y = p_yy * cos_theta - p_xx * sin_theta + roi_center_y;
            T p_x = p_yy * sin_theta + p_xx * cos_theta + roi_center_x;
            T p_w1, p_w2, p_w3, p_w4;
            bool p_empty = false;
            int p_x_low, p_x_high, p_y_low, p_y_high;
            bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
                                &p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
                                &p_empty);
            pongc_slice = c_slice;
            pongc_slice_align = c_slice_align;
            if (!p_empty) {
              __bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
                                 (T)0);
              src_offset =
                  (batch_idx * height * width + p_y_low * width + p_x_low) *
                      channel +
                  c_offset;
              dst_offset = 0;
              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
                       c_slice * sizeof(T), GDRAM2NRAM);
              src_offset =
                  (batch_idx * height * width + p_y_low * width + p_x_high) *
                      channel +
                  c_offset;
              dst_offset = pongc_slice_align;
              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
                       c_slice * sizeof(T), GDRAM2NRAM);
              src_offset =
                  (batch_idx * height * width + p_y_high * width + p_x_low) *
                      channel +
                  c_offset;
              dst_offset = pongc_slice_align * 2;
              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
                       c_slice * sizeof(T), GDRAM2NRAM);
              src_offset =
                  (batch_idx * height * width + p_y_high * width + p_x_high) *
                      channel +
                  c_offset;
              dst_offset = pongc_slice_align * 3;
              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
                       c_slice * sizeof(T), GDRAM2NRAM);
            }
          }
          // tmp_sum reuses the storage of ping's 4th corner row (row index 3).
          T *tmp_sum = nram_ping + 3 * c_slice_align;
          if (empty) {
            // Out-of-bounds sample contributes zero.
            __bang_write_value(tmp_sum, c_slice_align, T(0));
          } else {
            // Scale each corner row by its bilinear weight, then sum the
            // four rows element-wise into tmp_sum.
            __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
            __bang_mul_scalar(nram_ping + c_slice_align,
                              nram_ping + c_slice_align, w2, c_slice_align);
            __bang_mul_scalar(nram_ping + 2 * c_slice_align,
                              nram_ping + 2 * c_slice_align, w3, c_slice_align);
            __bang_mul_scalar(nram_ping + 3 * c_slice_align,
                              nram_ping + 3 * c_slice_align, w4, c_slice_align);
            __bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
                           1, SAMPLING_NUM, 1, 1);
          }
          // Accumulate the sample into the bin's running total, then swap
          // so the prefetched pong data becomes the next sample's ping.
          __bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
          swap(nram_ping, nram_pong);
          __asm__ volatile("sync;");
          is_first_sample = false;
        }
      }
      // Divide the accumulated sum by the sample count (average pooling).
      __bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
      // store the result to dram
      int output_offset =
          ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
              channel +
          c_offset;
      __memcpy(output_dram + output_offset, nram_out, c_slice * sizeof(T),
               NRAM2GDRAM);
    }
  }
}
// Backward pass. For every output bin (distributed over tasks like the
// forward pass), the incoming gradient row of that bin is scattered back to
// the four bilinear corners of each sample via atomic adds into
// bottom_grad_dram. Channel slices of the top gradient are double-buffered:
// while the current slice is processed from "ping", the next slice (or the
// first slice of the next bin) is asynchronously copied into "pong".
// Tensors are indexed as (n, h, w) * channel + c, i.e. NHWC.
template <typename T>
__mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
                                          const T *rois_dram, const int batch,
                                          const int height, const int width,
                                          const int channel, const int rois_num,
                                          const RoiAlignRotatedParams &params,
                                          T *bottom_grad_dram) {
  int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
  int channel_align = CEIL_ALIGN(channel, align_base_128);
  // NRAM is split into 4 regions of c_limit elements each:
  // ping{input, output} and pong{input, output} (see the diagram below).
  unsigned int max_element = MAX_NRAM_SIZE / sizeof(T);
  int c_limit = max_element >> 2;
  c_limit = c_limit > channel_align ? channel_align : c_limit;
  T *nram_ping = (T *)nram_buffer;
  T *nram_pong = nram_ping + 2 * c_limit;
  T *nram_output = nullptr;
  int bin_first = taskId;
  int bin_end = rois_num * params.pooled_height * params.pooled_width;
  bool is_first_bin = true;
  T roi_center_x, roi_center_y, roi_width, roi_height, theta;
  int batch_idx, roi_n, pw, ph;
  // "pong_*" variables describe the next bin whose data is prefetched.
  T pong_roi_center_x, pong_roi_center_y, pong_roi_width, pong_roi_height,
      pong_theta;
  int pong_batch_idx, pong_roi_n, pong_pw, pong_ph;
  for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
    getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
                  &roi_center_x, &roi_center_y, &roi_width, &roi_height,
                  &theta);
    T bin_size_h = roi_height / params.pooled_height;
    T bin_size_w = roi_width / params.pooled_width;
    // Samples per bin edge: fixed when sample_ratio > 0, otherwise adaptive.
    int roi_bin_grid_h =
        (params.sample_ratio > 0)
            ? params.sample_ratio
            : __float2int_up((float)roi_height / params.pooled_height);
    int roi_bin_grid_w =
        (params.sample_ratio > 0)
            ? params.sample_ratio
            : __float2int_up((float)roi_width / params.pooled_width);
    T roi_start_y = -roi_height / 2;
    T roi_start_x = -roi_width / 2;
    const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
                            ? roi_bin_grid_h * roi_bin_grid_w
                            : 1;
    T cos_theta = std::cos(theta);
    T sin_theta = std::sin(theta);
    // Averaging factor applied to every scattered gradient contribution.
    T zero_sign = 1.0f / bin_dim;
    int c_rem, c_slice, pongc_slice, c_offset;
    c_rem = channel;
    c_offset = 0;
    /****************************************
    | ping | pong |
    |---------|---------|---------|---------|
    | input | output | input | output |
    |---------|---------|---------|---------|
    *****************************************/
    if (is_first_bin) {
      // load the first top_grad to nram
      c_slice = c_limit < c_rem ? c_limit : c_rem;
      int top_grad_offset =
          ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
          channel;
      __memcpy(nram_ping, top_grad_dram + top_grad_offset, c_slice * sizeof(T),
               GDRAM2NRAM);
    }
    nram_output = nram_ping + c_limit;
    while (c_rem > 0) {
      c_slice = c_slice < c_rem ? c_slice : c_rem;
      // load the next top_grad to nram
      if (c_rem - c_slice > 0) {
        // load the rest channels to nram
        pongc_slice = (c_rem - c_slice > c_slice) ? c_slice : c_rem - c_slice;
        int top_grad_offset =
            ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
                channel +
            c_offset + c_slice;
        __memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
                       pongc_slice * sizeof(T), GDRAM2NRAM);
      } else if (bin_i + taskDim < bin_end) {
        // load next bin's data to nram
        getRoiBinInfo(rois_dram, bin_i + taskDim, params, &pong_batch_idx,
                      &pong_roi_n, &pong_pw, &pong_ph, &pong_roi_center_x,
                      &pong_roi_center_y, &pong_roi_width, &pong_roi_height,
                      &pong_theta);
        pongc_slice = c_limit < channel ? c_limit : channel;
        int top_grad_offset = ((pong_roi_n * params.pooled_height + pong_ph) *
                                   params.pooled_width +
                               pong_pw) *
                              channel;
        // NOTE(review): this copy uses c_slice although pongc_slice was just
        // computed for the next bin — confirm the two are always equal here.
        __memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
                       c_slice * sizeof(T), GDRAM2NRAM);
      }
      // compute the gradient contributions of a single bin
      for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
        const T yy = roi_start_y + ph * bin_size_h +
                     T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
        for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
          const T xx = roi_start_x + pw * bin_size_w +
                       T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
          // Rotate (xx, yy) around the ROI center by theta.
          T y = yy * cos_theta - xx * sin_theta + roi_center_y;
          T x = yy * sin_theta + xx * cos_theta + roi_center_x;
          T w1, w2, w3, w4;
          bool empty = false;
          int x_low, x_high, y_low, y_high;
          bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
                              &x_high, &y_low, &y_high, &empty);
          if (empty) {
            // Out-of-bounds sample: no gradient to scatter.
            continue;
          } else {
            // For each of the four corners: scale the bin's gradient by the
            // averaged bilinear weight and atomically add it into
            // bottom_grad at that corner's NHWC location.
            __bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
            __bang_atomic_add(
                (T *)nram_output,
                bottom_grad_dram + batch_idx * height * width * channel +
                    y_low * width * channel + x_low * channel + c_offset,
                (T *)nram_output, c_slice);
            __bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
            __bang_atomic_add(
                (T *)nram_output,
                bottom_grad_dram + batch_idx * height * width * channel +
                    y_low * width * channel + x_high * channel + c_offset,
                (T *)nram_output, c_slice);
            __bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
            __bang_atomic_add(
                (T *)nram_output,
                bottom_grad_dram + batch_idx * height * width * channel +
                    y_high * width * channel + x_low * channel + c_offset,
                (T *)nram_output, c_slice);
            __bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
            __bang_atomic_add(
                (T *)nram_output,
                bottom_grad_dram + batch_idx * height * width * channel +
                    y_high * width * channel + x_high * channel + c_offset,
                (T *)nram_output, c_slice);
          }
        }
      }
      // Make the prefetched pong slice current and advance the channel walk.
      swap(nram_ping, nram_pong);
      c_rem -= c_slice;
      c_offset += c_slice;
      __asm__ volatile("sync;");
    }
    is_first_bin = false;
  }
}
// Device entry point for the forward pass: dispatches to the float or half
// instantiation of roiAlignRotatedForward based on `data_type`.
// Cores with id 0x80 exit immediately and do no work.
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedForward(
    const void *features, const void *rois, void *output, const int batch,
    const int height, const int width, const int channel, const int rois_num,
    const RoiAlignRotatedParams rroiAlignParams,
    const cnrtDataType_t data_type) {
  if (coreId == 0x80) {
    return;
  }
  if (data_type != CNRT_FLOAT32) {
    roiAlignRotatedForward((half *)features, (half *)rois, batch, height, width,
                           channel, rois_num, rroiAlignParams, (half *)output);
  } else {
    roiAlignRotatedForward((float *)features, (float *)rois, batch, height,
                           width, channel, rois_num, rroiAlignParams,
                           (float *)output);
  }
}
// Device entry point for the backward pass: dispatches to the float or half
// instantiation of roiAlignRotatedBackward based on `data_type`.
// Cores with id 0x80 exit immediately and do no work.
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedBackward(
    const void *top_grad, const void *rois, void *bottom_grad, const int batch,
    const int height, const int width, const int channel, const int rois_num,
    const RoiAlignRotatedParams rroiAlignParams,
    const cnrtDataType_t data_type) {
  if (coreId == 0x80) {
    return;
  }
  if (data_type != CNRT_FLOAT32) {
    roiAlignRotatedBackward((half *)top_grad, (half *)rois, batch, height,
                            width, channel, rois_num, rroiAlignParams,
                            (half *)bottom_grad);
  } else {
    roiAlignRotatedBackward((float *)top_grad, (float *)rois, batch, height,
                            width, channel, rois_num, rroiAlignParams,
                            (float *)bottom_grad);
  }
}
// Host-side wrapper: launches MLUUnion1KernelRoiAlignRotatedForward on
// `queue` with the supplied launch dimensions and function type, forwarding
// all tensor pointers and shape/ROI parameters unchanged.
void KernelRoiAlignRotatedForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const void *features, const void *rois,
    void *output, const int batch, const int height, const int width,
    const int channel, const int rois_num,
    const RoiAlignRotatedParams roiAlignRotatedParams) {
  MLUUnion1KernelRoiAlignRotatedForward<<<k_dim, k_type, queue>>>(
      features, rois, output, batch, height, width, channel, rois_num,
      roiAlignRotatedParams, d_type);
}
// Host-side wrapper: launches MLUUnion1KernelRoiAlignRotatedBackward on
// `queue` with the supplied launch dimensions and function type, forwarding
// all tensor pointers and shape/ROI parameters unchanged.
void KernelRoiAlignRotatedBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const void *top_grad, const void *rois,
    void *bottom_grad, const int batch, const int height, const int width,
    const int channel, const int rois_num,
    const RoiAlignRotatedParams roiAlignRotatedParams) {
  MLUUnion1KernelRoiAlignRotatedBackward<<<k_dim, k_type, queue>>>(
      top_grad, rois, bottom_grad, batch, height, width, channel, rois_num,
      roiAlignRotatedParams, d_type);
}
mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp
deleted
100644 → 0
View file @
d2aecbe4
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef ROI_ALIGN_ROTATED_UTILS_HPP_
#define ROI_ALIGN_ROTATED_UTILS_HPP_

// Pooling configuration shared between the host launcher and the MLU
// kernels for the rotated RoI-Align operator.
struct RoiAlignRotatedParams {
  int pooled_height;    // number of output bin rows
  int pooled_width;     // number of output bin columns
  int sample_ratio;     // samples per bin edge; <= 0 selects adaptive count
  float spatial_scale;  // scale applied to ROI coordinates
  bool aligned;         // shift ROI coords by -0.5 (half-pixel convention)
  bool clockwise;       // interpret theta as a clockwise rotation
};

#endif  // ROI_ALIGN_ROTATED_UTILS_HPP_
mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
100755 → 100644
View file @
ba8aa764
...
...
@@ -9,37 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
namespace
{
void
policyFunc
(
int
bin_num
,
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
)
{
unsigned
int
core_num
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
unsigned
int
cluster_num
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
k_dim
->
x
=
core_num
;
unsigned
int
use_cluster
=
(
bin_num
+
core_num
-
1
)
/
core_num
;
k_dim
->
y
=
use_cluster
>
cluster_num
?
cluster_num
:
use_cluster
;
k_dim
->
z
=
1
;
}
}
// namespace
void
KernelRoiAlignRotatedForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
features
,
const
void
*
rois
,
void
*
output
,
const
int
batch
,
const
int
height
,
const
int
width
,
const
int
channel
,
const
int
rois_num
,
const
RoiAlignRotatedParams
roiAlignRotatedParams
);
void
KernelRoiAlignRotatedBackward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
top_grad
,
const
void
*
rois
,
void
*
bottom_grad
,
const
int
batch
,
const
int
height
,
const
int
width
,
const
int
channel
,
const
int
rois_num
,
const
RoiAlignRotatedParams
roiAlignRotatedParams
);
#include "mlu_common_helper.h"
void
ROIAlignRotatedForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
...
...
@@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
float
spatial_scale
,
int
sampling_ratio
,
bool
aligned
,
bool
clockwise
)
{
TORCH_CHECK
(((
input
.
scalar_type
()
==
output
.
scalar_type
())
&&
(
output
.
scalar_type
()
==
rois
.
scalar_type
())),
"data types of input, rois and output should be the same, "
,
"but now input type is "
,
input
.
scalar_type
(),
", rois type is "
,
rois
.
scalar_type
(),
", output type is "
,
output
.
scalar_type
(),
"."
);
TORCH_CHECK
(
(
input
.
scalar_type
()
==
at
::
kFloat
||
input
.
scalar_type
()
==
at
::
kHalf
),
"input type should be Float or Half, got "
,
input
.
scalar_type
(),
"."
);
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input should be a 4d tensor, got "
,
input
.
dim
(),
"D."
);
TORCH_CHECK
(
rois
.
dim
()
==
2
,
"rois should be a 2d tensor, got "
,
rois
.
dim
(),
"D."
);
TORCH_CHECK
(
output
.
dim
()
==
4
,
"output should be a 4d tensor, got "
,
output
.
dim
(),
"D."
);
TORCH_CHECK
((
rois
.
size
(
0
)
==
output
.
size
(
0
)),
"the 1st dimensions of rois and output should be the same, "
,
"but now the 1st dimension of rois is "
,
rois
.
size
(
0
),
", and output is "
,
output
.
size
(
0
),
"."
);
TORCH_CHECK
((
input
.
size
(
1
)
==
output
.
size
(
1
)),
"the 2nd dimensions of input and output should be the same, "
,
"but now the 2nd dimension of input is "
,
input
.
size
(
1
),
", and output is "
,
output
.
size
(
1
),
"."
);
int
channel
=
input
.
size
(
1
);
int
width
=
input
.
size
(
3
);
int
height
=
input
.
size
(
2
);
int
batch
=
input
.
size
(
0
);
int
rois_nums
=
rois
.
size
(
0
);
cnrtDataType_t
d_type
=
torch_mlu
::
toCnrtDtype
(
input
.
dtype
());
// return if zero-elements
if
(
input
.
numel
()
==
0
)
{
CNLOG
(
INFO
)
<<
"Skip the zero-elements case."
;
return
;
}
RoiAlignRotatedParams
roiAlignRotatedParams
{
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
};
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFunc
(
rois_nums
*
pooled_height
*
pooled_width
,
&
k_dim
,
&
k_type
);
auto
memory_format
=
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
input
.
dim
());
auto
input_
tensor
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
at
::
Tensor
output_tmp
=
at
::
empty
({
rois_nums
,
channel
,
pooled_height
,
pooled_width
},
input
.
options
()
,
memory_format
);
auto
input_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
auto
rois_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
rois
,
rois
.
suggest_memory_format
());
auto
output_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
output
,
memory_format
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
MluOpTensorDescriptor
input_desc
,
rois_desc
,
output_desc
;
input_desc
.
set_with_layout
(
input_
,
MLUOP_LAYOUT_NHWC
);
rois_desc
.
set
(
rois_contiguous
);
output_desc
.
set_with_layout
(
output_contiguous
,
MLUOP_LAYOUT_NHWC
);
// get ptr of tensors
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
tensor
);
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
);
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
);
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
_contiguous
);
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_
tmp
);
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_
contiguous
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
KernelRoiAlignRotatedForward
(
k_dim
,
k_type
,
queue
,
d_type
,
input_ptr
,
rois_ptr
,
output_ptr
,
batch
,
height
,
width
,
channel
,
rois_nums
,
roiAlignRotatedParams
);
output
.
copy_
(
output_tmp
);
// get compute handle
auto
handle
=
mluOpGetCurrentHandle
();
mluOpRoiAlignRotatedForward
(
handle
,
input_desc
.
desc
(),
input_ptr
,
rois_desc
.
desc
(),
rois_ptr
,
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
,
output_desc
.
desc
(),
output_ptr
);
output
.
copy_
(
output_contiguous
);
}
void
ROIAlignRotatedBackwardMLUKernelLauncher
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
sampling_ratio
,
bool
aligned
,
bool
clockwise
)
{
TORCH_CHECK
(((
top_grad
.
scalar_type
()
==
bottom_grad
.
scalar_type
())
&&
(
bottom_grad
.
scalar_type
()
==
rois
.
scalar_type
())),
"data types of top_grad, rois and bottom_grad should be "
,
"the same, but now top_grad type is "
,
top_grad
.
scalar_type
(),
", rois type is "
,
rois
.
scalar_type
(),
", bottom_grad type is "
,
bottom_grad
.
scalar_type
(),
"."
);
TORCH_CHECK
((
bottom_grad
.
scalar_type
()
==
at
::
kFloat
||
bottom_grad
.
scalar_type
()
==
at
::
kHalf
),
"Data type of bottom_grad should be Float ro Half, got "
,
bottom_grad
.
scalar_type
(),
"."
);
TORCH_CHECK
(
bottom_grad
.
dim
()
==
4
,
"bottom_grad should be a 4d tensor, got "
,
top_grad
.
dim
(),
"D."
);
TORCH_CHECK
(
rois
.
dim
()
==
2
,
"rois should be a 2d tensor, got "
,
rois
.
dim
(),
"D."
);
TORCH_CHECK
(
top_grad
.
dim
()
==
4
,
"top_grad should be a 4d tensor, got "
,
bottom_grad
.
dim
(),
"D."
);
TORCH_CHECK
((
rois
.
size
(
0
)
==
top_grad
.
size
(
0
)),
"the 1st dimensions of rois and top_grad should be the same, "
,
"but now the 1st dimension of rois is "
,
rois
.
size
(
0
),
", and top_grad is "
,
top_grad
.
size
(
0
),
"."
);
TORCH_CHECK
((
bottom_grad
.
size
(
1
)
==
top_grad
.
size
(
1
)),
"the 2nd dimensions of bottom_grad and top_grad should be "
,
"the same, but now the 2nd dimension of bottom_grad is "
,
bottom_grad
.
size
(
1
),
", and top_grad is "
,
top_grad
.
size
(
1
),
"."
);
int
channel
=
bottom_grad
.
size
(
1
);
int
width
=
bottom_grad
.
size
(
3
);
int
height
=
bottom_grad
.
size
(
2
);
int
batch
=
bottom_grad
.
size
(
0
);
int
rois_nums
=
rois
.
size
(
0
);
cnrtDataType_t
d_type
=
torch_mlu
::
toCnrtDtype
(
bottom_grad
.
dtype
());
// return if zero-elements
if
(
bottom_grad
.
numel
()
==
0
)
{
CNLOG
(
INFO
)
<<
"Skip the zero-elements case."
;
return
;
}
RoiAlignRotatedParams
roiAlignRotatedParams
{
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
};
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFunc
(
rois_nums
*
pooled_height
*
pooled_width
,
&
k_dim
,
&
k_type
);
auto
memory_format
=
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
top_grad
.
dim
());
auto
top_grad_
tensor
=
auto
top_grad_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
top_grad
,
memory_format
);
at
::
Tensor
bottom_grad_tmp
=
at
::
empty
({
batch
,
channel
,
height
,
width
},
top_grad
.
options
(),
memory_format
)
.
zero_
();
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
auto
rois_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
rois
,
rois
.
suggest_memory_format
());
auto
bottom_grad_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
bottom_grad
,
memory_format
);
// get ptr of tensors
auto
bottom_grad_impl
=
torch_mlu
::
getMluTensorImpl
(
bottom_grad_tmp
);
auto
bottom_grad_ptr
=
bottom_grad_impl
->
cnnlMalloc
();
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
);
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
top_grad_impl
=
torch_mlu
::
getMluTensorImpl
(
top_grad_tensor
);
auto
top_grad_impl
=
torch_mlu
::
getMluTensorImpl
(
top_grad_
);
auto
top_grad_ptr
=
top_grad_impl
->
cnnlMalloc
();
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois_contiguous
);
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
bottom_grad_impl
=
torch_mlu
::
getMluTensorImpl
(
bottom_grad_
);
auto
bottom_grad_ptr
=
bottom_grad_impl
->
cnnlMalloc
();
KernelRoiAlignRotatedBackward
(
k_dim
,
k_type
,
queue
,
d_type
,
top_grad_ptr
,
rois_ptr
,
bottom_grad_ptr
,
batch
,
height
,
width
,
channel
,
rois_nums
,
roiAlignRotatedParams
);
bottom_grad
.
copy_
(
bottom_grad_tmp
);
MluOpTensorDescriptor
top_grad_desc
,
rois_desc
,
bottom_grad_desc
;
top_grad_desc
.
set_with_layout
(
top_grad_
,
MLUOP_LAYOUT_NHWC
);
rois_desc
.
set
(
rois_contiguous
);
bottom_grad_desc
.
set_with_layout
(
bottom_grad_
,
MLUOP_LAYOUT_NHWC
);
// get compute handle
auto
handle
=
mluOpGetCurrentHandle
();
mluOpRoiAlignRotatedBackward
(
handle
,
top_grad_desc
.
desc
(),
top_grad_ptr
,
rois_desc
.
desc
(),
rois_ptr
,
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
,
bottom_grad_desc
.
desc
(),
bottom_grad_ptr
);
bottom_grad
.
copy_
(
bottom_grad_
);
}
void
roi_align_rotated_forward_mlu
(
Tensor
input
,
Tensor
rois
,
Tensor
output
,
...
...
tests/test_ops/test_roi_align_rotated.py
View file @
ba8aa764
...
...
@@ -11,7 +11,6 @@ try:
except
ImportError
:
from
torch.autograd
import
gradcheck
_USING_PARROTS
=
False
# yapf:disable
inputs
=
[([[[[
1.
,
2.
],
[
3.
,
4.
]]]],
[[
0.
,
0.5
,
0.5
,
1.
,
1.
,
0
]]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment