Commit b091e4d2 authored by tudejiang79, committed by Zaida Zhou
Browse files

[Fix] Fix roi_align_rotated op of MLU (#2210)

* [Fix] roi_align_rotated codes

* [Fix]: fix code style

* [Fix]: fix code style

* [Fix]: fix code style
parent 625e82ce
...@@ -25,10 +25,10 @@ __mlu_func__ void swap(T &a, T &b) { ...@@ -25,10 +25,10 @@ __mlu_func__ void swap(T &a, T &b) {
template <typename T> template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height, __mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T x, T y, const int input_width, T x, T y, T *w1,
const T zero_sign, T *w1, T *w2, T *w3, T *w2, T *w3, T *w4, int *x_low,
T *w4, int *x_low, int *x_high, int *x_high, int *y_low, int *y_high,
int *y_low, int *y_high, bool *empty) { bool *empty) {
// deal with case that the point is out of feature map boundary // deal with case that the point is out of feature map boundary
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true; *empty = true;
...@@ -58,10 +58,11 @@ __mlu_func__ void bilinearInterpolate(const int input_height, ...@@ -58,10 +58,11 @@ __mlu_func__ void bilinearInterpolate(const int input_height,
T lx = x - *x_low; T lx = x - *x_low;
T hy = 1.0 - ly; T hy = 1.0 - ly;
T hx = 1.0 - lx; T hx = 1.0 - lx;
*w1 = hy * hx * zero_sign; *w1 = hy * hx;
*w2 = hy * lx * zero_sign; *w2 = hy * lx;
*w3 = ly * hx * zero_sign; *w3 = ly * hx;
*w4 = ly * lx * zero_sign; *w4 = ly * lx;
return;
} }
template <typename T> template <typename T>
...@@ -141,7 +142,7 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram, ...@@ -141,7 +142,7 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
int dst_offset = 0; int dst_offset = 0;
int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align; int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
for (int c_offset = 0; c_offset < channel; c_offset += channel_align) { for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
__nramset(nram_out, channel_align, (T)0); __bang_write_value(nram_out, channel_align, (T)0);
c_rem = channel - c_offset; c_rem = channel - c_offset;
c_slice = channel_align > c_rem ? c_rem : channel_align; c_slice = channel_align > c_rem ? c_rem : channel_align;
c_slice_align = CEIL_ALIGN(c_slice, align_base_128); c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
...@@ -159,9 +160,8 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram, ...@@ -159,9 +160,8 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
T w1, w2, w3, w4; T w1, w2, w3, w4;
bool empty = false; bool empty = false;
int x_low, x_high, y_low, y_high; int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3, bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&w4, &x_low, &x_high, &y_low, &y_high, &empty); &x_high, &y_low, &y_high, &empty);
int sample_wdim = x_high - x_low + 1;
/******************************************************* /*******************************************************
| ping | pong | | ping | pong |
|------|-----|-----|-----|-----|-----|-----|-----|-----| |------|-----|-----|-----|-----|-----|-----|-----|-----|
...@@ -170,22 +170,32 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram, ...@@ -170,22 +170,32 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
********************************************************/ ********************************************************/
if (is_first_sample && !empty) { if (is_first_sample && !empty) {
// load input data from dram to nram // load input data from dram to nram
__nramset(nram_ping, SAMPLING_NUM * c_slice_align, (T)0); __bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
for (int h = y_low; h <= y_high; ++h) { src_offset =
src_offset = (batch_idx * height * width + y_low * width + x_low) * channel +
(batch_idx * height * width + h * width + x_low) * channel + c_offset;
c_offset; dst_offset = 0;
dst_offset = (h - y_low) * SAMPLING_NUM * c_slice_align / 2; __memcpy(nram_ping + dst_offset, input_dram + src_offset,
if (c_slice_align == channel) { c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + dst_offset, input_dram + src_offset, src_offset = (batch_idx * height * width + y_low * width + x_high) *
sample_wdim * channel * sizeof(T), GDRAM2NRAM); channel +
} else { c_offset;
__memcpy(nram_ping + dst_offset, input_dram + src_offset, dst_offset = c_slice_align;
c_slice * sizeof(T), GDRAM2NRAM, __memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice_align * sizeof(T), channel * sizeof(T), c_slice * sizeof(T), GDRAM2NRAM);
sample_wdim - 1); src_offset = (batch_idx * height * width + y_high * width + x_low) *
} channel +
} c_offset;
dst_offset = c_slice_align * 2;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + y_high * width + x_high) *
channel +
c_offset;
dst_offset = c_slice_align * 3;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
} }
// load next input data to nram // load next input data to nram
if (sample_i + 1 < bin_dim) { if (sample_i + 1 < bin_dim) {
...@@ -200,56 +210,65 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram, ...@@ -200,56 +210,65 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
T p_w1, p_w2, p_w3, p_w4; T p_w1, p_w2, p_w3, p_w4;
bool p_empty = false; bool p_empty = false;
int p_x_low, p_x_high, p_y_low, p_y_high; int p_x_low, p_x_high, p_y_low, p_y_high;
bilinearInterpolate(height, width, p_x, p_y, zero_sign, &p_w1, bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
&p_w2, &p_w3, &p_w4, &p_x_low, &p_x_high, &p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
&p_y_low, &p_y_high, &p_empty); &p_empty);
int p_sample_wdim = p_x_high - p_x_low + 1;
pongc_slice = c_slice; pongc_slice = c_slice;
pongc_slice_align = c_slice_align; pongc_slice_align = c_slice_align;
if (!p_empty) { if (!p_empty) {
__nramset(nram_pong, SAMPLING_NUM * pongc_slice_align, (T)0); __bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
for (int h = p_y_low; h <= p_y_high; ++h) { (T)0);
src_offset = src_offset =
(batch_idx * height * width + h * width + p_x_low) * (batch_idx * height * width + p_y_low * width + p_x_low) *
channel + channel +
c_offset; c_offset;
dst_offset = dst_offset = 0;
(h - p_y_low) * SAMPLING_NUM * pongc_slice_align / 2; __memcpy(nram_pong + dst_offset, input_dram + src_offset,
if (pongc_slice_align == channel) { c_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async( src_offset =
nram_pong + dst_offset, input_dram + src_offset, (batch_idx * height * width + p_y_low * width + p_x_high) *
p_sample_wdim * channel * sizeof(T), GDRAM2NRAM); channel +
} else { c_offset;
__memcpy_async(nram_pong + dst_offset, dst_offset = pongc_slice_align;
input_dram + src_offset, __memcpy(nram_pong + dst_offset, input_dram + src_offset,
pongc_slice * sizeof(T), GDRAM2NRAM, c_slice * sizeof(T), GDRAM2NRAM);
pongc_slice_align * sizeof(T), src_offset =
channel * sizeof(T), p_sample_wdim - 1); (batch_idx * height * width + p_y_high * width + p_x_low) *
} channel +
} c_offset;
dst_offset = pongc_slice_align * 2;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_high * width + p_x_high) *
channel +
c_offset;
dst_offset = pongc_slice_align * 3;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
} }
} }
T *tmp_sum = nram_ping + 3 * c_slice_align; T *tmp_sum = nram_ping + 3 * c_slice_align;
if (empty) { if (empty) {
__nramset(tmp_sum, c_slice_align, T(0)); __bang_write_value(tmp_sum, c_slice_align, T(0));
} else { } else {
__bang_mul_const(nram_ping, nram_ping, w1, c_slice_align); __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_const(nram_ping + c_slice_align, __bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align); nram_ping + c_slice_align, w2, c_slice_align);
__bang_mul_const(nram_ping + 2 * c_slice_align, __bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align); nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_mul_const(nram_ping + 3 * c_slice_align, __bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align); nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM, __bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
1, SAMPLING_NUM, 1, 1); 1, SAMPLING_NUM, 1, 1);
} }
__bang_add(nram_out, nram_out, tmp_sum, c_slice_align); __bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
swap(nram_ping, nram_pong); swap(nram_ping, nram_pong);
__asm__ volatile("sync;"); __asm__ volatile("sync;");
is_first_sample = false; is_first_sample = false;
} }
} }
__bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
// store the result to dram // store the result to dram
int output_offset = int output_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) * ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
...@@ -310,7 +329,6 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram, ...@@ -310,7 +329,6 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
T cos_theta = std::cos(theta); T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta); T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim; T zero_sign = 1.0f / bin_dim;
int c_rem, c_slice, pongc_slice, c_offset; int c_rem, c_slice, pongc_slice, c_offset;
c_rem = channel; c_rem = channel;
c_offset = 0; c_offset = 0;
...@@ -369,30 +387,30 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram, ...@@ -369,30 +387,30 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
T w1, w2, w3, w4; T w1, w2, w3, w4;
bool empty = false; bool empty = false;
int x_low, x_high, y_low, y_high; int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3, bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&w4, &x_low, &x_high, &y_low, &y_high, &empty); &x_high, &y_low, &y_high, &empty);
if (empty) { if (empty) {
continue; continue;
} else { } else {
__bang_mul_const(nram_output, nram_ping, w1, c_limit); __bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
__bang_atomic_add( __bang_atomic_add(
(T *)nram_output, (T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel + bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_low * channel + c_offset, y_low * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice); (T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w2, c_limit); __bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
__bang_atomic_add( __bang_atomic_add(
(T *)nram_output, (T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel + bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_high * channel + c_offset, y_low * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice); (T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w3, c_limit); __bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
__bang_atomic_add( __bang_atomic_add(
(T *)nram_output, (T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel + bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_low * channel + c_offset, y_high * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice); (T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w4, c_limit); __bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
__bang_atomic_add( __bang_atomic_add(
(T *)nram_output, (T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel + bottom_grad_dram + batch_idx * height * width * channel +
......
...@@ -99,8 +99,8 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois, ...@@ -99,8 +99,8 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
auto input_tensor = auto input_tensor =
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
at::Tensor output_tmp = at::Tensor output_tmp =
at::empty({batch, channel, pooled_height, pooled_width}, input.options(), at::empty({rois_nums, channel, pooled_height, pooled_width},
memory_format); input.options(), memory_format);
// get compute queue // get compute queue
auto queue = torch_mlu::getCurQueue(); auto queue = torch_mlu::getCurQueue();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment