OpenDAS / MMCV / Commits / 0c23eb02

Unverified commit 0c23eb02, authored May 19, 2023 by bdf, committed by GitHub on May 19, 2023

Sync main with mmcv1.x branch (#2800)

parent 59c1418e
Changes: 25. Showing 20 changed files with 408 additions and 7652 deletions (+408 −7652).
docs/en/understand_mmcv/ops.md                            +1   −1
docs/zh_cn/understand_mmcv/ops.md                         +1   −1
mmcv/ops/box_iou_rotated.py                               +4   −1
mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu             +0   −431
mmcv/ops/csrc/common/mlu/iou3d_utils.hpp                  +0   −695
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu    +0   −2094
mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu               +0   −483
mmcv/ops/csrc/common/mlu/nms_utils.hpp                    +0   −553
mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu         +0   −493
mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu   +0   −747
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu          +0   −466
mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu      +0   −532
mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp             +54  −0
mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp                   +35  −101
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h             +2   −2
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp          +86  −463
mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp                     +37  −108
mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp               +68  −89
mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp         +87  −322
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp                +33  −70
docs/en/understand_mmcv/ops.md  (view file @ 0c23eb02)
@@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
-| BoxIouRotated | √ | √ |   | | |
+| BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |
docs/zh_cn/understand_mmcv/ops.md  (view file @ 0c23eb02)
@@ -9,7 +9,7 @@ MMCV provides the operators commonly used in detection, segmentation, and other tasks
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
-| BoxIouRotated | √ | √ |   | | |
+| BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |
mmcv/ops/box_iou_rotated.py  (view file @ 0c23eb02)
@@ -132,6 +132,9 @@ def box_iou_rotated(bboxes1: torch.Tensor,
    cols = bboxes2.size(0)
    if aligned:
        ious = bboxes1.new_zeros(rows)
    else:
        if bboxes1.device.type == 'mlu':
            ious = bboxes1.new_zeros([rows, cols])
        else:
            ious = bboxes1.new_zeros(rows * cols)
    if not clockwise:
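For context, a minimal host-side usage sketch of the op touched above (tensor values are made up; it assumes an MMCV build with the MLU extensions, but the same call also runs on CPU/CUDA):

import torch
from mmcv.ops import box_iou_rotated

# Rotated boxes are (x_ctr, y_ctr, w, h, angle); the values below are illustrative only.
bboxes1 = torch.tensor([[10.0, 10.0, 8.0, 4.0, 0.5],
                        [20.0, 20.0, 6.0, 6.0, 0.0]])
bboxes2 = torch.tensor([[10.0, 10.0, 8.0, 4.0, 0.3],
                        [40.0, 40.0, 6.0, 6.0, 0.0]])

# aligned=False returns a (rows, cols) IoU matrix; on MLU the output buffer is
# allocated directly with that 2-D shape, as in the change above.
ious = box_iou_rotated(bboxes1, bboxes2, aligned=False)
print(ious.shape)  # torch.Size([2, 2])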
mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu  (deleted, 100644 → 0; view file @ 59c1418e)
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "iou3d_utils.hpp"
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
/* NRAM buffer
 * Suppose we process N boxes at a time.
----------------------------------------------------------------
| Basic |score (1N)+ |intersect_pts(48N)| |
| |valid_box(1N) |+ ordered_pts(48N)| temp_long(72N) |
| |+ temp_buffer(10N)| | |
|--------------------------|------------------|----------------|
| Reuse | null | null |rotated_pts(16N)|
|-------|------------------|------------------|----------------|
---------------------------------------------------------------------------
| Basic | dist_ram(24N) | valid_pts(24N) |box1(5N) |box1_buffer(5KB) |
| | |+ nums_in_ram(1N)|+ box2(5N)|+nram_save(5KB) |
|--------------------------|-----------------|----------|-----------------|
| Reuse | vec_buffer(5N) | null | null | null |
|-------|------------------|-----------------|----------|-----------------|
Total Basic Memory Size = 239N * sizeof(float) + 10KB
*/
__nram__ char nram_buffer[MAX_NRAM_SIZE];
__mlu_shared__ char sram_buffer[SIZE_SRAM_BUF];
template <typename T>
__mlu_func__ void iou3D_detection(int32_t &result_box_num, int32_t *output_data,
const T *boxes_data, float *scores_data,
const int core_limit, const int input_box_num,
const float iou_threshold,
mluMemcpyDirection_t scores_load_dir,
mluMemcpyDirection_t scores_store_dir,
mluMemcpyDirection_t boxes_load_dir) {
// NRAM divide by (2+4*COMPUTE_COUNT_ALIGN) copies of NRAM, counted by bytes
const int nram_save_limit_count = 256;
int box_read_limit_count = 256;
float div_thresh_iou = 1.0 / iou_threshold;
// every box require 239 * sizeof(float) space in nram;
const int32_t copies_of_nram = 239 * sizeof(float);
const int32_t limit = (MAX_NRAM_SIZE - 5 * box_read_limit_count * sizeof(T) -
nram_save_limit_count * sizeof(int32_t)) /
copies_of_nram;
// x,y,z,dx,dy,dz,angle
const T *input_x_ptr = boxes_data;
const T *input_y_ptr = input_x_ptr + input_box_num;
const T *input_dx_ptr = input_y_ptr + 2 * input_box_num;
const T *input_dy_ptr = input_dx_ptr + input_box_num;
const T *input_angle_ptr = input_dy_ptr + 2 * input_box_num;
float *input_score_ptr = scores_data;
// data split
int avg_cluster = 0;
int rem_cluster = 0;
int len_cluster = 0;
int cluster_offset = 0;
if (clusterDim > 0) {
// union
avg_cluster = input_box_num / clusterDim;
rem_cluster = input_box_num % clusterDim;
len_cluster = avg_cluster + (clusterId < rem_cluster ? 1 : 0);
cluster_offset = avg_cluster * clusterId +
(clusterId <= rem_cluster ? clusterId : rem_cluster);
} else {
// block
len_cluster = input_box_num;
cluster_offset = 0;
}
int len_core = input_box_num;
int input_offset = 0;
if (core_limit > 1) {
int avg_core = len_cluster / coreDim;
int rem_core = len_cluster % coreDim;
len_core = avg_core + (coreId < rem_core ? 1 : 0);
int core_offset =
avg_core * coreId + (coreId <= rem_core ? coreId : rem_core);
input_offset = cluster_offset + core_offset;
}
int32_t max_seg_pad = IOU3D_DOWN(limit, IOU3D_SIZE);
int repeat_iou_compute = len_core / max_seg_pad;
int remain_iou_compute = len_core % max_seg_pad;
// basic consistent memory layout
void *score = ((char *)nram_buffer);
void *valid_box = ((char *)score) + 1 * max_seg_pad * sizeof(float);
void *temp_buffer = ((char *)valid_box) + 1 * max_seg_pad * sizeof(float);
void *intersect_pts_x =
((char *)temp_buffer) + 10 * max_seg_pad * sizeof(float);
void *intersect_pts_y =
((char *)intersect_pts_x) + 24 * max_seg_pad * sizeof(float);
void *ordered_pts_x =
((char *)intersect_pts_y) + 24 * max_seg_pad * sizeof(float);
void *ordered_pts_y =
((char *)ordered_pts_x) + 24 * max_seg_pad * sizeof(float);
void *temp_long_1 =
((char *)ordered_pts_y) + 24 * max_seg_pad * sizeof(float);
void *temp_long_2 = ((char *)temp_long_1) + 24 * max_seg_pad * sizeof(float);
void *temp_long_3 = ((char *)temp_long_2) + 24 * max_seg_pad * sizeof(float);
void *dist_ram = ((char *)temp_long_3) + 24 * max_seg_pad * sizeof(float);
void *valid_pts = ((char *)dist_ram) + 24 * max_seg_pad * sizeof(float);
void *nums_in_ram = ((char *)valid_pts) + 24 * max_seg_pad * sizeof(float);
T *box1 = (T *)(((char *)nums_in_ram) + 1 * max_seg_pad * sizeof(float));
T *box2 = (T *)(((char *)box1) + 5 * max_seg_pad * sizeof(float));
void *box1_buffer = ((char *)box2) + 5 * max_seg_pad * sizeof(float);
int32_t *nram_save =
(int32_t *)(((char *)box1_buffer) + 5 * box_read_limit_count * sizeof(T));
// nram_save ~ nram_save_limit_count * sizeof(int32_t)
int nram_save_count = 0;
// reuse memory
void *rotated_pts1_x = ((char *)dist_ram);
void *rotated_pts1_y =
((char *)rotated_pts1_x) + 4 * max_seg_pad * sizeof(float);
void *rotated_pts2_x =
((char *)rotated_pts1_y) + 4 * max_seg_pad * sizeof(float);
void *rotated_pts2_y =
((char *)rotated_pts2_x) + 4 * max_seg_pad * sizeof(float);
void *vec_buffer = ((char *)temp_long_1) + 5 * max_seg_pad * sizeof(float);
// vec_buffer ~ 16 * max_seg_pad * sizeof(float)
  // First, initialize ram with all 0, otherwise nan/inf may cause unexpected results
__bang_write_zero((unsigned char *)nram_buffer, copies_of_nram * max_seg_pad);
  // the shift by 8 and the 0xff mask rely on box_read_limit_count being initialized to 256
const int max_box_seg_id = (input_box_num - 1) >> 8;
const int last_rem_box_number = ((input_box_num - 1) & 0xff) + 1;
for (int32_t cur_box = 0; cur_box < input_box_num; ++cur_box) {
__sync_all();
int box_seg_id = cur_box >> 8, box_id = cur_box & 0xff;
box_read_limit_count = box_seg_id == max_box_seg_id ? last_rem_box_number
: box_read_limit_count;
if (box_id == 0) {
// x,y,z,dx,dy,dz,angle
int offset_num = box_seg_id << 8;
// x
__memcpy((char *)box1_buffer, input_x_ptr + offset_num,
box_read_limit_count * 1 * sizeof(T), boxes_load_dir,
box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// y
__memcpy((char *)box1_buffer + box_read_limit_count * 1 * sizeof(T),
input_y_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// dx
__memcpy((char *)box1_buffer + box_read_limit_count * 2 * sizeof(T),
input_dx_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// dy
__memcpy((char *)box1_buffer + box_read_limit_count * 3 * sizeof(T),
input_dy_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// angle
__memcpy((char *)box1_buffer + box_read_limit_count * 4 * sizeof(T),
input_angle_ptr + offset_num,
box_read_limit_count * 1 * sizeof(T), boxes_load_dir,
box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
}
if (((float *)input_score_ptr)[cur_box] == 0) {
continue;
}
// save result
nram_save[nram_save_count] = cur_box;
result_box_num++;
nram_save_count++;
if (clusterId == 0 && coreId == 0 &&
nram_save_count == nram_save_limit_count) {
pvLock();
__memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t),
NRAM2GDRAM);
pvUnlock();
output_data += nram_save_count;
nram_save_count = 0;
}
// prepare box1
// x
__bang_write_value((float *)box1, max_seg_pad,
float(((T *)box1_buffer)[box_id]));
// y
__bang_write_value(
(float *)box1 + max_seg_pad, max_seg_pad,
float(((T *)box1_buffer)[box_id + 1 * box_read_limit_count]));
// dx
__bang_write_value(
(float *)box1 + max_seg_pad * 2, max_seg_pad,
float(((T *)box1_buffer)[box_id + 2 * box_read_limit_count]));
// dy
__bang_write_value(
(float *)box1 + max_seg_pad * 3, max_seg_pad,
float(((T *)box1_buffer)[box_id + 3 * box_read_limit_count]));
// angle
__bang_write_value(
(float *)box1 + max_seg_pad * 4, max_seg_pad,
float(((T *)box1_buffer)[box_id + 4 * box_read_limit_count]));
float max_area = 1.0f *
((T *)box1_buffer)[box_id + 2 * box_read_limit_count] *
((T *)box1_buffer)[box_id + 3 * box_read_limit_count];
// update score
for (int i = 0; i <= repeat_iou_compute; i++) {
if (i == repeat_iou_compute && remain_iou_compute == 0) {
break;
}
int seg_len = max_seg_pad;
int cpy_len =
(i == repeat_iou_compute) ? remain_iou_compute : max_seg_pad;
// int half_offset = std::is_same<T, half>::value ? max_seg_pad * 5 : 0;
int half_offset = (sizeof(T) == sizeof(half)) ? max_seg_pad * 5 : 0;
// score
__memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
cpy_len * sizeof(float), scores_load_dir,
cpy_len * sizeof(float), cpy_len * sizeof(float), 0);
// x
__memcpy(box2 + half_offset, input_x_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// y
__memcpy(box2 + half_offset + seg_len * 1,
input_y_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// dx
__memcpy(box2 + half_offset + seg_len * 2,
input_dx_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// dy
__memcpy(box2 + half_offset + seg_len * 3,
input_dy_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// angle
__memcpy(box2 + half_offset + seg_len * 4,
input_angle_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// if (std::is_same<T, half>::value) {
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)box2, (half *)(box2 + half_offset),
seg_len * 5);
}
// Calculate rotated vertices
void *temp1_ram = ((char *)temp_buffer);
void *temp2_ram = ((char *)temp_buffer) + seg_len * sizeof(float);
void *temp3_ram = ((char *)temp_buffer) + 2 * seg_len * sizeof(float);
void *temp4_ram = ((char *)temp_buffer) + 3 * seg_len * sizeof(float);
getRotatedVertices((float *)rotated_pts1_x, (float *)rotated_pts1_y,
(float *)box1, (float *)temp1_ram, (float *)temp2_ram,
(float *)temp3_ram, (float *)temp4_ram, seg_len);
getRotatedVertices((float *)rotated_pts2_x, (float *)rotated_pts2_y,
(float *)box2, (float *)temp1_ram, (float *)temp2_ram,
(float *)temp3_ram, (float *)temp4_ram, seg_len);
__bang_write_zero((float *)valid_pts, 24 * seg_len);
__bang_write_zero((float *)nums_in_ram, seg_len);
__bang_write_value(((float *)valid_box), seg_len, 1.0f);
void *vec1_x = ((char *)vec_buffer);
void *vec1_y = ((char *)vec1_x) + 4 * seg_len * sizeof(float);
void *vec2_x = ((char *)vec1_y) + 4 * seg_len * sizeof(float);
void *vec2_y = ((char *)vec2_x) + 4 * seg_len * sizeof(float);
void *temp5_ram = ((char *)temp_buffer) + 4 * seg_len * sizeof(float);
void *temp6_ram = ((char *)temp_buffer) + 5 * seg_len * sizeof(float);
void *temp7_ram = ((char *)temp_buffer) + 6 * seg_len * sizeof(float);
void *temp8_ram = ((char *)temp_buffer) + 7 * seg_len * sizeof(float);
void *temp9_ram = ((char *)temp_buffer) + 8 * seg_len * sizeof(float);
void *temp10_ram = ((char *)temp_buffer) + 9 * seg_len * sizeof(float);
// Get all intersection points
getIntersectPts(
(float *)rotated_pts1_x, (float *)rotated_pts1_y,
(float *)rotated_pts2_x, (float *)rotated_pts2_y, (float *)vec1_x,
(float *)vec1_y, (float *)vec2_x, (float *)vec2_y,
(float *)intersect_pts_x, (float *)intersect_pts_y,
(float *)valid_pts, (float *)nums_in_ram, (float *)temp1_ram,
(float *)temp2_ram, (float *)temp3_ram, (float *)temp4_ram,
(float *)temp5_ram, (float *)temp6_ram, (float *)temp7_ram,
(float *)temp8_ram, (float *)temp9_ram, (float *)temp10_ram, seg_len);
// Where nums_in <= 2, set valid_box to false
__bang_write_value((float *)temp9_ram, COMPUTE_COUNT_ALIGN, (float)2);
__bang_cycle_gt((float *)temp1_ram, (float *)nums_in_ram,
(float *)temp9_ram, seg_len, COMPUTE_COUNT_ALIGN);
__bang_and((float *)valid_box, (float *)valid_box, (float *)temp1_ram,
seg_len);
__bang_cycle_and((float *)valid_pts, (float *)valid_pts,
(float *)valid_box, 24 * seg_len, seg_len);
// Convex-hull-graham to order the intersection points in clockwise order
// and find the contour area
convexHullGraham(
(float *)intersect_pts_x, (float *)intersect_pts_y,
(float *)ordered_pts_x, (float *)ordered_pts_y, (float *)dist_ram,
(float *)valid_box, (float *)valid_pts, (float *)nums_in_ram,
(float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram,
(float *)temp_long_1, (float *)temp_long_2, (float *)temp_long_3,
seg_len, seg_len);
// Calculate polygon area
// set temp1 = intersection part area
polygonArea((float *)ordered_pts_x, (float *)ordered_pts_y,
(float *)valid_box, (float *)valid_pts, (float *)nums_in_ram,
(float *)temp1_ram, (float *)temp2_ram, (float *)temp3_ram,
(float *)temp4_ram, (float *)temp5_ram, (float *)temp6_ram,
(float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram,
seg_len);
// area
__bang_mul((float *)temp2_ram, (float *)box2 + seg_len * 2,
(float *)box2 + seg_len * 3, seg_len);
// get the area_U: area + max_area - area_I
__bang_add_scalar((float *)temp2_ram, (float *)temp2_ram, float(max_area),
seg_len);
__bang_sub((float *)temp2_ram, (float *)temp2_ram, (float *)temp1_ram,
seg_len); // area_U
if (iou_threshold > 0.0) {
__bang_mul_scalar((float *)temp1_ram, (float *)temp1_ram,
div_thresh_iou, seg_len);
} else {
__bang_mul_scalar((float *)temp2_ram, (float *)temp2_ram, iou_threshold,
seg_len);
}
__bang_ge((float *)temp1_ram, (float *)temp2_ram, (float *)temp1_ram,
seg_len);
__bang_mul((float *)score, (float *)score, (float *)temp1_ram, seg_len);
pvLock();
__memcpy(input_score_ptr + input_offset + i * max_seg_pad, score,
cpy_len * sizeof(float), scores_store_dir,
cpy_len * sizeof(float), cpy_len * sizeof(float), 0);
pvUnlock();
}
}
if (clusterId == 0 && coreId == 0 && nram_save_count) {
pvLock();
__memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t),
NRAM2GDRAM);
pvUnlock();
}
}
__mlu_global__ void MLUBlockorUnionIKernelOU3D(
const void *input_boxes, const int input_box_num, const float iou_threshold,
const cnrtDataType_t data_type_input, void *workspace, void *result_num,
void *output) {
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
mluMemcpyDirection_t scores_load_dir = GDRAM2NRAM;
mluMemcpyDirection_t scores_store_dir = NRAM2GDRAM;
mluMemcpyDirection_t boxes_load_dir = GDRAM2NRAM;
float *scores_data = (float *)workspace;
float *boxes_data = (float *)input_boxes;
const int cluster_score_size = input_box_num * sizeof(float);
const int cluster_boxes_size = input_box_num * 7 * input_dwidth;
char *sram_score = (char *)sram_buffer;
char *sram_boxes = (char *)sram_buffer + cluster_score_size;
if (clusterDim == 1 && SIZE_SRAM_BUF > cluster_score_size) {
scores_data = (float *)sram_score;
scores_load_dir = SRAM2NRAM;
scores_store_dir = NRAM2SRAM;
if (coreId == 0x80) {
__sramset((void *)sram_buffer, input_box_num, 1.0f);
}
} else {
if (coreId == 0) {
__gdramset(scores_data, input_box_num, 1.0f);
}
}
if (clusterDim == 1 &&
SIZE_SRAM_BUF - cluster_score_size >= cluster_boxes_size) {
boxes_load_dir = SRAM2NRAM;
boxes_data = (float *)sram_boxes;
if (coreId == 0x80) {
__memcpy((char *)boxes_data, (char *)input_boxes, cluster_boxes_size,
GDRAM2SRAM);
}
}
__sync_cluster();
int32_t result_box_num = 0;
int32_t *out_data = (int32_t *)output;
switch (data_type_input) {
default: { return; }
case CNRT_FLOAT16: {
iou3D_detection(result_box_num, out_data, (half *)boxes_data, scores_data,
taskDim, input_box_num, iou_threshold, scores_load_dir,
scores_store_dir, boxes_load_dir);
}; break;
case CNRT_FLOAT32: {
iou3D_detection(result_box_num, out_data, boxes_data, scores_data,
taskDim, input_box_num, iou_threshold, scores_load_dir,
scores_store_dir, boxes_load_dir);
}; break;
}
((int32_t *)result_num)[0] = result_box_num;
}
void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t data_type_input, const void *boxes_dram,
const int input_box_num, const float iou_threshold,
void *workspace, void *output_size, void *output) {
switch (k_type) {
default: { return; }
case CNRT_FUNC_TYPE_BLOCK:
case CNRT_FUNC_TYPE_UNION1:
case CNRT_FUNC_TYPE_UNION2:
case CNRT_FUNC_TYPE_UNION4:
case CNRT_FUNC_TYPE_UNION8:
case CNRT_FUNC_TYPE_UNION16: {
MLUBlockorUnionIKernelOU3D<<<k_dim, k_type, queue>>>(
(void *)boxes_dram, input_box_num, iou_threshold, data_type_input,
workspace, output_size, output);
}; break;
}
}
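As a readable reference for what this deleted kernel computes, here is a hedged NumPy sketch of the same score-suppression loop (host-side, scalar; `rotated_iou` is a placeholder for the BEV polygon-intersection math implemented in iou3d_utils.hpp below):

import numpy as np

def nms3d_keep(boxes, iou_threshold, rotated_iou):
    """Mirrors iou3D_detection: walk boxes in order, keep a box if its score
    is still non-zero, then zero the scores of every box whose rotated BEV
    IoU with it exceeds the threshold. `rotated_iou(a, boxes)` is assumed to
    return the IoU of box `a` against every row of `boxes`."""
    scores = np.ones(len(boxes), dtype=np.float32)
    keep = []
    for cur in range(len(boxes)):
        if scores[cur] == 0:
            continue
        keep.append(cur)
        ious = rotated_iou(boxes[cur], boxes)
        scores[ious > iou_threshold] = 0.0  # suppressed boxes are never revisited
    return keep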
mmcv/ops/csrc/common/mlu/iou3d_utils.hpp  (deleted, 100644 → 0; view file @ 59c1418e)
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef IOU3D_UTILS_HPP_
#define IOU3D_UTILS_HPP_
#include "common_mlu_helper.hpp"
#define IOU3D_SIZE 64
#define IOU3D_UP(x, y) (x / y + (int)(x % y > 0)) * y
#define IOU3D_DOWN(x, y) (x / y) * y
#define SIZE_NRAM_BUF (MAX_NRAM_SIZE)
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
#define COMPUTE_COUNT_ALIGN 64
#define INFO_NUM (5) // score, x1, y1, x2, y2
#define REDUCE_NUM \
(7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input)
#define SINGLE_BOX_DIM 5
#define MEMORY_CORE (0x80)
__mlu_func__ void pvLock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_lock(0, 0);
  }
#endif
}
__mlu_func__ void pvUnlock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_unlock(0, 0);
  }
#endif
}
// cross2d<T>(A, B) = A.x * B.y - A.y * B.x;
template <typename T>
inline __mlu_func__ void cross2d(T *result, const T *p1_x, const T *p1_y,
                                 const T *p2_x, const T *p2_y,
                                 const int &length, T *temp_ram) {
  __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_y, length);
  __bang_mul((T *)result, (T *)p1_y, (T *)p2_x, length);
  __bang_sub((T *)result, (T *)temp_ram, (T *)result, length);
}
// dot2d<T>(A, B) = A.x * B.x + A.y * B.y
template <typename T>
inline __mlu_func__ void dot2d(T *result, const T *p1_x, const T *p1_y,
                               const T *p2_x, const T *p2_y, const int &length,
                               T *temp_ram) {
  __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_x, length);
  __bang_mul((T *)result, (T *)p1_y, (T *)p2_y, length);
  __bang_add((T *)result, (T *)temp_ram, (T *)result, length);
}
template <typename T>
__mlu_func__ void getRotatedVertices(T *pts_x, T *pts_y, T *box, T *temp1,
                                     T *temp2, T *temp3, T *temp4,
                                     const uint32_t &actual_compute_box_num) {
  // T cosTheta2 = (T)cos(theta) * 0.5f; -- temp1
  // T sinTheta2 = (T)sin(theta) * 0.5f; -- temp2
  // theta is the box's 5th data: a, rotated radian;
#if __BANG_ARCH__ >= 300
  __bang_cos((float *)temp1, ((float *)box) + 4 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sin((float *)temp2, ((float *)box) + 4 * actual_compute_box_num,
             actual_compute_box_num);
#else
  __bang_taylor4_cos((T *)temp1, ((T *)box) + 4 * actual_compute_box_num,
                     (T *)temp3, (T *)temp4, actual_compute_box_num);
  __bang_taylor4_sin((T *)temp2, ((T *)box) + 4 * actual_compute_box_num,
                     (T *)temp3, (T *)temp4, actual_compute_box_num);
#endif
  __bang_mul_scalar((T *)temp1, (T *)temp1, (T)0.5, actual_compute_box_num);
  __bang_mul_scalar((T *)temp2, (T *)temp2, (T)0.5, actual_compute_box_num);
  // Temp3 = sinTheta2 * box.h;
  // Temp4 = cosTheta2 * box.w;
  __bang_mul((T *)temp3, (T *)temp2, ((T *)box) + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_mul((T *)temp4, (T *)temp1, ((T *)box) + 2 * actual_compute_box_num,
             actual_compute_box_num);
  // pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
  // pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
  __bang_sub((T *)pts_x, (T *)box, (T *)temp3, actual_compute_box_num);
  __bang_sub((T *)pts_x, (T *)pts_x, (T *)temp4, actual_compute_box_num);
  __bang_add((T *)pts_x + 1 * actual_compute_box_num, (T *)box, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 1 * actual_compute_box_num,
             (T *)pts_x + 1 * actual_compute_box_num, (T *)temp4,
             actual_compute_box_num);
  // Temp3 = cosTheta2 * box.h;
  // Temp4 = sinTheta2 * box.w;
  __bang_mul((T *)temp3, (T *)temp1, box + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_mul((T *)temp4, (T *)temp2, box + 2 * actual_compute_box_num,
             actual_compute_box_num);
  // pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  // pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  __bang_add((T *)pts_y, (T *)box + 1 * actual_compute_box_num, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_y, (T *)pts_y, (T *)temp4, actual_compute_box_num);
  __bang_sub((T *)pts_y + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_y + 1 * actual_compute_box_num,
             (T *)pts_y + 1 * actual_compute_box_num, (T *)temp4,
             actual_compute_box_num);
  // pts[2].x = 2 * box.x_ctr - pts[0].x;
  // pts[3].x = 2 * box.x_ctr - pts[1].x;
  __bang_add((T *)pts_x + 2 * actual_compute_box_num, (T *)box, (T *)box,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 2 * actual_compute_box_num,
             (T *)pts_x + 2 * actual_compute_box_num, (T *)pts_x,
             actual_compute_box_num);
  __bang_add((T *)pts_x + 3 * actual_compute_box_num, (T *)box, (T *)box,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 3 * actual_compute_box_num,
             (T *)pts_x + 3 * actual_compute_box_num,
             (T *)pts_x + 1 * actual_compute_box_num, actual_compute_box_num);
  // pts[2].y = 2 * box.y_ctr - pts[0].y;
  // pts[3].y = 2 * box.y_ctr - pts[1].y;
  __bang_add((T *)pts_y + 2 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, actual_compute_box_num);
  __bang_sub((T *)pts_y + 2 * actual_compute_box_num,
             (T *)pts_y + 2 * actual_compute_box_num, (T *)pts_y,
             actual_compute_box_num);
  __bang_add((T *)pts_y + 3 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, actual_compute_box_num);
  __bang_sub((T *)pts_y + 3 * actual_compute_box_num,
             (T *)pts_y + 3 * actual_compute_box_num,
             (T *)pts_y + 1 * actual_compute_box_num, actual_compute_box_num);
}
template <typename T>
__mlu_func__ void getIntersectPts(T *rotated_pts1_x, T *rotated_pts1_y,
                                  T *rotated_pts2_x, T *rotated_pts2_y,
                                  T *vec1_x, T *vec1_y, T *vec2_x, T *vec2_y,
                                  T *intersect_pts_x, T *intersect_pts_y,
                                  T *valid_pts, T *nums_in_ram, T *temp1_ram,
                                  T *temp2_ram, T *temp3_ram, T *temp4_ram,
                                  T *temp5_ram, T *temp6_ram, T *temp7_ram,
                                  T *temp8_ram, T *temp9_ram, T *temp10_ram,
                                  const uint32_t &actual_compute_box_num) {
  // Initialize const data to ram
  // temp3 = const 1e-14(@float), length = COMPUTE_COUNT_ALIGN
#if __BANG_ARCH__ >= 300
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (T)1e-14);
#else
  // NOTE: Since active_reciphp function has strict value range,
  //       [2.2205e-16, 2e6]@float, [0.00391, 65504]@half
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (float)1e-14);
#endif
  // temp4 = const T(0), length = COMPUTE_COUNT_ALIGN
  __bang_write_value((T *)temp4_ram, COMPUTE_COUNT_ALIGN, (T)0);
  // temp5 = const T(1), length = COMPUTE_COUNT_ALIGN
  __bang_write_value((T *)temp5_ram, COMPUTE_COUNT_ALIGN, (T)1);
  // Line vector, from p1 to p2 is: p1+(p2-p1)*t, t=[0,1]
  // for i = 0~3, vec[i] = pts[(i+1)%4] - pts[i]
  __bang_sub((T *)vec1_x, (T *)rotated_pts1_x + actual_compute_box_num,
             (T *)rotated_pts1_x, 3 * actual_compute_box_num);
  __bang_sub((T *)vec1_x + 3 * actual_compute_box_num, (T *)rotated_pts1_x,
             (T *)rotated_pts1_x + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec1_y, (T *)rotated_pts1_y + actual_compute_box_num,
             (T *)rotated_pts1_y, 3 * actual_compute_box_num);
  __bang_sub((T *)vec1_y + 3 * actual_compute_box_num, (T *)rotated_pts1_y,
             (T *)rotated_pts1_y + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec2_x, (T *)rotated_pts2_x + actual_compute_box_num,
             (T *)rotated_pts2_x, 3 * actual_compute_box_num);
  __bang_sub((T *)vec2_x + 3 * actual_compute_box_num, (T *)rotated_pts2_x,
             (T *)rotated_pts2_x + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec2_y, (T *)rotated_pts2_y + actual_compute_box_num,
             (T *)rotated_pts2_y, 3 * actual_compute_box_num);
  __bang_sub((T *)vec2_y + 3 * actual_compute_box_num, (T *)rotated_pts2_y,
             (T *)rotated_pts2_y + 3 * actual_compute_box_num,
             actual_compute_box_num);
  // First, line test - test all line combos for intersection, 4x4 possible
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // T det = cross2d<T>(vec2[j], vec1[i]) -- temp2
      cross2d<T>((T *)temp2_ram, (T *)vec2_x + j * actual_compute_box_num,
                 (T *)vec2_y + j * actual_compute_box_num,
                 (T *)vec1_x + i * actual_compute_box_num,
                 (T *)vec1_y + i * actual_compute_box_num,
                 actual_compute_box_num, (T *)temp1_ram);
      // temp8 = sign(det), since active_reciphp only receive positive values
      __bang_active_sign((T *)temp8_ram, (T *)temp2_ram,
                         actual_compute_box_num);
      // deal with parallel lines, temp2 = fabs(det), temp1 = temp2 > 1e-14
      __bang_active_abs((T *)temp2_ram, (T *)temp2_ram,
                        actual_compute_box_num);
      __bang_cycle_gt((T *)temp1_ram, (T *)temp2_ram, (T *)temp3_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      // Where temp1 = false, set recip input to 1, avoiding recip(0), cause inf
      __bang_not((T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram,
                 actual_compute_box_num);
      __bang_add((T *)temp2_ram, (T *)temp2_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      // temp2 = 1/temp2, use mult (1/temp2) instead of div temp2
#if __BANG_ARCH__ >= 300
      __bang_recip((float *)temp2_ram, (float *)temp2_ram,
                   actual_compute_box_num);
#else
      // NOTE: active_reciphp function has strict value range:
      //       [2.2205e-16, 2e6]@float, [0.00391, 65504]@half
      __bang_active_reciphp((T *)temp2_ram, (T *)temp2_ram,
                            actual_compute_box_num);
#endif
      // Restore temp2 invalid box value 1 and sign-bit
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram,
                 actual_compute_box_num);
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp8_ram,
                 actual_compute_box_num);
      // auto vec12 = pts2[j] - pts1[i], (temp6, temp7) = (x, y)
      __bang_sub((T *)temp6_ram,
                 (T *)rotated_pts2_x + j * actual_compute_box_num,
                 (T *)rotated_pts1_x + i * actual_compute_box_num,
                 actual_compute_box_num);
      __bang_sub((T *)temp7_ram,
                 (T *)rotated_pts2_y + j * actual_compute_box_num,
                 (T *)rotated_pts1_y + i * actual_compute_box_num,
                 actual_compute_box_num);
      // T t1 = cross2d<T>(vec2[j], vec12) mult (1/det) -- temp8
      cross2d<T>((T *)temp8_ram, (T *)vec2_x + j * actual_compute_box_num,
                 (T *)vec2_y + j * actual_compute_box_num, (T *)temp6_ram,
                 (T *)temp7_ram, actual_compute_box_num, (T *)temp9_ram);
      __bang_mul((T *)temp8_ram, (T *)temp8_ram, (T *)temp2_ram,
                 actual_compute_box_num);
      // temp1 &= (t1 >= 0.0f && t1 <= 1.0f) -- temp9
      __bang_cycle_ge((T *)temp9_ram, (T *)temp8_ram, (T *)temp4_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      __bang_cycle_le((T *)temp9_ram, (T *)temp8_ram, (T *)temp5_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      // T t2 = cross2d<T>(vec1[i], vec12) mult temp2 -- temp9
      // NOTE: temp8(t1) is used after, reuse temp7(p2_y) as cross2d temp ram
      cross2d<T>((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num,
                 (T *)vec1_y + i * actual_compute_box_num, (T *)temp6_ram,
                 (T *)temp7_ram, actual_compute_box_num, (T *)temp7_ram);
      __bang_mul((T *)temp9_ram, (T *)temp9_ram, (T *)temp2_ram,
                 actual_compute_box_num);
      // temp1 &= (t2 >= 0.0f && t2 <= 1.0f) -- temp9
      __bang_cycle_ge((T *)temp7_ram, (T *)temp9_ram, (T *)temp4_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram,
                 actual_compute_box_num);
      __bang_cycle_le((T *)temp7_ram, (T *)temp9_ram, (T *)temp5_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram,
                 actual_compute_box_num);
      // intersections = (pts1[i] + vec1[i] * t1) * temp1
      __bang_mul((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num,
                 (T *)temp8_ram, actual_compute_box_num);
      __bang_add((T *)temp9_ram,
                 (T *)rotated_pts1_x + i * actual_compute_box_num,
                 (T *)temp9_ram, actual_compute_box_num);
      __bang_mul((T *)intersect_pts_x + (4 * i + j) * actual_compute_box_num,
                 (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      __bang_mul((T *)temp9_ram, (T *)vec1_y + i * actual_compute_box_num,
                 (T *)temp8_ram, actual_compute_box_num);
      __bang_add((T *)temp9_ram,
                 (T *)rotated_pts1_y + i * actual_compute_box_num,
                 (T *)temp9_ram, actual_compute_box_num);
      __bang_mul((T *)intersect_pts_y + (4 * i + j) * actual_compute_box_num,
                 (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      // Assign `valid_pts` bit and accumulate `nums_in` of valid points of each
      // box pair
      __bang_or((T *)valid_pts + (4 * i + j) * actual_compute_box_num,
                (T *)valid_pts + (4 * i + j) * actual_compute_box_num,
                (T *)temp1_ram, actual_compute_box_num);
      __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
                 actual_compute_box_num);
    }
  }
  // Check for vertices of rect1 inside rect2
  // temp5 = ABdotAB
  dot2d<T>((T *)temp5_ram, (T *)vec2_x, (T *)vec2_y, (T *)vec2_x, (T *)vec2_y,
           actual_compute_box_num, (T *)temp9_ram);
  // temp6 = ADdotAD
  dot2d<T>((T *)temp6_ram, (T *)vec2_x + 3 * actual_compute_box_num,
           (T *)vec2_y + 3 * actual_compute_box_num,
           (T *)vec2_x + 3 * actual_compute_box_num,
           (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num,
           (T *)temp9_ram);
  // assume ABCD is the rectangle, and P is the point to be judged
  // P is inside ABCD iff. P's projection on AB lines within AB
  // and P's projection on AD lies within AD
  for (int i = 0; i < 4; i++) {
    // AP = pts1[i] - pts2[0] = (temp7, temp8)
    __bang_sub((T *)temp7_ram,
               (T *)rotated_pts1_x + i * actual_compute_box_num,
               (T *)rotated_pts2_x, actual_compute_box_num);
    __bang_sub((T *)temp8_ram,
               (T *)rotated_pts1_y + i * actual_compute_box_num,
               (T *)rotated_pts2_y, actual_compute_box_num);
    // temp9 = APdotAB = dot2d<T>(AP, AB)
    dot2d<T>((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec2_x,
             (T *)vec2_y, actual_compute_box_num, (T *)temp2_ram);
    // temp10 = APdotAD = -dot2d<T>(AP, DA)
    dot2d<T>((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram,
             (T *)vec2_x + 3 * actual_compute_box_num,
             (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num,
             (T *)temp2_ram);
    __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1,
                      actual_compute_box_num);
    // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <=
    //  ADdotAD))
    __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    // 16 means the 4x4 possible intersection points above
    __bang_mul((T *)intersect_pts_x + (16 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts1_x + i * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)intersect_pts_y + (16 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts1_y + i * actual_compute_box_num,
               actual_compute_box_num);
    // assign valid_pts bit and accumulate nums of valid points of each box pair
    __bang_or((T *)valid_pts + (16 + i) * actual_compute_box_num,
              (T *)valid_pts + (16 + i) * actual_compute_box_num,
              (T *)temp1_ram, actual_compute_box_num);
    __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
               actual_compute_box_num);
  }
  // Reverse the check - check for vertices of rect2 inside rect1
  // temp5 = ABdotAB
  dot2d<T>((T *)temp5_ram, (T *)vec1_x, (T *)vec1_y, (T *)vec1_x, (T *)vec1_y,
           actual_compute_box_num, (T *)temp9_ram);
  // temp6 = ADdotAD
  dot2d<T>((T *)temp6_ram, (T *)vec1_x + 3 * actual_compute_box_num,
           (T *)vec1_y + 3 * actual_compute_box_num,
           (T *)vec1_x + 3 * actual_compute_box_num,
           (T *)vec1_y + 3 * actual_compute_box_num, actual_compute_box_num,
           (T *)temp9_ram);
  for (int i = 0; i < 4; i++) {
    // AP = pts2[i] - pts1[0] = (temp7, temp8)
    __bang_sub((T *)temp7_ram,
               (T *)rotated_pts2_x + i * actual_compute_box_num,
               (T *)rotated_pts1_x, actual_compute_box_num);
    __bang_sub((T *)temp8_ram,
               (T *)rotated_pts2_y + i * actual_compute_box_num,
               (T *)rotated_pts1_y, actual_compute_box_num);
    // temp9 = APdotAB = dot2d<T>(AP, AB)
    dot2d<T>((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec1_x,
             (T *)vec1_y, actual_compute_box_num, (T *)temp2_ram);
    // temp10 = APdotAD = -dot2d<T>(AP, DA)
    dot2d<T>((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram,
             (T *)vec1_x + 3 * actual_compute_box_num,
             (T *)vec1_y + 3 * actual_compute_box_num, actual_compute_box_num,
             (T *)temp2_ram);
    __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1,
                      actual_compute_box_num);
    // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <=
    //  ADdotAD))
    __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    // 20 means the (4x4+4) possible intersection points above
    __bang_mul((T *)intersect_pts_x + (20 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts2_x + i * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)intersect_pts_y + (20 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts2_y + i * actual_compute_box_num,
               actual_compute_box_num);
    // assign valid_pts bit and accumulate nums of valid points of each box pair
    __bang_or((T *)valid_pts + (20 + i) * actual_compute_box_num,
              (T *)valid_pts + (20 + i) * actual_compute_box_num,
              (T *)temp1_ram, actual_compute_box_num);
    __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
               actual_compute_box_num);
  }
}
template <typename T>
__mlu_func__ void convexHullGraham(
    T *intersect_pts_x, T *intersect_pts_y, T *ordered_pts_x, T *ordered_pts_y,
    T *dist_ram, T *valid_box, T *valid_pts, T *nums_in_ram, T *temp1_ram,
    T *temp2_ram, T *temp3_ram, T *temp_long_1, T *temp_long_2, T *temp_long_3,
    const uint32_t &actual_box_num, const uint32_t &actual_compute_box_num) {
  // Step1. Find the point with minimum y, if more than 1 points have the same
  //        minimum y, pick the one with the minimum x.
  // set p[i].y to max_y_value if not valid_pts, to avoid invalid result
  // 24 means all possible intersection points
  __bang_max((T *)temp2_ram, (T *)intersect_pts_y,
             24 * actual_compute_box_num);
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, ((T *)temp2_ram)[0]);
  __bang_not((T *)temp_long_1, (T *)valid_pts, 24 * actual_compute_box_num);
  __bang_cycle_mul((T *)temp_long_1, (T *)temp_long_1, (T *)temp3_ram,
                   24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_mul((T *)temp_long_2, (T *)intersect_pts_y, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_add((T *)temp_long_2, (T *)temp_long_2, (T *)temp_long_1,
             24 * actual_compute_box_num);
  // temp2 = min_y_value(temp_long_2), use min_pool, channel=box_num, h=1, w=24
  __bang_minpool((T *)temp2_ram, (T *)temp_long_2, actual_compute_box_num, 1,
                 24, 1, 24, 1, 24);
  __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)valid_box,
             actual_compute_box_num);
  // set p[i].x to max_x_value if not min_y point
  __bang_max((T *)temp1_ram, (T *)intersect_pts_x,
             24 * actual_compute_box_num);
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, ((T *)temp1_ram)[0]);
  __bang_cycle_eq((T *)temp_long_1, (T *)temp_long_2, (T *)temp2_ram,
                  24 * actual_compute_box_num, actual_compute_box_num);
  __bang_and((T *)temp_long_1, (T *)temp_long_1, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_not((T *)temp_long_3, (T *)temp_long_1, 24 * actual_compute_box_num);
  __bang_cycle_mul((T *)temp_long_3, (T *)temp_long_3, (T *)temp3_ram,
                   24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_mul((T *)temp_long_1, (T *)intersect_pts_x, (T *)temp_long_1,
             24 * actual_compute_box_num);
  __bang_add((T *)temp_long_1, (T *)temp_long_1, (T *)temp_long_3,
             24 * actual_compute_box_num);
  // temp3 = min_x_value(temp_long_1), use min_pool, channel=box_num, h=1, w=24
  __bang_minpool((T *)temp3_ram, (T *)temp_long_1, actual_compute_box_num, 1,
                 24, 1, 24, 1, 24);
  __bang_mul((T *)temp3_ram, (T *)temp3_ram, (T *)valid_box,
             actual_compute_box_num);
  // Step2. All points subtract starting-point (for sorting in the next step)
  __bang_cycle_sub((T *)ordered_pts_x, (T *)intersect_pts_x, (T *)temp3_ram,
                   24 * actual_compute_box_num, actual_compute_box_num);
  __bang_cycle_sub((T *)ordered_pts_y, (T *)intersect_pts_y, (T *)temp2_ram,
                   24 * actual_compute_box_num, actual_compute_box_num);
  __bang_mul((T *)ordered_pts_x, (T *)ordered_pts_x, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_mul((T *)ordered_pts_y, (T *)ordered_pts_y, (T *)valid_pts,
             24 * actual_compute_box_num);
  // Step3. Sort every intersection point according to their relative
  //        cross-product values (essentially sorting according to angles)
  //        If the angles are the same, sort according to distance to origin
  dot2d<T>((T *)dist_ram, (T *)ordered_pts_x, (T *)ordered_pts_y,
           (T *)ordered_pts_x, (T *)ordered_pts_y,
           24 * actual_compute_box_num, (T *)temp_long_3);
  T temp, temp_nums_in, temp_dist_1, temp_dist_2;
  T temp1_x, temp1_y;
  T temp2_x, temp2_y;
  for (int i = 0; i < actual_box_num; i++) {
    if (((T *)valid_box)[i]) {
      // make sure all nums_in[i] points are at the front
      for (int ii = 0; ii < 23; ii++) {
        for (int jj = ii + 1; jj < 24; jj++) {
          int ii_index = ii * actual_compute_box_num + i;
          int jj_index = jj * actual_compute_box_num + i;
          // ii point is not valid and jj point is valid, swap jj for ii
          if ((!((T *)valid_pts)[ii_index]) && ((T *)valid_pts)[jj_index]) {
            ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[jj_index];
            ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[jj_index];
            ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[jj_index];
            ((T *)valid_pts)[ii_index] = true;
            ((T *)ordered_pts_x)[jj_index] = 0;
            ((T *)ordered_pts_y)[jj_index] = 0;
            ((T *)dist_ram)[jj_index] = 0;
            ((T *)valid_pts)[jj_index] = false;
            break;
          }
        }
      }
      temp_nums_in = ((T *)nums_in_ram)[i];
      // make original q[0] = min_x, min_y before sort
      for (int ii = 1; ii < temp_nums_in; ii++) {
        int ii_index = ii * actual_compute_box_num + i;
        if (((T *)dist_ram)[ii_index] == 0) {
          // swap q[ii_index] and q[0]
          ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[i];
          ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[i];
          ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[i];
          ((T *)ordered_pts_x)[i] = 0;
          ((T *)ordered_pts_y)[i] = 0;
          ((T *)dist_ram)[i] = 0;
          break;
        }
      }
      for (int ii = 1; ii < temp_nums_in - 1; ii++) {
        for (int jj = ii + 1; jj < temp_nums_in; jj++) {
          int ii_index = ii * actual_compute_box_num + i;
          int jj_index = jj * actual_compute_box_num + i;
          temp1_x = ((T *)ordered_pts_x)[ii_index];
          temp1_y = ((T *)ordered_pts_y)[ii_index];
          temp2_x = ((T *)ordered_pts_x)[jj_index];
          temp2_y = ((T *)ordered_pts_y)[jj_index];
          // calculate cross product and sort q (ordered_pts)
          temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
          temp_dist_1 = ((T *)dist_ram)[ii_index];
          temp_dist_2 = ((T *)dist_ram)[jj_index];
          if ((temp < (T)-1e-6) ||
              ((fabs(temp) < (T)1e-6) && (temp_dist_1 > temp_dist_2))) {
            ((T *)ordered_pts_x)[ii_index] = temp2_x;
            ((T *)ordered_pts_y)[ii_index] = temp2_y;
            ((T *)ordered_pts_x)[jj_index] = temp1_x;
            ((T *)ordered_pts_y)[jj_index] = temp1_y;
            ((T *)dist_ram)[ii_index] = temp_dist_2;
            ((T *)dist_ram)[jj_index] = temp_dist_1;
          }
        }
      }
      // Step4:
      // Make sure there are at least 2 points(that don't overlap with each
      // other) in the stack
      int k;  // index of the non-overlapped second point
      for (k = 1; k < temp_nums_in; k++) {
        if (((T *)dist_ram)[k * actual_compute_box_num + i] > (T)1e-8) {
          break;
        }
      }
      if (k == temp_nums_in) {
        // We reach the end, which means the convex hull is just one point
        // set valid_box = 0, to get ious = 0
        ((T *)valid_box)[i] = 0;
        continue;
      }
      // q[1] = q[k];
      ((T *)ordered_pts_x)[actual_compute_box_num + i] =
          ((T *)ordered_pts_x)[k * actual_compute_box_num + i];
      ((T *)ordered_pts_y)[actual_compute_box_num + i] =
          ((T *)ordered_pts_y)[k * actual_compute_box_num + i];
      // Step 5:
      // Finally we can start the scanning process.
      // When a non-convex relationship between the 3 points is found
      // (either concave shape or duplicated points),
      // we pop the previous point from the stack
      // until the 3-point relationship is convex again, or
      // until the stack only contains two points
      int m = 2;  // 2 points in the stack
      for (int j = k + 1; j < temp_nums_in; j++) {
        // while (m > 1 && cross2d<T>(q[j] - q[m - 2], q[m - 1] - q[m - 2]) >=
        //        0) {
        //   m--;
        // }
        temp1_x = ((T *)ordered_pts_x)[j * actual_compute_box_num + i] -
                  ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
        temp1_y = ((T *)ordered_pts_y)[j * actual_compute_box_num + i] -
                  ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
        temp2_x = ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] -
                  ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
        temp2_y = ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] -
                  ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
        temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
        while ((m > 1) && (temp >= 0)) {
          m--;
          if (m > 1) {
            temp1_x =
                ((T *)ordered_pts_x)[j * actual_compute_box_num + i] -
                ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
            temp1_y =
                ((T *)ordered_pts_y)[j * actual_compute_box_num + i] -
                ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
            temp2_x =
                ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] -
                ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
            temp2_y =
                ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] -
                ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
            temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
          }
        }
        // q[m++] = q[j];
        ((T *)ordered_pts_x)[m * actual_compute_box_num + i] =
            ((T *)ordered_pts_x)[j * actual_compute_box_num + i];
        ((T *)ordered_pts_y)[m * actual_compute_box_num + i] =
            ((T *)ordered_pts_y)[j * actual_compute_box_num + i];
        m++;
      }
      // set last(24-m) valid_pts to false, to erase invalid q in polygon area
      for (int j = m; j < temp_nums_in; j++) {
        ((T *)valid_pts)[j * actual_compute_box_num + i] = 0;
      }
      ((T *)nums_in_ram)[i] = m;
    }
  }
}
template <typename T>
__mlu_func__ void polygonArea(T *ordered_pts_x, T *ordered_pts_y, T *valid_box,
                              T *valid_pts, T *nums_in_ram, T *temp1_ram,
                              T *temp2_ram, T *temp3_ram, T *temp4_ram,
                              T *temp5_ram, T *temp6_ram, T *temp7_ram,
                              T *temp8_ram, T *temp9_ram,
                              const uint32_t &actual_compute_box_num) {
  // Set where nums_in <= 2, valid_box = false
  __bang_write_value((T *)temp9_ram, COMPUTE_COUNT_ALIGN, (T)2);
  __bang_cycle_gt((T *)temp1_ram, (T *)nums_in_ram, (T *)temp9_ram,
                  actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_and((T *)valid_box, (T *)valid_box, (T *)temp1_ram,
             actual_compute_box_num);
  // temp1 = area, initialize with all 0
  __bang_write_zero((T *)temp1_ram, actual_compute_box_num);
  __bang_max((T *)temp7_ram, (T *)nums_in_ram, actual_compute_box_num);
  // temp_nums_in = max(nums_in)
  T temp_nums_in = ((T *)temp7_ram)[0];
  for (int i = 1; i < temp_nums_in - 1; i++) {
    // q[i] - q[0]: (temp6, temp7)
    __bang_sub((T *)temp6_ram, (T *)ordered_pts_x + i * actual_compute_box_num,
               (T *)ordered_pts_x, actual_compute_box_num);
    __bang_sub((T *)temp7_ram, (T *)ordered_pts_y + i * actual_compute_box_num,
               (T *)ordered_pts_y, actual_compute_box_num);
    __bang_mul((T *)temp6_ram, (T *)temp6_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)temp7_ram, (T *)temp7_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    // q[i + 1] - q[0]: (temp8, temp9)
    __bang_sub((T *)temp8_ram,
               (T *)ordered_pts_x + (i + 1) * actual_compute_box_num,
               (T *)ordered_pts_x, actual_compute_box_num);
    __bang_sub((T *)temp9_ram,
               (T *)ordered_pts_y + (i + 1) * actual_compute_box_num,
               (T *)ordered_pts_y, actual_compute_box_num);
    __bang_mul((T *)temp8_ram, (T *)temp8_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)temp9_ram, (T *)temp9_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    // area += fabs(cross2d<T>(q[i] - q[0], q[i + 1] - q[0]));
    __bang_mul((T *)temp4_ram, (T *)temp6_ram, (T *)temp9_ram,
               actual_compute_box_num);
    __bang_mul((T *)temp5_ram, (T *)temp7_ram, (T *)temp8_ram,
               actual_compute_box_num);
    __bang_sub((T *)temp3_ram, (T *)temp4_ram, (T *)temp5_ram,
               actual_compute_box_num);
    __bang_active_abs((T *)temp3_ram, (T *)temp3_ram, actual_compute_box_num);
    __bang_add((T *)temp1_ram, (T *)temp1_ram, (T *)temp3_ram,
               actual_compute_box_num);
  }
  // Set where valid_box = false, intersection = 0
  __bang_mul((T *)temp1_ram, (T *)temp1_ram, (T *)valid_box,
             actual_compute_box_num);
  // area = area / 2.0
  __bang_mul_scalar((T *)temp1_ram, (T *)temp1_ram, (T)0.5,
                    actual_compute_box_num);
}
#endif // IOU3D_UTILS_HPP_
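The device helpers above vectorize a standard rotated-rectangle overlap routine: compute the four rotated vertices of each box, collect edge-edge intersections plus contained vertices, order them with a Graham-style scan, and sum the fan triangle areas. A compact NumPy sketch of the same geometry for a single box (scalar reference only, not the kernel itself):

import numpy as np

def rotated_vertices(box):
    """box = (x_ctr, y_ctr, w, h, angle); mirrors getRotatedVertices."""
    x, y, w, h, a = box
    cos2, sin2 = 0.5 * np.cos(a), 0.5 * np.sin(a)
    pts = np.array([
        [x - sin2 * h - cos2 * w, y + cos2 * h - sin2 * w],   # pts[0]
        [x + sin2 * h - cos2 * w, y - cos2 * h - sin2 * w],   # pts[1]
    ])
    # pts[2] and pts[3] are the reflections of pts[0], pts[1] about the center.
    return np.vstack([pts, 2.0 * np.array([x, y]) - pts])

def fan_area(ordered_pts):
    """Polygon area as a triangle fan from ordered_pts[0], as in polygonArea."""
    q = ordered_pts - ordered_pts[0]
    area = 0.0
    for i in range(1, len(q) - 1):
        area += abs(q[i, 0] * q[i + 1, 1] - q[i, 1] * q[i + 1, 0])
    return 0.5 * area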
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu  (deleted, 100644 → 0; view file @ 59c1418e)
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <math.h>
/****************************************************************************************
*
* NRAM partition forward:
* | spatial_shapes | data_value_p1_ping | data_value_p2_ping |
* | data_value_p3_ping | data_value_p4_ping | data_col_ping |
* | data_value_p1_pong | data_value_p2_pong | data_value_p3_pong |
* | data_value_p4_pong | data_col_pong | auxiliary_a |
* | auxiliary_b |
* | 128bytes | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size |
*
****************************************************************************************/
/****************************************************************************************
*
* NRAM partition backward:
* default kernel
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
* small channel kernel
* | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl |
* | nram_grad_output_br | grad_temp1 | grad_temp2 |
* | grad_temp3 | grad_temp4 | nram_loc_w |
* | nram_loc_h | nram_h_low | nram_w_low |
* | nram_h_high | nram_w_high | nram_h_low_temp |
* | nram_h_high_temp | nram_hw | nram_hh |
* | nram_lw | nram_lh | nram_h_low_ptr_offset |
* | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset |
* | nram_w1 | nram_w2 | nram_w3 |
* | nram_w4 | nram_grad_weight | nram_base_ptr |
* | nram_offset_temp | nram_offset1 | nram_offset2 |
* | nram_offset3 | nram_offset4 | nram_w_low_temp |
* | nram_spatial_shapes | nram_level_start_index | nram_h_stride |
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define ELE_COUNT 32
#define LEN_FLOAT sizeof(float)
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num,
const int32_t &width, const int32_t &height, const int32_t &num_heads,
const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = width * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int32_t base_ptr_offset = head_idx * channels;
// top-left point
if (h_low >= 0 && w_low >= 0) {
const int32_t v1_offset =
h_low_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p1_nram, data_value_gdram + v1_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
const int32_t v2_offset =
h_low_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p2_nram, data_value_gdram + v2_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
const int32_t v3_offset =
h_high_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p3_nram, data_value_gdram + v3_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
const int32_t v4_offset =
h_high_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p4_nram, data_value_gdram + v4_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void computeMsDeformAttn(
T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
T *data_col_nram, const T &weight, const size_t &deal_num,
const int32_t &width, const int32_t &height, const T &x, const T &y) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const T lw = x - w_low;
const T lh = y - h_low;
const T hw = 1 - lw;
const T hh = 1 - lh;
const T w1 = hh * hw;
const T w2 = hh * lw;
const T w3 = lh * hw;
const T w4 = lh * lw;
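// Standard bilinear weights for the four neighbors:
// w1 -> (h_low, w_low), w2 -> (h_low, w_high),
// w3 -> (h_high, w_low), w4 -> (h_high, w_high),
// so that sample = w1*v1 + w2*v2 + w3*v3 + w4*v4.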
__bang_write_value((T *)sample_point_value, deal_num, (T)0);
// top-left point
if (h_low >= 0 && w_low >= 0) {
// sample_point_value += v1 * w1
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p1_nram, (T)w1,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
// sample_point_value += v2 * w2
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p2_nram, (T)w2,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
// sample_point_value += v3 * w3
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p3_nram, (T)w3,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
// sample_point_value += v4 * w4
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p4_nram, (T)w4,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
__bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight,
deal_num);
__bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value,
deal_num);
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
if (coreId == 0x80) {
return;
}
const size_t spatial_size = PAD_UP(2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size) / TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const size_t align_num = NFU_ALIGN_SIZE;
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
char *data_spatial_shapes_nram = nram_buffer;
char *ping_data_value_p1_nram = data_spatial_shapes_nram + spatial_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
size_t data_col_ping_pong_idx = 0;
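// Software pipeline: while point k is interpolated out of one set of p1-p4
// buffers, the neighbors of point k+1 are prefetched into the other set
// (ping_pong_gap covers p1-p4 plus data_col). data_col alternates per channel
// segment so the NRAM2GDRAM store can overlap the next segment's compute.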
int32_t block_num_per_core = (batch_size * num_queries * num_heads) / taskDim;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t idx_start = taskId < (block_num_rem + 1)
? taskId * (block_num_per_core + 1)
: taskId * block_num_per_core + block_num_rem;
block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
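// Main loop: each core walks its contiguous range of (batch, query, head)
// blocks; the first block_num_rem cores were assigned one extra block above.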
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram + cur_idx * num_levels * num_points * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// the last point needs no further load; fall through to the compute step
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// the last point needs no further load; fall through to the compute step
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_align_rem, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
}
__asm__ volatile("sync;");
return;
}
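// Fill mask_ram with the repeating pattern 0,1,0,1,... of length `size`.
// It is used with __bang_collect to de-interleave the (x, y) pairs of
// data_sampling_loc: mask = 1 selects the y components, the negated mask
// selects the x components.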
__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int32_t i = 0; i < align_num; ++i) {
mask_ram[i] = i % 2;
}
__asm__ volatile("sync;");
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
__asm__ volatile("sync;");
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
#if __BANG_ARCH__ >= 300
if (coreId == 0x80) {
return;
}
size_t block_num_per_core, batch_start, deal_g, offset_g;
size_t block_num_rem = 0;
const size_t grid_total = num_queries * num_heads * num_levels * num_points;
if (batch_size >= taskDim) {
block_num_rem = batch_size % taskDim;
block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
: batch_size / taskDim;
batch_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
deal_g = grid_total;
offset_g = 0;
} else {
size_t skip_n = taskDim / batch_size;
batch_start = taskId / skip_n;
block_num_per_core = batch_start >= batch_size ? 0 : 1;
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
}
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num;
int32_t cut_channel_iter = 2;
const size_t spatial_size =
PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t level_start_index_size =
PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
int32_t channel = channels;
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
break;
} else {
channel = channels / cut_channel_iter;
cut_channel_iter += 2;
}
}
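// Each grid point needs 8 * channel + 7 floats of NRAM. If the full channel
// count does not fit, the channel dimension is tiled by an increasing factor
// (2, 4, 6, ...) until a positive deal_num is found; `mult` (== channel) is
// the channel tile size used by the buffers below.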
mult = channel;
const int32_t c_rep = channels / channel;
const int32_t c_rem = channels % channel;
const int32_t g_rep = deal_g / deal_num;
const int32_t g_rem = deal_g % deal_num;
// nram buffer alloc
char *data_spatial_shapes_nram = nram_buffer;
char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
char *input_tl = data_level_start_index_nram + level_start_index_size;
char *input_tr = input_tl + deal_num * mult * sizeof(T);
char *input_bl = input_tr + deal_num * mult * sizeof(T);
char *input_br = input_bl + deal_num * mult * sizeof(T);
char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
char *weight_br = weight_bl + deal_num * mult * sizeof(T);
char *mask_tl = weight_br + deal_num * mult * sizeof(T);
char *mask_tr = mask_tl + deal_num * sizeof(T);
char *mask_bl = mask_tr + deal_num * sizeof(T);
char *mask_br = mask_bl + deal_num * sizeof(T);
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
char *coord_x = input_bl;
char *coord_y = coord_x + deal_num * sizeof(T);
char *coord_x_low = input_tl;
char *coord_y_low = coord_x_low + deal_num * sizeof(T);
char *coord_x_low_int = weight_tl;
char *coord_y_low_int = weight_tr;
char *spatial_x = mask_tl;
char *spatial_y = mask_tr;
char *spatial_x_float = weight_bl;
char *spatial_y_float = weight_br;
char *spatial_x_temp = mask_bl;
char *spatial_y_temp = mask_br;
char *base_ptr_offset = weight_tl;
char *auxiliary_a = point_ram;
char *auxiliary_b = weight_bl;
__memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__asm__ volatile("sync;");
for (int32_t batch_idx = batch_start;
batch_idx < batch_start + block_num_per_core; ++batch_idx) {
for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
int32_t io_data_num = deal_num;
const int32_t grid_off_base =
batch_idx * grid_total + offset_g + grid_iter * deal_num;
if (grid_iter == g_rep) {
if (g_rem == 0) {
continue;
} else {
io_data_num = g_rem;
}
}
char *data_col_gdram_start =
data_col_gdram + (batch_idx * num_queries * num_heads * channels +
(offset_g + grid_iter * deal_num) /
(num_levels * num_points) * channels) *
sizeof(float);
// load data_sampling_loc
__memcpy_async(
grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
io_data_num * 2 * sizeof(float), GDRAM2NRAM);
genMask0101((float *)mask_ram, deal_num * 2);
__asm__ volatile("sync;");
// generate x and y coordinate vector
// generate spatial_x and spatial_y spatial vector
__bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // y
__bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_x
__bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
__bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // x
__bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_y
for (int32_t i = 0; i < num_levels; i++) {
__bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
((int32_t *)spatial_x_temp)[i]);
__bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
((int32_t *)spatial_y_temp)[i]);
}
__bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
num_levels * num_points, 0);
__bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
num_levels * num_points, 0);
// map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0,
// spatial_y]
__bang_cycle_mul((float *)coord_x, (float *)coord_x,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
deal_num);
__bang_cycle_mul((float *)coord_y, (float *)coord_y,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
const int32_t w_stride = num_heads * channels;
__bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
deal_num, 0);
__bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
deal_num, 0);
__bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
(int32_t *)spatial_x, deal_num, num_levels * num_points);
__bang_add((int32_t *)index_tl, (int32_t *)index_tl,
(int32_t *)coord_x_low_int, deal_num);
__bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
deal_num);
const int32_t deal_lp_num = deal_num / (num_levels * num_points);
const int32_t h_rep = deal_lp_num / num_heads;
const int32_t h_rem = deal_lp_num % num_heads;
const int32_t head_start =
((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
num_heads;
for (int32_t iter = 0; iter < num_heads; ++iter) {
((int32_t *)base_ptr_offset)[iter] =
((head_start + iter) % num_heads) * channels;
}
if (h_rep > 0) {
__memcpy((int32_t *)base_ptr_offset + num_heads,
(int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
}
if (h_rep > 0 && h_rem > 0) {
__memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
(int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
NRAM2NRAM);
}
__bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
num_levels * num_points);
__bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)base_ptr_offset, deal_num, deal_lp_num);
__bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
num_levels * num_points, deal_lp_num);
// calc index_bl
__bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
deal_num);
__bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
(int32_t *)auxiliary_a, deal_num,
num_levels * num_points);
// calc mask_tl, mask_tr, mask_bl, mask_br
__bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
(float)1.0, deal_num);
__bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
(float)1.0, deal_num);
// mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
deal_num);
__bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
__bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
deal_num);
// mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
deal_num);
__bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
// mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
__bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
(float)(-1.0), deal_num);
__bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)auxiliary_a, (float *)auxiliary_a,
(float *)auxiliary_b, deal_num);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
deal_num);
// mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high <
// spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// calc inner point num
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
deal_num);
__bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
deal_num);
__bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
deal_num);
__bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
deal_num);
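// point_ram = 7*mask_tl + 5*mask_tr + 3*mask_bl + 1*mask_br. The sums that can
// actually occur are distinct, so each value identifies exactly which corners
// are in bounds and selects the matching case of the load switch below.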
// calc interpolation weight
__bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
deal_num);
__bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
deal_num);
__bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
deal_num);
__bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
deal_num);
__bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
deal_num);
__bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
deal_num);
__bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
deal_num);
__bang_mul((float *)input_tl + deal_num, (float *)weight_br,
(float *)weight_tl, deal_num);
__bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
__asm__ volatile("sync;");
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
const int32_t w_rem = channel % ELE_COUNT;
if (w_rem != 0) {
const int32_t data_sz = 1 * sizeof(float);
const int32_t dst_str = channel * sizeof(float);
for (int32_t iter = w_rep; iter < channel; ++iter) {
__memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
}
}
if (w_rep != 0) {
for (int32_t i = 0; i < 4 * deal_num; i++) {
__bang_write_value((float *)weight_tl + i * channel, w_rep,
((float *)input_tl)[i]);
}
}
__asm__ volatile("sync;");
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(float);
const int32_t c_str = deal_num * channel * sizeof(float);
const int32_t cs_str = num_heads * channels * sizeof(float);
for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
int32_t c_real_num = channel;
if (c_iter == c_rep) {
if (c_rem == 0) {
continue;
} else {
c_real_num = c_rem;
}
}
__bang_write_zero((float *)input_tl, 4 * deal_num * channel);
__asm__ volatile("sync;");
// load data_value
for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
num_levels];
const char *data_value_ptr =
data_value_gdram_start +
(level_start_id * num_heads * channels + c_iter * channel) *
sizeof(float);
switch (inner_point_num) {
case 16: // 4 points are cached.
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 12: // 2 points are cached. (top_left, top_right)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 4: // 2 points are cached. (bottom_left, bottom_right)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 10: // 2 points are cached. (top_left, bottom_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 6: // 2 points are cached. (top_right, bottom_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 7: // 1 point is cached. (top_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 5: // 1 point is cached. (top_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 3: // 1 point is cached. (bottom_left)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 1: // 1 point is cached. (bottom_right)
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
default:
continue;
}
}
__asm__ volatile("sync;");
// interpolation
__bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
4 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
2 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
deal_num * channel);
// load attention weight
void *attn_weight = mask_tl;
__memcpy((float *)attn_weight,
(float *)data_attn_weight_gdram + grid_off_base,
io_data_num * sizeof(float), GDRAM2NRAM);
// calc data_col, muladd attention weight
__bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
channel);
__bang_cycle_mul((float *)input_tr, (float *)input_tr,
(float *)attn_weight, deal_num * channel, deal_num);
__bang_transpose((float *)input_tl, (float *)input_tr, channel,
deal_num);
__bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
io_data_num, 1, num_levels * num_points,
num_levels * num_points, 1);
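// __bang_sumpool sums the num_levels * num_points weighted samples belonging
// to each query, producing one channel-tile vector per query in input_bl.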
// store
__memcpy((float *)data_col_gdram_start + c_iter * channel,
(float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
channels * sizeof(float), channel * sizeof(float),
(io_data_num / (num_levels * num_points)) - 1);
}
}
}
__asm__ volatile("sync;");
#endif
return;
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardDefault(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForwardDefault<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForwardSmallChannel<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template <typename T>
void __mlu_func__ msDeformAttnCol2imBilinear(
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
const T &w2, const T &w3, const T &w4, const int32_t &h_low,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
T *grad_w_weight, T *grad_value, T *grad_output_nram, T *grad_weight,
T *grad_sampling_loc, T *grad_attn_weight, T *grad_output_nram_temp,
const int32_t &deal_num, const int32_t &deal_num_real,
const T *data_value_ptr) {
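// For one sampling point and one channel tile, accumulate the three gradients:
// grad_value (atomic adds at the four in-bounds corner offsets),
// grad_sampling_loc (via grad_w_weight / grad_h_weight) and grad_attn_weight
// (via the weighted value accumulated in grad_output_nram).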
if (h_low >= 0 && w_low >= 0) {
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real);
// accumulate w1 * v1 for grad_attn_weight
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
(T *)top_grad_temp, deal_num_real);
}
if (h_low >= 0 && w_high <= width - 1) {
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real);
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_low >= 0) {
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real);
// accumulate w3 * v3 for grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_high <= width - 1) {
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real);
// accumulate w4 * v4 for grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
(T *)top_grad_temp, deal_num_real);
}
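// Reduce the accumulated channel vectors to scalars and commit them:
// grad_output_nram * top_grad            -> grad_attn_weight,
// grad_w_weight * width  * top_grad_temp -> grad_sampling_loc[0],
// grad_h_weight * height * top_grad_temp -> grad_sampling_loc[1],
// each through an atomic add.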
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT;
recursiveSumPool(grad_output_nram, align_num_on_200,
deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_output_nram, grad_output_nram,
NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
(T *)grad_output_nram, 1);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
if (coreId == 0x80) {
return;
}
const int32_t split_num = 8;
const int32_t spatial_shapes_size = 64;
int32_t deal_num = PAD_DOWN(
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM);
float *grad_output_nram = (float *)nram_buffer;
float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
float *grad_weight = (float *)nram_buffer + 2 * deal_num;
float *grad_h_weight = (float *)nram_buffer + 3 * deal_num;
float *grad_w_weight = (float *)nram_buffer + 4 * deal_num;
float *top_grad = (float *)nram_buffer + 5 * deal_num;
float *top_grad_temp = (float *)nram_buffer + 6 * deal_num;
int32_t *spatial_shapes_nram =
(int32_t *)((float *)nram_buffer + 7 * deal_num);
float *sampling_loc_nram =
(float *)nram_buffer + 7 * deal_num + 2 * sizeof(int32_t);
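// Work partition: the batch * num_query * num_heads * num_levels iterations
// are split evenly across the taskDim cores; cores with taskId < num_rem
// process one extra iteration.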
const int32_t total_num = batch * num_query * num_heads * num_levels;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
int32_t end_per_core = start_per_core + num_per_core;
const int32_t C_repeat = channels / deal_num;
const int32_t C_tail = channels % deal_num;
const int32_t qid_stride = num_heads * channels;
int32_t base_ptr = 0;
for (int32_t num_loop = start_per_core; num_loop < end_per_core; ++num_loop) {
const int32_t l_col = num_loop % num_levels;
const int32_t m_col = num_loop / num_levels % num_heads;
const int32_t q_col = num_loop / num_levels / num_heads % num_query;
const int32_t b_col = num_loop / num_query / num_heads / num_levels;
int32_t data_weight_ptr = num_loop * num_points;
int32_t data_loc_w_ptr = data_weight_ptr << 1;
const int32_t value_offset = b_col * spatial_size * num_heads * channels;
const int32_t level_start_id = data_level_start_index[l_col];
int32_t spatial_h_ptr = l_col << 1;
int32_t grad_output_offset = b_col * num_query * num_heads * channels +
q_col * num_heads * channels +
m_col * channels;
__memcpy(spatial_shapes_nram, spatial_shapes + spatial_h_ptr,
2 * sizeof(int32_t), GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[0];
const int32_t spatial_w = spatial_shapes_nram[1];
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_attn_weight_out = num_loop * num_points;
const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
__memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
2 * LEN_FLOAT, GDRAM2NRAM);
const float loc_w = sampling_loc_nram[0];
const float loc_h = sampling_loc_nram[1];
const float weight = data_attn_weight[data_weight_ptr];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = spatial_w * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
float w1 = hh * hw;
float w2 = hh * lw;
float w3 = lh * hw;
float w4 = lh * lw;
for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
base_ptr = m_col * channels + C_loop * deal_num;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_loop * deal_num,
deal_num * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, deal_num, data_value_ptr);
}
if (C_tail != 0) {
base_ptr = m_col * channels + C_repeat * deal_num;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_repeat * deal_num,
C_tail * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, C_tail, data_value_ptr);
}
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
}
void __mlu_func__ computeGridMaskAndOffset(
float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
float *nram_lw, float *nram_hh, float *nram_hw,
float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
const int32_t qid_stride) {
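// Vectorized setup for a tile of num_deal_grid sampling points: scale the
// normalized locations to pixel coordinates, take floors to get the low
// corners, build the bilinear weights w1-w4, turn the corner coordinates into
// flattened value offsets (offset1-offset4, including the per-head channel
// base), and compute the four in-bounds masks mask1-mask4, which reuse the
// *_ptr_offset buffers once those are no longer needed.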
#if __BANG_ARCH__ >= 322
// [num_levels, 2] --> [2, num_levels]
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
__bang_transpose(nram_loc_w, nram_grad_output_tl,
num_per_time_real * num_heads * num_levels, num_points);
__bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_spatial_shapes,
(int32_t *)nram_spatial_shapes, num_levels * 2, 0);
__bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
num_levels, 2);
__bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
num_levels);
__memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
num_levels * 2 * sizeof(float), NRAM2NRAM);
__bang_cycle_mul(nram_loc_w, nram_loc_w,
(float *)nram_spatial_shapes + num_levels, num_deal_grid,
num_levels);
__bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
__bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
// get mask. (h_im > -1 && w_im > -1 &&
// h_im < spatial_h && w_im < spatial_w)
__bang_cycle_lt(nram_w_low_temp, nram_loc_w,
(float *)(nram_spatial_shapes + num_levels), num_deal_grid,
num_levels);
__bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
__bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_temp, nram_w_low_temp,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
__bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
__bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
__bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
__bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
__bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
__bang_transpose(nram_h_low_ptr_offset, nram_h_low,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
num_deal_grid);
__bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
num_deal_grid);
__bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
__bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
__bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
__bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
__bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset1,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset2,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset3,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset4,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
// h_low >= 0 && w_low >= 0 mask2
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
__bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
__bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
__bang_and(mask2, mask1, mask2, num_deal_grid);
__bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
// h_low >= 0 && w_high <= width - 1 mask1
__bang_transpose(mask3, nram_w_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
num_levels * 2);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
num_deal_grid, num_levels);
__bang_transpose(mask4, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask1, mask1, mask4, num_deal_grid);
__bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
// h_high <= height - 1 && w_high <= width - 1 mask3
__bang_transpose(mask3, nram_h_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
num_levels);
__bang_transpose(nram_h_low_temp, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
__bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
// h_high <= height - 1 && w_low >= 0 mask4
__bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
__bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
__bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
#endif
}
void __mlu_func__ loadValue(
float *nram_grad_output_tl, float *nram_grad_output_tr,
float *nram_grad_output_bl, float *nram_grad_output_br,
const float *data_value, const float *grad_output, float *grad_temp1,
float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_grad_weight,
int32_t *nram_level_start_index, int32_t offset_nram,
int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride) {
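// Gather phase: zero the four corner buffers, load the grad_output tile into
// grad_temp2, copy each in-bounds corner's deal_num_real-channel slice of
// data_value into nram_grad_output_{tl,tr,bl,br} according to mask1-mask4,
// and broadcast the per-point weights in nram_grad_weight across the channel
// dimension into grad_temp1.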
#if __BANG_ARCH__ >= 322
int32_t value_offset_temp = 0;
__bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
__sync_io_move_compute();
__memcpy_async(
grad_temp2,
grad_output + (start_per_core + grid_loop * num_per_time_theory) *
num_heads * deal_num_real,
num_per_time_real * num_heads * deal_num_real * sizeof(float),
GDRAM2NRAM);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
const int32_t b_col =
(grid_offset + loop) / num_query / num_heads / num_levels / num_points;
const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
const int32_t level_start_id = nram_level_start_index[l_col];
value_offset_temp =
b_col * spatial_size * qid_stride + level_start_id * qid_stride;
if (mask2[loop]) {
__memcpy_async(
nram_grad_output_tl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset1[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask1[loop]) {
__memcpy_async(
nram_grad_output_tr + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset2[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask4[loop]) {
__memcpy_async(
nram_grad_output_bl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset3[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask3[loop]) {
__memcpy_async(
nram_grad_output_br + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset4[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
}
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
#endif
}
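// computeGradValue: form attn_weight * grad_output per channel, scale it by each
// bilinear weight w1-w4, translate the per-point offsets into global addresses in
// grad_value, and atomically accumulate the four corner contributions (corners
// rejected by mask1-mask4 are skipped).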
void __mlu_func__ computeGradValue(
float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
float *nram_offset2, float *nram_offset3, float *nram_offset4,
int32_t *nram_level_start_index, int32_t deal_num_real,
const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t num_query,
int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride, float *nram_grid_offset1,
float *nram_grid_offset2) {
#if __BANG_ARCH__ >= 322
__bang_transpose(grad_temp3, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
deal_num_real * num_per_time_real * num_heads);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
num_levels / num_points) *
spatial_size * qid_stride;
}
__bang_transpose(nram_grid_offset2, nram_grid_offset1,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_level_start_index, nram_level_start_index,
num_levels, 0);
__bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
qid_stride, num_levels);
__bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
num_deal_grid, num_levels);
__bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
num_per_time_real * num_heads * num_levels);
__bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask2[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset1[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask1[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset2[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask4[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset3[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask3[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset4[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
#endif
}
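// computeGradAttnWeight: rebuild the bilinearly sampled value from the four corners
// (weights w1-w4) while accumulating its partial derivatives w.r.t. the sampling
// position into grad_h_weight / grad_w_weight, multiply by grad_output, reduce over
// channels, mask points that fell outside the feature map, and atomically add the
// result into grad_attn_weight.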
void __mlu_func__ computeGradAttnWeight(
float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *nram_grad_output_bl,
float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
const float *grad_attn_weight, float *nram_hw, float *nram_hh,
float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
int32_t grid_offset, float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_write_zero(grad_w_weight, 2 * offset_nram);
// grad_output_nram_tl
__bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
// nram_grad_output_tr
__bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
num_deal_grid * deal_num_real);
// nram_grad_output_tl
__bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
num_deal_grid * deal_num_real);
// nram_grad_output_br
__bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
num_deal_grid);
__bang_transpose(nram_grad_output_tr, nram_grad_output_br,
num_per_time_real * num_heads,
num_points * num_levels * deal_num_real);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
num_per_time_real * num_heads * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tr,
num_points * num_levels * deal_num_real,
num_per_time_real * num_heads);
__bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
num_deal_grid, deal_num_real);
recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
ALIGN_NUM);
__bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
0);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_atomic_add((float *)nram_grad_output_tr,
(float *)grad_attn_weight + grid_offset,
(float *)nram_grad_output_tr, num_deal_grid);
#endif
}
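// computeGradSampingLoc: scale the h/w partials by the level's spatial height and
// width, weight them by attn_weight * grad_output, reduce over channels, mask invalid
// samples, and atomically add the resulting (x, y) pairs into grad_sampling_loc.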
void __mlu_func__ computeGradSampingLoc(
const float *grad_sampling_loc, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_transpose(nram_grad_output_tl, grad_h_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
num_levels);
__bang_transpose(grad_h_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
__bang_transpose(nram_grad_output_tr, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp1, nram_grad_output_tr,
num_levels * num_points * deal_num_real,
num_per_time_real * num_heads);
__bang_mul(grad_h_weight, grad_h_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid,
deal_num_real);
__memcpy_async(grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_h_weight, (char *)grad_h_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_transpose(nram_grad_output_tl, grad_w_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)(nram_spatial_shapes + num_levels),
num_deal_grid * deal_num_real, num_levels);
__bang_transpose(grad_w_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
__bang_mul(grad_w_weight, grad_w_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid,
deal_num_real);
__memcpy(grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_w_weight, (char *)grad_w_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__memcpy(grad_w_weight + num_deal_grid, grad_h_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid);
__bang_atomic_add((float *)nram_grad_output_tl,
(float *)grad_sampling_loc + grid_offset * 2,
(float *)nram_grad_output_tl, 2 * num_deal_grid);
#endif
}
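// Small-channels backward kernel: each task handles a slice of batch * num_query.
// NRAM is split into eight channel-sized tiles plus 28 per-grid scratch arrays, and
// every tile runs the mask/offset computation, value loading and the three gradient
// accumulations above.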
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
#if __BANG_ARCH__ > 322
const int32_t split_grid_num = 28;
const int32_t split_num_c = 8;
const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
const int32_t num_hlp = num_heads * num_levels * num_points;
int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) /
(split_num_c * C_align + split_grid_num) /
PAD_UP((num_hlp), ALIGN_NUM);
int32_t deal_grid_num_theory = num_per_time_theory * num_hlp;
const int32_t offset_nram = num_per_time_theory * C_align * num_hlp;
const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM);
float *nram_grad_output_tl = (float *)nram_buffer;
float *nram_grad_output_tr = (float *)nram_buffer + offset_nram;
float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram;
float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram;
float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram;
float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram;
float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram;
float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram;
float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram;
float *nram_loc_h =
(float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc;
float *nram_h_low =
(float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc;
float *nram_w_low =
(float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc;
float *nram_h_high =
(float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc;
float *nram_w_high =
(float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc;
float *nram_h_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc;
float *nram_h_high_temp =
(float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc;
float *nram_hw =
(float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc;
float *nram_hh =
(float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc;
float *nram_lw =
(float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc;
float *nram_lh =
(float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc;
float *nram_h_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc;
float *nram_h_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc;
float *nram_w_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc;
float *nram_w_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc;
float *nram_w1 =
(float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc;
float *nram_w2 =
(float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc;
float *nram_w3 =
(float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc;
float *nram_w4 =
(float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc;
float *nram_grad_weight =
(float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc;
float *nram_base_ptr =
(float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc;
float *nram_offset_temp =
(float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc;
float *nram_offset1 =
(float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc;
float *nram_offset2 =
(float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc;
float *nram_offset3 =
(float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc;
float *nram_offset4 =
(float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc;
float *nram_w_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc;
int32_t *nram_spatial_shapes =
(int32_t *)((float *)nram_buffer + split_num_c * offset_nram +
28 * offset_nram_calc);
int32_t *nram_level_start_index =
(int32_t *)(nram_spatial_shapes + 2 * num_levels);
float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels);
const int32_t total_num = batch * num_query;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
num_per_time_theory =
num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core;
int32_t num_deal_grid = num_per_time_theory * num_hlp;
if (num_per_core == 0) return;
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
const int32_t qid_stride = num_heads * channels;
int32_t deal_num_real = channels;
const int32_t repeat_times = num_per_core / num_per_time_theory;
const int32_t tail_num = num_per_core % num_per_time_theory;
int32_t num_per_time_real = num_per_time_theory;
for (int32_t loop = 0; loop < num_heads; ++loop) {
nram_base_ptr[loop] = loop * channels;
}
const int32_t w_stride = num_heads * channels;
for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) {
int32_t grid_offset =
(start_per_core + grid_loop * num_per_time_theory) * num_hlp;
if (grid_loop == repeat_times) {
if (tail_num == 0) {
continue;
} else {
grid_offset =
(start_per_core + repeat_times * num_per_time_theory) * num_hlp;
num_per_time_real = tail_num;
num_deal_grid = tail_num * num_hlp;
}
}
__memcpy_async(nram_spatial_shapes, spatial_shapes,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_level_start_index, data_level_start_index,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
__memcpy(nram_grad_weight, data_attn_weight + grid_offset,
num_deal_grid * sizeof(float), GDRAM2NRAM);
computeGridMaskAndOffset(
nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
w_stride, qid_stride);
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
nram_grad_output_br, data_value, grad_output, grad_temp1,
grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
nram_level_start_index, offset_nram, start_per_core, grid_loop,
num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
num_deal_grid, num_query, num_levels, num_points, grid_offset,
spatial_size, qid_stride);
float *nram_grid_offset1 = nram_loc_h;
float *nram_grid_offset2 = nram_loc_w;
computeGradValue(
grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
nram_grid_offset1, nram_grid_offset2);
// compute grad_weight
float *grad_weight = grad_temp1;
float *grad_h_weight = grad_temp4;
float *grad_w_weight = grad_temp3;
computeGradAttnWeight(
grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
// compute grad_sampling_loc
computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl,
nram_grad_output_tr, grad_h_weight, grad_w_weight,
nram_spatial_shapes, grad_temp1, grad_temp2,
nram_grad_weight, num_deal_grid, deal_num_real,
num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
}
#endif
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<<k_dim, k_type, queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<<k_dim, k_type,
queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "nms_utils.hpp"
#define COORD_DIM (4)
#define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024)
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
__nram__ int8_t nram_buffer[SIZE_NRAM_BUF];
__mlu_shared__ int8_t sram_buffer[SIZE_SRAM_BUF];
enum Addr { SRAM, GDRAM };
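// nms_detection: greedy NMS for BLOCK / UNION1 launches. Each iteration finds the
// current max-score box, stores it if it passes thresh_score, and suppresses the
// remaining boxes whose IoU with it exceeds thresh_iou by zeroing their scores
// (scoreUpdate).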
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void nms_detection(
uint32_t &output_box_num, const int output_mode, OUT_DT *output_dram,
IN_DT *input_data_score, const IN_DT *input_data_box, const Addr input_ram,
IN_DT *sram, const int core_limit, const int input_num_boxes,
const int max_output_size, const float thresh_iou, const float thresh_score,
const float offset, const int algo) {
// global value
int32_t *exit_flag = (int32_t *)(sram + 28);
exit_flag[0] = 0;
// score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2
int nms_buffer_count1 = 9;
// temp nram buffer to store selected target.
int nram_save_limit_count = 256;
float div_thresh_iou = 1.0 / thresh_iou;
// input data ptr
const IN_DT *input_x1_ptr = input_data_box;
const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes;
const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes;
const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes;
int limit = 0; // find limit when GDRAM or SRAM
int max_seg_pad = 0; // the max length every repeat
int repeat = 0;
int remain = 0;
int remain_pad = 0;
int input_offset = 0; // offset of input_data for current core
int nram_save_count = 0;
if (output_mode == 0) {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
} else {
// 5 means: score, x1, y1, x2, y2
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * 5 * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
}
int max_seg_iou_compute = 0;
int repeat_iou_compute = 0;
int remain_iou_compute = 0;
int remain_pad_iou_compute = 0;
getComputeParamsBlockOrU1(sizeof(IN_DT), input_num_boxes, limit, core_limit,
input_offset, max_seg_pad, repeat, remain,
remain_pad, max_seg_iou_compute, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute);
// init the data ptr
IN_DT *score = (IN_DT *)nram_buffer;
IN_DT *x1 = score + max_seg_pad;
IN_DT *y1 = x1 + max_seg_pad;
IN_DT *x2 = y1 + max_seg_pad;
IN_DT *y2 = x2 + max_seg_pad;
IN_DT *inter_x1 = y2 + max_seg_pad;
IN_DT *inter_y1 = inter_x1 + max_seg_pad;
IN_DT *inter_x2 = inter_y1 + max_seg_pad;
IN_DT *inter_y2 = inter_x2 + max_seg_pad;
IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2
OUT_DT *nram_save =
(OUT_DT *)((char *)max_box +
NFU_ALIGN_SIZE); // offset two lines from max_box
#if __BANG_ARCH__ >= 300
float max_box_x1 = 0;
float max_box_y1 = 0;
float max_box_x2 = 0;
float max_box_y2 = 0;
#endif
mluMemcpyDirection_t load_dir = SRAM2NRAM;
mluMemcpyDirection_t store_dir = NRAM2SRAM;
load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM;
store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM;
for (int keep = 0; keep < max_output_size;
keep++) { // loop until the max_score <= 0
if (core_limit != 1) {
__sync_cluster(); // sync before current loop
}
/******FIND MAX START******/
int max_index = 0; // the max score index
int global_max_index = 0; // for U1
float max_area = 0; // the max score area
max_box[0] = 0; // init 0
findCoreMaxBox(input_data_score, score, inter_x1, max_box, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir,
input_offset, repeat, remain, remain_pad, max_seg_pad,
max_index);
if (core_limit == 1) {
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
input_data_score[max_index] = 0;
global_max_index = max_index;
} else if (core_limit == 4) {
__sync_cluster();
findClusterMaxBox(sram, max_box, inter_x1, input_data_score, core_limit);
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
global_max_index = ((uint32_t *)(max_box + 5))[0];
input_data_score[global_max_index] = 0;
}
// by now, we get: max_score|max_index|max_box|max_area
/******FIND MAX END******/
storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count,
max_output_size, thresh_score, output_mode, nram_save_count,
output_box_num);
// if the max score <= 0, end
if (core_limit == 1) {
if (float(max_box[0]) <= thresh_score) {
break;
}
} else {
if (float(max_box[0]) <= thresh_score) {
if (coreId == 0) {
exit_flag[0] = 1;
}
}
__sync_cluster();
if (exit_flag[0] == 1) {
break;
}
}
/******NMS STORE END******/
#if __BANG_ARCH__ >= 300
scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score,
inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box_x1,
max_box_y1, max_box_x2, max_box_y2, nram_save,
repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute,
max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou,
input_offset, offset, max_area, input_num_boxes, algo);
#else
scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score,
inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box[1],
max_box[2], max_box[3], max_box[4], nram_save,
repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute,
max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou,
input_offset, offset, max_area, input_num_boxes, algo);
#endif
} // for max_output_size
}
__mlu_global__ void MLUUnion1KernelNMS(
const void *input_boxes, const void *input_confidence,
const int input_num_boxes, const int max_output_size,
const float iou_threshold, const float confidence_threshold,
const int output_mode, void *workspace, void *result_num, void *output,
const cnrtDataType_t data_type_input, const float offset, const int algo) {
if (data_type_input == CNRT_FLOAT16) {
__memcpy(workspace, input_confidence, input_num_boxes * sizeof(half),
GDRAM2GDRAM);
} else if (data_type_input == CNRT_FLOAT32) {
__memcpy(workspace, input_confidence, input_num_boxes * sizeof(float),
GDRAM2GDRAM);
} else {
}
uint32_t output_box_num = 0;
float *score_data = (float *)workspace;
float *boxes_data = (float *)input_boxes;
float *sram = (float *)sram_buffer;
if (output_mode == 0) {
if (data_type_input == CNRT_FLOAT32) {
nms_detection(output_box_num, output_mode, (uint32_t *)output, score_data,
boxes_data, GDRAM, sram, taskDim, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, algo);
} else {
nms_detection(output_box_num, output_mode, (uint32_t *)output,
(half *)score_data, (half *)boxes_data, GDRAM, (half *)sram,
taskDim, input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, algo);
}
} else {
if (data_type_input == CNRT_FLOAT32) {
nms_detection(output_box_num, output_mode, (float *)output, score_data,
boxes_data, GDRAM, sram, taskDim, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, algo);
} else {
nms_detection(output_box_num, output_mode, (half *)output,
(half *)score_data, (half *)boxes_data, GDRAM, (half *)sram,
taskDim, input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, algo);
}
}
((uint32_t *)result_num)[0] = output_box_num;
}
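// nms_detection_ux: multi-cluster (UNION2/4/8/16) variant. Core 0 of every cluster
// finds its partial max box, the partial results are reduced across clusters (through
// SRAM, or through GDRAM on arch >= 590), and the score update runs on all cores in
// parallel.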
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void nms_detection_ux(
int32_t *exit_flag, uint32_t &output_box_num, OUT_DT *output_dram,
IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram,
const int input_num_boxes, const int max_output_size,
const float thresh_iou, const float thresh_score, const float offset,
const int output_mode, const int algo, char *cdma_gdram) {
exit_flag[0] = 0;
IN_DT *sram = (IN_DT *)sram_buffer;
// score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2
int nms_buffer_count1 = 9;
// temp nram buffer to store selected target.
int nram_save_limit_count = 256;
float div_thresh_iou = 1.0 / thresh_iou;
// input data ptr
const IN_DT *input_x1_ptr = boxes_data;
const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes;
const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes;
const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes;
int limit = 0; // find limit when GDRAM or SRAM
int max_seg_pad = 0; // the max length every repeat
int repeat = 0;
int remain = 0;
int remain_pad = 0;
int nram_save_count = 0;
if (output_mode == 0) {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
} else {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * INFO_NUM * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
}
int input_offset = 0;
int max_seg_iou_compute = 0;
int repeat_iou_compute = 0;
int remain_iou_compute = 0;
int remain_pad_iou_compute = 0;
getComputeParamsUx(sizeof(IN_DT), input_num_boxes, limit, input_offset,
max_seg_pad, repeat, remain, remain_pad,
max_seg_iou_compute, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute);
// init the nram ptr
IN_DT *score = (IN_DT *)nram_buffer;
IN_DT *x1 = score + max_seg_pad;
IN_DT *y1 = x1 + max_seg_pad;
IN_DT *x2 = y1 + max_seg_pad;
IN_DT *y2 = x2 + max_seg_pad;
IN_DT *inter_x1 = y2 + max_seg_pad;
IN_DT *inter_y1 = inter_x1 + max_seg_pad;
IN_DT *inter_x2 = inter_y1 + max_seg_pad;
IN_DT *inter_y2 = inter_x2 + max_seg_pad;
IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2
OUT_DT *nram_save =
(OUT_DT *)((char *)max_box +
NFU_ALIGN_SIZE); // offset two lines from max_box
#if __BANG_ARCH__ >= 300
float max_box_x1 = 0;
float max_box_y1 = 0;
float max_box_x2 = 0;
float max_box_y2 = 0;
#endif
mluMemcpyDirection_t load_dir = SRAM2NRAM;
mluMemcpyDirection_t store_dir = NRAM2SRAM;
load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM;
store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM;
for (int keep = 0; keep < max_output_size;
keep++) { // loop until the max_score <= 0
__sync_all();
int max_index = 0;
int global_max_index = 0; // for Ux
float max_area = 0; // the max score area
max_box[0] = 0; // init 0
if (coreId == 0) {
findCoreMaxBox(score_data, score, inter_x1, max_box, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir,
input_offset, repeat, remain, remain_pad, max_seg_pad,
max_index);
// copy max box info to sram
__memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
}
__sync_all();
#if __BANG_ARCH__ >= 590
__memcpy((char *)cdma_gdram + REDUCE_NUM * clusterId * sizeof(IN_DT), sram,
REDUCE_NUM * sizeof(IN_DT), SRAM2GDRAM);
__sync_all();
if (clusterId == 0 && coreId == 0) {
__bang_write_zero(inter_x1, NMS_SIZE);
__memcpy((char *)inter_x1, (char *)cdma_gdram, sizeof(IN_DT), GDRAM2NRAM,
sizeof(IN_DT), REDUCE_NUM * sizeof(IN_DT), clusterDim - 1);
__bang_max(max_box, inter_x1, NMS_SIZE);
int max_cluster = (sizeof(IN_DT) == sizeof(half))
? ((uint16_t *)max_box)[1]
: ((uint32_t *)max_box)[1];
__memcpy((char *)cdma_gdram,
(char *)cdma_gdram + max_cluster * REDUCE_NUM * sizeof(IN_DT),
REDUCE_NUM * sizeof(IN_DT), GDRAM2GDRAM);
}
__sync_all();
__memcpy(max_box, cdma_gdram, REDUCE_NUM * sizeof(IN_DT), GDRAM2NRAM);
#else
findGlobalMaxBox(max_box, sram, inter_x1);
#endif
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
global_max_index = ((uint32_t *)(max_box + 5))[0];
if (coreId != MEMORY_CORE) {
score_data[global_max_index] = 0;
}
storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count,
max_output_size, thresh_score, output_mode, nram_save_count,
output_box_num);
if (float(max_box[0]) <= thresh_score) {
if (clusterId == 0 && coreId == 0) {
exit_flag[0] = 1; // dram
}
}
__sync_all();
if (exit_flag[0] == 1) {
break;
}
/******NMS STORE END******/
#if __BANG_ARCH__ >= 300
scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr,
input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1,
inter_y1, inter_x2, inter_y2, max_box, max_box_x1, max_box_y1,
max_box_x2, max_box_y2, nram_save, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute,
max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset,
max_area, input_num_boxes, algo);
#else
scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr,
input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1,
inter_y1, inter_x2, inter_y2, max_box, max_box[1], max_box[2],
max_box[3], max_box[4], nram_save, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute,
max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset,
max_area, input_num_boxes, algo);
#endif
} // for max_output_size
}
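// MLUUionXKernelNMS stages the scores and boxes in SRAM when they fit, otherwise in
// the GDRAM workspace, before running nms_detection_ux.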
__mlu_global__ void MLUUionXKernelNMS(
const void *input_boxes, const void *input_confidence,
const int input_num_boxes, const int max_output_size,
const float iou_threshold, const float confidence_threshold,
const float offset, const cnrtDataType_t data_type_input,
const int output_mode, const int algo, void *workspace, void *result_num,
void *output) {
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
int32_t *exit_flag = (int32_t *)((char *)workspace +
INFO_NUM * input_num_boxes * input_dwidth);
char *cdma_addr = (char *)exit_flag + sizeof(int32_t);
int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth;
int available_sram_size = SIZE_SRAM_BUF - reduce_sram_size;
int cluster_score_size = input_num_boxes * input_dwidth;
int cluster_boxes_size = input_num_boxes * 4 * input_dwidth;
char *sram_score = (char *)sram_buffer + reduce_sram_size;
char *sram_boxes =
(char *)sram_buffer + reduce_sram_size + cluster_score_size;
Addr input_ram = GDRAM;
if ((cluster_score_size + cluster_boxes_size) < available_sram_size) {
input_ram = SRAM;
__memcpy(sram_score, input_confidence, cluster_score_size, GDRAM2SRAM);
__memcpy(sram_boxes, input_boxes, cluster_boxes_size, GDRAM2SRAM);
} else {
__memcpy(workspace, input_confidence, cluster_score_size, GDRAM2GDRAM);
}
__sync_cluster();
uint32_t output_box_num = 0;
float *score_data;
float *boxes_data;
score_data = (input_ram == SRAM) ? (float *)sram_score : (float *)workspace;
boxes_data = (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes;
if (output_mode == 0) {
if (data_type_input == CNRT_FLOAT32) {
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
score_data, boxes_data, input_ram, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, output_mode, algo, cdma_addr);
} else {
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
(half *)score_data, (half *)boxes_data, input_ram,
input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, output_mode, algo,
cdma_addr);
}
} else {
if (data_type_input == CNRT_FLOAT32) {
nms_detection_ux(exit_flag, output_box_num, (float *)output, score_data,
boxes_data, input_ram, input_num_boxes, max_output_size,
iou_threshold, confidence_threshold, offset, output_mode,
algo, cdma_addr);
} else {
nms_detection_ux(exit_flag, output_box_num, (half *)output,
(half *)score_data, (half *)boxes_data, input_ram,
input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, output_mode, algo,
cdma_addr);
}
}
((uint32_t *)result_num)[0] = output_box_num;
}
void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t data_type_input, const void *boxes_ptr,
const void *scores_ptr, const int input_num_boxes,
const int max_output_boxes, const float iou_threshold,
const float offset, void *workspace_ptr, void *output_size_ptr,
void *output_ptr) {
switch (k_type) {
default: { return; }
case CNRT_FUNC_TYPE_BLOCK:
case CNRT_FUNC_TYPE_UNION1: {
MLUUnion1KernelNMS<<<k_dim, k_type, queue>>>(
(void *)boxes_ptr, (void *)scores_ptr, input_num_boxes,
max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0,
/*output_mode=*/0, workspace_ptr, output_size_ptr, output_ptr,
data_type_input, offset, /*algo=*/1);
}; break;
case CNRT_FUNC_TYPE_UNION2:
case CNRT_FUNC_TYPE_UNION4:
case CNRT_FUNC_TYPE_UNION8:
case CNRT_FUNC_TYPE_UNION16: {
MLUUionXKernelNMS<<<k_dim, k_type, queue>>>(
(void *)boxes_ptr, (void *)scores_ptr, input_num_boxes,
max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, offset,
data_type_input, /*output_mode=*/0, /*algo=*/1, workspace_ptr,
output_size_ptr, output_ptr);
}; break;
}
}
mmcv/ops/csrc/common/mlu/nms_utils.hpp
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) [2019-2022] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef NMS_UTILS_HPP_
#define NMS_UTILS_HPP_
#include "common_mlu_helper.hpp"
#define NMS_SIZE (64)
#define NMS_UP(x, y) (x / y + (int)(x % y > 0)) * y
#define NMS_DOWN(x, y) (x / y) * y
#define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score
#define MEMORY_CORE (0x80)
#define REDUCE_NUM \
(7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input)
__mlu_func__ void pvLock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_lock(0, 0);
  }
#endif
}
__mlu_func__ void pvUnlock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_unlock(0, 0);
  }
#endif
}
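// computeReluN clamps nram_src to [0, threshold] (plain ReLU when threshold == 0);
// arch >= 300 uses __bang_relun / __bang_relu, older arches emulate it with cycle
// compares.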
template <typename T>
static __mlu_func__ void computeReluN(T *nram_dst, T *nram_src, void *nram_tmp,
                                      const int deal_num,
                                      const T threshold = 0) {
  if (threshold < 0) {
    return;
  }
  if (threshold) {
#if __BANG_ARCH__ >= 300
    __bang_relun(nram_dst, nram_src, deal_num, threshold);
#else
    int align_num = NFU_ALIGN_SIZE / sizeof(T);
    T *nram_aux_a = (T *)nram_tmp;
    T *nram_aux_b = nram_aux_a + deal_num;
    T *nram_zero = nram_aux_b + align_num;
    __bang_write_value(nram_aux_b, align_num, threshold);
    __bang_write_zero(nram_zero, align_num);
    __bang_cycle_lt((T *)nram_aux_a, nram_src, (T *)nram_aux_b, deal_num,
                    align_num);
    __bang_mul(nram_dst, nram_src, (T *)nram_aux_a, deal_num);
    __bang_cycle_eq((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_zero, deal_num,
                    align_num);
    __bang_cycle_mul((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_aux_b,
                     deal_num, align_num);
    __bang_add(nram_dst, nram_dst, (T *)nram_aux_a, deal_num);
    __bang_cycle_gt((T *)nram_aux_a, nram_dst, (T *)nram_zero, deal_num,
                    align_num);
    __bang_mul(nram_dst, nram_dst, (T *)nram_aux_a, deal_num);
#endif
  } else {
#if __BANG_ARCH__ >= 300
    __bang_relu(nram_dst, nram_src, deal_num);
#else
    __bang_active_relu(nram_dst, nram_src, deal_num);
#endif
  }
}
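// getComputeParamsBlockOrU1 / getComputeParamsUx split the boxes across cores (and
// clusters for Ux) and derive the segment sizes used by the load and IoU-compute
// loops.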
__mlu_func__ void getComputeParamsBlockOrU1(
    const int input_dwidth, const int input_box_num, const int limit,
    const int core_limit, int &input_offset, int &max_seg_pad, int &repeat,
    int &remain, int &remain_pad, int &max_seg_iou_compute,
    int &repeat_iou_compute, int &remain_iou_compute,
    int &remain_pad_iou_compute) {
  int avg_core = input_box_num / core_limit;
  int rem = input_box_num % core_limit;
  int len_core = avg_core + (coreId < rem ? 1 : 0);
  input_offset = avg_core * coreId + (coreId <= rem ? coreId : rem);
  max_seg_pad = NMS_DOWN(limit, NMS_SIZE);
  repeat = len_core / max_seg_pad;
  remain = len_core % max_seg_pad;
  remain_pad = NMS_UP(remain, NMS_SIZE);
  // if datatype is fp16, we should cvt to fp32 when compute iou
  max_seg_iou_compute = NMS_DOWN(max_seg_pad / (4 / input_dwidth), NMS_SIZE);
  repeat_iou_compute = len_core / max_seg_iou_compute;
  remain_iou_compute = len_core % max_seg_iou_compute;
  remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE);
}
__mlu_func__ void getComputeParamsUx(
    const int input_dwidth, const int input_num_boxes, const int limit,
    int &input_offset, int &max_seg_pad, int &repeat, int &remain,
    int &remain_pad, int &max_seg_iou_compute, int &repeat_iou_compute,
    int &remain_iou_compute, int &remain_pad_iou_compute) {
  // data split
  int avg_cluster = input_num_boxes / clusterDim;
  int rem_cluster = input_num_boxes % clusterDim;
  int len_cluster = avg_cluster + (clusterId < rem_cluster);
  int cluster_offset = avg_cluster * clusterId +
                       (clusterId <= rem_cluster ? clusterId : rem_cluster);
  int avg_core = len_cluster / coreDim;
  int rem_core = len_cluster % coreDim;
  int len_core = avg_core + (coreId < rem_core);
  int core_offset =
      avg_core * coreId + (coreId <= rem_core ? coreId : rem_core);
  input_offset = cluster_offset + core_offset;
  max_seg_pad = NMS_DOWN(limit, NMS_SIZE);
  // core 0 of each cluster calculate the max score index
  int max_index_len_core = avg_cluster + (clusterId < rem_cluster);
  repeat = max_index_len_core / max_seg_pad;
  remain = max_index_len_core % max_seg_pad;
  remain_pad = NMS_UP(remain, NMS_SIZE);
  // if datatype is fp16, we should cvt to fp32 when compute iou
  max_seg_iou_compute =
      NMS_DOWN(max_seg_pad / (sizeof(float) / input_dwidth), NMS_SIZE);
  repeat_iou_compute = len_core / max_seg_iou_compute;
  remain_iou_compute = len_core % max_seg_iou_compute;
  remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE);
}
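// findGlobalMaxBox gathers every cluster's partial max into cluster 0's SRAM, reduces
// it to the single global max box and broadcasts the result back to all clusters.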
template <typename IN_DT>
__mlu_func__ void findGlobalMaxBox(IN_DT *max_box, IN_DT *sram,
                                   IN_DT *inter_x1) {
  // copy all partial max to the sram of cluster 0
  if (clusterId != 0) {
    __memcpy(sram + REDUCE_NUM * clusterId, sram, REDUCE_NUM * sizeof(IN_DT),
             SRAM2SRAM, 0);
  }
  __sync_all();
  // reduce between clusters to get the global max box
  if (clusterId == 0) {
    if (coreId == 0) {
      __bang_write_zero(inter_x1, NMS_SIZE);
      __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT),
               REDUCE_NUM * sizeof(IN_DT), clusterDim - 1);
      __bang_max(max_box, inter_x1, NMS_SIZE);
      int max_cluster = (sizeof(IN_DT) == sizeof(half))
                            ? ((uint16_t *)max_box)[1]
                            : ((uint32_t *)max_box)[1];
      __memcpy(max_box, sram + max_cluster * REDUCE_NUM,
               REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM);
      __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
    }
    __sync_cluster();
    if (coreId == 0x80 && clusterDim > 1) {
      // broadcast global max box to each cluster's sram
      for (int cluster_idx = 1; cluster_idx < clusterDim; ++cluster_idx) {
        __memcpy(sram, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2SRAM,
                 cluster_idx);
      }
    }
    __sync_cluster();
  }
  __sync_all();
  // copy the global max box to max_box
  __memcpy(max_box, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM);
}
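// findCoreMaxBox scans this core's score segments with __bang_max and records the best
// score, its coordinates and its global index in max_box.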
template <typename IN_DT>
__mlu_func__ void findCoreMaxBox(
    IN_DT *input_score_ptr, IN_DT *score, IN_DT *inter_x1, IN_DT *max_box,
    const IN_DT *input_x1_ptr, const IN_DT *input_y1_ptr,
    const IN_DT *input_x2_ptr, const IN_DT *input_y2_ptr,
    const mluMemcpyDirection_t load_dir, const int input_offset,
    const int repeat, const int remain, const int remain_pad,
    const int max_seg_pad, int &max_index) {
  if (coreId != 0x80) {
    for (int i = 0; i <= repeat; i++) {
      if (i == repeat && remain == 0) {
        break;
      }
      int seg_len = 0;  // the length every nms compute
      int cpy_len = 0;  // the length every nms memcpy
      i == repeat ? seg_len = remain_pad : seg_len = max_seg_pad;
      i == repeat ? cpy_len = remain : cpy_len = max_seg_pad;
      /******NMS LOAD START******/
      __bang_write_zero(score, seg_len);
      __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      /******NMS LOAD END******/
      __bang_max(inter_x1, score, seg_len);
      if (inter_x1[0] > max_box[0]) {
        max_box[0] = inter_x1[0];
        if (sizeof(IN_DT) == sizeof(half)) {
          // offset start from head of input_data
          max_index =
              ((uint16_t *)inter_x1)[1] + input_offset + i * max_seg_pad;
        } else if (sizeof(IN_DT) == sizeof(float)) {
          // offset start from head of input_data
          max_index =
              ((uint32_t *)inter_x1)[1] + input_offset + i * max_seg_pad;
        }
      }
    }  // for repeat
    // the max box's x1, y1, x2, y2 on every core
    max_box[1] = input_x1_ptr[max_index];
    max_box[2] = input_y1_ptr[max_index];
    max_box[3] = input_x2_ptr[max_index];
    max_box[4] = input_y2_ptr[max_index];
    ((uint32_t *)(max_box + 5))[0] = max_index;
  }
}
template <typename IN_DT>
__mlu_func__ void findClusterMaxBox(IN_DT *sram, IN_DT *max_box,
                                    IN_DT *inter_x1, IN_DT *input_data_score,
                                    const int core_limit) {
  // find the max with sram
  // copy every core's box info to sram, form: score---x1---y1---x2---y2---
  __memcpy(sram + REDUCE_NUM * coreId, max_box, REDUCE_NUM * sizeof(IN_DT),
           NRAM2SRAM);  // int32_t datatype
  __sync_cluster();
  // copy score from sram to nram and find the max
  __bang_write_zero(inter_x1, 64);
  __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT),
           REDUCE_NUM * sizeof(IN_DT), coreDim - 1);
  __bang_max(max_box, inter_x1, 64);
  int max_core = sizeof(IN_DT) == sizeof(half) ? ((uint16_t *)max_box)[1]
                                               : ((uint32_t *)max_box)[1];
  // copy the max box to max_box
  __memcpy(max_box, sram + max_core * REDUCE_NUM, REDUCE_NUM * sizeof(IN_DT),
           SRAM2NRAM);
}
/*****************************************************************************/
/*******************************CALCULATE MAX AREA****************************/
/*****************************************************************************/
template <typename IN_DT>
__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset,
                             float &max_area) {
  if (algo == 0 || offset == 0.0) {
    max_area = ((float)max_box[3] - (float)max_box[1]) *
               ((float)max_box[4] - (float)max_box[2]);
  } else {
    max_area = ((float)max_box[3] - (float)max_box[1] + offset) *
               ((float)max_box[4] - (float)max_box[2] + offset);
  }
}
template <typename IN_DT>
__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset,
                             float &max_area, float &max_box_x1,
                             float &max_box_y1, float &max_box_x2,
                             float &max_box_y2) {
  // the case of random inf will break the requirement of x1<=x2, y1<=y2
  // so exchange it if it happens.
  max_box_x1 = float(max_box[1]);
  max_box_x2 = float(max_box[3]);
  if (max_box[1] > max_box[3]) {
    max_box_x1 = float(max_box[3]);
    max_box_x2 = float(max_box[1]);
  }
  max_box_y1 = float(max_box[2]);
  max_box_y2 = float(max_box[4]);
  if (max_box[2] > max_box[4]) {
    max_box_y1 = float(max_box[4]);
    max_box_y2 = float(max_box[2]);
  }
  if (algo == 0 || offset == 0.0) {
    max_area = (max_box_x2 - max_box_x1) * (max_box_y2 - max_box_y1);
  } else {
    max_area = (max_box_x2 - max_box_x1 + offset) *
               (max_box_y2 - max_box_y1 + offset);
  }
}
/***********************************************************************/
/*******************************STORE RESULT****************************/
/***********************************************************************/
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void storeResult(IN_DT *max_box, OUT_DT *nram_save,
                              OUT_DT *&output_dram, const int keep,
                              const int nram_save_limit_count,
                              const int max_output_size,
                              const float thresh_score, const int output_mode,
                              int &nram_save_count, uint32_t &output_box_num) {
  /******NMS STORE START******/
  // store to nram
  if (float(max_box[0]) > thresh_score) {
    OUT_DT *save_ptr;
    int save_offset = 0;
    int save_str_num = 0;
    save_ptr = nram_save;
    save_offset = nram_save_count;
    save_str_num = nram_save_limit_count;
    if (clusterId == 0 && coreId == 0) {
      if (output_mode == 0) {  // index1, index2, ...
        save_ptr[save_offset] = ((uint32_t *)(max_box + INFO_NUM))[0];
      } else if (output_mode == 1) {  // score, x1, y1, x2, y2
        __memcpy(save_ptr + save_offset * INFO_NUM, max_box,
                 INFO_NUM * sizeof(IN_DT), NRAM2NRAM, INFO_NUM * sizeof(IN_DT),
                 INFO_NUM * sizeof(IN_DT), 0);
      } else if (output_mode == 2) {  // score---, x1---, y1---, x2---, y2---
        __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), NRAM2NRAM,
                 save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), 4);
      }
    }
    nram_save_count++;
    output_box_num++;
  }
  // store to sram/gdram
  if (output_box_num != 0) {
    if ((nram_save_count == nram_save_limit_count) ||
        (float(max_box[0]) <= thresh_score) || keep == max_output_size - 1) {
      if (nram_save_count != 0) {
        if (clusterId == 0 && coreId == 0) {
          if (output_mode == 0) {  // index1, index2, ...
            pvLock();
            __memcpy(output_dram, nram_save,
                     nram_save_count * sizeof(uint32_t), NRAM2GDRAM);
            pvUnlock();
            output_dram += nram_save_count;
          } else if (output_mode == 1) {  // score, x1, y1, x2, y2
            pvLock();
            __memcpy(output_dram, nram_save,
                     nram_save_count * INFO_NUM * sizeof(IN_DT), NRAM2GDRAM);
            pvUnlock();
            output_dram += nram_save_count * INFO_NUM;
          } else if (output_mode == 2) {  // score---, x1---, y1---, x2---, y2---
            pvLock();
            __memcpy(output_dram, nram_save, nram_save_count * sizeof(IN_DT),
                     NRAM2GDRAM, max_output_size * sizeof(IN_DT),
                     nram_save_limit_count * sizeof(IN_DT), 4);
            pvUnlock();
            output_dram += nram_save_count;
          }
          nram_save_count = 0;
        }
      }
    }  // if move data nram->sram/gdram
  }    // if dst
}
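// scoreUpdate recomputes the IoU of every remaining box against the current max box
// and suppresses boxes whose IoU exceeds thresh_iou by zeroing their scores.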
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void scoreUpdate(
    IN_DT *input_score_ptr, const mluMemcpyDirection_t load_dir,
    const mluMemcpyDirection_t store_dir, const IN_DT *input_x1_ptr,
    const IN_DT *input_y1_ptr, const IN_DT *input_x2_ptr,
    const IN_DT *input_y2_ptr, IN_DT *x1, IN_DT *y1, IN_DT *x2, IN_DT *y2,
    IN_DT *score, IN_DT *inter_x1, IN_DT *inter_y1, IN_DT *inter_x2,
    IN_DT *inter_y2, IN_DT *max_box, const float max_box_x1,
    const float max_box_y1, const float max_box_x2, const float max_box_y2,
    OUT_DT *nram_save, int repeat_iou_compute, int remain_iou_compute,
    int remain_pad_iou_compute, int max_seg_iou_compute, int max_seg_pad,
    const float thresh_iou, const float div_thresh_iou, const int input_offset,
    const float offset, const float max_area, const int input_num_boxes,
    const int algo) {
  for (int i = 0; i <= repeat_iou_compute; i++) {
    if (i == repeat_iou_compute && remain_iou_compute == 0) {
      break;
    }
    int seg_len = (i == repeat_iou_compute) ? remain_pad_iou_compute
                                            : max_seg_iou_compute;
    int cpy_len = (i == repeat_iou_compute) ? remain_iou_compute
                                            : max_seg_iou_compute;
    /******NMS LOAD START******/
    int dt_offset = 0;
    if (sizeof(IN_DT) == sizeof(float)) {
      __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      dt_offset = 0;
    } else if (sizeof(IN_DT) == sizeof(half)) {
      __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      __bang_half2float((float *)score, (half *)x1, seg_len);
      dt_offset = max_seg_iou_compute;
    }
#if __BANG_ARCH__ >= 300
    __memcpy(inter_x1 + dt_offset,
             input_x1_ptr + input_offset + i * max_seg_iou_compute,
             cpy_len * sizeof(IN_DT), load_dir, max_seg_pad * sizeof(IN_DT),
             input_num_boxes * sizeof(IN_DT), 3);

    if (sizeof(IN_DT) == sizeof(half)) {
      __bang_half2float((float *)inter_x1,
                        (half *)inter_x1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_y1,
                        (half *)inter_y1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_x2,
                        (half *)inter_x2 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_y2,
                        (half *)inter_y2 + max_seg_iou_compute, seg_len);
    }
    // box transfer
    __bang_minequal((float *)x1, (float *)inter_x1, (float *)inter_x2, seg_len);
    __bang_maxequal((float *)x2, (float *)inter_x1, (float *)inter_x2, seg_len);
    __bang_minequal((float *)y1, (float *)inter_y1, (float *)inter_y2, seg_len);
    __bang_maxequal((float *)y2, (float *)inter_y1, (float *)inter_y2, seg_len);
    // 1、 compute IOU
    // get the area_I
    __bang_maxeq_scalar((float *)inter_x1, (float *)x1, max_box_x1,
                        seg_len);  // inter_x1
    __bang_mineq_scalar((float *)inter_x2, (float *)x2, max_box_x2,
                        seg_len);  // inter_x2
    __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_x1, (float *)inter_x1, offset, seg_len);
    }
    computeReluN((float *)inter_x1, (float *)inter_x1, NULL,
                 seg_len);  // inter_w
    __bang_maxeq_scalar((float *)inter_y1, (float *)y1, float(max_box_y1),
                        seg_len);  // inter_y1
    __bang_mineq_scalar((float *)inter_y2, (float *)y2, float(max_box_y2),
                        seg_len);  // inter_y2
    __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len);
    }
    computeReluN((float *)inter_y1, (float *)inter_y1, NULL,
                 seg_len);  // inter_h
    __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1,
               seg_len);  // area_I
    // get the area of input_box: area = (x2 - x1) * (y2 - y1);
    if (algo == 1 && offset != 0.0) {
      __bang_fusion(FUSION_FSA, (float *)inter_y1, (float *)x2, (float *)x1,
                    offset, seg_len, seg_len);
      __bang_fusion(FUSION_FSA, (float *)inter_y2, (float *)y2, (float *)y1,
                    offset, seg_len, seg_len);
      __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2,
                 seg_len);  // area
    } else {
      __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len);
      __bang_fusion(FUSION_FSM, (float *)inter_x2, (float *)y2, (float *)y1,
                    (float *)inter_y1, seg_len, seg_len);
    }
    // get the area_U: area + max_area - area_I
    __bang_fusion(FUSION_FAS, (float *)inter_x2, (float *)inter_x2, max_area,
                  (float *)inter_x1, seg_len, seg_len);
    // 2、 select the box
    // if IOU greater than thres, set the score to zero, abort it: area_U >
    // area_I * (1 / thresh)?
    if (thresh_iou > 0.0) {
      __bang_mul_scalar((float *)inter_x1, (float *)inter_x1, div_thresh_iou,
                        seg_len);
    } else {
      __bang_mul_scalar((float *)inter_x2, (float *)inter_x2, thresh_iou,
                        seg_len);
    }
    // process for nan
    __bang_lt((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, seg_len);
    __bang_not((float *)inter_x1, (float *)inter_x1, seg_len);
    __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len);
    /******NMS COMPUTE END******/
#else
    __memcpy(x1 + dt_offset,
             input_x1_ptr + input_offset + i * max_seg_iou_compute,
             cpy_len * sizeof(IN_DT), load_dir, max_seg_pad * sizeof(IN_DT),
             input_num_boxes * sizeof(IN_DT), 3);
    if (sizeof(IN_DT) == sizeof(half)) {
      __bang_half2float((float *)x1, (half *)x1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)y1, (half *)y1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)x2, (half *)x2 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)y2, (half *)y2 + max_seg_iou_compute, seg_len);
    }
    // 1、 compute IOU
    // get the area_I
    __bang_write_value((float *)inter_y1, seg_len,
                       float(max_box[1]));  // max_x1
    __bang_maxequal((float *)inter_x1, (float *)x1, (float *)inter_y1,
                    seg_len);  // inter_x1
    __bang_write_value((float *)inter_y2, seg_len,
                       float(max_box[3]));  // max_x2
    __bang_minequal((float *)inter_x2, (float *)x2, (float *)inter_y2,
                    seg_len);  // inter_x2
    __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_x1, (float *)inter_x1, offset, seg_len);
    }
    computeReluN((float *)inter_x1, (float *)inter_x1, NULL,
                 seg_len);  // inter_w
    __bang_write_value((float *)inter_x2, seg_len,
                       float(max_box[2]));  // max_y1
    __bang_maxequal((float *)inter_y1, (float *)y1, (float *)inter_x2,
                    seg_len);  // inter_y1
    __bang_write_value((float *)inter_x2, seg_len,
                       float(max_box[4]));  // max_y2
    __bang_minequal((float *)inter_y2, (float *)y2, (float *)inter_x2,
                    seg_len);  // inter_y2
    __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len);
    }
    computeReluN((float *)inter_y1, (float *)inter_y1, NULL,
                 seg_len);  // inter_h
    __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1,
               seg_len);  // area_I
    // get the area of input_box: area = (x2 - x1) * (y2 - y1);
    __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len);
    __bang_sub((float *)inter_y2, (float *)y2, (float *)y1, seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len);
      __bang_add_scalar((float *)inter_y2, (float *)inter_y2, offset, seg_len);
    }
    __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2,
               seg_len);  // area
    // get the area_U: area + max_area - area_I
    __bang_add_scalar((float *)inter_x2, (float *)inter_x2, float(max_area),
                      seg_len);
    __bang_sub((float *)inter_x2, (float *)inter_x2, (float *)inter_x1,
               seg_len);  // area_U
    // 2、 select the box
    // if IOU greater than thresh, set the score to zero, abort it: area_U >
    // area_I * (1 / thresh)?
    if (thresh_iou > 0.0) {
      __bang_mul_scalar((float *)inter_x1, (float *)inter_x1, div_thresh_iou,
                        seg_len);
    } else {
      __bang_mul_scalar((float *)inter_x2, (float *)inter_x2, thresh_iou,
                        seg_len);
    }
    __bang_ge((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, seg_len);
    __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len);
    /******NMS COMPUTE END******/
#endif
    // update the score
    if (sizeof(IN_DT) == sizeof(half)) {
      convertFloat2half((half *)score, (float *)score, seg_len);
    }
    pvLock();
    __memcpy(input_score_ptr + input_offset + i * max_seg_iou_compute, score,
             cpy_len * sizeof(IN_DT), store_dir, cpy_len * sizeof(IN_DT),
             cpy_len * sizeof(IN_DT), 0);
    pvUnlock();
  }
}
#endif // NMS_UTILS_HPP_
mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 5
__nram__ char buffer[MAX_NRAM_SIZE];
namespace forward {
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T y, T x, T *w1,
T *w2, T *w3, T *w4, int *x_low,
int *x_high, int *y_low, int *y_high,
bool *empty) {
// deal with cases where the inverse elements are out of the feature map boundary
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true;
return;
}
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low_ = int(y);
int x_low_ = int(x);
if (y_low_ >= input_height - 1) {
*y_high = y_low_ = input_height - 1;
y = (T)y_low_;
} else {
*y_high = y_low_ + 1;
}
if (x_low_ >= input_width - 1) {
*x_high = x_low_ = input_width - 1;
x = T(x_low_);
} else {
*x_high = x_low_ + 1;
}
*y_low = y_low_;
*x_low = x_low_;
T ly = y - y_low_;
T lx = x - x_low_;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
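// A small worked example of the weights above (illustrative only): for a
// sampling point y = 1.25, x = 2.5 on an input of height 4 and width 4,
// y_low = 1, y_high = 2, x_low = 2, x_high = 3, ly = 0.25, lx = 0.5,
// hy = 0.75, hx = 0.5, so w1 = 0.375, w2 = 0.375, w3 = 0.125, w4 = 0.125.
// The four weights always sum to 1 and feed the interpolation
// value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4 used in computeChannel below.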
template <typename T>
__mlu_func__ void computeChannel(T *input_core, T *nram_in, T *output_core,
T *nram_out, const int roi_bin_grid_h,
const int roi_bin_grid_w, const T roi_start_h,
const T roi_start_w, const int ph,
const int pw, const T bin_size_h,
const T bin_size_w, const float count,
const int input_height, const int input_width,
const int channels, const int cyc_num,
const int max_elements) {
int cyc_channel = max_elements;
for (int i = 0; i < cyc_num; i++) {
int real_channel =
(i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel;
int align_channel = PAD_UP(real_channel, NFU_ALIGN_SIZE / sizeof(T));
__bang_write_zero(nram_out, align_channel);
uint32_t real_size = real_channel * sizeof(T);
int iy, ix;
for (iy = 0; iy < roi_bin_grid_h; iy++) {
// 1. compute the coordinates of the y axis in the current roi_bin_grid_h
T y = roi_start_h + ph * bin_size_h +
(T)(iy + 0.5) * bin_size_h / (T)(roi_bin_grid_h);
for (ix = 0; ix < roi_bin_grid_w; ix++) {
// 2. compute the coordinates of the x axis in the current
// roi_bin_grid_w
T x = roi_start_w + pw * bin_size_w +
(T)(ix + 0.5) * bin_size_w / (T)(roi_bin_grid_w);
// 3. compute the four weights (w1, w2, w3 and w4), the height (y_low
// and y_high) and width (x_low and x_high) of the input feature map in
// the current roi bin grid, and the flag (empty) which shows whether x, y
// are out of the input feature map range
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bool empty = false;
bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, &w3, &w4,
&x_low, &x_high, &y_low, &y_high, &empty);
// 4. compute interpolation of the current roi bin grid
// tmp_cyc1, temp_cyc2, tmp_cyc3 and tmp_cyc4 store the input values
// to compute the interpolation, and then reused to compute
// the argmax_x and argmax_y.
T *tmp_cyc1 = nram_in + cyc_channel;
T *tmp_cyc2 = nram_in + cyc_channel * 2;
T *tmp_cyc3 = nram_in + cyc_channel * 3;
T *tmp_cyc4 = nram_in + cyc_channel * 4;
if (empty) { // the sampling point is outside the feature map
__bang_write_zero(nram_in, align_channel);
} else {
__bang_write_zero(nram_in, align_channel);
uint32_t offset1 = (y_low * input_width + x_low) * channels;
uint32_t offset2 = (y_low * input_width + x_high) * channels;
uint32_t offset3 = (y_high * input_width + x_low) * channels;
uint32_t offset4 = (y_high * input_width + x_high) * channels;
T *input1 = (T *)input_core + offset1 + i * cyc_channel;
T *input2 = (T *)input_core + offset2 + i * cyc_channel;
T *input3 = (T *)input_core + offset3 + i * cyc_channel;
T *input4 = (T *)input_core + offset4 + i * cyc_channel;
// load the four pixels (p1, p2, p3 and p4) of input feature map to
// compute interpolation
__memcpy(tmp_cyc1, input1, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc2, input2, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc3, input3, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc4, input4, real_size, GDRAM2NRAM);
// interpolation value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4
__bang_mul_scalar(tmp_cyc1, tmp_cyc1, w1, align_channel);
__bang_mul_scalar(tmp_cyc2, tmp_cyc2, w2, align_channel);
__bang_mul_scalar(tmp_cyc3, tmp_cyc3, w3, align_channel);
__bang_mul_scalar(tmp_cyc4, tmp_cyc4, w4, align_channel);
__bang_add(nram_in, tmp_cyc1, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc2, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc3, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc4, nram_in, align_channel);
}
// 5. compute sum value and corresponding coordinates of x axis and y
// axis. Update the sum value.
__bang_add(nram_out, nram_in, nram_out, align_channel);
} // loop_roi_grid_w
} // loop_roi_grid_h
T count_value = (T)(1.0 / count);
__bang_mul_scalar(nram_out, nram_out, count_value, align_channel);
__memcpy(output_core + i * cyc_channel, nram_out, real_size, NRAM2GDRAM);
} // loop_cyc_num
}
template <typename T>
__mlu_func__ void roialignForwardAvg(
T *input, T *rois, T *output, const bool aligned, const int channels,
const int pooled_height, const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio, const T spatial_scale,
const int num_rois) {
// find the channel limit: the nram space is divided into 6 parts, which are
// the input, 4 weights to compute the interpolation (w1, w2, w3, w4), and the output
// max_elements : 300 : float datatype : 27296, half datatype : 54592
// max_elements : 200 : float datatype : 16384, half datatype : 32768
int max_elements = (PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE)) / sizeof(T);
int cyc_num = channels / max_elements + (int)(channels % max_elements != 0);
T offset = aligned ? (T)0.5 : (T)0.0;
int task_num = num_rois * pooled_height * pooled_width;
T *nram_out = (T *)buffer;
T *nram_in = nram_out + max_elements;
if (task_num < taskDim) {
if (taskId >= task_num) {
return;
}
}
for (int bin_idx = taskId; bin_idx < task_num; bin_idx = bin_idx + taskDim) {
if (bin_idx >= task_num) {
return;
}
// (n, ph, pw) indexes a channel vector in the pooled output
int pw = bin_idx % pooled_width;
int ph = (bin_idx / pooled_width) % pooled_height;
int n = bin_idx / pooled_width / pooled_height;
T *roi_id_tmp = rois + n * ROI_OFFSET;
// 1. compute width and height of roi region.
int batch_idx = (int)roi_id_tmp[0];
T roi_x1 = roi_id_tmp[1];
T roi_y1 = roi_id_tmp[2];
T roi_x2 = roi_id_tmp[3];
T roi_y2 = roi_id_tmp[4];
T roi_start_w = roi_x1 * spatial_scale - offset;
T roi_start_h = roi_y1 * spatial_scale - offset;
T roi_end_w = roi_x2 * spatial_scale - offset;
T roi_end_h = roi_y2 * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
roi_width = roi_width > (T)(1.0) ? roi_width : (T)(1.0);
roi_height = roi_height > (T)(1.0) ? roi_height : (T)(1.0);
}
// 2. compute float-type width and height of roi bin region.
T bin_size_w = (T)roi_width / (T)pooled_width;
T bin_size_h = (T)roi_height / (T)pooled_height;
// 3. compute int-type width and height of roi bin region.
int roi_bin_grid_h, roi_bin_grid_w;
roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: int(ceilf(roi_height / pooled_height));
roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: int(ceilf(roi_width / pooled_width));
float count = (float)((roi_bin_grid_h * roi_bin_grid_w) > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1.0);
T *input_core = input + batch_idx * channels * input_width * input_height;
T *output_core = output + bin_idx * channels;
// 4. compute avg value and corresponding coordinates of x axis and y axis.
computeChannel(input_core, nram_in, output_core, nram_out, roi_bin_grid_h,
roi_bin_grid_w, roi_start_h, roi_start_w, ph, pw, bin_size_h,
bin_size_w, count, input_height, input_width, channels,
cyc_num, max_elements);
}
}
__mlu_global__ void MLUUnion1KernelRoiAlignAvg(
const void *input, const void *rois, const int channels, const bool aligned,
const int pooled_height, const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio, const float spatial_scale,
const int num_rois, const cnrtDataType_t data_type, void *output) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_type) {
case CNRT_FLOAT16: {
roialignForwardAvg((half *)input, (half *)rois, (half *)output, aligned,
channels, pooled_height, pooled_width, input_height,
input_width, sampling_ratio, (half)spatial_scale,
num_rois);
}; break;
case CNRT_FLOAT32: {
roialignForwardAvg((float *)input, (float *)rois, (float *)output,
aligned, channels, pooled_height, pooled_width,
input_height, input_width, sampling_ratio,
(float)spatial_scale, num_rois);
}; break;
default:
break;
}
return;
}
} // namespace forward
namespace backward {
__mlu_func__ void bilinearInterpolateGradient(int height, int width, float y,
float x, float *w1, float *w2,
float *w3, float *w4, int *x_low,
int *x_high, int *y_low,
int *y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
*w1 = 0.0, *w2 = 0.0, *w3 = 0.0, *w4 = 0.0;
*x_low = -1, *x_high = -1, *y_low = -1, *y_high = -1;
return;
}
if (y <= 0) {
y = 0;
}
if (x <= 0) {
x = 0;
}
*y_low = (int)y;
*x_low = (int)x;
if (*y_low >= height - 1) {
*y_high = height - 1, *y_low = height - 1;
y = (float)(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= width - 1) {
*x_high = width - 1, *x_low = width - 1;
x = (float)(*x_low);
} else {
*x_high = *x_low + 1;
}
float ly = y - *y_low, lx = x - *x_low;
float hy = 1.0 - ly, hx = 1.0 - lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
template <typename T>
__mlu_func__ void unionRoiAlignBp(
T *grads, T *boxes, T *grads_image, const int boxes_num, const int hi,
const int wi, const int c, const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio, const bool aligned) {
int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T));
int deal_all = boxes_num * hi * wi;
int deal_this_core = deal_all / taskDim + (int)(taskId < deal_all % taskDim);
for (int i = 0; i < deal_this_core; ++i) {
int bhw_id = i * taskDim + taskId;
int box_id = bhw_id / (hi * wi);
int ih = (bhw_id / wi) % hi;
int iw = bhw_id % wi;
T *box = boxes + box_id * 5;
int image_id = (int)box[0];
T *image_offset = grads_image + image_id * ho * wo * c;
T *grads_ = grads + box_id * hi * wi * c + ih * wi * c + iw * c;
float offset = aligned ? 0.5 : 0.0;
float x1 = box[1] * spatial_scale - offset;
float y1 = box[2] * spatial_scale - offset;
float x2 = box[3] * spatial_scale - offset;
float y2 = box[4] * spatial_scale - offset;
float roi_width = x2 - x1;
float roi_height = y2 - y1;
if (!aligned) {
roi_width = (roi_width > 1.0) ? roi_width : 1.0;
roi_height = (roi_height > 1.0) ? roi_height : 1.0;
}
float bin_size_h = roi_height / hi;
float bin_size_w = roi_width / wi;
int roi_grid_h =
(sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi);
int roi_grid_w =
(sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / wi);
const T count = roi_grid_h * roi_grid_w;
if (c_align * sizeof(T) * 2 <= MAX_NRAM_SIZE) {
for (int iy = 0; iy < roi_grid_h; ++iy) {
const float y =
y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h;
for (int ix = 0; ix < roi_grid_w; ++ix) {
const float x =
x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w;
float w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high);
if (x_low >= 0 && y_low >= 0) {
__memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w1,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_low * wo * c + x_low * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w2,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_low * wo * c + x_high * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w3,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_high * wo * c + x_low * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w4,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_high * wo * c + x_high * c,
(T *)buffer + c_align, c);
} // x_low && y_low
} // ix
} // iy
} else {
for (int iy = 0; iy < roi_grid_h; ++iy) {
const float y =
y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h;
for (int ix = 0; ix < roi_grid_w; ++ix) {
const float x =
x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w;
float w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high);
if (x_low >= 0 && y_low >= 0) {
int deal_once =
PAD_DOWN(MAX_NRAM_SIZE / 2, NFU_ALIGN_SIZE) / sizeof(T);
int c_repeat = c / deal_once + (int)(c % deal_once != 0);
for (int i = 0; i < c_repeat; ++i) {
int deal_c = deal_once;
int align_c = deal_once;
if (i == c_repeat - 1) {
deal_c = c - i * deal_once;
align_c = c_align - i * deal_once;
}
__memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w1,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_low * wo * c + x_low * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w2,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_low * wo * c + x_high * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w3,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_high * wo * c + x_low * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w4,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_high * wo * c + x_high * c + i * deal_once,
(T *)buffer + align_c, deal_c);
} // for c_repeat
} // x_low >= 0 && y_low >= 0
} // ix
} // iy
} // if c
} // i
}
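// Illustrative note on the gradient scatter above: each output gradient
// grads_[c] contributes w_k * grads_[c] / count to the k-th of the four
// neighbouring input positions via __bang_atomic_add, so the four
// contributions of one sampling point sum to exactly grads_[c] / count,
// mirroring the forward average over roi_grid_h * roi_grid_w sampling points.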
__mlu_global__ void MLUUnion1KernelRoiAlignBackward(
const void *grads, const void *boxes, void *grads_image,
const cnrtDataType_t dtype, const int boxes_num, const int hi, const int wi,
const int c, const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio, const bool aligned) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (dtype) {
case CNRT_FLOAT16: {
unionRoiAlignBp((half *)grads, (half *)boxes, (half *)grads_image,
boxes_num, hi, wi, c, no, ho, wo, spatial_scale,
sampling_ratio, aligned);
}; break;
case CNRT_FLOAT32: {
unionRoiAlignBp((float *)grads, (float *)boxes, (float *)grads_image,
boxes_num, hi, wi, c, no, ho, wo, spatial_scale,
sampling_ratio, aligned);
}; break;
default: { return; }
}
}
} // namespace backward
void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *rois, const int channels,
const bool aligned, const int pooled_height,
const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio,
const float spatial_scale, const int num_rois,
void *output) {
forward::MLUUnion1KernelRoiAlignAvg<<<k_dim, k_type, queue>>>(
input, rois, channels, aligned, pooled_height, pooled_width, input_height,
input_width, sampling_ratio, spatial_scale, num_rois, d_type, output);
}
void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t dtype,
const void *grads, const void *boxes,
void *grads_image, const int boxes_num,
const int hi, const int wi, const int c,
const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio,
const bool aligned) {
backward::MLUUnion1KernelRoiAlignBackward<<<k_dim, k_type, queue>>>(
grads, boxes, grads_image, dtype, boxes_num, hi, wi, c, no, ho, wo,
spatial_scale, sampling_ratio, aligned);
}
mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 7
#define FLOAT_NRAM_BUFFER_NUM 14
#define HALF_NRAM_BUFFER_NUM 25
#define ALIGN_NUM 64
__nram__ char data_nram[MAX_NRAM_SIZE];
template <typename T>
__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels(
const int pool_method, const int boxes_num, const int pts_num,
const int max_pts_each_voxel, const int out_x, const int out_y,
const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) {
// params (T)rois: (boxes_num, 7)
// params (T)pts: (3, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int nram_pts_num = 0;
if (sizeof(T) == sizeof(float)) {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM);
} else {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM);
}
char *X = NULL;
char *Y = NULL;
char *Z = NULL;
char *local_X = NULL;
char *local_Y = NULL;
char *local_Z = NULL;
char *nram_pts_in_flag = NULL;
float *temp_buffer1 = NULL;
float *temp_buffer2 = NULL;
float *temp_buffer3 = NULL;
float *temp_buffer4 = NULL;
float *temp_buffer5 = NULL;
float *nram_voxel_offset = NULL;
int *nram_pts_idx_seq = NULL;
float *fp_local_X = NULL;
float *fp_local_Y = NULL;
float *fp_local_Z = NULL;
float *fp_nram_pts_in_flag = NULL;
if (sizeof(T) == sizeof(float)) {
X = (char *)((float *)data_nram);
Y = (char *)((float *)data_nram + nram_pts_num);
Z = (char *)((float *)data_nram + nram_pts_num * 2);
local_X = (char *)((float *)data_nram + nram_pts_num * 3);
local_Y = (char *)((float *)data_nram + nram_pts_num * 4);
local_Z = (char *)((float *)data_nram + nram_pts_num * 5);
nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6);
temp_buffer1 = (float *)data_nram + nram_pts_num * 7;
temp_buffer2 = (float *)data_nram + nram_pts_num * 8;
temp_buffer3 = (float *)data_nram + nram_pts_num * 9;
temp_buffer4 = (float *)data_nram + nram_pts_num * 10;
temp_buffer5 = (float *)data_nram + nram_pts_num * 11;
nram_voxel_offset = (float *)data_nram + nram_pts_num * 12;
nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13);
fp_local_X = (float *)local_X;
fp_local_Y = (float *)local_Y;
fp_local_Z = (float *)local_Z;
fp_nram_pts_in_flag = (float *)nram_pts_in_flag;
} else {
X = (char *)((half *)data_nram);
Y = (char *)((half *)data_nram + nram_pts_num);
Z = (char *)((half *)data_nram + nram_pts_num * 2);
local_X = (char *)((half *)data_nram + nram_pts_num * 4);
local_Y = (char *)((half *)data_nram + nram_pts_num * 6);
local_Z = (char *)((half *)data_nram + nram_pts_num * 8);
nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10);
temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11);
temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13);
temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15);
temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17);
temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19);
nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21);
nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23);
fp_local_X = (float *)((half *)local_X - nram_pts_num);
fp_local_Y = (float *)((half *)local_Y - nram_pts_num);
fp_local_Z = (float *)((half *)local_Z - nram_pts_num);
fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num);
}
for (int i = 0; i < nram_pts_num; i++) {
nram_pts_idx_seq[i] = i;
}
int nram_pts_loop_times = pts_num / nram_pts_num;
int rem_nram_num = pts_num % nram_pts_num;
for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) {
const T *cur_roi = rois + roi_index * ROI_OFFSET;
T cx = cur_roi[0];
T cy = cur_roi[1];
T cz = cur_roi[2];
T dx = cur_roi[3];
T dy = cur_roi[4];
T dz = cur_roi[5];
T rz = cur_roi[6];
T dx_2 = dx / 2.0;
T dy_2 = dy / 2.0;
T dz_2 = dz / 2.0;
for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) {
int load_pts_num =
(loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num;
if (load_pts_num == 0) {
break;
}
int pts_offset_cur_loop = nram_pts_num * loop_idx;
int compute_pts_num = (loop_idx == nram_pts_loop_times)
? PAD_UP(rem_nram_num, ALIGN_NUM)
: nram_pts_num;
// load pts
__memcpy((void *)X, (T *)pts + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
// fabs(local_z)
__bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2),
compute_pts_num);
__bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num);
#if __BANG_ARCH__ >= 322
__bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2));
__bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
T cosa = std::cos(-rz);
T sina = std::sin(-rz);
__bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num);
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina,
compute_pts_num);
// local_x
__bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_x)
__bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num);
// fabs(local_x) < dx/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num); // flush res
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa,
compute_pts_num);
// local_y
__bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_y)
__bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num);
// fabs(local_y) < dy/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num); // flush res
T x_res = dx / out_x;
T y_res = dy / out_y;
T z_res = dz / out_z;
__bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num);
__bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num);
// local_Z does not need to add dz/2.0
#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372)
__bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num);
__bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num);
__bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num);
#else
__bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res),
compute_pts_num);
#endif
// float = float2int + int2float, half = half2int + int2float
if (sizeof(T) == sizeof(float)) {
#if __BANG_ARCH__ >= 322
__bang_float2int32_tz((int *)temp_buffer1, (float *)local_X,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2,
(float *)fp_local_X, (float *)temp_buffer3,
compute_pts_num);
convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3,
(float *)fp_local_Y, (float *)temp_buffer4,
compute_pts_num);
convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4,
(float *)fp_local_Z, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_X, (float *)temp_buffer4,
(int *)temp_buffer1, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4,
(int *)temp_buffer2, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4,
(int *)temp_buffer3, (float *)temp_buffer5,
compute_pts_num);
#endif
} else {
__bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag,
compute_pts_num);
__bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4,
compute_pts_num * sizeof(float));
#if __BANG_ARCH__ >= 322
__bang_half2int32_tz((int *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
__bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3,
compute_pts_num, 0);
#endif
}
// process index >= 0
__bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f);
__bang_maxequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer4, compute_pts_num);
// process index <= (out_x - 1)
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_x - 1));
__bang_minequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_y - 1));
__bang_minequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_z - 1));
__bang_minequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer5, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X,
(float)(out_y * out_z), compute_pts_num);
__bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y,
(float)out_z, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0,
compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)temp_buffer1,
(float *)temp_buffer2, compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float *)temp_buffer3, compute_pts_num);
__bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float)max_pts_each_voxel, compute_pts_num);
if (compute_pts_num != load_pts_num) {
__memset_nram((float *)fp_nram_pts_in_flag + load_pts_num,
compute_pts_num - load_pts_num, (float)0.0);
}
__bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq,
(float *)fp_nram_pts_in_flag, compute_pts_num);
int pts_num_in_cur_roi =
(int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num);
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels +
roi_index * out_x * out_y * out_z * max_pts_each_voxel;
for (int idx = 0; idx < pts_num_in_cur_roi; idx++) {
int cur_pts_idx = *((int *)temp_buffer4 + idx);
int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx));
int cnt = pts_idx_cur_voxels[offset];
if (cnt < max_pts_each_voxel - 1) {
pts_idx_cur_voxels[offset + cnt + 1] =
cur_pts_idx + loop_idx * nram_pts_num;
pts_idx_cur_voxels[offset]++;
}
}
}
}
}
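// Illustrative note on the layout written above: for a point falling in voxel
// (vx, vy, vz) of box b, the kernel computes
//   offset = (vx * out_y * out_z + vy * out_z + vz) * max_pts_each_voxel
// inside that box's slice of pts_idx_of_voxels; element [offset] holds the
// running point count of the voxel and elements [offset + 1 ...] hold the
// point indices, which is why at most max_pts_each_voxel - 1 points are
// recorded per voxel.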
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward(
const int pool_method, const int boxes_num, const int pts_num,
const int channels, const int max_pts_each_voxel, const int out_x,
const int out_y, const int out_z, const T *pts_feature,
const int *pts_idx_of_voxels, T *pooled_features, int *argmax) {
// params (T)pts_feature: (channels, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel) params (int)argmax: (boxes_num, out_x, out_y, out_z,
// channels) params (T)pooled_features: (boxes_num, out_x, out_y, out_z,
// channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit =
PAD_DOWN((MAX_NRAM_SIZE - 128 -
align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) /
((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)),
align_num);
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_max_pts_feature_tmp =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_max_pts_feature_tmp [align_max_pts_each_voxel]
T *nram_pts_feature_in_voxel =
((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel);
// nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel]
T *nram_pooled_features_cur_voxel =
((T *)nram_pts_feature_in_voxel +
nram_channels_limit * align_max_pts_each_voxel);
// nram_pooled_features_cur_voxel [nram_channels_limit]
int *nram_argmax_cur_voxel =
(int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit);
// nram_argmax_cur_voxel [nram_channels_limit]
char *one_pooled_feature =
(char *)((int *)nram_argmax_cur_voxel + nram_channels_limit);
// one_pooled_feature [128]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
for (int voxel_index = taskId;
voxel_index < boxes_num * out_x * out_y * out_z;
voxel_index += taskDim) {
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0];
if (pts_num_cur_voxel == 0) {
continue;
}
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
int channels_offset = nram_channels_limit * channels_loop_idx;
#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300))
int compute_channels_num = (channels_loop_idx == channels_loop_times)
? PAD_UP(rem_channels, align_num)
: nram_channels_limit;
if (pool_method == 0) {
__bang_write_value((void *)nram_pts_feature_in_voxel,
compute_channels_num * align_max_pts_each_voxel,
(T)-INFINITY);
}
#endif
T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num;
for (int idx = 0; idx < pts_num_cur_voxel; idx++) {
__memcpy((T *)nram_pts_feature_in_voxel + idx,
(T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1],
sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T),
pts_num * sizeof(T), actual_channels_num - 1);
}
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
if (pool_method == 0) {
#if __BANG_ARCH__ >= 322
__bang_argmax((T *)one_pooled_feature,
(T *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
pts_num_cur_voxel);
T max_val = ((T *)one_pooled_feature)[0];
int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1));
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
#else
// __bang_max needs an aligned num on mlu200 series
if (sizeof(T) == sizeof(float)) {
__bang_max((float *)one_pooled_feature,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
align_max_pts_each_voxel);
float max_val = ((float *)one_pooled_feature)[0];
__bang_write_value((void *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel, (float)max_val);
__bang_eq((float *)nram_max_pts_feature_tmp,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
(float *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel);
int max_idx = (int)__bang_findfirst1(
(float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel);
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1
: nram_pts_idx_cur_voxel[max_idx + 1];
} else {
int max_idx = -1;
float max_val = -INFINITY;
for (int k = 0; k < pts_num_cur_voxel; k++) {
float pts_feature_cur_channel = __half2float_rd(
*((half *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel + k));
if (pts_feature_cur_channel > max_val) {
max_val = pts_feature_cur_channel;
max_idx = k;
}
}
nram_pooled_features_cur_voxel[channel_idx] =
(max_idx == -1) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_idx == -1) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
}
#endif
} else if (pool_method == 1) {
float sum_val_cur_channel = 0;
for (int k = 0; k < pts_num_cur_voxel; k++) {
sum_val_cur_channel += static_cast<float>(
((T *)nram_pts_feature_in_voxel)[channel_idx *
align_max_pts_each_voxel +
k]);
}
nram_pooled_features_cur_voxel[channel_idx] =
(T)(sum_val_cur_channel / pts_num_cur_voxel);
}
}
// store
__memcpy((T *)pooled_features + voxel_index * channels + channels_offset,
(void *)nram_pooled_features_cur_voxel,
actual_channels_num * sizeof(T), NRAM2GDRAM);
if (pool_method == 0) {
__memcpy((int *)argmax + voxel_index * channels + channels_offset,
(void *)nram_argmax_cur_voxel,
actual_channels_num * sizeof(int), NRAM2GDRAM);
}
}
}
}
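// Illustrative note on the max-pool branch above: a voxel channel whose best
// value is still -INFINITY contains no valid point, so the pooled feature is
// written as 0 and its argmax as -1; the backward max-pool kernel below skips
// exactly those -1 entries when scattering gradients.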
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const int pool_method, const int boxes_num,
const int pts_num, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z,
const void *rois, const void *pts,
int *pts_idx_of_voxels) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelPtsIdxOfVoxels<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelPtsIdxOfVoxels<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels);
}; break;
default: {
break;
}
}
}
void KernelRoiawarePool3dForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int pts_num, const int channels, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z, const void *pts_feature,
const int *pts_idx_of_voxels, void *pooled_features, int *argmax) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawarePool3dForward<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels,
(float *)pooled_features, (int *)argmax);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawarePool3dForward<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels,
(half *)pooled_features, (int *)argmax);
}; break;
default: {
break;
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int *argmax, const T *grad_out, T *grad_in) {
// params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int nram_channels_limit =
(MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int));
int *nram_argmax_cur_loop = (int *)data_nram;
// nram_argmax_cur_loop [nram_channels_limit]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_argmax_cur_loop + nram_channels_limit);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_channel =
(T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_channel [1]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const int *argmax_cur_voxel = argmax + voxel_index * channels;
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const int *argmax_cur_loop =
argmax_cur_voxel + nram_channels_limit * channels_loop_idx;
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop,
actual_channels_num * sizeof(int), GDRAM2NRAM);
__memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx;
T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx;
if (nram_argmax_cur_channel[0] == -1) {
continue;
}
T *grad_in_cur_channel =
grad_in + nram_argmax_cur_channel[0] * channels +
nram_channels_limit * channels_loop_idx + channel_idx;
__bang_atomic_add((T *)nram_grad_in_cur_channel,
(T *)grad_in_cur_channel,
(T *)(nram_grad_out_cur_channel), 1);
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int max_pts_each_voxel,
const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) {
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel) params (T)grad_out: (boxes_num, out_x, out_y, out_z,
// channels) params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit = PAD_DOWN(
(MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T),
align_num);
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_loop [nram_channels_limit]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
const int *pts_idx_cur_voxel =
pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int total_pts_of_voxel = nram_pts_idx_cur_voxel[0];
if (total_pts_of_voxel <= 0) {
continue;
}
float cur_grad = 1.0 / ((float)total_pts_of_voxel);
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
int align_actual_channels_num = PAD_UP(actual_channels_num, align_num);
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)nram_grad_out_cur_loop,
(half *)nram_grad_in_cur_loop,
align_actual_channels_num);
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop, (float)cur_grad,
align_actual_channels_num);
convertFloat2half((half *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop,
align_actual_channels_num);
} else {
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_in_cur_loop, (float)cur_grad,
align_actual_channels_num);
}
for (int k = 1; k <= total_pts_of_voxel; k++) {
T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels +
nram_channels_limit * channels_loop_idx;
__bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop,
(T *)nram_grad_out_cur_loop, actual_channels_num);
}
}
}
}
void KernelRoiawarePool3dBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int out_x, const int out_y, const int out_z, const int channels,
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
const int *argmax, const void *grad_out, void *grad_in) {
if (pool_method == 0) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareMaxPool3dBackward<float>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (float *)grad_out,
(float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareMaxPool3dBackward<half>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (half *)grad_out,
(half *)grad_in);
}; break;
default: {
break;
}
}
} else {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareAvgPool3dBackward<float>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareAvgPool3dBackward<half>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in);
}; break;
default: {
break;
}
}
}
}
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <algorithm>
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
/**
 * returns the index of the result, which is stored at the 1st position of `ret`;
 * used after bang_min
*/
__mlu_func__ uint32_t getIndice(half *ret) {
uint32_t indice = *((uint32_t *)((uint16_t *)ret + 1));
return indice;
}
/**
 * returns the index of the result, which is stored at the 1st position of `ret`;
 * used after bang_min
*/
__mlu_func__ uint32_t getIndice(float *ret) {
uint32_t indice = ((uint32_t *)ret)[1];
return indice;
}
#endif
template <typename T>
__mlu_func__ void auxArgmin(T *nram_dst, T *nram_src, const int num_deal,
T *value, int *index) {
__bang_min(nram_dst, nram_src, num_deal);
*value = nram_dst[0];
__bang_write_value(nram_dst, num_deal, *value);
__bang_eq(nram_dst, nram_src, nram_dst, num_deal);
__bang_findfirst1((uint32_t *)nram_dst, nram_dst, num_deal);
*index = *((int *)nram_dst);
}
template <typename T>
__mlu_func__ void auxFuncFind3Min(T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset,
T *nram_dest, T *nram_aux_sort_a,
int *nram_aux_sort_b, const int deal_offset) {
__bang_write_value(nram_aux_sort_a, auxa_offset, (T)(INFINITY));
__bang_write_value(nram_aux_sort_b, auxb_offset, (int)0);
int index = 0;
for (int i = 0; i < 3; i++) {
#if __BANG_ARCH__ >= 322
__bang_argmin(nram_dest, nram_aux_a, auxa_offset);
nram_aux_sort_a[i] = nram_dest[0];
index = getIndice(nram_dest);
#else
T value = 0;
auxArgmin(nram_dest, nram_aux_a, auxa_offset, &value, &index);
nram_aux_sort_a[i] = value;
#endif
nram_aux_sort_b[i] = nram_aux_b[index];
__memset_nram(nram_aux_a + index, 1, (T)(INFINITY));
}
__memcpy((char *)nram_aux_a, (char *)nram_aux_sort_a, auxa_offset * sizeof(T),
NRAM2NRAM);
__memcpy((char *)nram_aux_b, (char *)nram_aux_sort_b,
auxb_offset * sizeof(int), NRAM2NRAM);
}
template <typename T>
__mlu_func__ void auxFuncSort(T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset,
T *nram_dest, T *nram_help_value,
int *nram_help_idx, const int num_deal,
const int deal_offset) {
for (int k = 0; k < num_deal; ++k) {
auxFuncFind3Min(nram_aux_a + k * auxa_offset, auxa_offset,
nram_aux_b + k * auxb_offset, auxb_offset, nram_dest,
nram_help_value, nram_help_idx, deal_offset);
}
}
template <typename T>
__mlu_func__ void auxFuncNN(
size_t *output_aux_sort_a_gap, size_t *output_aux_sort_b_gap,
size_t *output_aux_dest_gap, size_t *output_unknown_gap,
size_t *output_known_gap, size_t *output_dist_gap, size_t *auxillary_a_gap,
size_t *auxillary_b_gap, size_t *known_num_deal, size_t *unknown_num_deal,
size_t *align_num, size_t *auxa_offset, size_t *auxb_offset) {
  /*
   * nram partition:
   * |-NFU_ALIGN_SIZE-|-2*NFU_ALIGN_SIZE-|-X*3*sizeof(T)-|
   * space: | aux_sort_a | aux_sort_b | nram_unknown |
   *
   * | ------ (Y * 7 * sizeof(T)) ---------------- |
   * | nram_known | nram_dist | nram_dest |
   *
   * | -X * NFU_ALIGN_SIZE ---|---X * 2 * NFU_ALIGN_SIZE-|
   * | output_dist(aux_a)     | output_dist(aux_b)       |
   *
   * 200 series:
   * X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (2/3) / (3 * sizeof(T) + 3 * NFU_ALIGN_SIZE)
   * Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (1/3) / (7 * sizeof(T))
   *
   * 300 series:
   * X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (4/5) / (3 * sizeof(T) + 3 * NFU_ALIGN_SIZE)
   * Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (1/5) / (7 * sizeof(T))
   */
*align_num = NFU_ALIGN_SIZE / sizeof(T);
*auxa_offset = NFU_ALIGN_SIZE / sizeof(T);
*auxb_offset = 2 * NFU_ALIGN_SIZE / sizeof(int);
#if __BANG_ARCH__ >= 322
*known_num_deal = PAD_DOWN(
(MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 / (7 * sizeof(T)), *align_num);
*unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 * 4 /
(3 * sizeof(T) + 3 * NFU_ALIGN_SIZE),
*align_num);
#else
*known_num_deal = PAD_DOWN(
(MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 / (7 * sizeof(T)), *align_num);
*unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 * 2 /
(3 * sizeof(T) + 3 * NFU_ALIGN_SIZE),
*align_num);
#endif
*output_aux_sort_a_gap = 0;
*output_aux_sort_b_gap = *output_aux_sort_a_gap + NFU_ALIGN_SIZE;
*output_aux_dest_gap = *output_aux_sort_b_gap + 2 * NFU_ALIGN_SIZE;
*output_unknown_gap = *output_aux_dest_gap + *known_num_deal * sizeof(T);
*output_known_gap = *output_unknown_gap + *unknown_num_deal * 3 * sizeof(T);
*output_dist_gap = *output_known_gap + *known_num_deal * 3 * sizeof(T);
*auxillary_a_gap = *output_dist_gap + *known_num_deal * 3 * sizeof(T);
*auxillary_b_gap = *auxillary_a_gap + *unknown_num_deal * NFU_ALIGN_SIZE;
}
#if __BANG_ARCH__ >= 322
template <typename T>
__mlu_func__ bool containNanInf(T *nram_unknown) {
if (std::isnan(nram_unknown[0]) || std::isnan(nram_unknown[1]) ||
std::isnan(nram_unknown[2]) || std::isinf(nram_unknown[0]) ||
std::isinf(nram_unknown[1]) || std::isinf(nram_unknown[2]))
return true;
else
return false;
}
#endif
template <typename T>
__mlu_func__ void computeThreeNN(T *nram_unknown, T *nram_known, T *nram_dist,
T *nram_dest, T *nram_aux_a,
T *nram_aux_sort_a, int *nram_aux_b,
int *nram_aux_sort_b, const int known_num_deal,
const int known_seg_num, const int deal_offset,
const int known_count,
const int known_count_align) {
__bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY));
#if __BANG_ARCH__ >= 322
if (!containNanInf(nram_unknown)) {
#endif
// x1 - x2
__bang_sub_scalar(nram_dist, nram_known, nram_unknown[0],
known_count_align);
// y1 - y2
__bang_sub_scalar(nram_dist + known_count_align,
nram_known + known_count_align, nram_unknown[1],
known_count_align);
// z1 - z2
__bang_sub_scalar(nram_dist + 2 * known_count_align,
nram_known + 2 * known_count_align, nram_unknown[2],
known_count_align);
__bang_square(nram_dist, nram_dist, 3 * known_count_align);
__bang_add(nram_dist, nram_dist, nram_dist + known_count_align,
known_count_align);
__bang_add(nram_dist, nram_dist, nram_dist + 2 * known_count_align,
known_count_align);
#if __BANG_ARCH__ >= 322
}
#endif
int index = 0;
for (int i = 0; i < 3; i++) {
#if __BANG_ARCH__ >= 322
__bang_argmin(nram_dest, nram_dist, known_count_align);
nram_aux_a[i + deal_offset] = nram_dest[0];
index = getIndice(nram_dest);
#else
T value = 0;
auxArgmin(nram_dest, nram_dist, known_count_align, &value, &index);
nram_aux_a[i + deal_offset] = value;
#endif
nram_aux_b[i + deal_offset] = index + known_seg_num * known_num_deal;
__memset_nram(nram_dist + index, 1, (T)(INFINITY));
}
}
template <typename T>
__mlu_func__ void loadTransposedKnownTensor(
char *nram_known, char *nram_dist, const char *known_gdram,
const int known_num_deal, const int batch_id, const int m,
const int known_seg_num, const int count, const int count_align_num) {
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
#if __BANG_ARCH__ >= 322
__bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY));
__memcpy(nram_dist,
known_gdram +
(batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T),
count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T),
m * sizeof(T), 2);
__bang_minequal((T *)nram_known, (T *)nram_known, (T *)nram_dist,
3 * count_align_num);
#else
__memcpy(nram_known,
known_gdram +
(batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T),
count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T),
m * sizeof(T), 2);
#endif
}
template <typename T>
__mlu_func__ void loadUnknownTensor(char *nram_unknown,
const char *unknown_gdram,
const int unknown_num_deal,
const int unknown_seg_num, const int count,
const int count_align_num) {
__memcpy(nram_unknown,
unknown_gdram + unknown_seg_num * unknown_num_deal * 3 * sizeof(T),
count * 3 * sizeof(T), GDRAM2NRAM);
}
template <typename T>
__mlu_func__ void auxProcessSegment(
const int m, const int n, T *nram_unknown, T *nram_known, T *nram_dist,
T *nram_dest, T *known_gdram, T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset, T *nram_aux_sort_a,
int *nram_aux_sort_b, const int unknown_num_deal, const int known_num_deal,
const int known_seg_num, const int unknown_seg_num, const int unknown_count,
const int known_count, const int known_count_align, const int start_idx,
int *deal_offset) {
int pre_batch_id = -1;
int cur_batch_id = -1;
pre_batch_id = start_idx / n;
// if aux_a is running out of space, keep only its first 3 minima and clear the rest.
if (*deal_offset >= PAD_DOWN(auxa_offset, 3)) {
auxFuncSort(nram_aux_a, auxa_offset, nram_aux_b, auxb_offset, nram_dest,
nram_aux_sort_a, nram_aux_sort_b, unknown_count, *deal_offset);
*deal_offset = 3;
}
// load i'th segment of known batch data.
loadTransposedKnownTensor<T>((char *)nram_known, (char *)nram_dist,
(char *)known_gdram, known_num_deal,
pre_batch_id, m, known_seg_num, known_count,
known_count_align);
for (int k = 0; k < unknown_count; ++k) {
cur_batch_id = (start_idx + k) / n;
if (cur_batch_id != pre_batch_id) { // if batch id of unknown data changed,
// load corresponding known batch data
pre_batch_id = cur_batch_id;
loadTransposedKnownTensor<T>((char *)nram_known, (char *)nram_dist,
(char *)known_gdram, known_num_deal,
pre_batch_id, m, known_seg_num, known_count,
known_count_align);
}
computeThreeNN(nram_unknown + 3 * k, nram_known, nram_dist, nram_dest,
nram_aux_a + k * auxa_offset, nram_aux_sort_a,
nram_aux_b + k * auxb_offset, nram_aux_sort_b,
known_num_deal, known_seg_num, *deal_offset, known_count,
known_count_align);
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelThreeNN(const int b, const int n,
const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram,
int *idx_gdram) {
if (coreId == 0x80) {
return;
}
size_t output_aux_sort_a_gap = 0, output_aux_sort_b_gap = 0,
output_dest_gap = 0, output_unknown_gap = 0, output_known_gap = 0,
output_dist_gap = 0, auxillary_a_gap = 0, auxillary_b_gap = 0,
known_num_deal = 0, unknown_num_deal = 0, align_num = 0,
auxa_offset = 0, auxb_offset = 0;
auxFuncNN<T>(&output_aux_sort_a_gap, &output_aux_sort_b_gap, &output_dest_gap,
&output_unknown_gap, &output_known_gap, &output_dist_gap,
&auxillary_a_gap, &auxillary_b_gap, &known_num_deal,
&unknown_num_deal, &align_num, &auxa_offset, &auxb_offset);
int num_per_core = b * n / taskDim;
const int core_offset = num_per_core;
char *unknown_gdram_start =
unknown_gdram + taskId * 3 * core_offset * sizeof(T);
char *known_gdram_start = known_gdram;
char *output_dist_start = dist2_gdram + taskId * 3 * core_offset * sizeof(T);
int *output_idx_start = idx_gdram + taskId * 3 * core_offset;
const int rem = (b * n) % taskDim;
if (taskId == taskDim - 1) {
num_per_core += rem;
}
const int unknown_repeat =
num_per_core / unknown_num_deal;  // if the unknown count is large, process
// it in unknown_repeat passes.
const int unknown_rem = num_per_core % unknown_num_deal;  // unknown remainder
const int unknown_rem_align = PAD_UP(unknown_rem, align_num);
const int known_repeat =
m / known_num_deal;  // if the known count is large, process it in
// known_repeat passes.
const int known_rem = m % known_num_deal;  // known remainder
const int known_rem_align = PAD_UP(known_rem, align_num);
char *nram_aux_sort_a = nram_buffer;
int *nram_aux_sort_b = (int *)(nram_buffer + output_aux_sort_b_gap);
char *nram_dest = nram_buffer + output_dest_gap;
char *nram_unknown = nram_buffer + output_unknown_gap;
char *nram_known = nram_buffer + output_known_gap;
char *nram_dist = nram_buffer + output_dist_gap;
char *nram_aux_a = nram_buffer + auxillary_a_gap;
int *nram_aux_b = (int *)(nram_buffer + auxillary_b_gap);
int deal_offset = 0;
int start_idx = -1;
for (int j = 0; j < unknown_repeat;
++j) {  // process one unknown segment per iteration
// if the unknown data has to be processed segment by segment, use the
// aux_a and aux_b buffers to track the first 3 minimum distances.
__bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset,
(T)(INFINITY));
__bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0);
loadUnknownTensor<T>(nram_unknown, unknown_gdram_start, unknown_num_deal, j,
unknown_num_deal, unknown_num_deal);
deal_offset = 0;
start_idx = taskId * core_offset + j * unknown_num_deal;
for (int i = 0; i < known_repeat;
++i) {  // process the known data segment by segment.
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, i, j, unknown_num_deal,
known_num_deal, known_num_deal, start_idx, &deal_offset);
deal_offset += 3;
}
if (known_rem > 0) { // process known rem
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, known_repeat, j, unknown_num_deal,
known_rem, known_rem_align, start_idx, &deal_offset);
}
deal_offset += 3;
if (deal_offset > 3) {
auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset,
(T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, deal_offset);
deal_offset = 0;
}
__memcpy((char *)output_dist_start + j * unknown_num_deal * 3 * sizeof(T),
(char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T),
auxa_offset * sizeof(T), unknown_num_deal - 1);
__memcpy((char *)output_idx_start + j * unknown_num_deal * 3 * sizeof(int),
(char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int),
auxb_offset * sizeof(int), unknown_num_deal - 1);
}
if (unknown_rem > 0) { // process unknown rem
deal_offset = 0;
__bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset,
(T)(INFINITY));
__bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0);
loadUnknownTensor<T>(nram_unknown, unknown_gdram_start, unknown_num_deal,
unknown_repeat, unknown_rem, unknown_rem_align);
start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal;
for (int i = 0; i < known_repeat; ++i) {
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, i, unknown_repeat, unknown_rem,
known_num_deal, known_num_deal, start_idx, &deal_offset);
deal_offset += 3;
}
if (known_rem > 0) {
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal;
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, known_repeat, unknown_repeat,
unknown_rem, known_rem, known_rem_align, start_idx, &deal_offset);
deal_offset += 3;
}
if (deal_offset > 3) {
auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset,
(T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_rem, deal_offset);
deal_offset = 0;
}
__memcpy((char *)output_dist_start +
unknown_repeat * unknown_num_deal * 3 * sizeof(T),
(char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T),
auxa_offset * sizeof(T), unknown_rem - 1);
__memcpy((char *)output_idx_start +
unknown_repeat * unknown_num_deal * 3 * sizeof(int),
(char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int),
auxb_offset * sizeof(int), unknown_rem - 1);
}
}
template __mlu_global__ void MLUUnion1KernelThreeNN<float>(
const int b, const int n, const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram, int *idx_gdram);
template __mlu_global__ void MLUUnion1KernelThreeNN<half>(
const int b, const int n, const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram, int *idx_gdram);
void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *unknown, const void *known, void *dist2,
int *idx, const int b, const int n, const int m) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1KernelThreeNN<half><<<k_dim, k_type, queue>>>(
b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx);
}; break;
case CNRT_FLOAT32: {
MLUUnion1KernelThreeNN<float><<<k_dim, k_type, queue>>>(
b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx);
}; break;
default: {
break;
}
}
}
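As a rough illustration of the NRAM partition arithmetic documented in the auxFuncNN comment above, the standalone sketch below reproduces the 300-series split (4/5 of the budget for the unknown-point buffers, 1/5 for the known/dist/dest buffers). The MAX_NRAM_SIZE and NFU_ALIGN_SIZE values here are assumptions for the example, not the real constants from common_mlu_helper.hpp.

#include <cstdio>

// Illustrative stand-ins; the real constants live in common_mlu_helper.hpp
// and depend on the MLU generation.
constexpr size_t kMaxNramSize = 384 * 1024;  // assumed NRAM budget in bytes
constexpr size_t kNfuAlignSize = 128;        // assumed NFU alignment in bytes

constexpr size_t pad_down(size_t x, size_t align) { return x / align * align; }

int main() {
  using T = float;
  const size_t align_num = kNfuAlignSize / sizeof(T);
  // 300-series split: 1/5 of the budget for known/dist/dest (7 buffers of T),
  // 4/5 for the unknown points plus their per-point aux_a / aux_b scratch.
  const size_t known_num_deal = pad_down(
      (kMaxNramSize - 3 * kNfuAlignSize) / 5 / (7 * sizeof(T)), align_num);
  const size_t unknown_num_deal =
      pad_down((kMaxNramSize - 3 * kNfuAlignSize) / 5 * 4 /
                   (3 * sizeof(T) + 3 * kNfuAlignSize),
               align_num);
  printf("known_num_deal=%zu unknown_num_deal=%zu\n", known_num_deal,
         unknown_num_deal);
  return 0;
}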
mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
__mlu_func__ void computeDynamicVoxelize(
char *points_x, char *points_y, char *points_z, char *auxiliary_a,
char *auxiliary_b, char *auxiliary_c, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float voxel_x,
const float voxel_y, const float voxel_z, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t deal_num) {
// x - coors_x_min
__bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min,
deal_num);
// y - coors_y_min
__bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min,
deal_num);
// z - coors_z_min
__bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min,
deal_num);
// (x - coors_x_min) / voxel_x
__bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x,
deal_num);
// (y - coors_y_min) / voxel_y
__bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y,
deal_num);
// (z - coors_z_min) / voxel_z
__bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z,
deal_num);
// c_x = floor((x - coors_x_min) / voxel_x)
__bang_floor((float *)auxiliary_a, (float *)points_x, deal_num);
__bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0);
// c_y = floor((y - coors_y_min) / voxel_y)
__bang_floor((float *)auxiliary_a, (float *)points_y, deal_num);
__bang_float2int32((int32_t *)points_y, (float *)auxiliary_a, deal_num, 0);
// c_z = floor((z - coors_z_min) / voxel_z)
__bang_floor((float *)auxiliary_a, (float *)points_z, deal_num);
__bang_float2int32((int32_t *)points_z, (float *)auxiliary_a, deal_num, 0);
// c_x >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_x, (int32_t)0,
deal_num);
// c_x < grid_x
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x,
deal_num);
// 0 <= c_x < grid_x
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_y >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0,
deal_num);
// c_y < grid_y
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y,
deal_num);
// 0 <= c_y < grid_y
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
// c_z >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0,
deal_num);
// c_z < grid_z
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z,
deal_num);
// 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
__bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num);
__bang_mul((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_a,
deal_num);
__bang_mul_scalar((int32_t *)auxiliary_b, (int32_t *)auxiliary_c,
(int32_t)(-1), deal_num);
__bang_add((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b,
deal_num);
}
__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y,
char *coors_z, const int32_t c_x,
const int32_t c_y, const int32_t c_z,
const int32_t max_points, int32_t *num,
int32_t *first_point,
const int32_t deal_idx,
const int32_t deal_num) {
__bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num);
__bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num);
__bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_y,
deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_z,
deal_num);
if (*num == 0) {
*num = (int32_t)__bang_count((float *)coors_x, deal_num);
if (*num > 0) {
*first_point =
(int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx;
}
} else {
*num += (int32_t)__bang_count((float *)coors_x, deal_num);
}
}
#endif
__mlu_global__ void MLUUnion1KernelDynamicVoxelize(
const float *points, int32_t *coors, const float voxel_x,
const float voxel_y, const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float coors_x_max,
const float coors_y_max, const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t points_rem = num_points % taskDim;
const int32_t points_per_core =
taskId < points_rem ? num_points / taskDim + 1 : num_points / taskDim;
const int32_t points_start = taskId < points_rem
? taskId * points_per_core
: taskId * points_per_core + points_rem;
const int32_t split_num = 9;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE);
const int32_t repeat = points_per_core / deal_num;
const int32_t rem = points_per_core % deal_num;
const int32_t ping_pong_gap = 3 * deal_num * sizeof(float);
char *points_x = nram_buffer;
char *points_y = points_x + deal_num * sizeof(float);
char *points_z = points_y + deal_num * sizeof(float);
char *auxiliary_a = points_x + 2 * ping_pong_gap;
char *auxiliary_b = auxiliary_a + deal_num * sizeof(float);
char *auxiliary_c = auxiliary_b + deal_num * sizeof(float);
int32_t *coors_z_start = coors + points_start;
int32_t *coors_y_start = coors + num_points + points_start;
int32_t *coors_x_start = coors + num_points * 2 + points_start;
if (repeat > 0) {
__memcpy_async(points_x, points + points_start * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y, points + points_start * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z, points + points_start * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__asm__ volatile("sync;");
}
if (repeat > 1) {
__memcpy_async(points_x + ping_pong_gap,
points + (points_start + deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y + ping_pong_gap,
points + (points_start + deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z + ping_pong_gap,
points + (points_start + deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
computeDynamicVoxelize(points_x, points_y, points_z, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 2; ++i) {
__memcpy_async(coors_x_start + i * deal_num,
points_x + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + i * deal_num,
points_y + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + i * deal_num,
points_z + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(points_x + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(
points_y + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
__memcpy_async(
points_z + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
computeDynamicVoxelize(points_x + ((i + 1) % 2) * ping_pong_gap,
points_y + ((i + 1) % 2) * ping_pong_gap,
points_z + ((i + 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
if (repeat >= 2) {
__memcpy_async(coors_x_start + (repeat - 2) * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 2) * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 2) * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
__memcpy_async(points_x + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), rem - 1);
__memcpy_async(
points_y + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
__memcpy_async(
points_z + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
}
if (repeat > 0) {
computeDynamicVoxelize(points_x + ((repeat - 1) % 2) * ping_pong_gap,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
}
__asm__ volatile("sync;");
if (repeat > 0) {
__memcpy_async(coors_x_start + (repeat - 1) * deal_num,
points_x + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 1) * deal_num,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 1) * deal_num,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
computeDynamicVoxelize(points_x + (repeat % 2) * ping_pong_gap,
points_y + (repeat % 2) * ping_pong_gap,
points_z + (repeat % 2) * ping_pong_gap, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, rem);
__asm__ volatile("sync;");
__memcpy_async(coors_x_start + repeat * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + repeat * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + repeat * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
}
#endif
}
__mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors,
int32_t *point_to_pointidx,
int32_t *point_to_voxelidx,
const int32_t num_points,
const int32_t max_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t split_num = 6;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE);
const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t);
char *coors_x = nram_buffer;
char *coors_y = coors_x + deal_num * sizeof(int32_t);
char *coors_z = coors_y + deal_num * sizeof(int32_t);
int32_t *coors_z_start = coors;
int32_t *coors_y_start = coors + num_points;
int32_t *coors_x_start = coors + num_points * 2;
for (int32_t point_idx = taskId; point_idx < num_points;
point_idx += taskDim) {
if (coors_x_start[point_idx] == -1) {
point_to_pointidx[point_idx] = -1;
point_to_voxelidx[point_idx] = -1;
continue;
}
int32_t c_x = coors_x_start[point_idx];
int32_t c_y = coors_y_start[point_idx];
int32_t c_z = coors_z_start[point_idx];
int32_t deal_total_num = point_idx;
int32_t repeat = deal_total_num / deal_num;
int32_t rem = deal_total_num % deal_num;
int32_t num = 0;
int32_t first_point = -1;
if (repeat > 0) {
__memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y, coors_y_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z, coors_z_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 1; ++i) {
__memcpy_async(coors_x + ((i + 1) % 2) * ping_pong_gap,
coors_x_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_y + ((i + 1) % 2) * ping_pong_gap,
coors_y_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap,
coors_z_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
computePoint2Voxel(
coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap,
coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num,
&first_point, i * deal_num, deal_num);
__asm__ volatile("sync;");
}
if (rem > 0) {
__memcpy_async(coors_x + (repeat % 2) * ping_pong_gap,
coors_x_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y + (repeat % 2) * ping_pong_gap,
coors_y_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z + (repeat % 2) * ping_pong_gap,
coors_z_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
}
if (repeat > 0) {
computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap,
coors_y + ((repeat - 1) % 2) * ping_pong_gap,
coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y,
c_z, max_points, &num, &first_point,
(repeat - 1) * deal_num, deal_num);
}
__asm__ volatile("sync;");
if (rem > 0) {
computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap,
coors_y + (repeat % 2) * ping_pong_gap,
coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z,
max_points, &num, &first_point, repeat * deal_num,
rem);
__asm__ volatile("sync;");
}
if (num == 0) {
point_to_pointidx[point_idx] = point_idx;
} else if (num > 0) {
point_to_pointidx[point_idx] = first_point;
}
if (num < max_points) {
point_to_voxelidx[point_idx] = num;
} else {
point_to_voxelidx[point_idx] = -1;
}
}
#endif
}
__mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel(
int32_t *point_to_pointidx, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel,
int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0) {
int32_t voxel_num_temp = 0;
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
int32_t point_pos_in_voxel = point_to_voxelidx[point_idx];
coor_to_voxelidx[point_idx] = -1;
if (point_pos_in_voxel == -1) {
continue;
} else if (point_pos_in_voxel == 0) {
int32_t voxel_idx = voxel_num_temp;
if (voxel_num_temp >= max_voxels) {
continue;
}
voxel_num_temp += 1;
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] = 1;
} else {
int32_t point_idx_temp = point_to_pointidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx_temp];
if (voxel_idx != -1) {
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] += 1;
}
}
}
*voxel_num = voxel_num_temp;
}
#endif
}
__mlu_global__ void MLUUnion1KernelAssignVoxelsCoors(
const float *points, int32_t *temp_coors, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, float *voxels, int32_t *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
int32_t points_per_core = num_points / taskDim;
int32_t points_rem = num_points % taskDim;
int32_t points_start = taskId < points_rem
? taskId * (points_per_core + 1)
: taskId * points_per_core + points_rem;
int32_t points_end = taskId < points_rem ? points_start + points_per_core + 1
: points_start + points_per_core;
for (int32_t point_idx = points_start; point_idx < points_end; ++point_idx) {
int32_t num = point_to_voxelidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx];
if (num > -1 && voxel_idx > -1) {
float *voxels_offset =
voxels + voxel_idx * max_points * num_features + num * num_features;
const float *points_offset = points + point_idx * num_features;
__memcpy_async(voxels_offset, points_offset, num_features * sizeof(float),
GDRAM2GDRAM);
if (num == 0) {
int32_t *coors_offset = coors + voxel_idx * 3;
__memcpy_async(coors_offset, temp_coors + point_idx, sizeof(int32_t),
GDRAM2GDRAM, sizeof(int32_t),
num_points * sizeof(int32_t), 2);
}
}
}
__asm__ volatile("sync;");
#endif
}
void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points, void *coors,
const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min,
const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z,
const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelDynamicVoxelize<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)coors, voxel_x, voxel_y, voxel_z, coors_x_min,
coors_y_min, coors_z_min, coors_x_max, coors_y_max, coors_z_max, grid_x,
grid_y, grid_z, num_points, num_features);
}
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *coors, void *point_to_pointidx,
void *point_to_voxelidx, const int32_t num_points,
const int32_t max_points) {
MLUUnion1KernelPoint2Voxel<<<k_dim, k_type, queue>>>(
(int32_t *)coors, (int32_t *)point_to_pointidx,
(int32_t *)point_to_voxelidx, num_points, max_points);
}
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *point_to_pointidx,
void *point_to_voxelidx, void *coor_to_voxelidx,
void *num_points_per_voxel, void *voxel_num,
const int32_t max_voxels,
const int32_t num_points) {
MLUUnion1KernelCalcPointsPerVoxel<<<k_dim, k_type, queue>>>(
(int32_t *)point_to_pointidx, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (int32_t *)num_points_per_voxel,
(int32_t *)voxel_num, max_voxels, num_points);
}
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points,
void *temp_coors, void *point_to_voxelidx,
void *coor_to_voxelidx, void *voxels, void *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelAssignVoxelsCoors<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)temp_coors, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (float *)voxels, (int32_t *)coors,
max_points, num_points, num_features);
}
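For reference, a plain C++ sketch of the per-point arithmetic that the vectorized computeDynamicVoxelize routine above performs on whole NRAM tiles; the function name and the sample grid parameters are illustrative only.

#include <cmath>
#include <cstdio>

// Scalar reference for one point: a coordinate is kept only if all three
// voxel indices fall inside the grid, otherwise every index becomes -1
// (the kernel does the same with vectorized masks).
void dynamic_voxelize_point(float x, float y, float z, float coors_x_min,
                            float coors_y_min, float coors_z_min, float voxel_x,
                            float voxel_y, float voxel_z, int grid_x,
                            int grid_y, int grid_z, int out[3]) {
  int c_x = static_cast<int>(std::floor((x - coors_x_min) / voxel_x));
  int c_y = static_cast<int>(std::floor((y - coors_y_min) / voxel_y));
  int c_z = static_cast<int>(std::floor((z - coors_z_min) / voxel_z));
  const bool in_grid = c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y &&
                       c_z >= 0 && c_z < grid_z;
  out[0] = in_grid ? c_x : -1;
  out[1] = in_grid ? c_y : -1;
  out[2] = in_grid ? c_z : -1;
}

int main() {
  int coor[3];
  // Assumed point and voxel grid, purely for demonstration.
  dynamic_voxelize_point(1.3f, 2.7f, -0.4f, 0.f, 0.f, -1.f, 0.5f, 0.5f, 0.5f,
                         40, 40, 4, coor);
  printf("voxel coor = (%d, %d, %d)\n", coor[0], coor[1], coor[2]);
  return 0;
}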
mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
0 → 100644
View file @
0c23eb02
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
                                    Tensor ious, const int mode_flag,
                                    const bool aligned) {
  // get compute handle
  auto handle = mluOpGetCurrentHandle();

  auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes1, boxes1.suggest_memory_format());
  auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes2, boxes2.suggest_memory_format());
  auto ious_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      ious, ious.suggest_memory_format());

  MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc;
  boxes1_desc.set(boxes1_contiguous);
  boxes2_desc.set(boxes2_contiguous);
  ious_desc.set(ious_contiguous);

  auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous);
  auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous);
  auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous);
  auto boxes1_ptr = boxes1_impl->cnnlMalloc();
  auto boxes2_ptr = boxes2_impl->cnnlMalloc();
  auto ious_ptr = ious_impl->cnnlMalloc();

  CNLOG(INFO) << "Call mluOpBoxIouRotated().";
  mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
                     boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
                     ious_ptr);
}

void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                         const int mode_flag, const bool aligned) {
  BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned);
}

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                          const int mode_flag, const bool aligned);

REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu);
mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
View file @
0c23eb02
...
...
@@ -10,114 +10,30 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
                 const cnrtDataType_t data_type_input, const void *boxes_dram,
                 const int input_box_num, const float iou_threshold,
                 void *workspace, void *output_size, void *output);

int selectType(uint32_t use_job, int box_num_per_core) {
  // the box_num_per_core should be at least 256, otherwise the real IO
  // bandwidth would be very low
  while (box_num_per_core < 256 && use_job >= 4) {
    box_num_per_core *= 2;
    use_job /= 2;
  }
  return use_job;
}

static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                               int &core_num_per_class,
                               const int input_box_num) {
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  uint32_t job_limit = getJobLimitCapability();
  uint32_t core_number = job_limit;

  int box_num_per_core = (input_box_num + core_number - 1) / core_number;
  int use_job = selectType(job_limit, box_num_per_core);
  // initiate k_type as Union1
  k_dim->x = core_dim;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
  switch (job_limit) {
    case CN_KERNEL_CLASS_BLOCK:
    case CN_KERNEL_CLASS_UNION:
    case CN_KERNEL_CLASS_UNION2:
    case CN_KERNEL_CLASS_UNION4:
    case CN_KERNEL_CLASS_UNION8:
    case CN_KERNEL_CLASS_UNION16: {
      if (use_job < 4) {
        k_dim->x = 1;
        *k_type = CNRT_FUNC_TYPE_BLOCK;
      } else if (use_job == 4) {
        k_dim->x = core_dim;
        *k_type = CNRT_FUNC_TYPE_UNION1;
      } else {
        k_dim->x = use_job;
        *k_type = (cnrtFunctionType_t)use_job;
      }
    }; break;
    default:
      LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number."
                   << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task.";
  }
  return CNNL_STATUS_SUCCESS;
}
#include "mlu_common_helper.h"

void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
                                 float iou_threshold) {
  // dimension parameters check
  TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ",
              boxes.dim(), "D");
  TORCH_CHECK(boxes.size(1) == 7,
              "boxes should have 7 elements in dimension 1, got ",
              boxes.size(1));
  // data type check
  TORCH_CHECK(
      boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf,
      "data type of boxes should be Float or Half, got ", boxes.scalar_type());
  if (boxes.numel() == 0) {
    return;
  }
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(boxes.numel() < max_input_num,
              "boxes.numel() should be less than 2147483648, got ",
              boxes.numel());
  int input_box_num = boxes.size(0);
  cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype());
  cnrtDim3_t k_dim;
  cnrtJobType_t k_type;
  int core_num_per_class;
  policyFunc(&k_dim, &k_type, core_num_per_class, input_box_num);
  // transpose boxes (n, 7) to (7, n) for better performance
  auto boxes_t = boxes.transpose(0, 1);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t);
  auto output = at::empty({input_box_num}, boxes.options().dtype(at::kLong));
  int input_box_num = boxes.size(0);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto output = keep.to(boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, boxes.options().dtype(at::kInt));
  // workspace
  const int info_num = 7;  // x, y, z, dx, dy, dz, angle
  size_t space_size = 0;
  if (boxes.scalar_type() == at::kHalf) {
    space_size = input_box_num * sizeof(int16_t) * info_num +
                 input_box_num * sizeof(float) + sizeof(float);
  } else {
    space_size = input_box_num * sizeof(float) * (info_num + 1) + sizeof(float);
  }
  MluOpTensorDescriptor boxes_desc, output_desc;
  boxes_desc.set(boxes_);
  output_desc.set(output);
  auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
  // workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size);
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
...
...
@@ -127,11 +43,29 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
  auto output_size_impl = torch_mlu::getMluTensorImpl(keep_num);
  auto output_size_ptr = output_size_impl->cnnlMalloc();
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  CNLOG(INFO) << "Launch Kernel KernelIou3d<<<Union" << k_type / core_dim
              << ", " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  KernelIou3d(k_dim, k_type, queue, data_type_input, boxes_ptr, input_box_num,
              iou_threshold, workspace_ptr, output_size_ptr, output_ptr);
  // nms desc
  mluOpNmsDescriptor_t nms_desc;
  const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
  const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
  const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
  const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
  const float soft_nms_sigma = 0.0;
  const float confidence_threshold = 0.0;
  const int input_layout = 0;
  const bool pad_to_max_output_size = false;
  const int max_output_size = input_box_num;
  const float offset = 0.0;

  mluOpCreateNmsDescriptor(&nms_desc);
  mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
                        iou_threshold, soft_nms_sigma, max_output_size,
                        confidence_threshold, offset, input_layout,
                        pad_to_max_output_size);

  mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
           workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
           output_size_ptr);
  mluOpDestroyNmsDescriptor(nms_desc);
}

void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
...
...
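To make the job-selection heuristic in selectType concrete, here is a standalone sketch with assumed numbers (1000 boxes, job limit 16); it mirrors the loop that halves the job count until each core sees at least 256 boxes.

#include <cstdint>
#include <cstdio>

// Mirrors selectType() from iou3d_mlu.cpp: halve the job count until every
// core gets at least 256 boxes, so the per-core IO bandwidth stays usable.
uint32_t select_type(uint32_t use_job, int box_num_per_core) {
  while (box_num_per_core < 256 && use_job >= 4) {
    box_num_per_core *= 2;
    use_job /= 2;
  }
  return use_job;
}

int main() {
  // Assumed example: 1000 boxes with a job limit of 16 gives
  // (1000 + 15) / 16 = 63 boxes per core; the loop halves the job count
  // down to 2, so policyFunc would fall back to a BLOCK task.
  const int input_box_num = 1000;
  const uint32_t job_limit = 16;
  const int box_num_per_core = (input_box_num + job_limit - 1) / job_limit;
  printf("use_job = %u\n", select_type(job_limit, box_num_per_core));
  return 0;
}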
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
View file @
0c23eb02
...
...
@@ -18,8 +18,8 @@
#include "pytorch_device_registry.hpp"
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 5
#define MLUOP_PATCHLEVEL 302
#define MLUOP_MINOR 6
#define MLUOP_PATCHLEVEL 0

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
...
...
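A possible way to consume the MLUOP_* version macros above is a compile-time guard; this is an illustrative sketch only, and the MLUOP_VERSION helper is an assumption, not part of mlu_common_helper.h.

#include "mlu_common_helper.h"  // provides MLUOP_MAJOR / MLUOP_MINOR / MLUOP_PATCHLEVEL

// MLUOP_VERSION is a hypothetical helper for this sketch only.
#define MLUOP_VERSION(major, minor, patch) \
  ((major)*10000 + (minor)*100 + (patch))

#if MLUOP_VERSION(MLUOP_MAJOR, MLUOP_MINOR, MLUOP_PATCHLEVEL) < \
    MLUOP_VERSION(0, 6, 0)
#error "expected the bundled mlu-ops headers to be at least 0.6.0"
#endif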
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,495 +9,117 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
typedef enum {
  MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
  MS_DEFORM_ATTN_FORWARD_DEFAULT =
      1, /*!< MLUKernelMsDeformAttnForwardDefault */
  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
      2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
} MsDeformAttnForwardPolicy;

void KernelMsDeformAttnForwardDefault(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char* data_value_gdram,
    const char* data_spatial_shapes_gdram,
    const char* data_level_start_index_gdram,
    const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char* data_value_gdram,
    const char* data_spatial_shapes_gdram,
    const char* data_level_start_index_gdram,
    const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char* data_col_gdram);

typedef enum {
  MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
} MsDeformAttnBackwardKernelPolicy;

MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
    const int32_t channels, const int32_t num_levels, const int32_t num_points,
    const int32_t num_heads) {
  const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const int num_hlp = num_heads * num_levels * num_points;
  int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
                             3 * num_levels * sizeof(int32_t)) /
                            sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
                            PAD_UP((num_hlp), 32);
  if (num_per_time_theory >= 1) {
    return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
  }
  return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}

void KernelMsDeformAttnBackwardDefaultKernel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float* data_value,
    const int32_t* spatial_shapes, const int32_t* data_level_start_index,
    const float* data_sampling_loc, const float* data_attn_weight,
    const float* grad_output, const int32_t batch_size, const int32_t num_keys,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_queries, const int32_t num_points, float* grad_value,
    float* grad_sampling_loc, float* grad_attn_weight);
void KernelMsDeformAttnBackwardSmallChannelsKernel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float* data_value,
    const int32_t* spatial_shapes, const int32_t* data_level_start_index,
    const float* data_sampling_loc, const float* data_attn_weight,
    const float* grad_output, const int32_t batch, const int32_t spatial_size,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_query, const int32_t num_points, float* grad_value,
    float* grad_sampling_loc, float* grad_attn_weight);

// policy function
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
    cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
    const int32_t num_keys, const int32_t num_heads, const int32_t channels,
    const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points) {
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y =
      MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
          torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
  k_dim->z = 1;
#if __BANG_ARCH__ == 520
  *k_type = CNRT_FUNC_TYPE_BLOCK;
#else
  *k_type = CNRT_FUNC_TYPE_UNION1;
#endif

  int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
             channels < 16) {
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else {
    return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
  }
}

// policy function for backward
static void policyFuncBackward(const int32_t batch_size,
                               const int32_t num_queries,
                               const int32_t num_heads,
                               const int32_t num_levels,
                               cnrtFunctionType_t* k_type, cnrtDim3_t* k_dim) {
  size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->x = core_limit;
  int32_t total_num = batch_size * num_queries * num_heads * num_levels;
  size_t total_num_align = CEIL_ALIGN(total_num, core_limit);
  k_dim->y = (total_num_align / core_limit) > cluster_limit
                 ? cluster_limit
                 : (total_num_align / core_limit);
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}

Tensor ms_deform_attn_mlu_forward(const Tensor& value,
/*************************************************************************
* This macro turns a plain tensor into its MLU-side handles:
* NAME##_contigous, NAME##_desc, NAME##_impl and NAME##_ptr are generated
* automatically.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();
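For readability, this is roughly what INITIAL_MLU_PARAM_WITH_TENSOR(output) expands to (hand-expanded here purely as illustration; the _contigous spelling follows the macro, and the macro itself is authoritative):

// Hand expansion of INITIAL_MLU_PARAM_WITH_TENSOR(output), shown only to make
// the generated identifiers explicit.
auto output_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(
    output, output.suggest_memory_format());
MluOpTensorDescriptor output_desc;
output_desc.set(output_contigous);
auto output_impl = torch_mlu::getMluTensorImpl(output_contigous);
auto output_ptr = output_impl->cnnlMalloc();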
Tensor
MsDeformAttnForwardLauncher
(
const
Tensor
&
value
,
const
Tensor
&
spatial_shapes
,
const
Tensor
&
level_start_index
,
const
Tensor
&
sampling_loc
,
const
Tensor
&
attn_weight
,
const
int
im2col_step
)
{
// check contiguous
AT_ASSERTM
(
value
.
is_contiguous
(),
"value tensor has to be contiguous"
);
AT_ASSERTM
(
spatial_shapes
.
is_contiguous
(),
"spatial_shapes tensor has to be contiguous"
);
AT_ASSERTM
(
level_start_index
.
is_contiguous
(),
"level_start_index tensor has to be contiguous"
);
AT_ASSERTM
(
sampling_loc
.
is_contiguous
(),
"sampling_loc tensor has to be contiguous"
);
AT_ASSERTM
(
attn_weight
.
is_contiguous
(),
"attn_weight tensor has to be contiguous"
);
// check datatype
TORCH_CHECK
((
value
.
scalar_type
()
==
at
::
kFloat
),
"value type should be Float, got "
,
value
.
scalar_type
(),
"."
);
TORCH_CHECK
((
spatial_shapes
.
scalar_type
()
==
at
::
kInt
||
spatial_shapes
.
scalar_type
()
==
at
::
kLong
),
"spatial_shapes type should be Int, got "
,
spatial_shapes
.
scalar_type
(),
"."
);
TORCH_CHECK
((
level_start_index
.
scalar_type
()
==
at
::
kInt
||
level_start_index
.
scalar_type
()
==
at
::
kLong
),
"level_start_index type should be Int, got "
,
level_start_index
.
scalar_type
(),
"."
);
TORCH_CHECK
((
sampling_loc
.
scalar_type
()
==
at
::
kFloat
),
"sampling_loc type should be Float, got "
,
sampling_loc
.
scalar_type
(),
"."
);
TORCH_CHECK
((
attn_weight
.
scalar_type
()
==
at
::
kFloat
),
"attn_weight type should be Float, got "
,
attn_weight
.
scalar_type
(),
"."
);
// check shape
TORCH_CHECK
(
value
.
dim
()
==
4
,
"value should be a 4d tensor, got "
,
value
.
dim
(),
"D."
);
TORCH_CHECK
(
spatial_shapes
.
dim
()
==
2
,
"spatial_shapes should be a 2d tensor, got "
,
spatial_shapes
.
dim
(),
"D."
);
TORCH_CHECK
(
level_start_index
.
dim
()
==
1
,
"level_start_index should be a 1d tensor, got "
,
level_start_index
.
dim
(),
"D."
);
TORCH_CHECK
(
sampling_loc
.
dim
()
==
6
,
"sampling_loc should be a 6d tensor, got "
,
sampling_loc
.
dim
(),
"D."
);
TORCH_CHECK
(
attn_weight
.
dim
()
==
5
,
"attn_weight should be a 5d tensor, got "
,
attn_weight
.
dim
(),
"D."
);
auto
handle
=
mluOpGetCurrentHandle
();
const
int
batch_size
=
value
.
size
(
0
);
const
int
num_keys
=
value
.
size
(
1
);
const
int
num_heads
=
value
.
size
(
2
);
const
int
channels
=
value
.
size
(
3
);
const
int
num_levels
=
spatial_shapes
.
size
(
0
);
const
int
num_queries
=
sampling_loc
.
size
(
1
);
const
int
num_points
=
sampling_loc
.
size
(
4
);
TORCH_CHECK
(
spatial_shapes
.
size
(
1
)
==
2
,
"the 2nd dimensions of spatial_shapes should be 2, got "
,
spatial_shapes
.
size
(
1
),
"."
);
TORCH_CHECK
(
sampling_loc
.
size
(
5
)
==
2
,
"the 6th dimensions of sampling_loc should be 2, got "
,
sampling_loc
.
size
(
5
),
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of sampling_loc should be batch_size, "
,
"but now the 1st dimension of sampling_loc is "
,
sampling_loc
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of attn_weight should be batch_size, "
,
"but now the 1st dimension of attn_weight is "
,
attn_weight
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of sampling_loc should be num_heads, "
,
"but now the 3rd dimension of sampling_loc is "
,
sampling_loc
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of attn_weight should be num_heads, "
,
"but now the 3rd dimension of attn_weight is "
,
attn_weight
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
level_start_index
.
size
(
0
)
==
num_levels
),
"the 1st dimensions of level_start_index should be num_levels, "
,
"but now the 1st dimension of level_start_index is "
,
level_start_index
.
size
(
0
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of sampling_loc should be num_levels, "
,
"but now the 4th dimension of sampling_loc is "
,
sampling_loc
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of attn_weight should be num_levels, "
,
"but now the 4th dimension of attn_weight is "
,
attn_weight
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
1
)
==
num_queries
),
"the 2nd dimensions of attn_weight should be num_queries, "
,
"but now the 2nd dimension of attn_weight is "
,
attn_weight
.
size
(
1
),
", and num_queries is "
,
num_queries
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
4
)
==
num_points
),
"the 5th dimensions of attn_weight should be num_points, "
,
"but now the 5th dimension of attn_weight is "
,
attn_weight
.
size
(
4
),
", and num_points is "
,
num_points
,
"."
);
auto
output
=
at
::
zeros
({
batch_size
,
num_queries
,
num_heads
,
channels
},
value
.
options
());
// large tensor check
const
size_t
max_input_size
=
      2147483648;
  TORCH_CHECK(value.numel() < max_input_size,
              "value element num should be less than 2^31, got ", value.numel(), ".");
  TORCH_CHECK(sampling_loc.numel() < max_input_size,
              "sampling_loc element num should be less than 2^31, got ", sampling_loc.numel(), ".");
  TORCH_CHECK(output.numel() < max_input_size,
              "output element num should be less than 2^31, got ", output.numel(), ".");
  // check zero element
  TORCH_CHECK(batch_size != 0, "batch_size should not be zero");
  TORCH_CHECK(num_heads != 0, "num_heads should not be zero");
  TORCH_CHECK(channels != 0, "channels should not be zero");
  TORCH_CHECK(num_queries != 0, "num_queries should not be zero");
  if (num_keys == 0 || num_levels == 0 || num_points == 0) {
    return output;
  }
  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
      &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
      num_queries, num_points);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto spatial_shapes_ = spatial_shapes.to(at::kInt);
  auto level_start_index_ = level_start_index.to(at::kInt);
  // get ptr of tensors
  auto value_impl = torch_mlu::getMluTensorImpl(value);
  auto value_ptr = value_impl->cnnlMalloc();
  auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_);
  auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
  auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_);
  auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
  auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
  auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
  auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
  auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
  // launch kernel
  switch (policy) {
    default: {
      VLOG(5) << "MsDeformAttnForward Policy not supported";
    }; break;
    case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
      KernelMsDeformAttnForwardDefault(
          k_dim, k_type, queue, data_type, (char*)value_ptr,
          (char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
          (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
          num_heads, channels, num_levels, num_queries, num_points,
          (char*)output_ptr);
      break;
    }
    case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
      KernelMsDeformAttnForwardSmallChannel(
          k_dim, k_type, queue, data_type, (char*)value_ptr,
          (char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
          (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
          num_heads, channels, num_levels, num_queries, num_points,
          (char*)output_ptr);
      break;
    }
  }
  auto spatial_shapes_int = spatial_shapes.to(at::kInt);
  auto level_start_index_int = level_start_index.to(at::kInt);
  INITIAL_MLU_PARAM_WITH_TENSOR(output);
  INITIAL_MLU_PARAM_WITH_TENSOR(value);
  INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);
  mluOpMsDeformAttnForward(
      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
      attn_weight_desc.desc(), attn_weight_ptr, im2col_step,
      output_desc.desc(), output_ptr);
  output = output.view({batch_size, num_queries, num_heads * channels});
  return output;
}
void ms_deform_attn_mlu_backward(
void MsDeformAttnBackwardLauncher(
    const Tensor& value, const Tensor& spatial_shapes,
    const Tensor& level_start_index, const Tensor& sampling_loc,
    const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
    Tensor& grad_sampling_loc, Tensor& grad_attn_weight,
    const int im2col_step) {
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ", spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ", level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ", sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ", attn_weight.scalar_type(), ".");
  TORCH_CHECK((grad_output.scalar_type() == at::kFloat),
              "grad_output type should be Float, got ", grad_output.scalar_type(), ".");
  auto handle = mluOpGetCurrentHandle();
  auto spatial_shapes_int = spatial_shapes.to(at::kInt);
  auto level_start_index_int = level_start_index.to(at::kInt);
  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);
  // Check shape.
  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK((level_start_index.size(0) == num_levels),
              "the 1st dimensions of level_start_index should be num_levels, ",
              "but now the 1st dimension of level_start_index is ",
              level_start_index.size(0), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((sampling_loc.size(0) == batch_size),
              "the 1st dimensions of sampling_loc should be batch_size, ",
              "but now the 1st dimension of sampling_loc is ",
              sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((sampling_loc.size(2) == num_heads),
              "the 3rd dimensions of sampling_loc should be num_heads, ",
              "but now the 3rd dimension of sampling_loc is ",
              sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((sampling_loc.size(3) == num_levels),
              "the 4th dimensions of sampling_loc should be num_levels, ",
              "but now the 4th dimension of sampling_loc is ",
              sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK(sampling_loc.size(5) == 2,
              "the 6th dimensions of sampling_loc should be 2, got ",
              sampling_loc.size(5), ".");
  auto grad_output_dim4 =
      grad_output.view({batch_size, num_queries, num_heads, channels});
  // auto grad_output_dim4 = grad_output.view({batch_size, num_queries,
  // num_heads, channels}).detach();
  INITIAL_MLU_PARAM_WITH_TENSOR(value);
  INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_output_dim4);
  // INITIAL_MLU_PARAM_WITH_TENSOR(grad_output);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_value);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_attn_weight);
  mluOpMsDeformAttnBackward(
      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
      attn_weight_desc.desc(), attn_weight_ptr, grad_output_dim4_desc.desc(),
      grad_output_dim4_ptr, im2col_step, grad_value_desc.desc(),
      grad_value_ptr, grad_sampling_loc_desc.desc(), grad_sampling_loc_ptr,
      grad_attn_weight_desc.desc(), grad_attn_weight_ptr);
  TORCH_CHECK((attn_weight.size(0) == batch_size),
              "the 1st dimensions of attn_weight should be batch_size, ",
              "but now the 1st dimension of attn_weight is ",
              attn_weight.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((attn_weight.size(1) == num_queries),
              "the 2nd dimensions of attn_weight should be num_queries, ",
              "but now the 2nd dimension of attn_weight is ",
              attn_weight.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK((attn_weight.size(2) == num_heads),
              "the 3rd dimensions of attn_weight should be num_heads, ",
              "but now the 3rd dimension of attn_weight is ",
              attn_weight.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((attn_weight.size(3) == num_levels),
              "the 4th dimensions of attn_weight should be num_levels, ",
              "but now the 4th dimension of attn_weight is ",
              attn_weight.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(4) == num_points),
              "the 5th dimensions of attn_weight should be num_points, ",
              "but now the 5th dimension of attn_weight is ",
              attn_weight.size(4), ", and num_points is ", num_points, ".");
  TORCH_CHECK((grad_output.size(0) == batch_size),
              "the 1st dimensions of grad_output should be batch_size, ",
              "but now the 1st dimension of grad_output is ",
              grad_output.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((grad_output.size(1) == num_queries),
              "the 2nd dimensions of grad_output should be num_queries, ",
              "but now the 2nd dimension of grad_output is ",
              grad_output.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK(
      (grad_output.size(2) == num_heads * channels),
      "the 3rd dimensions of grad_output should be num_heads * channels, ",
      "but now the 3rd dimension of grad_output is ", grad_output.size(2),
      ", and num_heads * channels is ", num_heads * channels, ".");
  // check zero element
  TORCH_CHECK(batch_size != 0, "The batch_size is zero.");
  TORCH_CHECK(channels != 0, "The channels is zero.");
  TORCH_CHECK(num_keys != 0, "The num_keys is zero.");
  TORCH_CHECK(num_heads != 0, "The num_heads is zero.");
  TORCH_CHECK(num_queries != 0, "The num_queries is zero.");
  if (num_levels == 0 || num_points == 0) {
    return;
  }
  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncBackward(batch_size, num_queries, num_heads, num_levels, &k_type, &k_dim);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto value_impl = torch_mlu::getMluTensorImpl(value);
  auto value_ptr = value_impl->cnnlMalloc();
  auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes);
  auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
  auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index);
  auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
  auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
  auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
  auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
  auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto grad_value_impl = torch_mlu::getMluTensorImpl(grad_value);
  auto grad_value_ptr = grad_value_impl->cnnlMalloc();
  auto grad_sampling_loc_impl = torch_mlu::getMluTensorImpl(grad_sampling_loc);
  auto grad_sampling_loc_ptr = grad_sampling_loc_impl->cnnlMalloc();
  auto grad_attn_weight_impl = torch_mlu::getMluTensorImpl(grad_attn_weight);
  auto grad_attn_weight_ptr = grad_attn_weight_impl->cnnlMalloc();
}
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
Tensor ms_deform_attn_mlu_forward(const Tensor& value,
                                  const Tensor& spatial_shapes,
                                  const Tensor& level_start_index,
                                  const Tensor& sampling_loc,
                                  const Tensor& attn_weight,
                                  const int im2col_step) {
  return MsDeformAttnForwardLauncher(value, spatial_shapes, level_start_index,
                                     sampling_loc, attn_weight, im2col_step);
}
  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  MsDeformAttnBackwardKernelPolicy kernelPolicy =
      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points, num_heads);
  switch (kernelPolicy) {
    default: {
      VLOG(5) << "NotImplemented.";
    } break;
    case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
      KernelMsDeformAttnBackwardDefaultKernel(
          k_dim, k_type, queue, data_type, (float*)value_ptr,
          (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
          (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
          (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
          num_levels, num_queries, num_points, (float*)grad_value_ptr,
          (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
    } break;
    case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
      KernelMsDeformAttnBackwardSmallChannelsKernel(
          k_dim, k_type, queue, data_type, (float*)value_ptr,
          (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
          (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
          (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
          num_levels, num_queries, num_points, (float*)grad_value_ptr,
          (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
    } break;
  }
void ms_deform_attn_mlu_backward(
    const Tensor& value, const Tensor& spatial_shapes,
    const Tensor& level_start_index, const Tensor& sampling_loc,
    const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
    Tensor& grad_sampling_loc, Tensor& grad_attn_weight,
    const int im2col_step) {
  return MsDeformAttnBackwardLauncher(value, spatial_shapes, level_start_index,
                                      sampling_loc, attn_weight, grad_output,
                                      grad_value, grad_sampling_loc,
                                      grad_attn_weight, im2col_step);
}
Tensor ms_deform_attn_impl_forward(const Tensor& value,
...
...
@@ -515,5 +137,6 @@ void ms_deform_attn_impl_backward(
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU,
                     ms_deform_attn_mlu_forward);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU,
                     ms_deform_attn_mlu_backward);
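For orientation, these registrations are what `mmcv.ops` dispatches to when the input tensors live on an MLU device. Below is a minimal, hedged sketch of exercising the forward and backward paths from Python; it assumes an MLU-enabled build of mmcv plus torch_mlu, and that `.to('mlu')` is available for device transfer in that install. Shapes follow the checks in the launcher above.

```python
# Hedged sketch, not part of this commit: assumes mmcv built with MLU support and torch_mlu installed.
import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

bs, num_queries, num_heads, channels = 2, 4, 8, 32
num_levels, num_points = 2, 4
spatial_shapes = torch.tensor([[6, 4], [3, 2]], dtype=torch.long).to('mlu')
num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]))

value = (torch.rand(bs, num_keys, num_heads, channels) * 0.01).to('mlu').requires_grad_()
sampling_locs = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2).to('mlu')
attn_weights = (torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5).to('mlu')

# The last argument is im2col_step, matching the launcher signature above.
output = MultiScaleDeformableAttnFunction.apply(
    value, spatial_shapes, level_start_index, sampling_locs, attn_weights, 2)
output.sum().backward()  # drives ms_deform_attn_mlu_backward
```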
mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
View file @
0c23eb02
...
...
@@ -10,123 +10,35 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
               const cnrtDataType_t data_type_input, const void *boxes_ptr,
               const void *scores_ptr, const int input_num_boxes,
               const int max_output_boxes, const float iou_threshold,
               const float offset, void *workspace_ptr, void *output_size_ptr,
               void *output_ptr);
int selectUnionType(uint32_t use_job, int box_num_per_core) {
  // the box_num_per_core should be at least 256, otherwise the real IO
  // bandwidth would be very low
  while (box_num_per_core < 256 && use_job >= 4) {
    box_num_per_core *= 2;
    use_job /= 2;
  }
  return use_job;
}
static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                               int &core_num_per_class,
                               const int input_box_num) {
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  uint32_t cluster_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  uint32_t job_limit = getJobLimitCapability();
  uint32_t core_number = job_limit;
  int box_num_per_core = (input_box_num + core_number - 1) / core_number;
  int use_job = selectUnionType(job_limit, box_num_per_core);
  // initiate k_type as Union1
  k_dim->x = core_dim;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
  switch (job_limit) {
    case CN_KERNEL_CLASS_BLOCK:
    case CN_KERNEL_CLASS_UNION:
    case CN_KERNEL_CLASS_UNION2:
    case CN_KERNEL_CLASS_UNION4:
    case CN_KERNEL_CLASS_UNION8:
    case CN_KERNEL_CLASS_UNION16: {
      if (use_job < 4) {
        k_dim->x = 1;
        *k_type = CNRT_FUNC_TYPE_BLOCK;
      } else if (use_job == 4) {
        k_dim->x = core_dim;
        *k_type = CNRT_FUNC_TYPE_UNION1;
      } else {
        k_dim->x = use_job;
        *k_type = (cnrtFunctionType_t)use_job;
      }
    }; break;
    default:
      LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number."
                   << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task.";
  }
  return CNNL_STATUS_SUCCESS;
}
#include "mlu_common_helper.h"
Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
                            int offset) {
  // dimension parameters check
  TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", boxes.dim(), "D");
  TORCH_CHECK(boxes.size(1) == 4,
              "boxes should have 4 elements in dimension 1, got ", boxes.size(1));
  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
  // data type check
  TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(),
              "boxes should have the same type as scores");
  TORCH_CHECK(boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf,
              "data type of boxes should be Float or Half, got ", boxes.scalar_type());
  if (boxes.numel() == 0) {
    return at::empty({0}, boxes.options().dtype(at::kLong));
  }
  int input_num_boxes = boxes.size(0);
  int max_output_boxes = boxes.size(0);
  cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype());
  cnrtDim3_t k_dim;
  cnrtJobType_t k_type;
  int core_num_per_class;
  policyFunc(&k_dim, &k_type, core_num_per_class, input_num_boxes);
  // transpose boxes (n, 4) to (4, n) for better performance
  auto boxes_t = boxes.transpose(0, 1);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
  auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong));
  auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, scores.options().dtype(at::kInt));
  MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
  boxes_desc.set(boxes_);
  scores_desc.set(scores_);
  output_desc.set(output);
  // workspace
  const int info_num = 5;  // x1, x2, y1, y2 and score
  size_t space_size = 0;
  if (boxes.scalar_type() == at::kHalf) {
    space_size = input_num_boxes * sizeof(int16_t) * info_num + sizeof(float);
  } else {
    space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float);
  }
#if __BANG_ARCH__ > 370
  int cluster_num = getCoreNumOfJobLimitCapability() /
                    torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  space_size += cluster_number * sizeof(float) * 7;
#endif
  auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(), &workspace_size);
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
...
...
@@ -138,14 +50,31 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
  auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
  auto output_size_ptr = output_size_impl->cnnlMalloc();
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  CNLOG(INFO) << "Launch Kernel MLUUnionX NMS<<<Union" << k_type / core_dim
              << ", " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr,
            input_num_boxes, max_output_boxes, iou_threshold, offset,
            workspace_ptr, output_size_ptr, output_ptr);
  // nms desc
  mluOpNmsDescriptor_t nms_desc;
  const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
  const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
  const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
  const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
  const float soft_nms_sigma = 0.0;
  const float confidence_threshold = 0.0;
  const int input_layout = 0;
  const bool pad_to_max_output_size = false;
  const int max_output_size = max_output_boxes;
  mluOpCreateNmsDescriptor(&nms_desc);
  mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
                        iou_threshold, soft_nms_sigma, max_output_size,
                        confidence_threshold, (float)offset, input_layout,
                        pad_to_max_output_size);
  mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
           scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
           output_ptr, output_size_ptr);
  mluOpDestroyNmsDescriptor(nms_desc);
  int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
  return output.slice(0, 0, output_num);
  auto ret = output.to(boxes.options().dtype(at::kLong));
  return ret.slice(0, 0, output_num);
}
Tensor nms_mlu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
...
...
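As a usage-level counterpart to the launcher above, `mmcv.ops.nms` routes to `NMSMLUKernelLauncher` when `boxes` and `scores` are MLU tensors. A hedged sketch follows; the `.to('mlu')` transfer assumes a torch_mlu install.

```python
# Hedged sketch, not part of this commit: requires an MLU device and an MLU-enabled mmcv build.
import torch
from mmcv.ops import nms

boxes = torch.tensor([[10., 10., 50., 60.],
                      [11., 12., 48., 59.],
                      [100., 100., 150., 160.]]).to('mlu')
scores = torch.tensor([0.9, 0.8, 0.7]).to('mlu')
dets, keep = nms(boxes, scores, iou_threshold=0.5, offset=0)
# `keep` holds indices of retained boxes as a Long tensor, matching the
# output.to(at::kLong) conversion performed by the launcher above.
```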
mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,26 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                    cnrtQueue_t queue, const cnrtDataType_t d_type,
                    const void *input, const void *rois, const int channels,
                    const bool aligned, const int pooled_height,
                    const int pooled_width, const int input_height,
                    const int input_width, const int sampling_ratio,
                    const float spatial_scale, const int num_rois,
                    void *output);
void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                            cnrtQueue_t queue, const cnrtDataType_t dtype,
                            const void *grads, const void *boxes,
                            void *grads_image, const int boxes_num,
                            const int hi, const int wi, const int c,
                            const int no, const int ho, const int wo,
                            const float spatial_scale,
                            const int sampling_ratio, const bool aligned);
#include "mlu_common_helper.h"
void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      Tensor argmax_y, Tensor argmax_x,
...
...
@@ -36,17 +17,7 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      float spatial_scale, int sampling_ratio,
                                      int pool_mode, bool aligned) {
  // params check
  TORCH_CHECK(input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
              "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(rois.scalar_type() == input.scalar_type(),
              "rois should have the same type as input");
  TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ", input.dim(), "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), "D");
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_tensor =
...
...
@@ -57,52 +28,56 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
  int height = input.size(2);
  int width = input.size(3);
  if (output.numel() == 0) {
    output = at::zeros({num_rois, channels, aligned_height, aligned_width},
                       input.options());
    return;
  }
  at::Tensor output_tmp =
  auto output_contiguous =
      at::empty({num_rois, channels, aligned_height, aligned_width},
                input.options(), memory_format);
  // get tensor impl
  auto self_impl = torch_mlu::getMluTensorImpl(input_tensor);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto output_impl = torch_mlu::getMluTensorImpl(output_tmp);
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  MluOpTensorDescriptor input_desc, rois_desc, argmax_y_desc, argmax_x_desc,
      output_desc;
  input_desc.set_with_layout(input_tensor, MLUOP_LAYOUT_NHWC);
  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
  // get the mlu ptr
  auto self_ptr = self_impl->cnnlMalloc();
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto output_ptr = output_impl->cnnlMalloc();
  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
  cnrtDim3_t k_dim;
  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim.z = 1;
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());
  KernelRoiAlign(k_dim, k_type, queue, data_type, self_ptr, rois_ptr, channels,
                 aligned, aligned_height, aligned_width, height, width,
                 sampling_ratio, spatial_scale, num_rois, output_ptr);
  output.copy_(output_tmp);
}
static int nearestPower2(int x) {
  x--;
  x |= x >> 1;
  x |= x >> 2;
  x |= x >> 4;
  x |= x >> 8;
  x |= x >> 16;
  x++;
  return x;
  mluOpRoiAlignForwardDescriptor_t roialign_desc;
  mluOpCreateRoiAlignForwardDescriptor(&roialign_desc);
  mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
                                       aligned_width, sampling_ratio,
                                       spatial_scale, pool_mode, aligned);
  auto handle = mluOpGetCurrentHandle();
  if (pool_mode == 0) {
    auto argmax_y_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
    auto argmax_x_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
    argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
                            rois_desc.desc(), rois_ptr, output_desc.desc(),
                            output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
                            argmax_y_desc.desc(), argmax_y_ptr);
    argmax_x.copy_(argmax_x_contiguous);
    argmax_y.copy_(argmax_y_contiguous);
  } else {
    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
                            rois_desc.desc(), rois_ptr, output_desc.desc(),
                            output_ptr, NULL, NULL, NULL, NULL);
  }
  mluOpDestroyRoiAlignForwardDescriptor(roialign_desc);
  output.copy_(output_contiguous);
}
void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
...
...
@@ -112,17 +87,7 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
                                       int sampling_ratio, int pool_mode,
                                       bool aligned) {
  // params check
  TORCH_CHECK(grad.scalar_type() == at::kFloat || grad.scalar_type() == at::kHalf,
              "grad type should be Float or Half, got ", grad.scalar_type());
  TORCH_CHECK(rois.scalar_type() == grad.scalar_type(),
              "rois should have the same type as grad");
  TORCH_CHECK(grad.dim() == 4, "grad should be a 4d tensor, got ", grad.dim(), "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), "D");
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  int batch_size = grad_input.size(0);
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
...
...
@@ -148,26 +113,40 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get the mlu ptr
  auto grad_ptr = grad_impl->cnnlMalloc();
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();
  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
  int need_core = nearestPower2(boxes_num);
  int union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  uint32_t dim_x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  uint32_t dim_y = (need_core - 1) / dim_x + 1;
  dim_y = (dim_y > union_number) ? union_number : dim_y;
  cnrtDim3_t k_dim = {dim_x, dim_y, 1};
  cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad.dtype());
  KernelRoiAlignBackward(k_dim, k_type, queue, k_dtype, grad_ptr, rois_ptr,
                         grad_input_ptr, boxes_num, hi, wi, c, no, ho, wo,
                         spatial_scale, sampling_ratio, aligned);
  MluOpTensorDescriptor grads_desc, rois_desc, argmax_y_desc, argmax_x_desc,
      grad_input_desc;
  grads_desc.set_with_layout(grad_, MLUOP_LAYOUT_NHWC);
  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
  grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
  auto handle = mluOpGetCurrentHandle();
  if (pool_mode == 0) {
    auto argmax_y_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
    auto argmax_x_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
    argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
                             rois_desc.desc(), rois_ptr, argmax_y_desc.desc(),
                             argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
                             spatial_scale, sampling_ratio, aligned, pool_mode,
                             grad_input_desc.desc(), grad_input_ptr);
  } else {
    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
                             rois_desc.desc(), rois_ptr, NULL, NULL, NULL, NULL,
                             spatial_scale, sampling_ratio, aligned, pool_mode,
                             grad_input_desc.desc(), grad_input_ptr);
  }
  grad_input.copy_(grad_input_);
}
...
...
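At the Python level, the launchers above back `mmcv.ops.RoIAlign` on MLU inputs; note the launcher only accepts the 'avg' pool mode. A hedged usage sketch, assuming a torch_mlu install that provides `.to('mlu')`:

```python
# Hedged sketch, not part of this commit: assumes an MLU device and MLU-enabled mmcv.
import torch
from mmcv.ops import RoIAlign

feat = torch.rand(2, 16, 32, 32).to('mlu')
# rois are (batch_idx, x1, y1, x2, y2)
rois = torch.tensor([[0., 1., 1., 20., 20.],
                     [1., 4., 4., 28., 30.]]).to('mlu')
roi_align = RoIAlign(output_size=(7, 7), spatial_scale=1.0,
                     sampling_ratio=2, pool_mode='avg', aligned=True)
pooled = roi_align(feat, rois)  # expected shape (2, 16, 7, 7)
```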
mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,49 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, const cnrtDataType_t d_type,
                          const int pool_method, const int boxes_num,
                          const int pts_num, const int max_pts_each_voxel,
                          const int out_x, const int out_y, const int out_z,
                          const void *rois, const void *pts,
                          int *pts_idx_of_voxels);
void KernelRoiawarePool3dForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
    const int pts_num, const int channels, const int max_pts_each_voxel,
    const int out_x, const int out_y, const int out_z, const void *pts_feature,
    const int *pts_idx_of_voxels, void *pooled_features, int *argmax);
// policy function
static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num,
                                           cnrtDim3_t *k_dim,
                                           cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  unsigned int use_cluster = (boxes_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
static void kernelRoiawarePool3dForwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
#include "mlu_common_helper.h"
void RoiawarePool3dForwardMLUKernelLauncher(
    const int pool_method, const int boxes_num, const int pts_num,
...
...
@@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher(
    const int out_y, const int out_z, const Tensor rois, const Tensor pts,
    const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
    Tensor argmax) {
  // check datatype
  TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
               (pts_feature.scalar_type() == rois.scalar_type()) &&
               (pooled_features.scalar_type() == rois.scalar_type())),
              "data types of rois, rois, pts_feature and pooled_features "
              "should be the same, ",
              "but now rois type is ", rois.scalar_type(), ", pts type is ",
              pts.scalar_type(), ", pts_feature type is ", pts_feature.scalar_type(),
              ", pooled_features type is ", pooled_features.scalar_type(), ".");
  TORCH_CHECK((rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
              "rois type should be Float or Half, got ", rois.scalar_type(), ".");
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  // check dim
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(), "D.");
  TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(), "D.");
  TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
              pts_feature.dim(), "D.");
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
              "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(pooled_features.dim() == 5,
              "pooled_features should be a 5D tensor, got ",
              pooled_features.dim(), "D.");
  // check shape
  TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
              "the dimensions of rois should be (boxes_num, 7), ", "but got (",
              rois.size(0), ", ", rois.size(1), ") .");
  TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
              "the dimensions of pts should be (pts_num, 3), ", "but got (",
              pts.size(0), ",", pts.size(1), ").");
  TORCH_CHECK(((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
              "the dimensions of pts_feature should be (pts_num, channels), ",
              "but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",", pts_idx_of_voxels.size(1),
              ",", pts_idx_of_voxels.size(2), ",", pts_idx_of_voxels.size(3), ",",
              pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
               (pooled_features.size(1) == out_x) &&
               (pooled_features.size(2) == out_y) &&
               (pooled_features.size(3) == out_z) &&
               (pooled_features.size(4) == channels)),
              "the dimensions of pooled_features should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", pooled_features.size(0), ",", pooled_features.size(1),
              ",", pooled_features.size(2), ",", pooled_features.size(3), ",",
              pooled_features.size(4), ").");
  // check other params : pool_method
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ", "but got ",
              pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(rois.numel() < max_input_size,
              "rois element num should be less than 2^31, got ", rois.numel(), ".");
  TORCH_CHECK(pts.numel() < max_input_size,
              "pts element num should be less than 2^31, got ", pts.numel(), ".");
  TORCH_CHECK(pts_feature.numel() < max_input_size,
              "pts_feature element num should be less than 2^31, got ",
              pts_feature.numel(), ".");
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(pooled_features.numel() < max_input_size,
              "pooled_features element num should be less than 2^31, got ",
              pooled_features.numel(), ".");
  // check zero element
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", rois.numel());
  TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ", pts.numel());
  TORCH_CHECK(pts_feature.numel() != 0,
              "pts_feature.numel() should not be zero, got ", pts_feature.numel());
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(pooled_features.numel() != 0,
              "pooled_features.numel() should not be zero, got ",
              pooled_features.numel());
  if (pool_method == 0) {
    // check datatype
    TORCH_CHECK((argmax.scalar_type() == at::kInt),
                "argmax type should be Int, got ", argmax.scalar_type(), ".");
    // check dim
    TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
                argmax.dim(), "D.");
    // check shape
    TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
                 (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
                 (argmax.size(4) == channels)),
                "the dimensions of argmax should be (boxes_num, out_x, out_y, "
                "out_z, channels), ",
                "but got (", argmax.size(0), ",", argmax.size(1), ",",
                argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
    // check large tensor
    TORCH_CHECK(argmax.numel() < max_input_size,
                "argmax element num should be less than 2^31, got ",
                argmax.numel(), ".");
    // check zero element
    TORCH_CHECK(argmax.numel() != 0,
                "argmax.numel() should not be zero, got ", argmax.numel());
    // when pool_method is 0, which is max pool, init argmax data value to -1
    argmax.fill_(static_cast<int>(-1));
  }
  // calculate task one dimension
  cnrtDim3_t k1_dim;
  cnrtFunctionType_t k1_type;
  kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
  cnrtDim3_t k2_dim;
  cnrtFunctionType_t k2_type;
  kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z,
                                        &k2_dim, &k2_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto pts_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format());
  auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_feature, pts_feature.suggest_memory_format());
  auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      argmax, argmax.suggest_memory_format());
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pooled_features, pooled_features.suggest_memory_format());
  MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc,
      pts_idx_of_voxels_desc, pooled_features_desc;
  rois_desc.set(rois_contiguous);
  pts_desc.set(pts_contiguous);
  pts_feature_desc.set(pts_feature_contiguous);
  argmax_desc.set(argmax_contiguous);
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  pooled_features_desc.set(pooled_features_contiguous);
  // allocate extra space for workspace
  size_t workspace_size = 0;
  mluOpGetRoiawarePool3dForwardWorkspaceSize(
      handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
      &workspace_size);
  auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
  auto workspace_ptr = workspace_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous);
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto pts_idx_of_voxels_impl =
      torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto pooled_features_impl =
      torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  // transpose points [pts_num, 3] -> [3, pts_num]
  auto pts_ = pts.permute({1, 0}).contiguous();
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
  auto pts_ptr = pts_impl->cnnlMalloc();
  // transpose points_features [pts_num, channels] -> [channels, pts_num]
  auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype());
  // launch kernel PtsIdxOfVoxels
  CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x
              << ", " << k1_dim.y << ", " << k1_dim.z << ">>>";
  KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method,
                       boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
                       out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr);
  // launch kernel RoiawarePool3dForward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x
              << ", " << k2_dim.y << ", " << k2_dim.z << ">>>";
  KernelRoiawarePool3dForward(
      k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num,
      channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr,
      (int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr);
  CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
  mluOpRoiawarePool3dForward(
      handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(),
      rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(),
      pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel,
      out_x, out_y, out_z, argmax_desc.desc(), argmax_ptr,
      pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
      pooled_features_desc.desc(), pooled_features_ptr);
}
void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
...
...
@@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU,
                     roiaware_pool3d_forward_mlu);
void KernelRoiawarePool3dBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
    const int out_x, const int out_y, const int out_z, const int channels,
    const int max_pts_each_voxel, const int *pts_idx_of_voxels,
    const int *argmax, const void *grad_out, void *grad_in);
static void kernelRoiawarePool3dBackwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
void RoiawarePool3dBackwardMLUKernelLauncher(
    int pool_method, int boxes_num, int out_x, int out_y, int out_z,
    int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels,
    const Tensor argmax, const Tensor grad_out, Tensor grad_in) {
  // check datatype
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  TORCH_CHECK((argmax.scalar_type() == at::kInt),
              "argmax type should be Int, got ", argmax.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == at::kFloat ||
               grad_out.scalar_type() == at::kHalf),
              "grad_out type should be Float or Half, got ",
              grad_out.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
              "data types of grad_out, grad_in, should be the same, ",
              "but now grad_out type is ", grad_out.scalar_type(),
              ", grad_in type is ", grad_in.scalar_type(), ".");
  // check dim
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
              "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
              argmax.dim(), "D.");
  TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ",
              grad_out.dim(), "D.");
  TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ",
              grad_in.dim(), "D.");
  // check shape
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",", pts_idx_of_voxels.size(1),
              ",", pts_idx_of_voxels.size(2), ",", pts_idx_of_voxels.size(3), ",",
              pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
               (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
               (argmax.size(4) == channels)),
              "the dimensions of argmax should be (boxes_num, out_x, out_y, "
              "out_z, channels), ",
              "but got (", argmax.size(0), ",", argmax.size(1), ",",
              argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
  TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
               (grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
               (grad_out.size(4) == channels)),
              "the dimensions of grad_out should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", grad_out.size(0), ",", grad_out.size(1), ",",
              grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4), ").");
  TORCH_CHECK((grad_in.size(1) == channels),
              "the 1st dimensions of grad_in should be channels, ", "but got ",
              grad_in.size(1), ".");
  // check other params : pool_method
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ", "but got ",
              pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(argmax.numel() < max_input_size,
              "argmax element num should be less than 2^31, got ",
              argmax.numel(), ".");
  TORCH_CHECK(grad_out.numel() < max_input_size,
              "grad_out element num should be less than 2^31, got ",
              grad_out.numel(), ".");
  TORCH_CHECK(grad_in.numel() < max_input_size,
              "grad_in element num should be less than 2^31, got ",
              grad_in.numel(), ".");
  // check zero element
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(argmax.numel() != 0,
              "argmax.numel() should not be zero, got ", argmax.numel());
  TORCH_CHECK(grad_out.numel() != 0,
              "grad_out.numel() should not be zero, got ", grad_out.numel());
  TORCH_CHECK(grad_in.numel() != 0,
              "grad_in.numel() should not be zero, got ", grad_in.numel());
  // calculate task one dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z,
                                         &k_dim, &k_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // transpose points_features [pts_num, channels] -> [channels, pts_num]
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      argmax, argmax.suggest_memory_format());
  auto grad_out_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_out, grad_out.suggest_memory_format());
  auto grad_in_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_in, grad_in.suggest_memory_format());
  MluOpTensorDescriptor pts_idx_of_voxels_desc, argmax_desc, grad_out_desc,
      grad_in_desc;
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  argmax_desc.set(argmax_contiguous);
  grad_out_desc.set(grad_out_contiguous);
  grad_in_desc.set(grad_in_contiguous);
  auto pts_idx_of_voxels_impl =
      torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous);
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
  auto grad_out_ptr = grad_out_impl->cnnlMalloc();
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
  auto grad_in_ptr = grad_in_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
  // launch kernel RoiawarePool3dBackward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
                               boxes_num, out_x, out_y, out_z, channels,
                               max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
                               (int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
  CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
  mluOpRoiawarePool3dBackward(
      handle, pool_method, boxes_num, out_x, out_y, out_z, channels,
      max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
      argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr,
      grad_in_desc.desc(), grad_in_ptr);
}
void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
...
...
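From Python, the forward/backward launchers above sit behind `mmcv.ops.RoIAwarePool3d`. The sketch below is hedged: it assumes an MLU-enabled mmcv build, a torch_mlu install providing `.to('mlu')`, and the constructor argument names used in recent mmcv releases; shapes follow the launcher checks above (rois: (N, 7), pts: (npts, 3), pts_feature: (npts, C)).

```python
# Hedged sketch, not part of this commit.
import torch
from mmcv.ops import RoIAwarePool3d

pool = RoIAwarePool3d(out_size=4, max_pts_per_voxel=128, mode='avg')  # mode 'avg' == pool_method 1
rois = torch.tensor([[1.0, 2.0, 3.0, 5.0, 4.0, 6.0, -0.3],
                     [-10.0, 23.0, 16.0, 20.0, 10.0, 20.0, -0.5]]).to('mlu')
pts = (torch.rand(100, 3) * 10.0).to('mlu')
pts_feature = torch.rand(100, 16).to('mlu')
pooled = pool(rois, pts, pts_feature)  # expected shape (2, 4, 4, 4, 16)
```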
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,84 +9,47 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, cnrtDataType_t data_type,
                          const void *unknown, const void *known, void *dist2,
                          int *idx, const int b, const int n, const int m);
#include "mlu_common_helper.h"
void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
                              const Tensor known, Tensor dist2, Tensor idx) {
  // Check dtype.
  TORCH_CHECK(unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf,
              "unknown type should be Float or Half, got ", unknown.scalar_type(), ".");
  TORCH_CHECK(unknown.scalar_type() == known.scalar_type(),
              "known should have the same type as unknown.");
  TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(),
              "dist2 should have the same type as unknown.");
  TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int.");
  // Check shape.
  TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ", unknown.dim(), "D.");
  TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(), "D.");
  TORCH_CHECK(unknown.size(0) == known.size(0),
              "known.dim0 should be equal to unknown.dim0, got ", known.size(0), ".");
  TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ", unknown.size(2), ".");
  TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2), ".");
  // zero element check
  TORCH_CHECK(unknown.numel() > 0,
              "unknown.numel should greater than zero, got ", unknown.numel(), ".");
  if (known.numel() == 0) {
    // return if known zero element
    return;
  }
  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(unknown.numel() < max_input_num,
              "unknown.numel() should be less than 2147483648, got ",
              unknown.numel(), ".");
  TORCH_CHECK(known.numel() < max_input_num,
              "known.numel() should be less than 2147483648, got ",
              known.numel(), ".");
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown);
  auto unknown_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      unknown, unknown.suggest_memory_format());
  auto known_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      known, known.suggest_memory_format());
  auto dist2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      dist2, dist2.suggest_memory_format());
  auto idx_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(idx, idx.suggest_memory_format());
  MluOpTensorDescriptor unknown_desc, known_desc, dist2_desc, idx_desc;
  unknown_desc.set(unknown_contiguous);
  known_desc.set(known_contiguous);
  dist2_desc.set(dist2_contiguous);
  idx_desc.set(idx_contiguous);
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(), &workspace_size);
  auto known_workspace =
      at::empty(workspace_size, known.options().dtype(at::kByte));
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown_contiguous);
  auto known_impl = torch_mlu::getMluTensorImpl(known_contiguous);
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2_contiguous);
  auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
  auto workspace_impl = torch_mlu::getMluTensorImpl(known_workspace);
  auto unknown_ptr = unknown_impl->cnnlMalloc();
  auto known_t = known.permute({0, 2, 1}).contiguous();
  auto known_impl = torch_mlu::getMluTensorImpl(known_t);
  auto known_ptr = known_impl->cnnlMalloc();
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2);
  auto dist2_ptr = dist2_impl->cnnlMalloc();
  auto idx_impl = torch_mlu::getMluTensorImpl(idx);
  auto idx_ptr = idx_impl->cnnlMalloc();
  auto workspace_ptr = workspace_impl->cnnlMalloc();
  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
  cnrtDim3_t k_dim;
  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim.z = 1;
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype());
  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", "
              << k_dim.y << ", " << k_dim.z << ">>>.";
  KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr,
                       dist2_ptr, (int *)idx_ptr, b, n, m);
  mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr,
                      known_desc.desc(), known_ptr, workspace_ptr,
                      workspace_size, dist2_desc.desc(), dist2_ptr,
                      idx_desc.desc(), idx_ptr);
}
void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
...
...
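For completeness, a hedged sketch of the Python-level entry point this launcher serves; it assumes MLU-enabled mmcv, torch_mlu with `.to('mlu')`, and the `(B, N, 3)` / `(B, M, 3)` shapes required by the checks in ThreeNNMLUKernelLauncher above.

```python
# Hedged sketch, not part of this commit.
import torch
from mmcv.ops import three_nn

unknown = torch.rand(2, 256, 3).to('mlu')  # query points
known = torch.rand(2, 64, 3).to('mlu')     # reference points
dist, idx = three_nn(unknown, known)
# dist: distances from each query point to its three nearest reference points, (2, 256, 3)
# idx:  indices of those reference points, int32, (2, 256, 3)
```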