Unverified Commit fb5062ca authored by zhouchenyang, committed by GitHub

[Feature] Support RoiAlignRotated with cambricon MLU backend (#2033)

* [Feature] Support RoiAlignRotated with cambricon MLU backend

* [Fix] Remove std lib in mlu files

* [Fix] Replace std::min/max with conditional operators
parent 27d1d7fe
@@ -33,4 +33,6 @@
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#endif
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
#endif // UTILS_H_
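// --- Illustrative standalone sketch (not part of this commit): behavior of
// the alignment macros above. PAD_DOWN rounds x down to a multiple of y and
// CEIL_ALIGN rounds it up; both assume integral operands and y > 0.
#include <cassert>
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
int main() {
  assert(PAD_DOWN(65, 64) == 64);     // round down
  assert(CEIL_ALIGN(65, 64) == 128);  // round up
  assert(CEIL_ALIGN(64, 64) == 64);   // already a multiple
  return 0;
}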
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
#define ROI_OFFSET 6
#define SAMPLING_NUM 4
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void swap(T &a, T &b) {
T tmp = a;
a = b;
b = tmp;
}
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T x, T y,
const T zero_sign, T *w1, T *w2, T *w3,
T *w4, int *x_low, int *x_high,
int *y_low, int *y_high, bool *empty) {
// deal with the case where the sampling point falls outside the feature map
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true;
return;
}
if (y <= 0) y = (T)0;
if (x <= 0) x = (T)0;
*y_low = int(y);
*x_low = int(x);
if (*y_low >= input_height - 1) {
*y_high = *y_low = input_height - 1;
y = (T)(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= input_width - 1) {
*x_high = *x_low = input_width - 1;
x = T(*x_low);
} else {
*x_high = *x_low + 1;
}
T ly = y - *y_low;
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx * zero_sign;
*w2 = hy * lx * zero_sign;
*w3 = ly * hx * zero_sign;
*w4 = ly * lx * zero_sign;
}
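// --- Illustrative standalone sketch (not part of this commit): the four
// bilinear weights produced above always sum to zero_sign = 1 / bin_dim, so
// accumulating bin_dim weighted samples yields the average over the bin.
// Plain host-side C++, no MLU intrinsics; names mirror the device helper.
#include <cstdio>
int main() {
  float x = 2.3f, y = 5.7f;
  float zero_sign = 1.0f / 4;  // e.g. bin_dim == 4 samples per bin
  int x_low = (int)x, y_low = (int)y;
  float lx = x - x_low, ly = y - y_low;  // fractional offsets
  float hx = 1.0f - lx, hy = 1.0f - ly;
  float w1 = hy * hx * zero_sign, w2 = hy * lx * zero_sign;
  float w3 = ly * hx * zero_sign, w4 = ly * lx * zero_sign;
  // (hy + ly) * (hx + lx) == 1, so the sum collapses to zero_sign.
  std::printf("w1+w2+w3+w4 = %f, zero_sign = %f\n", w1 + w2 + w3 + w4,
              zero_sign);
  return 0;
}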
template <typename T>
__mlu_func__ void getRoiBinInfo(const T *rois_dram, const int bin_i,
const RoiAlignRotatedParams &params,
int *batch_idx, int *roi_n, int *pw, int *ph,
T *roi_center_x, T *roi_center_y, T *roi_width,
T *roi_height, T *theta) {
T offset = params.aligned ? (T)0.5 : (T)0.0;
*pw = bin_i % params.pooled_width;
*ph = (bin_i / params.pooled_width) % params.pooled_height;
*roi_n = bin_i / params.pooled_width / params.pooled_height;
const T *roi_info = rois_dram + (*roi_n) * ROI_OFFSET;
*batch_idx = (int)roi_info[0];
*roi_center_x = roi_info[1] * (T)params.spatial_scale - offset;
*roi_center_y = roi_info[2] * (T)params.spatial_scale - offset;
*roi_width = roi_info[3] * (T)params.spatial_scale;
*roi_height = roi_info[4] * (T)params.spatial_scale;
*theta = roi_info[5];
if (params.clockwise) {
*theta = -(*theta);
}
if (!params.aligned) {
*roi_width = *roi_width > (T)1.0 ? *roi_width : (T)1.0;
*roi_height = *roi_height > (T)1.0 ? *roi_height : (T)1.0;
}
}
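// --- Illustrative standalone sketch (not part of this commit): decoding one
// ROI from the 6-float layout consumed above, i.e.
// [batch_idx, center_x, center_y, width, height, theta]. Values are made up.
#include <cstdio>
int main() {
  const float roi[6] = {0.f, 64.f, 32.f, 40.f, 20.f, 0.5f};
  const float spatial_scale = 0.25f;
  const bool aligned = true, clockwise = false;
  float offset = aligned ? 0.5f : 0.0f;  // half-pixel shift when aligned
  int batch_idx = (int)roi[0];
  float cx = roi[1] * spatial_scale - offset;
  float cy = roi[2] * spatial_scale - offset;
  float w = roi[3] * spatial_scale;
  float h = roi[4] * spatial_scale;
  if (!aligned) {  // legacy mode clamps boxes to at least 1x1
    w = w > 1.0f ? w : 1.0f;
    h = h > 1.0f ? h : 1.0f;
  }
  float theta = clockwise ? -roi[5] : roi[5];
  std::printf("batch %d, center (%.2f, %.2f), size (%.2f, %.2f), theta %.2f\n",
              batch_idx, cx, cy, w, h, theta);
  return 0;
}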
template <typename T>
__mlu_func__ void roiAlignRotatedForward(const T *input_dram,
const T *rois_dram, const int batch,
const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams &params,
T *output_dram) {
int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
int channel_max_cap = MAX_NRAM_SIZE / sizeof(T) / (2 * SAMPLING_NUM + 1);
channel_max_cap = channel_max_cap / align_base_128 * align_base_128;
int channel_align = channel < channel_max_cap ? channel : channel_max_cap;
channel_align = CEIL_ALIGN(channel_align, align_base_128);
T *nram_out = (T *)nram_buffer;
T *nram_ping = nram_out + channel_align;
T *nram_pong = nram_ping + channel_align * SAMPLING_NUM;
int bin_first = taskId;
int bin_end = rois_num * params.pooled_height * params.pooled_width;
for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
T roi_center_x, roi_center_y, roi_width, roi_height, theta;
int batch_idx, roi_n, pw, ph;
getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
&roi_center_x, &roi_center_y, &roi_width, &roi_height,
&theta);
T bin_size_h = roi_height / params.pooled_height;
T bin_size_w = roi_width / params.pooled_width;
int roi_bin_grid_h =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_height / params.pooled_height);
int roi_bin_grid_w =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_width / params.pooled_width);
T roi_start_y = -roi_height / 2;
T roi_start_x = -roi_width / 2;
const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1;
T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim;
bool is_first_sample = true;
int src_offset = 0;
int dst_offset = 0;
int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
__nramset(nram_out, channel_align, (T)0);
c_rem = channel - c_offset;
c_slice = channel_align > c_rem ? c_rem : channel_align;
c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
is_first_sample = true;
for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T yy = roi_start_y + ph * bin_size_h +
T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
const T xx = roi_start_x + pw * bin_size_w +
T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
int sample_i = iy * roi_bin_grid_w + ix;
T y = yy * cos_theta - xx * sin_theta + roi_center_y;
T x = yy * sin_theta + xx * cos_theta + roi_center_x;
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
int sample_wdim = x_high - x_low + 1;
/*******************************************************
| ping | pong |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
|output| p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
********************************************************/
if (is_first_sample && !empty) {
// load input data from dram to nram
__nramset(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
for (int h = y_low; h <= y_high; ++h) {
src_offset =
(batch_idx * height * width + h * width + x_low) * channel +
c_offset;
dst_offset = (h - y_low) * SAMPLING_NUM * c_slice_align / 2;
if (c_slice_align == channel) {
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
sample_wdim * channel * sizeof(T), GDRAM2NRAM);
} else {
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM,
c_slice_align * sizeof(T), channel * sizeof(T),
sample_wdim - 1);
}
}
}
// load next input data to nram
if (sample_i + 1 < bin_dim) {
int p_iy = (sample_i + 1) / roi_bin_grid_w;
int p_ix = (sample_i + 1) % roi_bin_grid_w;
const T p_yy = roi_start_y + ph * bin_size_h +
T(p_iy + 0.5) * bin_size_h / roi_bin_grid_h;
const T p_xx = roi_start_x + pw * bin_size_w +
T(p_ix + 0.5) * bin_size_w / roi_bin_grid_w;
T p_y = p_yy * cos_theta - p_xx * sin_theta + roi_center_y;
T p_x = p_yy * sin_theta + p_xx * cos_theta + roi_center_x;
T p_w1, p_w2, p_w3, p_w4;
bool p_empty = false;
int p_x_low, p_x_high, p_y_low, p_y_high;
bilinearInterpolate(height, width, p_x, p_y, zero_sign, &p_w1,
&p_w2, &p_w3, &p_w4, &p_x_low, &p_x_high,
&p_y_low, &p_y_high, &p_empty);
int p_sample_wdim = p_x_high - p_x_low + 1;
pongc_slice = c_slice;
pongc_slice_align = c_slice_align;
if (!p_empty) {
__nramset(nram_pong, SAMPLING_NUM * pongc_slice_align, (T)0);
for (int h = p_y_low; h <= p_y_high; ++h) {
src_offset =
(batch_idx * height * width + h * width + p_x_low) *
channel +
c_offset;
dst_offset =
(h - p_y_low) * SAMPLING_NUM * pongc_slice_align / 2;
if (pongc_slice_align == channel) {
__memcpy_async(
nram_pong + dst_offset, input_dram + src_offset,
p_sample_wdim * channel * sizeof(T), GDRAM2NRAM);
} else {
__memcpy_async(nram_pong + dst_offset,
input_dram + src_offset,
pongc_slice * sizeof(T), GDRAM2NRAM,
pongc_slice_align * sizeof(T),
channel * sizeof(T), p_sample_wdim - 1);
}
}
}
}
T *tmp_sum = nram_ping + 3 * c_slice_align;
if (empty) {
__nramset(tmp_sum, c_slice_align, T(0));
} else {
__bang_mul_const(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_const(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_mul_const(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_mul_const(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
1, SAMPLING_NUM, 1, 1);
}
__bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
swap(nram_ping, nram_pong);
__asm__ volatile("sync;");
is_first_sample = false;
}
}
// store the result to dram
int output_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel +
c_offset;
__memcpy(output_dram + output_offset, nram_out, c_slice * sizeof(T),
NRAM2GDRAM);
}
}
}
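// --- Illustrative standalone sketch (not part of this commit): the sample
// placement used by the loops above. A sample is laid on a regular grid in
// ROI-local coordinates (yy, xx), then rotated by theta around the ROI
// center to obtain feature-map coordinates (x, y). Values are made up.
#include <cmath>
#include <cstdio>
int main() {
  float roi_cx = 16.f, roi_cy = 8.f, roi_w = 12.f, roi_h = 6.f, theta = 0.3f;
  int pooled_h = 2, pooled_w = 2, grid_h = 2, grid_w = 2;
  int ph = 1, pw = 0, iy = 0, ix = 1;  // bin (ph, pw), sample (iy, ix)
  float bin_h = roi_h / pooled_h, bin_w = roi_w / pooled_w;
  float yy = -roi_h / 2 + ph * bin_h + (iy + 0.5f) * bin_h / grid_h;
  float xx = -roi_w / 2 + pw * bin_w + (ix + 0.5f) * bin_w / grid_w;
  float y = yy * std::cos(theta) - xx * std::sin(theta) + roi_cy;
  float x = yy * std::sin(theta) + xx * std::cos(theta) + roi_cx;
  std::printf("sample lands at (x = %.3f, y = %.3f)\n", x, y);
  return 0;
}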
template <typename T>
__mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
const T *rois_dram, const int batch,
const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams &params,
T *bottom_grad_dram) {
int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
int channel_align = CEIL_ALIGN(channel, align_base_128);
unsigned int max_element = MAX_NRAM_SIZE / sizeof(T);
int c_limit = max_element >> 2;
c_limit = c_limit > channel_align ? channel_align : c_limit;
T *nram_ping = (T *)nram_buffer;
T *nram_pong = nram_ping + 2 * c_limit;
T *nram_output = nullptr;
int bin_first = taskId;
int bin_end = rois_num * params.pooled_height * params.pooled_width;
bool is_first_bin = true;
T roi_center_x, roi_center_y, roi_width, roi_height, theta;
int batch_idx, roi_n, pw, ph;
T pong_roi_center_x, pong_roi_center_y, pong_roi_width, pong_roi_height,
pong_theta;
int pong_batch_idx, pong_roi_n, pong_pw, pong_ph;
for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
&roi_center_x, &roi_center_y, &roi_width, &roi_height,
&theta);
T bin_size_h = roi_height / params.pooled_height;
T bin_size_w = roi_width / params.pooled_width;
int roi_bin_grid_h =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_height / params.pooled_height);
int roi_bin_grid_w =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_width / params.pooled_width);
T roi_start_y = -roi_height / 2;
T roi_start_x = -roi_width / 2;
const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1;
T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim;
int c_rem, c_slice, pongc_slice, c_offset;
c_rem = channel;
c_offset = 0;
/****************************************
| ping | pong |
|---------|---------|---------|---------|
| input | output | input | output |
|---------|---------|---------|---------|
*****************************************/
if (is_first_bin) {
// load the first top_grad to nram
c_slice = c_limit < c_rem ? c_limit : c_rem;
int top_grad_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel;
__memcpy(nram_ping, top_grad_dram + top_grad_offset, c_slice * sizeof(T),
GDRAM2NRAM);
}
nram_output = nram_ping + c_limit;
while (c_rem > 0) {
c_slice = c_slice < c_rem ? c_slice : c_rem;
// load the next top_grad to nram
if (c_rem - c_slice > 0) {
// load the rest channels to nram
pongc_slice = (c_rem - c_slice > c_slice) ? c_slice : c_rem - c_slice;
int top_grad_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel +
c_offset + c_slice;
__memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
} else if (bin_i + taskDim < bin_end) {
// load next bin's data to nram
getRoiBinInfo(rois_dram, bin_i + taskDim, params, &pong_batch_idx,
&pong_roi_n, &pong_pw, &pong_ph, &pong_roi_center_x,
&pong_roi_center_y, &pong_roi_width, &pong_roi_height,
&pong_theta);
pongc_slice = c_limit < channel ? c_limit : channel;
int top_grad_offset = ((pong_roi_n * params.pooled_height + pong_ph) *
params.pooled_width +
pong_pw) *
channel;
__memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
// compute the output in a single bin
for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T yy = roi_start_y + ph * bin_size_h +
T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
const T xx = roi_start_x + pw * bin_size_w +
T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
T y = yy * cos_theta - xx * sin_theta + roi_center_y;
T x = yy * sin_theta + xx * cos_theta + roi_center_x;
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
if (empty) {
continue;
} else {
__bang_mul_const(nram_output, nram_ping, w1, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w2, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w3, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w4, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice);
}
}
}
swap(nram_ping, nram_pong);
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
is_first_bin = false;
}
}
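// --- Illustrative standalone sketch (not part of this commit): the backward
// pass is the transpose of the forward. Each bin's incoming gradient is
// scaled by the same four bilinear weights and scattered onto the four
// neighboring input pixels; the device does this per channel with
// __bang_atomic_add, while in this serial sketch a plain += suffices.
#include <cstdio>
int main() {
  const int H = 4, W = 4;
  float bottom_grad[H * W] = {0};
  float top_grad = 1.0f;  // gradient flowing into one sampling point
  int x_low = 1, x_high = 2, y_low = 2, y_high = 3;
  float w1 = 0.3f, w2 = 0.2f, w3 = 0.35f, w4 = 0.15f;  // bilinear weights
  bottom_grad[y_low * W + x_low] += w1 * top_grad;
  bottom_grad[y_low * W + x_high] += w2 * top_grad;
  bottom_grad[y_high * W + x_low] += w3 * top_grad;
  bottom_grad[y_high * W + x_high] += w4 * top_grad;
  std::printf("grad at (y=%d, x=%d) = %.2f\n", y_low, x_low,
              bottom_grad[y_low * W + x_low]);
  return 0;
}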
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedForward(
const void *features, const void *rois, void *output, const int batch,
const int height, const int width, const int channel, const int rois_num,
const RoiAlignRotatedParams rroiAlignParams,
const cnrtDataType_t data_type) {
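  // Skip the cluster's memory core: a coreId of 0x80 denotes the MPU, which
  // takes no part in the per-bin compute (assumption: standard BANG layout).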
if (0x80 == coreId) {
return;
}
if (data_type == CNRT_FLOAT32) {
roiAlignRotatedForward((float *)features, (float *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(float *)output);
} else {
roiAlignRotatedForward((half *)features, (half *)rois, batch, height, width,
channel, rois_num, rroiAlignParams, (half *)output);
}
}
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedBackward(
const void *top_grad, const void *rois, void *bottom_grad, const int batch,
const int height, const int width, const int channel, const int rois_num,
const RoiAlignRotatedParams rroiAlignParams,
const cnrtDataType_t data_type) {
if (0x80 == coreId) {
return;
}
if (data_type == CNRT_FLOAT32) {
roiAlignRotatedBackward((float *)top_grad, (float *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(float *)bottom_grad);
} else {
roiAlignRotatedBackward((half *)top_grad, (half *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(half *)bottom_grad);
}
}
void KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *features, const void *rois,
void *output, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams) {
MLUUnion1KernelRoiAlignRotatedForward<<<k_dim, k_type, queue>>>(
features, rois, output, batch, height, width, channel, rois_num,
roiAlignRotatedParams, d_type);
}
void KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *top_grad, const void *rois,
void *bottom_grad, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams) {
MLUUnion1KernelRoiAlignRotatedBackward<<<k_dim, k_type, queue>>>(
top_grad, rois, bottom_grad, batch, height, width, channel, rois_num,
roiAlignRotatedParams, d_type);
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef ROI_ALIGN_ROTATED_UTILS_HPP_
#define ROI_ALIGN_ROTATED_UTILS_HPP_
struct RoiAlignRotatedParams {
int pooled_height;
int pooled_width;
int sample_ratio;
float spatial_scale;
bool aligned;
bool clockwise;
};
#endif // ROI_ALIGN_ROTATED_UTILS_HPP_
@@ -21,6 +21,8 @@
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
#endif
#endif // PYTORCH_MLU_HELPER_HPP_
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
namespace {
void policyFunc(int bin_num, cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
unsigned int use_cluster = (bin_num + core_num - 1) / core_num;
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
} // namespace
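// --- Illustrative standalone sketch (not part of this commit): the launch
// policy above assigns one core per bin, capped by the device size. With
// assumed attributes core_num = 4 and cluster_num = 8, 100 bins would ask
// for 25 clusters and be capped at 8, giving k_dim = {4, 8, 1}.
#include <algorithm>
#include <cstdio>
int main() {
  unsigned int core_num = 4, cluster_num = 8;  // assumed device attributes
  int bin_num = 100;  // rois_num * pooled_height * pooled_width
  unsigned int use_cluster = (bin_num + core_num - 1) / core_num;
  unsigned int y = std::min(use_cluster, cluster_num);
  std::printf("k_dim = {x=%u, y=%u, z=1}\n", core_num, y);
  return 0;
}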
void KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *features, const void *rois,
void *output, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams);
void KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *top_grad, const void *rois,
void *bottom_grad, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams);
void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width,
float spatial_scale,
int sampling_ratio, bool aligned,
bool clockwise) {
TORCH_CHECK(((input.scalar_type() == output.scalar_type()) &&
(output.scalar_type() == rois.scalar_type())),
"data types of input, rois and output should be the same, ",
"but now input type is ", input.scalar_type(), ", rois type is ",
rois.scalar_type(), ", output type is ", output.scalar_type(),
".");
TORCH_CHECK(
(input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf),
"input type should be Float or Half, got ", input.scalar_type(), ".");
TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ",
input.dim(), "D.");
TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
"D.");
TORCH_CHECK(output.dim() == 4, "output should be a 4d tensor, got ",
output.dim(), "D.");
TORCH_CHECK((rois.size(0) == output.size(0)),
"the 1st dimensions of rois and output should be the same, ",
"but now the 1st dimension of rois is ", rois.size(0),
", and output is ", output.size(0), ".");
TORCH_CHECK((input.size(1) == output.size(1)),
"the 2nd dimensions of input and output should be the same, ",
"but now the 2nd dimension of input is ", input.size(1),
", and output is ", output.size(1), ".");
int channel = input.size(1);
int width = input.size(3);
int height = input.size(2);
int batch = input.size(0);
int rois_nums = rois.size(0);
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
// return directly for zero-element input
if (input.numel() == 0) {
CNLOG(INFO) << "Skip the zero-elements case.";
return;
}
RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sampling_ratio, spatial_scale,
aligned, clockwise};
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_tensor =
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
at::Tensor output_tmp =
at::empty({batch, channel, pooled_height, pooled_width}, input.options(),
memory_format);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(input_tensor);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_ptr = rois_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output_tmp);
auto output_ptr = output_impl->cnnlMalloc();
KernelRoiAlignRotatedForward(k_dim, k_type, queue, d_type, input_ptr,
rois_ptr, output_ptr, batch, height, width,
channel, rois_nums, roiAlignRotatedParams);
output.copy_(output_tmp);
}
void ROIAlignRotatedBackwardMLUKernelLauncher(
Tensor top_grad, Tensor rois, Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale, int sampling_ratio, bool aligned,
bool clockwise) {
TORCH_CHECK(((top_grad.scalar_type() == bottom_grad.scalar_type()) &&
(bottom_grad.scalar_type() == rois.scalar_type())),
"data types of top_grad, rois and bottom_grad should be ",
"the same, but now top_grad type is ", top_grad.scalar_type(),
", rois type is ", rois.scalar_type(), ", bottom_grad type is ",
bottom_grad.scalar_type(), ".");
TORCH_CHECK((bottom_grad.scalar_type() == at::kFloat ||
bottom_grad.scalar_type() == at::kHalf),
"Data type of bottom_grad should be Float ro Half, got ",
bottom_grad.scalar_type(), ".");
TORCH_CHECK(bottom_grad.dim() == 4, "bottom_grad should be a 4d tensor, got ",
            bottom_grad.dim(), "D.");
TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
"D.");
TORCH_CHECK(top_grad.dim() == 4, "top_grad should be a 4d tensor, got ",
            top_grad.dim(), "D.");
TORCH_CHECK((rois.size(0) == top_grad.size(0)),
"the 1st dimensions of rois and top_grad should be the same, ",
"but now the 1st dimension of rois is ", rois.size(0),
", and top_grad is ", top_grad.size(0), ".");
TORCH_CHECK((bottom_grad.size(1) == top_grad.size(1)),
"the 2nd dimensions of bottom_grad and top_grad should be ",
"the same, but now the 2nd dimension of bottom_grad is ",
bottom_grad.size(1), ", and top_grad is ", top_grad.size(1), ".");
int channel = bottom_grad.size(1);
int width = bottom_grad.size(3);
int height = bottom_grad.size(2);
int batch = bottom_grad.size(0);
int rois_nums = rois.size(0);
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(bottom_grad.dtype());
// return directly for zero-element input
if (bottom_grad.numel() == 0) {
CNLOG(INFO) << "Skip the zero-elements case.";
return;
}
RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sampling_ratio, spatial_scale,
aligned, clockwise};
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
auto top_grad_tensor =
torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
at::Tensor bottom_grad_tmp = at::empty({batch, channel, height, width},
top_grad.options(), memory_format)
.zero_();
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_tmp);
auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_ptr = rois_impl->cnnlMalloc();
auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_tensor);
auto top_grad_ptr = top_grad_impl->cnnlMalloc();
KernelRoiAlignRotatedBackward(k_dim, k_type, queue, d_type, top_grad_ptr,
rois_ptr, bottom_grad_ptr, batch, height, width,
channel, rois_nums, roiAlignRotatedParams);
bottom_grad.copy_(bottom_grad_tmp);
}
void roi_align_rotated_forward_mlu(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) {
ROIAlignRotatedForwardMLUKernelLauncher(input, rois, output, aligned_height,
aligned_width, spatial_scale,
sampling_ratio, aligned, clockwise);
}
void roi_align_rotated_backward_mlu(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, bool aligned,
bool clockwise) {
ROIAlignRotatedBackwardMLUKernelLauncher(
top_grad, rois, bottom_grad, aligned_height, aligned_width, spatial_scale,
sampling_ratio, aligned, clockwise);
}
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, MLU,
roi_align_rotated_forward_mlu);
REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, MLU,
roi_align_rotated_backward_mlu);
@@ -3,6 +3,8 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
@@ -51,8 +53,6 @@ sampling_ratio = 2
def _test_roialign_rotated_gradcheck(device, dtype):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('unittest does not support GPU yet.')
try:
from mmcv.ops import RoIAlignRotated
except ModuleNotFoundError:
@@ -69,7 +69,6 @@ def _test_roialign_rotated_gradcheck(device, dtype):
froipool = RoIAlignRotated((pool_h, pool_w), spatial_scale,
sampling_ratio)
if torch.__version__ == 'parrots':
gradcheck(
froipool, (x, rois), no_grads=[rois], delta=1e-5, pt_atol=1e-5)
@@ -78,8 +77,6 @@ def _test_roialign_rotated_allclose(device, dtype):
def _test_roialign_rotated_allclose(device, dtype):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('unittest does not support GPU yet.')
try:
from mmcv.ops import RoIAlignRotated, roi_align_rotated
except ModuleNotFoundError:
@@ -127,10 +124,28 @@
output_2.data.type(torch.float).cpu().numpy())
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support 64-bit floating point')),
torch.half
])
def test_roialign_rotated(device, dtype):
# gradcheck is only run for double precision
if dtype is torch.double:
_test_roialign_rotated_gradcheck(device=device, dtype=dtype)
_test_roialign_rotated_allclose(device=device, dtype=dtype)