Unverified Commit ba8aa764 authored by tudejiang79's avatar tudejiang79 Committed by GitHub
Browse files

[Refactor] Repalce the implementation of roi_align_rotated with mlu-ops (#2808)

parent d2aecbe4
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* OR IMPLIED, INCLUDING BUvoid NOKType LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENvoid SHALL THE AUTHORS OR COPYRIGHKType HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORvoid OR OTHERWISE, ARISING FROM, OUKType OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
#define ROI_OFFSET 6
#define SAMPLING_NUM 4
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void swap(T &a, T &b) {
T tmp = a;
a = b;
b = tmp;
}
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T x, T y, T *w1,
T *w2, T *w3, T *w4, int *x_low,
int *x_high, int *y_low, int *y_high,
bool *empty) {
// deal with case that the point is out of feature map boundary
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true;
return;
}
if (y <= 0) y = (T)0;
if (x <= 0) x = (T)0;
*y_low = int(y);
*x_low = int(x);
if (*y_low >= input_height - 1) {
*y_high = *y_low = input_height - 1;
y = (T)(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= input_width - 1) {
*x_high = *x_low = input_width - 1;
x = T(*x_low);
} else {
*x_high = *x_low + 1;
}
T ly = y - *y_low;
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx;
*w2 = hy * lx;
*w3 = ly * hx;
*w4 = ly * lx;
return;
}
template <typename T>
__mlu_func__ void getRoiBinInfo(const T *rois_dram, const int bin_i,
const RoiAlignRotatedParams &params,
int *batch_idx, int *roi_n, int *pw, int *ph,
T *roi_center_x, T *roi_center_y, T *roi_width,
T *roi_height, T *theta) {
T offset = params.aligned ? (T)0.5 : (T)0.0;
*pw = bin_i % params.pooled_width;
*ph = (bin_i / params.pooled_width) % params.pooled_height;
*roi_n = bin_i / params.pooled_width / params.pooled_height;
const T *roi_info = rois_dram + (*roi_n) * ROI_OFFSET;
*batch_idx = (int)roi_info[0];
*roi_center_x = roi_info[1] * (T)params.spatial_scale - offset;
*roi_center_y = roi_info[2] * (T)params.spatial_scale - offset;
*roi_width = roi_info[3] * (T)params.spatial_scale;
*roi_height = roi_info[4] * (T)params.spatial_scale;
*theta = roi_info[5];
if (params.clockwise) {
*theta = -(*theta);
}
if (!params.aligned) {
*roi_width = *roi_width > (T)1.0 ? *roi_width : (T)1.0;
*roi_height = *roi_height > (T)1.0 ? *roi_height : (T)1.0;
}
}
template <typename T>
__mlu_func__ void roiAlignRotatedForward(const T *input_dram,
const T *rois_dram, const int batch,
const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams &params,
T *output_dram) {
int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
int channel_max_cap = MAX_NRAM_SIZE / sizeof(T) / (2 * SAMPLING_NUM + 1);
channel_max_cap = channel_max_cap / align_base_128 * align_base_128;
int channel_align = channel < channel_max_cap ? channel : channel_max_cap;
channel_align = CEIL_ALIGN(channel_align, align_base_128);
T *nram_out = (T *)nram_buffer;
T *nram_ping = nram_out + channel_align;
T *nram_pong = nram_ping + channel_align * SAMPLING_NUM;
int bin_first = taskId;
int bin_end = rois_num * params.pooled_height * params.pooled_width;
for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
T roi_center_x, roi_center_y, roi_width, roi_height, theta;
int batch_idx, roi_n, pw, ph;
getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
&roi_center_x, &roi_center_y, &roi_width, &roi_height,
&theta);
T bin_size_h = roi_height / params.pooled_height;
T bin_size_w = roi_width / params.pooled_width;
int roi_bin_grid_h =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_height / params.pooled_height);
int roi_bin_grid_w =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_width / params.pooled_width);
T roi_start_y = -roi_height / 2;
T roi_start_x = -roi_width / 2;
const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1;
T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim;
bool is_first_sample = true;
int src_offset = 0;
int dst_offset = 0;
int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
__bang_write_value(nram_out, channel_align, (T)0);
c_rem = channel - c_offset;
c_slice = channel_align > c_rem ? c_rem : channel_align;
c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
is_first_sample = true;
for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T yy = roi_start_y + ph * bin_size_h +
T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
const T xx = roi_start_x + pw * bin_size_w +
T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
int sample_i = iy * roi_bin_grid_w + ix;
T y = yy * cos_theta - xx * sin_theta + roi_center_y;
T x = yy * sin_theta + xx * cos_theta + roi_center_x;
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high, &empty);
/*******************************************************
| ping | pong |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
|output| p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
********************************************************/
if (is_first_sample && !empty) {
// load input data from dram to nram
__bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
src_offset =
(batch_idx * height * width + y_low * width + x_low) * channel +
c_offset;
dst_offset = 0;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset = (batch_idx * height * width + y_low * width + x_high) *
channel +
c_offset;
dst_offset = c_slice_align;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset = (batch_idx * height * width + y_high * width + x_low) *
channel +
c_offset;
dst_offset = c_slice_align * 2;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + y_high * width + x_high) *
channel +
c_offset;
dst_offset = c_slice_align * 3;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
// load next input data to nram
if (sample_i + 1 < bin_dim) {
int p_iy = (sample_i + 1) / roi_bin_grid_w;
int p_ix = (sample_i + 1) % roi_bin_grid_w;
const T p_yy = roi_start_y + ph * bin_size_h +
T(p_iy + 0.5) * bin_size_h / roi_bin_grid_h;
const T p_xx = roi_start_x + pw * bin_size_w +
T(p_ix + 0.5) * bin_size_w / roi_bin_grid_w;
T p_y = p_yy * cos_theta - p_xx * sin_theta + roi_center_y;
T p_x = p_yy * sin_theta + p_xx * cos_theta + roi_center_x;
T p_w1, p_w2, p_w3, p_w4;
bool p_empty = false;
int p_x_low, p_x_high, p_y_low, p_y_high;
bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
&p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
&p_empty);
pongc_slice = c_slice;
pongc_slice_align = c_slice_align;
if (!p_empty) {
__bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
(T)0);
src_offset =
(batch_idx * height * width + p_y_low * width + p_x_low) *
channel +
c_offset;
dst_offset = 0;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_low * width + p_x_high) *
channel +
c_offset;
dst_offset = pongc_slice_align;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_high * width + p_x_low) *
channel +
c_offset;
dst_offset = pongc_slice_align * 2;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_high * width + p_x_high) *
channel +
c_offset;
dst_offset = pongc_slice_align * 3;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
}
T *tmp_sum = nram_ping + 3 * c_slice_align;
if (empty) {
__bang_write_value(tmp_sum, c_slice_align, T(0));
} else {
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
1, SAMPLING_NUM, 1, 1);
}
__bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
swap(nram_ping, nram_pong);
__asm__ volatile("sync;");
is_first_sample = false;
}
}
__bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
// store the result to dram
int output_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel +
c_offset;
__memcpy(output_dram + output_offset, nram_out, c_slice * sizeof(T),
NRAM2GDRAM);
}
}
}
template <typename T>
__mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
const T *rois_dram, const int batch,
const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams &params,
T *bottom_grad_dram) {
int align_base_128 = NFU_ALIGN_SIZE / sizeof(T);
int channel_align = CEIL_ALIGN(channel, align_base_128);
unsigned int max_element = MAX_NRAM_SIZE / sizeof(T);
int c_limit = max_element >> 2;
c_limit = c_limit > channel_align ? channel_align : c_limit;
T *nram_ping = (T *)nram_buffer;
T *nram_pong = nram_ping + 2 * c_limit;
T *nram_output = nullptr;
int bin_first = taskId;
int bin_end = rois_num * params.pooled_height * params.pooled_width;
bool is_first_bin = true;
T roi_center_x, roi_center_y, roi_width, roi_height, theta;
int batch_idx, roi_n, pw, ph;
T pong_roi_center_x, pong_roi_center_y, pong_roi_width, pong_roi_height,
pong_theta;
int pong_batch_idx, pong_roi_n, pong_pw, pong_ph;
for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) {
getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph,
&roi_center_x, &roi_center_y, &roi_width, &roi_height,
&theta);
T bin_size_h = roi_height / params.pooled_height;
T bin_size_w = roi_width / params.pooled_width;
int roi_bin_grid_h =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_height / params.pooled_height);
int roi_bin_grid_w =
(params.sample_ratio > 0)
? params.sample_ratio
: __float2int_up((float)roi_width / params.pooled_width);
T roi_start_y = -roi_height / 2;
T roi_start_x = -roi_width / 2;
const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1;
T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim;
int c_rem, c_slice, pongc_slice, c_offset;
c_rem = channel;
c_offset = 0;
/****************************************
| ping | pong |
|---------|---------|---------|---------|
| input | output | input | output |
|---------|---------|---------|---------|
*****************************************/
if (is_first_bin) {
// load the first top_grad to nram
c_slice = c_limit < c_rem ? c_limit : c_rem;
int top_grad_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel;
__memcpy(nram_ping, top_grad_dram + top_grad_offset, c_slice * sizeof(T),
GDRAM2NRAM);
}
nram_output = nram_ping + c_limit;
while (c_rem > 0) {
c_slice = c_slice < c_rem ? c_slice : c_rem;
// load the next top_grad to nram
if (c_rem - c_slice > 0) {
// load the rest channels to nram
pongc_slice = (c_rem - c_slice > c_slice) ? c_slice : c_rem - c_slice;
int top_grad_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
channel +
c_offset + c_slice;
__memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
} else if (bin_i + taskDim < bin_end) {
// load next bin's data to nram
getRoiBinInfo(rois_dram, bin_i + taskDim, params, &pong_batch_idx,
&pong_roi_n, &pong_pw, &pong_ph, &pong_roi_center_x,
&pong_roi_center_y, &pong_roi_width, &pong_roi_height,
&pong_theta);
pongc_slice = c_limit < channel ? c_limit : channel;
int top_grad_offset = ((pong_roi_n * params.pooled_height + pong_ph) *
params.pooled_width +
pong_pw) *
channel;
__memcpy_async(nram_pong, top_grad_dram + top_grad_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
// comput the output in a single bin
for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T yy = roi_start_y + ph * bin_size_h +
T(iy + 0.5) * bin_size_h / roi_bin_grid_h;
for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
const T xx = roi_start_x + pw * bin_size_w +
T(ix + 0.5) * bin_size_w / roi_bin_grid_w;
T y = yy * cos_theta - xx * sin_theta + roi_center_y;
T x = yy * sin_theta + xx * cos_theta + roi_center_x;
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high, &empty);
if (empty) {
continue;
} else {
__bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice);
}
}
}
swap(nram_ping, nram_pong);
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
is_first_bin = false;
}
}
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedForward(
const void *features, const void *rois, void *output, const int batch,
const int height, const int width, const int channel, const int rois_num,
const RoiAlignRotatedParams rroiAlignParams,
const cnrtDataType_t data_type) {
if (0x80 == coreId) {
return;
}
if (data_type == CNRT_FLOAT32) {
roiAlignRotatedForward((float *)features, (float *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(float *)output);
} else {
roiAlignRotatedForward((half *)features, (half *)rois, batch, height, width,
channel, rois_num, rroiAlignParams, (half *)output);
}
}
__mlu_global__ void MLUUnion1KernelRoiAlignRotatedBackward(
const void *top_grad, const void *rois, void *bottom_grad, const int batch,
const int height, const int width, const int channel, const int rois_num,
const RoiAlignRotatedParams rroiAlignParams,
const cnrtDataType_t data_type) {
if (0x80 == coreId) {
return;
}
if (data_type == CNRT_FLOAT32) {
roiAlignRotatedBackward((float *)top_grad, (float *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(float *)bottom_grad);
} else {
roiAlignRotatedBackward((half *)top_grad, (half *)rois, batch, height,
width, channel, rois_num, rroiAlignParams,
(half *)bottom_grad);
}
}
void KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *features, const void *rois,
void *output, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams) {
MLUUnion1KernelRoiAlignRotatedForward<<<k_dim, k_type, queue>>>(
features, rois, output, batch, height, width, channel, rois_num,
roiAlignRotatedParams, d_type);
}
void KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *top_grad, const void *rois,
void *bottom_grad, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams) {
MLUUnion1KernelRoiAlignRotatedBackward<<<k_dim, k_type, queue>>>(
top_grad, rois, bottom_grad, batch, height, width, channel, rois_num,
roiAlignRotatedParams, d_type);
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef ROI_ALIGN_ROTATED_UTILS_HPP_
#define ROI_ALIGN_ROTATED_UTILS_HPP_
struct RoiAlignRotatedParams {
int pooled_height;
int pooled_width;
int sample_ratio;
float spatial_scale;
bool aligned;
bool clockwise;
};
#endif // ROI_ALIGN_ROTATED_UTILS_HPP_
...@@ -9,37 +9,7 @@ ...@@ -9,37 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/ *************************************************************************/
#include "pytorch_device_registry.hpp" #include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
namespace {
void policyFunc(int bin_num, cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
unsigned int use_cluster = (bin_num + core_num - 1) / core_num;
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
} // namespace
void KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *features, const void *rois,
void *output, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams);
void KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const void *top_grad, const void *rois,
void *bottom_grad, const int batch, const int height, const int width,
const int channel, const int rois_num,
const RoiAlignRotatedParams roiAlignRotatedParams);
void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois, void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height, Tensor output, int pooled_height,
...@@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois, ...@@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
float spatial_scale, float spatial_scale,
int sampling_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise) { bool clockwise) {
TORCH_CHECK(((input.scalar_type() == output.scalar_type()) &&
(output.scalar_type() == rois.scalar_type())),
"data types of input, rois and output should be the same, ",
"but now input type is ", input.scalar_type(), ", rois type is ",
rois.scalar_type(), ", output type is ", output.scalar_type(),
".");
TORCH_CHECK(
(input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf),
"input type should be Float or Half, got ", input.scalar_type(), ".");
TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ",
input.dim(), "D.");
TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
"D.");
TORCH_CHECK(output.dim() == 4, "output should be a 4d tensor, got ",
output.dim(), "D.");
TORCH_CHECK((rois.size(0) == output.size(0)),
"the 1st dimensions of rois and output should be the same, ",
"but now the 1st dimension of rois is ", rois.size(0),
", and output is ", output.size(0), ".");
TORCH_CHECK((input.size(1) == output.size(1)),
"the 2nd dimensions of input and output should be the same, ",
"but now the 2nd dimension of input is ", input.size(1),
", and output is ", output.size(1), ".");
int channel = input.size(1);
int width = input.size(3);
int height = input.size(2);
int batch = input.size(0);
int rois_nums = rois.size(0);
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
// return if zero-elements
if (input.numel() == 0) {
CNLOG(INFO) << "Skip the zero-elements case.";
return;
}
RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sampling_ratio, spatial_scale,
aligned, clockwise};
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
auto memory_format = auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_tensor = auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); auto rois_contiguous =
at::Tensor output_tmp = torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
at::empty({rois_nums, channel, pooled_height, pooled_width}, auto output_contiguous =
input.options(), memory_format); torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);
// get compute queue MluOpTensorDescriptor input_desc, rois_desc, output_desc;
auto queue = torch_mlu::getCurQueue(); input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
rois_desc.set(rois_contiguous);
output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
// get ptr of tensors // get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(input_tensor); auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc(); auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois); auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc(); auto rois_ptr = rois_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output_tmp); auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
auto output_ptr = output_impl->cnnlMalloc(); auto output_ptr = output_impl->cnnlMalloc();
KernelRoiAlignRotatedForward(k_dim, k_type, queue, d_type, input_ptr, // get compute handle
rois_ptr, output_ptr, batch, height, width, auto handle = mluOpGetCurrentHandle();
channel, rois_nums, roiAlignRotatedParams); mluOpRoiAlignRotatedForward(
output.copy_(output_tmp); handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
clockwise, output_desc.desc(), output_ptr);
output.copy_(output_contiguous);
} }
void ROIAlignRotatedBackwardMLUKernelLauncher( void ROIAlignRotatedBackwardMLUKernelLauncher(
Tensor top_grad, Tensor rois, Tensor bottom_grad, int pooled_height, Tensor top_grad, Tensor rois, Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale, int sampling_ratio, bool aligned, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned,
bool clockwise) { bool clockwise) {
TORCH_CHECK(((top_grad.scalar_type() == bottom_grad.scalar_type()) &&
(bottom_grad.scalar_type() == rois.scalar_type())),
"data types of top_grad, rois and bottom_grad should be ",
"the same, but now top_grad type is ", top_grad.scalar_type(),
", rois type is ", rois.scalar_type(), ", bottom_grad type is ",
bottom_grad.scalar_type(), ".");
TORCH_CHECK((bottom_grad.scalar_type() == at::kFloat ||
bottom_grad.scalar_type() == at::kHalf),
"Data type of bottom_grad should be Float ro Half, got ",
bottom_grad.scalar_type(), ".");
TORCH_CHECK(bottom_grad.dim() == 4, "bottom_grad should be a 4d tensor, got ",
top_grad.dim(), "D.");
TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
"D.");
TORCH_CHECK(top_grad.dim() == 4, "top_grad should be a 4d tensor, got ",
bottom_grad.dim(), "D.");
TORCH_CHECK((rois.size(0) == top_grad.size(0)),
"the 1st dimensions of rois and top_grad should be the same, ",
"but now the 1st dimension of rois is ", rois.size(0),
", and top_grad is ", top_grad.size(0), ".");
TORCH_CHECK((bottom_grad.size(1) == top_grad.size(1)),
"the 2nd dimensions of bottom_grad and top_grad should be ",
"the same, but now the 2nd dimension of bottom_grad is ",
bottom_grad.size(1), ", and top_grad is ", top_grad.size(1), ".");
int channel = bottom_grad.size(1);
int width = bottom_grad.size(3);
int height = bottom_grad.size(2);
int batch = bottom_grad.size(0);
int rois_nums = rois.size(0);
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(bottom_grad.dtype());
// return if zero-elements
if (bottom_grad.numel() == 0) {
CNLOG(INFO) << "Skip the zero-elements case.";
return;
}
RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sampling_ratio, spatial_scale,
aligned, clockwise};
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
auto memory_format = auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim()); torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
auto top_grad_tensor = auto top_grad_ =
torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format); torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
at::Tensor bottom_grad_tmp = at::empty({batch, channel, height, width}, auto rois_contiguous =
top_grad.options(), memory_format) torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
.zero_(); auto bottom_grad_ =
torch_mlu::cnnl::ops::cnnl_contiguous(bottom_grad, memory_format);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors // get ptr of tensors
auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_tmp); auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_);
auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_ptr = rois_impl->cnnlMalloc();
auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_tensor);
auto top_grad_ptr = top_grad_impl->cnnlMalloc(); auto top_grad_ptr = top_grad_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_);
auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();
KernelRoiAlignRotatedBackward(k_dim, k_type, queue, d_type, top_grad_ptr, MluOpTensorDescriptor top_grad_desc, rois_desc, bottom_grad_desc;
rois_ptr, bottom_grad_ptr, batch, height, width, top_grad_desc.set_with_layout(top_grad_, MLUOP_LAYOUT_NHWC);
channel, rois_nums, roiAlignRotatedParams); rois_desc.set(rois_contiguous);
bottom_grad.copy_(bottom_grad_tmp); bottom_grad_desc.set_with_layout(bottom_grad_, MLUOP_LAYOUT_NHWC);
// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpRoiAlignRotatedBackward(
handle, top_grad_desc.desc(), top_grad_ptr, rois_desc.desc(), rois_ptr,
pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
clockwise, bottom_grad_desc.desc(), bottom_grad_ptr);
bottom_grad.copy_(bottom_grad_);
} }
void roi_align_rotated_forward_mlu(Tensor input, Tensor rois, Tensor output, void roi_align_rotated_forward_mlu(Tensor input, Tensor rois, Tensor output,
......
...@@ -11,7 +11,6 @@ try: ...@@ -11,7 +11,6 @@ try:
except ImportError: except ImportError:
from torch.autograd import gradcheck from torch.autograd import gradcheck
_USING_PARROTS = False _USING_PARROTS = False
# yapf:disable # yapf:disable
inputs = [([[[[1., 2.], [3., 4.]]]], inputs = [([[[[1., 2.], [3., 4.]]]],
[[0., 0.5, 0.5, 1., 1., 0]]), [[0., 0.5, 0.5, 1., 1., 0]]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment