OpenDAS / MMCV · Commits · 0c23eb02
".github/vscode:/vscode.git/clone" did not exist on "62c22317760f5aa0ff181a6ad7b3f801fa8639b6"
Unverified commit 0c23eb02, authored May 19, 2023 by bdf; committed by GitHub on May 19, 2023
Sync main with mmcv1.x branch (#2800)
parent
59c1418e
Changes: 25
Showing 20 changed files with 408 additions and 7652 deletions (+408, −7652)
docs/en/understand_mmcv/ops.md  +1 −1
docs/zh_cn/understand_mmcv/ops.md  +1 −1
mmcv/ops/box_iou_rotated.py  +4 −1
mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu  +0 −431
mmcv/ops/csrc/common/mlu/iou3d_utils.hpp  +0 −695
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu  +0 −2094
mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu  +0 −483
mmcv/ops/csrc/common/mlu/nms_utils.hpp  +0 −553
mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu  +0 −493
mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu  +0 −747
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu  +0 −466
mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu  +0 −532
mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp  +54 −0
mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp  +35 −101
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h  +2 −2
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp  +86 −463
mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp  +37 −108
mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp  +68 −89
mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp  +87 −322
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp  +33 −70
docs/en/understand_mmcv/ops.md
View file @ 0c23eb02
...
...
@@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
- | BoxIouRotated     |  √  |  √  |     |     |     |
+ | BoxIouRotated     |  √  |  √  |  √  |     |     |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |
...
...
docs/zh_cn/understand_mmcv/ops.md
View file @ 0c23eb02
...
...
@@ -9,7 +9,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
- | BoxIouRotated     |  √  |  √  |     |     |     |
+ | BoxIouRotated     |  √  |  √  |  √  |     |     |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |
...
...
mmcv/ops/box_iou_rotated.py
View file @ 0c23eb02
...
...
@@ -133,7 +133,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
     if aligned:
         ious = bboxes1.new_zeros(rows)
     else:
-        ious = bboxes1.new_zeros(rows * cols)
+        if bboxes1.device.type == 'mlu':
+            ious = bboxes1.new_zeros([rows, cols])
+        else:
+            ious = bboxes1.new_zeros(rows * cols)
     if not clockwise:
         flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
         flip_mat[-1] = -1
...
...
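Note on the change above: with `aligned=False`, `box_iou_rotated` now pre-allocates `ious` as a 2-D `[rows, cols]` tensor on MLU devices instead of a flat `rows * cols` buffer; the shape returned to the caller is unchanged. A minimal usage sketch (shapes and values here are illustrative, not part of the diff):

import torch
from mmcv.ops import box_iou_rotated

# Each rotated box is (x_ctr, y_ctr, w, h, angle).
bboxes1 = torch.rand(4, 5)  # rows = 4
bboxes2 = torch.rand(6, 5)  # cols = 6

# Pairwise IoU: result shape is (rows, cols).
ious = box_iou_rotated(bboxes1, bboxes2, aligned=False)
assert ious.shape == (4, 6)

# Aligned IoU over matching pairs: result shape is (rows,).
ious_aligned = box_iou_rotated(bboxes1, bboxes1, aligned=True)
assert ious_aligned.shape == (4,)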
mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu
deleted 100644 → 0 · View file @ 59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "iou3d_utils.hpp"
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
/* NRAM buffer
 * Suppose we process N boxes at a time.
----------------------------------------------------------------
| Basic |score (1N)+ |intersect_pts(48N)| |
| |valid_box(1N) |+ ordered_pts(48N)| temp_long(72N) |
| |+ temp_buffer(10N)| | |
|--------------------------|------------------|----------------|
| Reuse | null | null |rotated_pts(16N)|
|-------|------------------|------------------|----------------|
---------------------------------------------------------------------------
| Basic | dist_ram(24N) | valid_pts(24N) |box1(5N) |box1_buffer(5KB) |
| | |+ nums_in_ram(1N)|+ box2(5N)|+nram_save(5KB) |
|--------------------------|-----------------|----------|-----------------|
| Reuse | vec_buffer(5N) | null | null | null |
|-------|------------------|-----------------|----------|-----------------|
Total Basic Memory Size = 239N * sizeof(float) + 10KB
*/
__nram__ char nram_buffer[MAX_NRAM_SIZE];
__mlu_shared__ char sram_buffer[SIZE_SRAM_BUF];
template <typename T>
__mlu_func__ void iou3D_detection(int32_t &result_box_num, int32_t *output_data,
const T *boxes_data, float *scores_data,
const int core_limit, const int input_box_num,
const float iou_threshold,
mluMemcpyDirection_t scores_load_dir,
mluMemcpyDirection_t scores_store_dir,
mluMemcpyDirection_t boxes_load_dir) {
// NRAM divide by (2+4*COMPUTE_COUNT_ALIGN) copies of NRAM, counted by bytes
const int nram_save_limit_count = 256;
int box_read_limit_count = 256;
float div_thresh_iou = 1.0 / iou_threshold;
// every box require 239 * sizeof(float) space in nram;
const int32_t copies_of_nram = 239 * sizeof(float);
const int32_t limit = (MAX_NRAM_SIZE - 5 * box_read_limit_count * sizeof(T) -
nram_save_limit_count * sizeof(int32_t)) /
copies_of_nram;
// x,y,z,dx,dy,dz,angle
const T *input_x_ptr = boxes_data;
const T *input_y_ptr = input_x_ptr + input_box_num;
const T *input_dx_ptr = input_y_ptr + 2 * input_box_num;
const T *input_dy_ptr = input_dx_ptr + input_box_num;
const T *input_angle_ptr = input_dy_ptr + 2 * input_box_num;
float *input_score_ptr = scores_data;
// data split
int avg_cluster = 0;
int rem_cluster = 0;
int len_cluster = 0;
int cluster_offset = 0;
if (clusterDim > 0) {
// union
avg_cluster = input_box_num / clusterDim;
rem_cluster = input_box_num % clusterDim;
len_cluster = avg_cluster + (clusterId < rem_cluster ? 1 : 0);
cluster_offset = avg_cluster * clusterId +
(clusterId <= rem_cluster ? clusterId : rem_cluster);
} else {
// block
len_cluster = input_box_num;
cluster_offset = 0;
}
int len_core = input_box_num;
int input_offset = 0;
if (core_limit > 1) {
int avg_core = len_cluster / coreDim;
int rem_core = len_cluster % coreDim;
len_core = avg_core + (coreId < rem_core ? 1 : 0);
int core_offset =
avg_core * coreId + (coreId <= rem_core ? coreId : rem_core);
input_offset = cluster_offset + core_offset;
}
int32_t max_seg_pad = IOU3D_DOWN(limit, IOU3D_SIZE);
int repeat_iou_compute = len_core / max_seg_pad;
int remain_iou_compute = len_core % max_seg_pad;
// basic consistent memory layout
void *score = ((char *)nram_buffer);
void *valid_box = ((char *)score) + 1 * max_seg_pad * sizeof(float);
void *temp_buffer = ((char *)valid_box) + 1 * max_seg_pad * sizeof(float);
void *intersect_pts_x =
((char *)temp_buffer) + 10 * max_seg_pad * sizeof(float);
void *intersect_pts_y =
((char *)intersect_pts_x) + 24 * max_seg_pad * sizeof(float);
void *ordered_pts_x =
((char *)intersect_pts_y) + 24 * max_seg_pad * sizeof(float);
void *ordered_pts_y =
((char *)ordered_pts_x) + 24 * max_seg_pad * sizeof(float);
void *temp_long_1 =
((char *)ordered_pts_y) + 24 * max_seg_pad * sizeof(float);
void *temp_long_2 = ((char *)temp_long_1) + 24 * max_seg_pad * sizeof(float);
void *temp_long_3 = ((char *)temp_long_2) + 24 * max_seg_pad * sizeof(float);
void *dist_ram = ((char *)temp_long_3) + 24 * max_seg_pad * sizeof(float);
void *valid_pts = ((char *)dist_ram) + 24 * max_seg_pad * sizeof(float);
void *nums_in_ram = ((char *)valid_pts) + 24 * max_seg_pad * sizeof(float);
T *box1 = (T *)(((char *)nums_in_ram) + 1 * max_seg_pad * sizeof(float));
T *box2 = (T *)(((char *)box1) + 5 * max_seg_pad * sizeof(float));
void *box1_buffer = ((char *)box2) + 5 * max_seg_pad * sizeof(float);
int32_t *nram_save =
(int32_t *)(((char *)box1_buffer) + 5 * box_read_limit_count * sizeof(T));
// nram_save ~ nram_save_limit_count * sizeof(int32_t)
int nram_save_count = 0;
// reuse memory
void *rotated_pts1_x = ((char *)dist_ram);
void *rotated_pts1_y =
((char *)rotated_pts1_x) + 4 * max_seg_pad * sizeof(float);
void *rotated_pts2_x =
((char *)rotated_pts1_y) + 4 * max_seg_pad * sizeof(float);
void *rotated_pts2_y =
((char *)rotated_pts2_x) + 4 * max_seg_pad * sizeof(float);
void *vec_buffer = ((char *)temp_long_1) + 5 * max_seg_pad * sizeof(float);
// vec_buffer ~ 16 * max_seg_pad * sizeof(float)
  // First, initialize ram with all 0, or it could cause unexpected nan/inf results
__bang_write_zero((unsigned char *)nram_buffer, copies_of_nram * max_seg_pad);
  // the shift by 8 and the 0xff mask rely on box_read_limit_count being initialized to 256
const int max_box_seg_id = (input_box_num - 1) >> 8;
const int last_rem_box_number = ((input_box_num - 1) & 0xff) + 1;
for (int32_t cur_box = 0; cur_box < input_box_num; ++cur_box) {
__sync_all();
int box_seg_id = cur_box >> 8, box_id = cur_box & 0xff;
box_read_limit_count = box_seg_id == max_box_seg_id ? last_rem_box_number
: box_read_limit_count;
if (box_id == 0) {
// x,y,z,dx,dy,dz,angle
int offset_num = box_seg_id << 8;
// x
__memcpy((char *)box1_buffer, input_x_ptr + offset_num,
box_read_limit_count * 1 * sizeof(T), boxes_load_dir,
box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// y
__memcpy((char *)box1_buffer + box_read_limit_count * 1 * sizeof(T),
input_y_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// dx
__memcpy((char *)box1_buffer + box_read_limit_count * 2 * sizeof(T),
input_dx_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// dy
__memcpy((char *)box1_buffer + box_read_limit_count * 3 * sizeof(T),
input_dy_ptr + offset_num, box_read_limit_count * 1 * sizeof(T),
boxes_load_dir, box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
// angle
__memcpy((char *)box1_buffer + box_read_limit_count * 4 * sizeof(T),
input_angle_ptr + offset_num,
box_read_limit_count * 1 * sizeof(T), boxes_load_dir,
box_read_limit_count * 1 * sizeof(T),
box_read_limit_count * 1 * sizeof(T), 0);
}
if (((float *)input_score_ptr)[cur_box] == 0) {
continue;
}
// save result
nram_save[nram_save_count] = cur_box;
result_box_num++;
nram_save_count++;
if (clusterId == 0 && coreId == 0 &&
nram_save_count == nram_save_limit_count) {
pvLock();
__memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t),
NRAM2GDRAM);
pvUnlock();
output_data += nram_save_count;
nram_save_count = 0;
}
// prepare box1
// x
__bang_write_value((float *)box1, max_seg_pad,
float(((T *)box1_buffer)[box_id]));
// y
__bang_write_value(
(float *)box1 + max_seg_pad, max_seg_pad,
float(((T *)box1_buffer)[box_id + 1 * box_read_limit_count]));
// dx
__bang_write_value(
(float *)box1 + max_seg_pad * 2, max_seg_pad,
float(((T *)box1_buffer)[box_id + 2 * box_read_limit_count]));
// dy
__bang_write_value(
(float *)box1 + max_seg_pad * 3, max_seg_pad,
float(((T *)box1_buffer)[box_id + 3 * box_read_limit_count]));
// angle
__bang_write_value(
(float *)box1 + max_seg_pad * 4, max_seg_pad,
float(((T *)box1_buffer)[box_id + 4 * box_read_limit_count]));
float max_area = 1.0f *
((T *)box1_buffer)[box_id + 2 * box_read_limit_count] *
((T *)box1_buffer)[box_id + 3 * box_read_limit_count];
// update score
for (int i = 0; i <= repeat_iou_compute; i++) {
if (i == repeat_iou_compute && remain_iou_compute == 0) {
break;
}
int seg_len = max_seg_pad;
int cpy_len =
(i == repeat_iou_compute) ? remain_iou_compute : max_seg_pad;
// int half_offset = std::is_same<T, half>::value ? max_seg_pad * 5 : 0;
int half_offset = (sizeof(T) == sizeof(half)) ? max_seg_pad * 5 : 0;
// score
__memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
cpy_len * sizeof(float), scores_load_dir,
cpy_len * sizeof(float), cpy_len * sizeof(float), 0);
// x
__memcpy(box2 + half_offset, input_x_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// y
__memcpy(box2 + half_offset + seg_len * 1,
input_y_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// dx
__memcpy(box2 + half_offset + seg_len * 2,
input_dx_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// dy
__memcpy(box2 + half_offset + seg_len * 3,
input_dy_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// angle
__memcpy(box2 + half_offset + seg_len * 4,
input_angle_ptr + input_offset + i * max_seg_pad,
cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T),
cpy_len * 1 * sizeof(T), 0);
// if (std::is_same<T, half>::value) {
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)box2, (half *)(box2 + half_offset),
seg_len * 5);
}
// Calculate rotated vertices
void *temp1_ram = ((char *)temp_buffer);
void *temp2_ram = ((char *)temp_buffer) + seg_len * sizeof(float);
void *temp3_ram = ((char *)temp_buffer) + 2 * seg_len * sizeof(float);
void *temp4_ram = ((char *)temp_buffer) + 3 * seg_len * sizeof(float);
getRotatedVertices((float *)rotated_pts1_x, (float *)rotated_pts1_y,
(float *)box1, (float *)temp1_ram, (float *)temp2_ram,
(float *)temp3_ram, (float *)temp4_ram, seg_len);
getRotatedVertices((float *)rotated_pts2_x, (float *)rotated_pts2_y,
(float *)box2, (float *)temp1_ram, (float *)temp2_ram,
(float *)temp3_ram, (float *)temp4_ram, seg_len);
__bang_write_zero((float *)valid_pts, 24 * seg_len);
__bang_write_zero((float *)nums_in_ram, seg_len);
__bang_write_value(((float *)valid_box), seg_len, 1.0f);
void *vec1_x = ((char *)vec_buffer);
void *vec1_y = ((char *)vec1_x) + 4 * seg_len * sizeof(float);
void *vec2_x = ((char *)vec1_y) + 4 * seg_len * sizeof(float);
void *vec2_y = ((char *)vec2_x) + 4 * seg_len * sizeof(float);
void *temp5_ram = ((char *)temp_buffer) + 4 * seg_len * sizeof(float);
void *temp6_ram = ((char *)temp_buffer) + 5 * seg_len * sizeof(float);
void *temp7_ram = ((char *)temp_buffer) + 6 * seg_len * sizeof(float);
void *temp8_ram = ((char *)temp_buffer) + 7 * seg_len * sizeof(float);
void *temp9_ram = ((char *)temp_buffer) + 8 * seg_len * sizeof(float);
void *temp10_ram = ((char *)temp_buffer) + 9 * seg_len * sizeof(float);
// Get all intersection points
getIntersectPts(
(float *)rotated_pts1_x, (float *)rotated_pts1_y,
(float *)rotated_pts2_x, (float *)rotated_pts2_y, (float *)vec1_x,
(float *)vec1_y, (float *)vec2_x, (float *)vec2_y,
(float *)intersect_pts_x, (float *)intersect_pts_y,
(float *)valid_pts, (float *)nums_in_ram, (float *)temp1_ram,
(float *)temp2_ram, (float *)temp3_ram, (float *)temp4_ram,
(float *)temp5_ram, (float *)temp6_ram, (float *)temp7_ram,
(float *)temp8_ram, (float *)temp9_ram, (float *)temp10_ram, seg_len);
// Where nums_in <= 2, set valid_box to false
__bang_write_value((float *)temp9_ram, COMPUTE_COUNT_ALIGN, (float)2);
__bang_cycle_gt((float *)temp1_ram, (float *)nums_in_ram,
(float *)temp9_ram, seg_len, COMPUTE_COUNT_ALIGN);
__bang_and((float *)valid_box, (float *)valid_box, (float *)temp1_ram,
seg_len);
__bang_cycle_and((float *)valid_pts, (float *)valid_pts,
(float *)valid_box, 24 * seg_len, seg_len);
// Convex-hull-graham to order the intersection points in clockwise order
// and find the contour area
convexHullGraham(
(float *)intersect_pts_x, (float *)intersect_pts_y,
(float *)ordered_pts_x, (float *)ordered_pts_y, (float *)dist_ram,
(float *)valid_box, (float *)valid_pts, (float *)nums_in_ram,
(float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram,
(float *)temp_long_1, (float *)temp_long_2, (float *)temp_long_3,
seg_len, seg_len);
// Calculate polygon area
// set temp1 = intersection part area
polygonArea((float *)ordered_pts_x, (float *)ordered_pts_y,
(float *)valid_box, (float *)valid_pts, (float *)nums_in_ram,
(float *)temp1_ram, (float *)temp2_ram, (float *)temp3_ram,
(float *)temp4_ram, (float *)temp5_ram, (float *)temp6_ram,
(float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram,
seg_len);
// area
__bang_mul((float *)temp2_ram, (float *)box2 + seg_len * 2,
(float *)box2 + seg_len * 3, seg_len);
// get the area_U: area + max_area - area_I
__bang_add_scalar((float *)temp2_ram, (float *)temp2_ram, float(max_area),
seg_len);
__bang_sub((float *)temp2_ram, (float *)temp2_ram, (float *)temp1_ram,
seg_len); // area_U
if (iou_threshold > 0.0) {
__bang_mul_scalar((float *)temp1_ram, (float *)temp1_ram,
div_thresh_iou, seg_len);
} else {
__bang_mul_scalar((float *)temp2_ram, (float *)temp2_ram, iou_threshold,
seg_len);
}
__bang_ge((float *)temp1_ram, (float *)temp2_ram, (float *)temp1_ram,
seg_len);
__bang_mul((float *)score, (float *)score, (float *)temp1_ram, seg_len);
pvLock();
__memcpy(input_score_ptr + input_offset + i * max_seg_pad, score,
cpy_len * sizeof(float), scores_store_dir,
cpy_len * sizeof(float), cpy_len * sizeof(float), 0);
pvUnlock();
}
}
if (clusterId == 0 && coreId == 0 && nram_save_count) {
pvLock();
__memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t),
NRAM2GDRAM);
pvUnlock();
}
}
__mlu_global__ void MLUBlockorUnionIKernelOU3D(
const void *input_boxes, const int input_box_num, const float iou_threshold,
const cnrtDataType_t data_type_input, void *workspace, void *result_num,
void *output) {
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
mluMemcpyDirection_t scores_load_dir = GDRAM2NRAM;
mluMemcpyDirection_t scores_store_dir = NRAM2GDRAM;
mluMemcpyDirection_t boxes_load_dir = GDRAM2NRAM;
float *scores_data = (float *)workspace;
float *boxes_data = (float *)input_boxes;
const int cluster_score_size = input_box_num * sizeof(float);
const int cluster_boxes_size = input_box_num * 7 * input_dwidth;
char *sram_score = (char *)sram_buffer;
char *sram_boxes = (char *)sram_buffer + cluster_score_size;
if (clusterDim == 1 && SIZE_SRAM_BUF > cluster_score_size) {
scores_data = (float *)sram_score;
scores_load_dir = SRAM2NRAM;
scores_store_dir = NRAM2SRAM;
if (coreId == 0x80) {
__sramset((void *)sram_buffer, input_box_num, 1.0f);
}
} else {
if (coreId == 0) {
__gdramset(scores_data, input_box_num, 1.0f);
}
}
if (clusterDim == 1 &&
SIZE_SRAM_BUF - cluster_score_size >= cluster_boxes_size) {
boxes_load_dir = SRAM2NRAM;
boxes_data = (float *)sram_boxes;
if (coreId == 0x80) {
__memcpy((char *)boxes_data, (char *)input_boxes, cluster_boxes_size,
GDRAM2SRAM);
}
}
__sync_cluster();
int32_t result_box_num = 0;
int32_t *out_data = (int32_t *)output;
switch (data_type_input) {
default: { return; }
case CNRT_FLOAT16: {
iou3D_detection(result_box_num, out_data, (half *)boxes_data, scores_data,
taskDim, input_box_num, iou_threshold, scores_load_dir,
scores_store_dir, boxes_load_dir);
}; break;
case CNRT_FLOAT32: {
iou3D_detection(result_box_num, out_data, boxes_data, scores_data,
taskDim, input_box_num, iou_threshold, scores_load_dir,
scores_store_dir, boxes_load_dir);
}; break;
}
((int32_t *)result_num)[0] = result_box_num;
}
void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t data_type_input, const void *boxes_dram,
const int input_box_num, const float iou_threshold,
void *workspace, void *output_size, void *output) {
switch (k_type) {
default: { return; }
case CNRT_FUNC_TYPE_BLOCK:
case CNRT_FUNC_TYPE_UNION1:
case CNRT_FUNC_TYPE_UNION2:
case CNRT_FUNC_TYPE_UNION4:
case CNRT_FUNC_TYPE_UNION8:
case CNRT_FUNC_TYPE_UNION16: {
MLUBlockorUnionIKernelOU3D<<<k_dim, k_type, queue>>>(
(void *)boxes_dram, input_box_num, iou_threshold, data_type_input,
workspace, output_size, output);
}; break;
}
}
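The kernel above follows the usual greedy NMS control flow: every score starts at 1, each surviving box zeroes the scores of boxes whose IoU with it exceeds the threshold (the comparison is done as area_I * (1 / iou_threshold) against area_U, trading a division per pair for one reciprocal), and the indices of boxes whose score is still non-zero are streamed to the output. A host-side sketch of that flow, assuming a caller-supplied `pairwise_overlap` helper (hypothetical) that returns the intersection and union areas of two boxes:

def greedy_nms(num_boxes, iou_threshold, pairwise_overlap):
    # Scores start at 1 and are zeroed when a box is suppressed,
    # mirroring the workspace buffer initialised by __gdramset/__sramset.
    scores = [1.0] * num_boxes
    keep = []
    inv_thresh = 1.0 / iou_threshold if iou_threshold > 0 else 0.0
    for cur in range(num_boxes):
        if scores[cur] == 0:
            continue
        keep.append(cur)  # corresponds to nram_save / output_data
        for other in range(num_boxes):
            area_i, area_u = pairwise_overlap(cur, other)
            # Same test as the kernel: suppress when area_I / thresh > area_U,
            # i.e. IoU > iou_threshold.
            if iou_threshold > 0 and area_i * inv_thresh > area_u:
                scores[other] = 0.0
    return keep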
mmcv/ops/csrc/common/mlu/iou3d_utils.hpp
deleted 100644 → 0 · View file @ 59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef IOU3D_UTILS_HPP_
#define IOU3D_UTILS_HPP_
#include "common_mlu_helper.hpp"
#define IOU3D_SIZE 64
#define IOU3D_UP(x, y) (x / y + (int)(x % y > 0)) * y
#define IOU3D_DOWN(x, y) (x / y) * y
#define SIZE_NRAM_BUF (MAX_NRAM_SIZE)
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
#define COMPUTE_COUNT_ALIGN 64
#define INFO_NUM (5) // score, x1, y1, x2, y2
#define REDUCE_NUM \
(7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input)
#define SINGLE_BOX_DIM 5
#define MEMORY_CORE (0x80)
__mlu_func__ void pvLock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_lock(0, 0);
  }
#endif
}
__mlu_func__ void pvUnlock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_unlock(0, 0);
  }
#endif
}
// cross2d<T>(A, B) = A.x * B.y - A.y * B.x;
template <typename T>
inline __mlu_func__ void cross2d(T *result, const T *p1_x, const T *p1_y,
                                 const T *p2_x, const T *p2_y,
                                 const int &length, T *temp_ram) {
  __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_y, length);
  __bang_mul((T *)result, (T *)p1_y, (T *)p2_x, length);
  __bang_sub((T *)result, (T *)temp_ram, (T *)result, length);
}
// dot2d<T>(A, B) = A.x * B.x + A.y * B.y
template <typename T>
inline __mlu_func__ void dot2d(T *result, const T *p1_x, const T *p1_y,
                               const T *p2_x, const T *p2_y, const int &length,
                               T *temp_ram) {
  __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_x, length);
  __bang_mul((T *)result, (T *)p1_y, (T *)p2_y, length);
  __bang_add((T *)result, (T *)temp_ram, (T *)result, length);
}
template <typename T>
__mlu_func__ void getRotatedVertices(T *pts_x, T *pts_y, T *box, T *temp1,
                                     T *temp2, T *temp3, T *temp4,
                                     const uint32_t &actual_compute_box_num) {
  // T cosTheta2 = (T)cos(theta) * 0.5f; -- temp1
  // T sinTheta2 = (T)sin(theta) * 0.5f; -- temp2
  // theta is the box's 5th data: a, rotated radian;
#if __BANG_ARCH__ >= 300
  __bang_cos((float *)temp1, ((float *)box) + 4 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sin((float *)temp2, ((float *)box) + 4 * actual_compute_box_num,
             actual_compute_box_num);
#else
  __bang_taylor4_cos((T *)temp1, ((T *)box) + 4 * actual_compute_box_num,
                     (T *)temp3, (T *)temp4, actual_compute_box_num);
  __bang_taylor4_sin((T *)temp2, ((T *)box) + 4 * actual_compute_box_num,
                     (T *)temp3, (T *)temp4, actual_compute_box_num);
#endif
  __bang_mul_scalar((T *)temp1, (T *)temp1, (T)0.5, actual_compute_box_num);
  __bang_mul_scalar((T *)temp2, (T *)temp2, (T)0.5, actual_compute_box_num);
  // Temp3 = sinTheta2 * box.h;
  // Temp4 = cosTheta2 * box.w;
  __bang_mul((T *)temp3, (T *)temp2, ((T *)box) + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_mul((T *)temp4, (T *)temp1, ((T *)box) + 2 * actual_compute_box_num,
             actual_compute_box_num);
  // pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
  // pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
  __bang_sub((T *)pts_x, (T *)box, (T *)temp3, actual_compute_box_num);
  __bang_sub((T *)pts_x, (T *)pts_x, (T *)temp4, actual_compute_box_num);
  __bang_add((T *)pts_x + 1 * actual_compute_box_num, (T *)box, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 1 * actual_compute_box_num,
             (T *)pts_x + 1 * actual_compute_box_num, (T *)temp4,
             actual_compute_box_num);
  // Temp3 = cosTheta2 * box.h;
  // Temp4 = sinTheta2 * box.w;
  __bang_mul((T *)temp3, (T *)temp1, box + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_mul((T *)temp4, (T *)temp2, box + 2 * actual_compute_box_num,
             actual_compute_box_num);
  // pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  // pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  __bang_add((T *)pts_y, (T *)box + 1 * actual_compute_box_num, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_y, (T *)pts_y, (T *)temp4, actual_compute_box_num);
  __bang_sub((T *)pts_y + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, (T *)temp3,
             actual_compute_box_num);
  __bang_sub((T *)pts_y + 1 * actual_compute_box_num,
             (T *)pts_y + 1 * actual_compute_box_num, (T *)temp4,
             actual_compute_box_num);
  // pts[2].x = 2 * box.x_ctr - pts[0].x;
  // pts[3].x = 2 * box.x_ctr - pts[1].x;
  __bang_add((T *)pts_x + 2 * actual_compute_box_num, (T *)box, (T *)box,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 2 * actual_compute_box_num,
             (T *)pts_x + 2 * actual_compute_box_num, (T *)pts_x,
             actual_compute_box_num);
  __bang_add((T *)pts_x + 3 * actual_compute_box_num, (T *)box, (T *)box,
             actual_compute_box_num);
  __bang_sub((T *)pts_x + 3 * actual_compute_box_num,
             (T *)pts_x + 3 * actual_compute_box_num,
             (T *)pts_x + 1 * actual_compute_box_num, actual_compute_box_num);
  // pts[2].y = 2 * box.y_ctr - pts[0].y;
  // pts[3].y = 2 * box.y_ctr - pts[1].y;
  __bang_add((T *)pts_y + 2 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, actual_compute_box_num);
  __bang_sub((T *)pts_y + 2 * actual_compute_box_num,
             (T *)pts_y + 2 * actual_compute_box_num, (T *)pts_y,
             actual_compute_box_num);
  __bang_add((T *)pts_y + 3 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num,
             (T *)box + 1 * actual_compute_box_num, actual_compute_box_num);
  __bang_sub((T *)pts_y + 3 * actual_compute_box_num,
             (T *)pts_y + 3 * actual_compute_box_num,
             (T *)pts_y + 1 * actual_compute_box_num, actual_compute_box_num);
}
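For reference, the vectorized arithmetic above is the standard corner expansion of a rotated box; per box $(x_c, y_c, w, h, \theta)$ it evaluates (restating the in-code comments):

$$
c = \tfrac{1}{2}\cos\theta,\qquad s = \tfrac{1}{2}\sin\theta,
$$
$$
p_0 = (x_c - s h - c w,\; y_c + c h - s w),\qquad
p_1 = (x_c + s h - c w,\; y_c - c h - s w),
$$
$$
p_2 = (2x_c, 2y_c) - p_0,\qquad p_3 = (2x_c, 2y_c) - p_1 .
$$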
template <typename T>
__mlu_func__ void getIntersectPts(T *rotated_pts1_x, T *rotated_pts1_y,
                                  T *rotated_pts2_x, T *rotated_pts2_y,
                                  T *vec1_x, T *vec1_y, T *vec2_x, T *vec2_y,
                                  T *intersect_pts_x, T *intersect_pts_y,
                                  T *valid_pts, T *nums_in_ram, T *temp1_ram,
                                  T *temp2_ram, T *temp3_ram, T *temp4_ram,
                                  T *temp5_ram, T *temp6_ram, T *temp7_ram,
                                  T *temp8_ram, T *temp9_ram, T *temp10_ram,
                                  const uint32_t &actual_compute_box_num) {
  // Initialize const data to ram
  // temp3 = const 1e-14(@float), length = COMPUTE_COUNT_ALIGN
#if __BANG_ARCH__ >= 300
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (T)1e-14);
#else
  // NOTE: Since active_reciphp function has strict value range,
  //       [2.2205e-16, 2e6]@float, [0.00391, 65504]@half
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (float)1e-14);
#endif
  // temp4 = const T(0), length = COMPUTE_COUNT_ALIGN
  __bang_write_value((T *)temp4_ram, COMPUTE_COUNT_ALIGN, (T)0);
  // temp5 = const T(1), length = COMPUTE_COUNT_ALIGN
  __bang_write_value((T *)temp5_ram, COMPUTE_COUNT_ALIGN, (T)1);
  // Line vector, from p1 to p2 is: p1+(p2-p1)*t, t=[0,1]
  // for i = 0~3, vec[i] = pts[(i+1)%4] - pts[i]
  __bang_sub((T *)vec1_x, (T *)rotated_pts1_x + actual_compute_box_num,
             (T *)rotated_pts1_x, 3 * actual_compute_box_num);
  __bang_sub((T *)vec1_x + 3 * actual_compute_box_num, (T *)rotated_pts1_x,
             (T *)rotated_pts1_x + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec1_y, (T *)rotated_pts1_y + actual_compute_box_num,
             (T *)rotated_pts1_y, 3 * actual_compute_box_num);
  __bang_sub((T *)vec1_y + 3 * actual_compute_box_num, (T *)rotated_pts1_y,
             (T *)rotated_pts1_y + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec2_x, (T *)rotated_pts2_x + actual_compute_box_num,
             (T *)rotated_pts2_x, 3 * actual_compute_box_num);
  __bang_sub((T *)vec2_x + 3 * actual_compute_box_num, (T *)rotated_pts2_x,
             (T *)rotated_pts2_x + 3 * actual_compute_box_num,
             actual_compute_box_num);
  __bang_sub((T *)vec2_y, (T *)rotated_pts2_y + actual_compute_box_num,
             (T *)rotated_pts2_y, 3 * actual_compute_box_num);
  __bang_sub((T *)vec2_y + 3 * actual_compute_box_num, (T *)rotated_pts2_y,
             (T *)rotated_pts2_y + 3 * actual_compute_box_num,
             actual_compute_box_num);
  // First, line test - test all line combos for intersection, 4x4 possible
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // T det = cross2d<T>(vec2[j], vec1[i]) -- temp2
      cross2d<T>((T *)temp2_ram, (T *)vec2_x + j * actual_compute_box_num,
                 (T *)vec2_y + j * actual_compute_box_num,
                 (T *)vec1_x + i * actual_compute_box_num,
                 (T *)vec1_y + i * actual_compute_box_num,
                 actual_compute_box_num, (T *)temp1_ram);
      // temp8 = sign(det), since active_reciphp only receive positive values
      __bang_active_sign((T *)temp8_ram, (T *)temp2_ram,
                         actual_compute_box_num);
      // deal with parallel lines, temp2 = fabs(det), temp1 = temp2 > 1e-14
      __bang_active_abs((T *)temp2_ram, (T *)temp2_ram,
                        actual_compute_box_num);
      __bang_cycle_gt((T *)temp1_ram, (T *)temp2_ram, (T *)temp3_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      // Where temp1 = false, set recip input to 1, avoiding recip(0), cause inf
      __bang_not((T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram,
                 actual_compute_box_num);
      __bang_add((T *)temp2_ram, (T *)temp2_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      // temp2 = 1/temp2, use mult (1/temp2) instead of div temp2
#if __BANG_ARCH__ >= 300
      __bang_recip((float *)temp2_ram, (float *)temp2_ram,
                   actual_compute_box_num);
#else
      // NOTE: active_reciphp function has strict value range:
      //       [2.2205e-16, 2e6]@float, [0.00391, 65504]@half
      __bang_active_reciphp((T *)temp2_ram, (T *)temp2_ram,
                            actual_compute_box_num);
#endif
      // Restore temp2 invalid box value 1 and sign-bit
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram,
                 actual_compute_box_num);
      __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp8_ram,
                 actual_compute_box_num);
      // auto vec12 = pts2[j] - pts1[i], (temp6, temp7) = (x, y)
      __bang_sub((T *)temp6_ram,
                 (T *)rotated_pts2_x + j * actual_compute_box_num,
                 (T *)rotated_pts1_x + i * actual_compute_box_num,
                 actual_compute_box_num);
      __bang_sub((T *)temp7_ram,
                 (T *)rotated_pts2_y + j * actual_compute_box_num,
                 (T *)rotated_pts1_y + i * actual_compute_box_num,
                 actual_compute_box_num);
      // T t1 = cross2d<T>(vec2[j], vec12) mult (1/det) -- temp8
      cross2d<T>((T *)temp8_ram, (T *)vec2_x + j * actual_compute_box_num,
                 (T *)vec2_y + j * actual_compute_box_num, (T *)temp6_ram,
                 (T *)temp7_ram, actual_compute_box_num, (T *)temp9_ram);
      __bang_mul((T *)temp8_ram, (T *)temp8_ram, (T *)temp2_ram,
                 actual_compute_box_num);
      // temp1 &= (t1 >= 0.0f && t1 <= 1.0f) -- temp9
      __bang_cycle_ge((T *)temp9_ram, (T *)temp8_ram, (T *)temp4_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      __bang_cycle_le((T *)temp9_ram, (T *)temp8_ram, (T *)temp5_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram,
                 actual_compute_box_num);
      // T t2 = cross2d<T>(vec1[i], vec12) mult temp2 -- temp9
      // NOTE: temp8(t1) is used after, reuse temp7(p2_y) as cross2d temp ram
      cross2d<T>((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num,
                 (T *)vec1_y + i * actual_compute_box_num, (T *)temp6_ram,
                 (T *)temp7_ram, actual_compute_box_num, (T *)temp7_ram);
      __bang_mul((T *)temp9_ram, (T *)temp9_ram, (T *)temp2_ram,
                 actual_compute_box_num);
      // temp1 &= (t2 >= 0.0f && t2 <= 1.0f) -- temp9
      __bang_cycle_ge((T *)temp7_ram, (T *)temp9_ram, (T *)temp4_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram,
                 actual_compute_box_num);
      __bang_cycle_le((T *)temp7_ram, (T *)temp9_ram, (T *)temp5_ram,
                      actual_compute_box_num, COMPUTE_COUNT_ALIGN);
      __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram,
                 actual_compute_box_num);
      // intersections = (pts1[i] + vec1[i] * t1) * temp1
      __bang_mul((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num,
                 (T *)temp8_ram, actual_compute_box_num);
      __bang_add((T *)temp9_ram,
                 (T *)rotated_pts1_x + i * actual_compute_box_num,
                 (T *)temp9_ram, actual_compute_box_num);
      __bang_mul((T *)intersect_pts_x + (4 * i + j) * actual_compute_box_num,
                 (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      __bang_mul((T *)temp9_ram, (T *)vec1_y + i * actual_compute_box_num,
                 (T *)temp8_ram, actual_compute_box_num);
      __bang_add((T *)temp9_ram,
                 (T *)rotated_pts1_y + i * actual_compute_box_num,
                 (T *)temp9_ram, actual_compute_box_num);
      __bang_mul((T *)intersect_pts_y + (4 * i + j) * actual_compute_box_num,
                 (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num);
      // Assign `valid_pts` bit and accumulate `nums_in` of valid points of
      // each box pair
      __bang_or((T *)valid_pts + (4 * i + j) * actual_compute_box_num,
                (T *)valid_pts + (4 * i + j) * actual_compute_box_num,
                (T *)temp1_ram, actual_compute_box_num);
      __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
                 actual_compute_box_num);
    }
  }
  // Check for vertices of rect1 inside rect2
  // temp5 = ABdotAB
  dot2d<T>((T *)temp5_ram, (T *)vec2_x, (T *)vec2_y, (T *)vec2_x, (T *)vec2_y,
           actual_compute_box_num, (T *)temp9_ram);
  // temp6 = ADdotAD
  dot2d<T>((T *)temp6_ram, (T *)vec2_x + 3 * actual_compute_box_num,
           (T *)vec2_y + 3 * actual_compute_box_num,
           (T *)vec2_x + 3 * actual_compute_box_num,
           (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num,
           (T *)temp9_ram);
  // assume ABCD is the rectangle, and P is the point to be judged
  // P is inside ABCD iff. P's projection on AB lines within AB
  // and P's projection on AD lies within AD
  for (int i = 0; i < 4; i++) {
    // AP = pts1[i] - pts2[0] = (temp7, temp8)
    __bang_sub((T *)temp7_ram,
               (T *)rotated_pts1_x + i * actual_compute_box_num,
               (T *)rotated_pts2_x, actual_compute_box_num);
    __bang_sub((T *)temp8_ram,
               (T *)rotated_pts1_y + i * actual_compute_box_num,
               (T *)rotated_pts2_y, actual_compute_box_num);
    // temp9 = APdotAB = dot2d<T>(AP, AB)
    dot2d<T>((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec2_x,
             (T *)vec2_y, actual_compute_box_num, (T *)temp2_ram);
    // temp10 = APdotAD = -dot2d<T>(AP, DA)
    dot2d<T>((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram,
             (T *)vec2_x + 3 * actual_compute_box_num,
             (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num,
             (T *)temp2_ram);
    __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1,
                      actual_compute_box_num);
    // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
    //  (APdotAD <= ADdotAD))
    __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    // 16 means the 4x4 possible intersection points above
    __bang_mul((T *)intersect_pts_x + (16 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts1_x + i * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)intersect_pts_y + (16 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts1_y + i * actual_compute_box_num,
               actual_compute_box_num);
    // assign valid_pts bit and accumulate nums of valid points of each box pair
    __bang_or((T *)valid_pts + (16 + i) * actual_compute_box_num,
              (T *)valid_pts + (16 + i) * actual_compute_box_num,
              (T *)temp1_ram, actual_compute_box_num);
    __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
               actual_compute_box_num);
  }
  // Reverse the check - check for vertices of rect2 inside rect1
  // temp5 = ABdotAB
  dot2d<T>((T *)temp5_ram, (T *)vec1_x, (T *)vec1_y, (T *)vec1_x, (T *)vec1_y,
           actual_compute_box_num, (T *)temp9_ram);
  // temp6 = ADdotAD
  dot2d<T>((T *)temp6_ram, (T *)vec1_x + 3 * actual_compute_box_num,
           (T *)vec1_y + 3 * actual_compute_box_num,
           (T *)vec1_x + 3 * actual_compute_box_num,
           (T *)vec1_y + 3 * actual_compute_box_num, actual_compute_box_num,
           (T *)temp9_ram);
  for (int i = 0; i < 4; i++) {
    // AP = pts2[i] - pts1[0] = (temp7, temp8)
    __bang_sub((T *)temp7_ram,
               (T *)rotated_pts2_x + i * actual_compute_box_num,
               (T *)rotated_pts1_x, actual_compute_box_num);
    __bang_sub((T *)temp8_ram,
               (T *)rotated_pts2_y + i * actual_compute_box_num,
               (T *)rotated_pts1_y, actual_compute_box_num);
    // temp9 = APdotAB = dot2d<T>(AP, AB)
    dot2d<T>((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec1_x,
             (T *)vec1_y, actual_compute_box_num, (T *)temp2_ram);
    // temp10 = APdotAD = -dot2d<T>(AP, DA)
    dot2d<T>((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram,
             (T *)vec1_x + 3 * actual_compute_box_num,
             (T *)vec1_y + 3 * actual_compute_box_num, actual_compute_box_num,
             (T *)temp2_ram);
    __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1,
                      actual_compute_box_num);
    // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
    //  (APdotAD <= ADdotAD))
    __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram,
                    actual_compute_box_num, COMPUTE_COUNT_ALIGN);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram,
              actual_compute_box_num);
    __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram,
               actual_compute_box_num);
    // 20 means the (4x4+4) possible intersection points above
    __bang_mul((T *)intersect_pts_x + (20 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts2_x + i * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)intersect_pts_y + (20 + i) * actual_compute_box_num,
               (T *)temp1_ram,
               (T *)rotated_pts2_y + i * actual_compute_box_num,
               actual_compute_box_num);
    // assign valid_pts bit and accumulate nums of valid points of each box pair
    __bang_or((T *)valid_pts + (20 + i) * actual_compute_box_num,
              (T *)valid_pts + (20 + i) * actual_compute_box_num,
              (T *)temp1_ram, actual_compute_box_num);
    __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram,
               actual_compute_box_num);
  }
}
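The 4x4 line test above is the parametric two-segment intersection spelled out in the in-code comments; with $\operatorname{cross}(a, b) = a_x b_y - a_y b_x$ and edge vectors $v_1^{(i)}, v_2^{(j)}$ starting at vertices $p_1^{(i)}, p_2^{(j)}$:

$$
\det = \operatorname{cross}\!\bigl(v_2^{(j)}, v_1^{(i)}\bigr),\qquad
t_1 = \frac{\operatorname{cross}\!\bigl(v_2^{(j)},\, p_2^{(j)} - p_1^{(i)}\bigr)}{\det},\qquad
t_2 = \frac{\operatorname{cross}\!\bigl(v_1^{(i)},\, p_2^{(j)} - p_1^{(i)}\bigr)}{\det},
$$

and the candidate point $p_1^{(i)} + t_1 v_1^{(i)}$ is kept only when $|\det| > 10^{-14}$ and $0 \le t_1, t_2 \le 1$; multiplying by the precomputed reciprocal of $\det$ stands in for the division.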
template <typename T>
__mlu_func__ void convexHullGraham(
    T *intersect_pts_x, T *intersect_pts_y, T *ordered_pts_x, T *ordered_pts_y,
    T *dist_ram, T *valid_box, T *valid_pts, T *nums_in_ram, T *temp1_ram,
    T *temp2_ram, T *temp3_ram, T *temp_long_1, T *temp_long_2, T *temp_long_3,
    const uint32_t &actual_box_num, const uint32_t &actual_compute_box_num) {
  // Step1. Find the point with minimum y, if more than 1 points have the same
  //        minimum y, pick the one with the minimum x.
  // set p[i].y to max_y_value if not valid_pts, to avoid invalid result
  // 24 means all possible intersection points
  __bang_max((T *)temp2_ram, (T *)intersect_pts_y,
             24 * actual_compute_box_num);
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN,
                     ((T *)temp2_ram)[0]);
  __bang_not((T *)temp_long_1, (T *)valid_pts, 24 * actual_compute_box_num);
  __bang_cycle_mul((T *)temp_long_1, (T *)temp_long_1, (T *)temp3_ram,
                   24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_mul((T *)temp_long_2, (T *)intersect_pts_y, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_add((T *)temp_long_2, (T *)temp_long_2, (T *)temp_long_1,
             24 * actual_compute_box_num);
  // temp2 = min_y_value(temp_long_2), use min_pool, channel=box_num, h=1, w=24
  __bang_minpool((T *)temp2_ram, (T *)temp_long_2, actual_compute_box_num, 1,
                 24, 1, 24, 1, 24);
  __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)valid_box,
             actual_compute_box_num);
  // set p[i].x to max_x_value if not min_y point
  __bang_max((T *)temp1_ram, (T *)intersect_pts_x,
             24 * actual_compute_box_num);
  __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN,
                     ((T *)temp1_ram)[0]);
  __bang_cycle_eq((T *)temp_long_1, (T *)temp_long_2, (T *)temp2_ram,
                  24 * actual_compute_box_num, actual_compute_box_num);
  __bang_and((T *)temp_long_1, (T *)temp_long_1, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_not((T *)temp_long_3, (T *)temp_long_1, 24 * actual_compute_box_num);
  __bang_cycle_mul((T *)temp_long_3, (T *)temp_long_3, (T *)temp3_ram,
                   24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_mul((T *)temp_long_1, (T *)intersect_pts_x, (T *)temp_long_1,
             24 * actual_compute_box_num);
  __bang_add((T *)temp_long_1, (T *)temp_long_1, (T *)temp_long_3,
             24 * actual_compute_box_num);
  // temp3 = min_x_value(temp_long_1), use min_pool, channel=box_num, h=1, w=24
  __bang_minpool((T *)temp3_ram, (T *)temp_long_1, actual_compute_box_num, 1,
                 24, 1, 24, 1, 24);
  __bang_mul((T *)temp3_ram, (T *)temp3_ram, (T *)valid_box,
             actual_compute_box_num);
  // Step2. All points subtract starting-point (for sorting in the next step)
  __bang_cycle_sub((T *)ordered_pts_x, (T *)intersect_pts_x, (T *)temp3_ram,
                   24 * actual_compute_box_num, actual_compute_box_num);
  __bang_cycle_sub((T *)ordered_pts_y, (T *)intersect_pts_y, (T *)temp2_ram,
                   24 * actual_compute_box_num, actual_compute_box_num);
  __bang_mul((T *)ordered_pts_x, (T *)ordered_pts_x, (T *)valid_pts,
             24 * actual_compute_box_num);
  __bang_mul((T *)ordered_pts_y, (T *)ordered_pts_y, (T *)valid_pts,
             24 * actual_compute_box_num);
  // Step3. Sort every intersection point according to their relative
  //        cross-product values (essentially sorting according to angles)
  //        If the angles are the same, sort according to distance to origin
  dot2d<T>((T *)dist_ram, (T *)ordered_pts_x, (T *)ordered_pts_y,
           (T *)ordered_pts_x, (T *)ordered_pts_y,
           24 * actual_compute_box_num, (T *)temp_long_3);
  T temp, temp_nums_in, temp_dist_1, temp_dist_2;
  T temp1_x, temp1_y;
  T temp2_x, temp2_y;
  for (int i = 0; i < actual_box_num; i++) {
    if (((T *)valid_box)[i]) {
      // make sure all nums_in[i] points are at the front
      for (int ii = 0; ii < 23; ii++) {
        for (int jj = ii + 1; jj < 24; jj++) {
          int ii_index = ii * actual_compute_box_num + i;
          int jj_index = jj * actual_compute_box_num + i;
          // ii point is not valid and jj point is valid, swap jj for ii
          if ((!((T *)valid_pts)[ii_index]) && ((T *)valid_pts)[jj_index]) {
            ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[jj_index];
            ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[jj_index];
            ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[jj_index];
            ((T *)valid_pts)[ii_index] = true;
            ((T *)ordered_pts_x)[jj_index] = 0;
            ((T *)ordered_pts_y)[jj_index] = 0;
            ((T *)dist_ram)[jj_index] = 0;
            ((T *)valid_pts)[jj_index] = false;
            break;
          }
        }
      }
      temp_nums_in = ((T *)nums_in_ram)[i];
      // make original q[0] = min_x, min_y before sort
      for (int ii = 1; ii < temp_nums_in; ii++) {
        int ii_index = ii * actual_compute_box_num + i;
        if (((T *)dist_ram)[ii_index] == 0) {
          // swap q[ii_index] and q[0]
          ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[i];
          ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[i];
          ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[i];
          ((T *)ordered_pts_x)[i] = 0;
          ((T *)ordered_pts_y)[i] = 0;
          ((T *)dist_ram)[i] = 0;
          break;
        }
      }
      for (int ii = 1; ii < temp_nums_in - 1; ii++) {
        for (int jj = ii + 1; jj < temp_nums_in; jj++) {
          int ii_index = ii * actual_compute_box_num + i;
          int jj_index = jj * actual_compute_box_num + i;
          temp1_x = ((T *)ordered_pts_x)[ii_index];
          temp1_y = ((T *)ordered_pts_y)[ii_index];
          temp2_x = ((T *)ordered_pts_x)[jj_index];
          temp2_y = ((T *)ordered_pts_y)[jj_index];
          // calculate cross product and sort q (ordered_pts)
          temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
          temp_dist_1 = ((T *)dist_ram)[ii_index];
          temp_dist_2 = ((T *)dist_ram)[jj_index];
          if ((temp < (T)-1e-6) ||
              ((fabs(temp) < (T)1e-6) && (temp_dist_1 > temp_dist_2))) {
            ((T *)ordered_pts_x)[ii_index] = temp2_x;
            ((T *)ordered_pts_y)[ii_index] = temp2_y;
            ((T *)ordered_pts_x)[jj_index] = temp1_x;
            ((T *)ordered_pts_y)[jj_index] = temp1_y;
            ((T *)dist_ram)[ii_index] = temp_dist_2;
            ((T *)dist_ram)[jj_index] = temp_dist_1;
          }
        }
      }
      // Step4:
      // Make sure there are at least 2 points(that don't overlap with each
      // other) in the stack
      int k;  // index of the non-overlapped second point
      for (k = 1; k < temp_nums_in; k++) {
        if (((T *)dist_ram)[k * actual_compute_box_num + i] > (T)1e-8) {
          break;
        }
      }
      if (k == temp_nums_in) {
        // We reach the end, which means the convex hull is just one point
        // set valid_box = 0, to get ious = 0
        ((T *)valid_box)[i] = 0;
        continue;
      }
      // q[1] = q[k];
      ((T *)ordered_pts_x)[actual_compute_box_num + i] =
          ((T *)ordered_pts_x)[k * actual_compute_box_num + i];
      ((T *)ordered_pts_y)[actual_compute_box_num + i] =
          ((T *)ordered_pts_y)[k * actual_compute_box_num + i];
      // Step 5:
      // Finally we can start the scanning process.
      // When a non-convex relationship between the 3 points is found
      // (either concave shape or duplicated points),
      // we pop the previous point from the stack
      // until the 3-point relationship is convex again, or
      // until the stack only contains two points
      int m = 2;  // 2 points in the stack
      for (int j = k + 1; j < temp_nums_in; j++) {
        // while (m > 1 && cross2d<T>(q[j] - q[m - 2], q[m - 1] - q[m - 2]) >=
        //        0) {
        //   m--;
        // }
        temp1_x = ((T *)ordered_pts_x)[j * actual_compute_box_num + i] -
                  ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
        temp1_y = ((T *)ordered_pts_y)[j * actual_compute_box_num + i] -
                  ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
        temp2_x = ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] -
                  ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
        temp2_y = ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] -
                  ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
        temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
        while ((m > 1) && (temp >= 0)) {
          m--;
          if (m > 1) {
            temp1_x =
                ((T *)ordered_pts_x)[j * actual_compute_box_num + i] -
                ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
            temp1_y =
                ((T *)ordered_pts_y)[j * actual_compute_box_num + i] -
                ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
            temp2_x =
                ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] -
                ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i];
            temp2_y =
                ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] -
                ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i];
            temp = (temp1_x * temp2_y) - (temp1_y * temp2_x);
          }
        }
        // q[m++] = q[j];
        ((T *)ordered_pts_x)[m * actual_compute_box_num + i] =
            ((T *)ordered_pts_x)[j * actual_compute_box_num + i];
        ((T *)ordered_pts_y)[m * actual_compute_box_num + i] =
            ((T *)ordered_pts_y)[j * actual_compute_box_num + i];
        m++;
      }
      // set last(24-m) valid_pts to false, to erase invalid q in polygon area
      for (int j = m; j < temp_nums_in; j++) {
        ((T *)valid_pts)[j * actual_compute_box_num + i] = 0;
      }
      ((T *)nums_in_ram)[i] = m;
    }
  }
}
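The per-box scalar section above is a Graham scan over points already translated to the starting vertex $q_0$ and sorted by angle (ties broken by the squared distances stored in dist_ram); while the top of the stack makes a non-left turn, i.e. while

$$
\operatorname{cross}\!\bigl(q_j - q_{m-2},\; q_{m-1} - q_{m-2}\bigr) \ge 0 ,
$$

the previous point is popped ($m$ decremented) before $q_j$ is pushed.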
template <typename T>
__mlu_func__ void polygonArea(T *ordered_pts_x, T *ordered_pts_y, T *valid_box,
                              T *valid_pts, T *nums_in_ram, T *temp1_ram,
                              T *temp2_ram, T *temp3_ram, T *temp4_ram,
                              T *temp5_ram, T *temp6_ram, T *temp7_ram,
                              T *temp8_ram, T *temp9_ram,
                              const uint32_t &actual_compute_box_num) {
  // Set where nums_in <= 2, valid_box = false
  __bang_write_value((T *)temp9_ram, COMPUTE_COUNT_ALIGN, (T)2);
  __bang_cycle_gt((T *)temp1_ram, (T *)nums_in_ram, (T *)temp9_ram,
                  actual_compute_box_num, COMPUTE_COUNT_ALIGN);
  __bang_and((T *)valid_box, (T *)valid_box, (T *)temp1_ram,
             actual_compute_box_num);
  // temp1 = area, initialize with all 0
  __bang_write_zero((T *)temp1_ram, actual_compute_box_num);
  __bang_max((T *)temp7_ram, (T *)nums_in_ram, actual_compute_box_num);
  // temp_nums_in = max(nums_in)
  T temp_nums_in = ((T *)temp7_ram)[0];
  for (int i = 1; i < temp_nums_in - 1; i++) {
    // q[i] - q[0]: (temp6, temp7)
    __bang_sub((T *)temp6_ram,
               (T *)ordered_pts_x + i * actual_compute_box_num,
               (T *)ordered_pts_x, actual_compute_box_num);
    __bang_sub((T *)temp7_ram,
               (T *)ordered_pts_y + i * actual_compute_box_num,
               (T *)ordered_pts_y, actual_compute_box_num);
    __bang_mul((T *)temp6_ram, (T *)temp6_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)temp7_ram, (T *)temp7_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    // q[i + 1] - q[0]: (temp8, temp9)
    __bang_sub((T *)temp8_ram,
               (T *)ordered_pts_x + (i + 1) * actual_compute_box_num,
               (T *)ordered_pts_x, actual_compute_box_num);
    __bang_sub((T *)temp9_ram,
               (T *)ordered_pts_y + (i + 1) * actual_compute_box_num,
               (T *)ordered_pts_y, actual_compute_box_num);
    __bang_mul((T *)temp8_ram, (T *)temp8_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    __bang_mul((T *)temp9_ram, (T *)temp9_ram,
               (T *)valid_pts + (i + 1) * actual_compute_box_num,
               actual_compute_box_num);
    // area += fabs(cross2d<T>(q[i] - q[0], q[i + 1] - q[0]));
    __bang_mul((T *)temp4_ram, (T *)temp6_ram, (T *)temp9_ram,
               actual_compute_box_num);
    __bang_mul((T *)temp5_ram, (T *)temp7_ram, (T *)temp8_ram,
               actual_compute_box_num);
    __bang_sub((T *)temp3_ram, (T *)temp4_ram, (T *)temp5_ram,
               actual_compute_box_num);
    __bang_active_abs((T *)temp3_ram, (T *)temp3_ram, actual_compute_box_num);
    __bang_add((T *)temp1_ram, (T *)temp1_ram, (T *)temp3_ram,
               actual_compute_box_num);
  }
  // Set where valid_box = false, intersection = 0
  __bang_mul((T *)temp1_ram, (T *)temp1_ram, (T *)valid_box,
             actual_compute_box_num);
  // area = area / 2.0
  __bang_mul_scalar((T *)temp1_ram, (T *)temp1_ram, (T)0.5,
                    actual_compute_box_num);
}
#endif // IOU3D_UTILS_HPP_
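polygonArea above accumulates the hull area by fan triangulation from $q_0$; with $m$ ordered hull points,

$$
\text{area}_I = \frac{1}{2}\sum_{i=1}^{m-2}\bigl|\operatorname{cross}(q_i - q_0,\; q_{i+1} - q_0)\bigr| ,
$$

which the calling kernel then compares against $\text{area}_U = \text{area}_1 + \text{area}_2 - \text{area}_I$ to apply the IoU threshold.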
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
deleted 100644 → 0 · View file @ 59c1418e
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <math.h>
/****************************************************************************************
*
* NRAM partition forward:
* | spatial_shapes | data_value_p1_ping | data_value_p2_ping |
* | data_value_p3_ping | data_value_p4_ping | data_col_ping |
* | data_value_p1_pong | data_value_p2_pong | data_value_p3_pong |
* | data_value_p4_pong | data_col_pong | auxiliary_a |
* | auxiliary_b |
* | 128bytes | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size |
*
****************************************************************************************/
/****************************************************************************************
*
* NRAM partition backward:
* default kernel
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
* small channel kernel
* | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl |
* | nram_grad_output_br | grad_temp1 | grad_temp2 |
* | grad_temp3 | grad_temp4 | nram_loc_w |
* | nram_loc_h | nram_h_low | nram_w_low |
* | nram_h_high | nram_w_high | nram_h_low_temp |
* | nram_h_high_temp | nram_hw | nram_hh |
* | nram_lw | nram_lh | nram_h_low_ptr_offset |
* | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset |
* | nram_w1 | nram_w2 | nram_w3 |
* | nram_w4 | nram_grad_weight | nram_base_ptr |
* | nram_offset_temp | nram_offset1 | nram_offset2 |
* | nram_offset3 | nram_offset4 | nram_w_low_temp |
* | nram_spatial_shapes | nram_level_start_index | nram_h_stride |
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define ELE_COUNT 32
#define LEN_FLOAT sizeof(float)
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num,
const int32_t &width, const int32_t &height, const int32_t &num_heads,
const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = width * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int32_t base_ptr_offset = head_idx * channels;
// top-left point
if (h_low >= 0 && w_low >= 0) {
const int32_t v1_offset =
h_low_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p1_nram, data_value_gdram + v1_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
const int32_t v2_offset =
h_low_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p2_nram, data_value_gdram + v2_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
const int32_t v3_offset =
h_high_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p3_nram, data_value_gdram + v3_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
const int32_t v4_offset =
h_high_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p4_nram, data_value_gdram + v4_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
}
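// Illustrative sketch (not part of the original kernel): the scalar offset
// arithmetic that loadNeighborPointsData applies per corner. For a sample at
// (x, y) inside a (height, width) level with num_heads * channels values per
// pixel, the four neighbours live at the offsets below, and a corner is only
// loaded when it falls inside the feature map. Names are hypothetical.
static inline void exampleNeighborOffsets(float x, float y, int height,
                                          int width, int num_heads,
                                          int channels, int head_idx,
                                          int offsets[4], bool valid[4]) {
  const int w_low = (int)floorf(x), h_low = (int)floorf(y);
  const int w_high = w_low + 1, h_high = h_low + 1;
  const int w_stride = num_heads * channels;  // stride between columns
  const int h_stride = width * w_stride;      // stride between rows
  const int base = head_idx * channels;       // select the current head
  offsets[0] = h_low * h_stride + w_low * w_stride + base;    // top-left
  offsets[1] = h_low * h_stride + w_high * w_stride + base;   // top-right
  offsets[2] = h_high * h_stride + w_low * w_stride + base;   // bottom-left
  offsets[3] = h_high * h_stride + w_high * w_stride + base;  // bottom-right
  valid[0] = h_low >= 0 && w_low >= 0;
  valid[1] = h_low >= 0 && w_high <= width - 1;
  valid[2] = h_high <= height - 1 && w_low >= 0;
  valid[3] = h_high <= height - 1 && w_high <= width - 1;
}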
template <typename T>
__mlu_func__ void computeMsDeformAttn(
T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
T *data_col_nram, const T &weight, const size_t &deal_num,
const int32_t &width, const int32_t &height, const T &x, const T &y) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const T lw = x - w_low;
const T lh = y - h_low;
const T hw = 1 - lw;
const T hh = 1 - lh;
const T w1 = hh * hw;
const T w2 = hh * lw;
const T w3 = lh * hw;
const T w4 = lh * lw;
__bang_write_value((T *)sample_point_value, deal_num, (T)0);
// top-left point
if (h_low >= 0 && w_low >= 0) {
// sample_point_value += v1 * w1
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p1_nram, (T)w1,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
// sample_point_value += v2 * w2
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p2_nram, (T)w2,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
// sample_point_value += v3 * w3
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p3_nram, (T)w3,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
// sample_point_value += v4 * w4
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p4_nram, (T)w4,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
__bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight,
deal_num);
__bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value,
deal_num);
}
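// Illustrative sketch (not part of the original kernel): the per-channel math
// that computeMsDeformAttn vectorizes. Given the four neighbour values v1..v4
// (treated as zero when the corresponding corner is out of range) and the
// fractional location (x, y), the sampled value is the usual bilinear blend,
// scaled by the attention weight and accumulated into the output column.
// Names are hypothetical.
static inline float exampleBilinearAccumulate(float v1, float v2, float v3,
                                              float v4, float x, float y,
                                              float attn_weight, float acc) {
  const float lw = x - floorf(x), lh = y - floorf(y);  // fractional offsets
  const float hw = 1.f - lw, hh = 1.f - lh;
  const float w1 = hh * hw, w2 = hh * lw;  // top-left, top-right
  const float w3 = lh * hw, w4 = lh * lw;  // bottom-left, bottom-right
  const float sample = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
  return acc + attn_weight * sample;  // data_col += weight * sample
}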
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
if (coreId == 0x80) {
return;
}
const size_t spatial_size = PAD_UP(2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size) / TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const size_t align_num = NFU_ALIGN_SIZE;
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
char *data_spatial_shapes_nram = nram_buffer;
char *ping_data_value_p1_nram = data_spatial_shapes_nram + spatial_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
size_t data_col_ping_pong_idx = 0;
int32_t block_num_per_core = (batch_size * num_queries * num_heads) / taskDim;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t idx_start = taskId < (block_num_rem + 1)
? taskId * (block_num_per_core + 1)
: taskId * block_num_per_core + block_num_rem;
block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram + cur_idx * num_levels * num_points * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
          // the last point needs no data load; go straight to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
          // the last point needs no data load; go straight to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_align_rem, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
}
__asm__ volatile("sync;");
return;
}
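// Illustrative sketch (not part of the original kernel): how the forward
// default kernel above distributes its batch_size * num_queries * num_heads
// output columns over the taskDim MLU cores. The first `rem` cores each take
// one extra column so the remainder is spread evenly; this mirrors the
// idx_start / block_num_per_core computation at the top of the kernel. Names
// are hypothetical.
static inline void examplePartitionBlocks(int total, int task_dim, int task_id,
                                          int *start, int *count) {
  const int per_core = total / task_dim;
  const int rem = total % task_dim;
  *count = task_id < rem ? per_core + 1 : per_core;
  *start = task_id < rem ? task_id * (per_core + 1)
                         : task_id * per_core + rem;
}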
__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int32_t i = 0; i < align_num; ++i) {
mask_ram[i] = i % 2;
}
__asm__ volatile("sync;");
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
__asm__ volatile("sync;");
}
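// Illustrative sketch (not part of the original kernel): what the 0101 mask
// produced by genMask0101 is used for. __bang_collect with this mask (and
// its negation) splits the interleaved (loc_w, loc_h) sampling locations into
// separate x and y vectors; the scalar equivalent is a plain deinterleave.
// Names are hypothetical.
static inline void exampleDeinterleaveXY(const float *xy_pairs, int num_pairs,
                                         float *x_out, float *y_out) {
  for (int i = 0; i < num_pairs; ++i) {
    x_out[i] = xy_pairs[2 * i];      // even slots: collected with the negated mask
    y_out[i] = xy_pairs[2 * i + 1];  // odd slots: collected with the 0101 mask
  }
}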
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
#if __BANG_ARCH__ >= 300
if (coreId == 0x80) {
return;
}
size_t block_num_per_core, batch_start, deal_g, offset_g;
size_t block_num_rem = 0;
const size_t grid_total = num_queries * num_heads * num_levels * num_points;
if (batch_size >= taskDim) {
block_num_rem = batch_size % taskDim;
block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
: batch_size / taskDim;
batch_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
deal_g = grid_total;
offset_g = 0;
} else {
size_t skip_n = taskDim / batch_size;
batch_start = taskId / skip_n;
block_num_per_core = batch_start >= batch_size ? 0 : 1;
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
}
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num;
int32_t cut_channel_iter = 2;
const size_t spatial_size =
PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t level_start_index_size =
PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
int32_t channel = channels;
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
break;
} else {
channel = channels / cut_channel_iter;
cut_channel_iter += 2;
}
}
mult = channel;
const int32_t c_rep = channels / channel;
const int32_t c_rem = channels % channel;
const int32_t g_rep = deal_g / deal_num;
const int32_t g_rem = deal_g % deal_num;
// nram buffer alloc
char *data_spatial_shapes_nram = nram_buffer;
char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
char *input_tl = data_level_start_index_nram + level_start_index_size;
char *input_tr = input_tl + deal_num * mult * sizeof(T);
char *input_bl = input_tr + deal_num * mult * sizeof(T);
char *input_br = input_bl + deal_num * mult * sizeof(T);
char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
char *weight_br = weight_bl + deal_num * mult * sizeof(T);
char *mask_tl = weight_br + deal_num * mult * sizeof(T);
char *mask_tr = mask_tl + deal_num * sizeof(T);
char *mask_bl = mask_tr + deal_num * sizeof(T);
char *mask_br = mask_bl + deal_num * sizeof(T);
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
char *coord_x = input_bl;
char *coord_y = coord_x + deal_num * sizeof(T);
char *coord_x_low = input_tl;
char *coord_y_low = coord_x_low + deal_num * sizeof(T);
char *coord_x_low_int = weight_tl;
char *coord_y_low_int = weight_tr;
char *spatial_x = mask_tl;
char *spatial_y = mask_tr;
char *spatial_x_float = weight_bl;
char *spatial_y_float = weight_br;
char *spatial_x_temp = mask_bl;
char *spatial_y_temp = mask_br;
char *base_ptr_offset = weight_tl;
char *auxiliary_a = point_ram;
char *auxiliary_b = weight_bl;
__memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__asm__ volatile("sync;");
for (int32_t batch_idx = batch_start;
batch_idx < batch_start + block_num_per_core; ++batch_idx) {
for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
int32_t io_data_num = deal_num;
const int32_t grid_off_base =
batch_idx * grid_total + offset_g + grid_iter * deal_num;
if (grid_iter == g_rep) {
if (g_rem == 0) {
continue;
} else {
io_data_num = g_rem;
}
}
char *data_col_gdram_start =
data_col_gdram + (batch_idx * num_queries * num_heads * channels +
(offset_g + grid_iter * deal_num) /
(num_levels * num_points) * channels) *
sizeof(float);
// load data_sampling_loc
__memcpy_async(
grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
io_data_num * 2 * sizeof(float), GDRAM2NRAM);
genMask0101((float *)mask_ram, deal_num * 2);
__asm__ volatile("sync;");
      // generate the x and y coordinate vectors
      // and the per-level spatial_x and spatial_y shape vectors
__bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // y
__bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_x
__bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
__bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // x
__bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_y
for (int32_t i = 0; i < num_levels; i++) {
__bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
((int32_t *)spatial_x_temp)[i]);
__bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
((int32_t *)spatial_y_temp)[i]);
}
__bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
num_levels * num_points, 0);
__bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
num_levels * num_points, 0);
// map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0,
// spatial_y]
__bang_cycle_mul((float *)coord_x, (float *)coord_x,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
deal_num);
__bang_cycle_mul((float *)coord_y, (float *)coord_y,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
const int32_t w_stride = num_heads * channels;
__bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
deal_num, 0);
__bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
deal_num, 0);
__bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
(int32_t *)spatial_x, deal_num, num_levels * num_points);
__bang_add((int32_t *)index_tl, (int32_t *)index_tl,
(int32_t *)coord_x_low_int, deal_num);
__bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
deal_num);
const int32_t deal_lp_num = deal_num / (num_levels * num_points);
const int32_t h_rep = deal_lp_num / num_heads;
const int32_t h_rem = deal_lp_num % num_heads;
const int32_t head_start =
((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
num_heads;
for (int32_t iter = 0; iter < num_heads; ++iter) {
((int32_t *)base_ptr_offset)[iter] =
((head_start + iter) % num_heads) * channels;
}
if (h_rep > 0) {
__memcpy((int32_t *)base_ptr_offset + num_heads,
(int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
}
if (h_rep > 0 && h_rem > 0) {
__memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
(int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
NRAM2NRAM);
}
__bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
num_levels * num_points);
__bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)base_ptr_offset, deal_num, deal_lp_num);
__bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
num_levels * num_points, deal_lp_num);
// calc index_bl
__bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
deal_num);
__bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
(int32_t *)auxiliary_a, deal_num,
num_levels * num_points);
// calc mask_tl, mask_tr, mask_bl, mask_br
__bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
(float)1.0, deal_num);
__bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
(float)1.0, deal_num);
// mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
deal_num);
__bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
__bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
deal_num);
// mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
deal_num);
__bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
// mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
__bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
(float)(-1.0), deal_num);
__bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)auxiliary_a, (float *)auxiliary_a,
(float *)auxiliary_b, deal_num);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
deal_num);
// mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high <
// spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// calc inner point num
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
deal_num);
__bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
deal_num);
__bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
deal_num);
__bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
deal_num);
// calc interpolation weight
__bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
deal_num);
__bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
deal_num);
__bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
deal_num);
__bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
deal_num);
__bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
deal_num);
__bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
deal_num);
__bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
deal_num);
__bang_mul((float *)input_tl + deal_num, (float *)weight_br,
(float *)weight_tl, deal_num);
__bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
__asm__ volatile("sync;");
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
const int32_t w_rem = channel % ELE_COUNT;
if (w_rem != 0) {
const int32_t data_sz = 1 * sizeof(float);
const int32_t dst_str = channel * sizeof(float);
for (int32_t iter = w_rep; iter < channel; ++iter) {
__memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
}
}
if (w_rep != 0) {
for (int32_t i = 0; i < 4 * deal_num; i++) {
__bang_write_value((float *)weight_tl + i * channel, w_rep,
((float *)input_tl)[i]);
}
}
__asm__ volatile("sync;");
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(float);
const int32_t c_str = deal_num * channel * sizeof(float);
const int32_t cs_str = num_heads * channels * sizeof(float);
for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
int32_t c_real_num = channel;
if (c_iter == c_rep) {
if (c_rem == 0) {
continue;
} else {
c_real_num = c_rem;
}
}
__bang_write_zero((float *)input_tl, 4 * deal_num * channel);
__asm__ volatile("sync;");
// load data_value
for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
num_levels];
const char *data_value_ptr =
data_value_gdram_start +
(level_start_id * num_heads * channels + c_iter * channel) *
sizeof(float);
switch (inner_point_num) {
case 16: // 4 points are cached.
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 12: // 2 points are cached. (top_left, top_right)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 4: // 2 points are cached. (bottom_left, bottom_right)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 10: // 2 points are cached. (top_left, bottom_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 6: // 2 points are cached. (top_right, bottom_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 7: // 1 point is cached. (top_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 5: // 1 point is cached. (top_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 3: // 1 point is cached. (bottom_left)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 1: // 1 point is cached. (bottom_right)
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
default:
continue;
}
}
__asm__ volatile("sync;");
// interpolation
__bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
4 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
2 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
deal_num * channel);
// load attention weight
void *attn_weight = mask_tl;
__memcpy((float *)attn_weight,
(float *)data_attn_weight_gdram + grid_off_base,
io_data_num * sizeof(float), GDRAM2NRAM);
// calc data_col, muladd attention weight
__bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
channel);
__bang_cycle_mul((float *)input_tr, (float *)input_tr,
(float *)attn_weight, deal_num * channel, deal_num);
__bang_transpose((float *)input_tl, (float *)input_tr, channel,
deal_num);
__bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
io_data_num, 1, num_levels * num_points,
num_levels * num_points, 1);
// store
__memcpy((float *)data_col_gdram_start + c_iter * channel,
(float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
channels * sizeof(float), channel * sizeof(float),
(io_data_num / (num_levels * num_points)) - 1);
}
}
}
__asm__ volatile("sync;");
#endif
return;
}
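// Illustrative sketch (not part of the original kernel): the corner-validity
// encoding used by the small-channel kernel above. Weighting the four corner
// masks by 7/5/3/1 gives every combination that can actually occur a unique
// code, which the data_value load switch dispatches on (e.g. 16 = all four
// corners in range, 12 = top row only, 4 = bottom row only, 7/5/3/1 = a
// single corner). Names are hypothetical.
static inline int exampleInnerPointCode(bool top_left, bool top_right,
                                        bool bottom_left, bool bottom_right) {
  return 7 * (int)top_left + 5 * (int)top_right + 3 * (int)bottom_left +
         1 * (int)bottom_right;
}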
template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardDefault(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForwardDefault<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForwardSmallChannel<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template <typename T>
void __mlu_func__ msDeformAttnCol2imBilinear(
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
const T &w2, const T &w3, const T &w4, const int32_t &h_low,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
T *grad_w_weight, T *grad_value, T *grad_output_nram, T *grad_weight,
T *grad_sampling_loc, T *grad_attn_weight, T *grad_output_nram_temp,
const int32_t &deal_num, const int32_t &deal_num_real,
const T *data_value_ptr) {
if (h_low >= 0 && w_low >= 0) {
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real);
    // contribution to grad_attn_weight
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
(T *)top_grad_temp, deal_num_real);
}
if (h_low >= 0 && w_high <= width - 1) {
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real);
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_low >= 0) {
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real);
    // contribution to grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_high <= width - 1) {
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real);
    // contribution to grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
(T *)top_grad_temp, deal_num_real);
}
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT;
recursiveSumPool(grad_output_nram, align_num_on_200,
deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_output_nram, grad_output_nram,
NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
(T *)grad_output_nram, 1);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
}
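// Illustrative sketch (not part of the original kernel): the scalar gradients
// that msDeformAttnCol2imBilinear accumulates per channel. v1..v4 are the four
// neighbour values, g is the upstream gradient for this channel, a is the
// attention weight, and (lw, lh) are the fractional offsets of the sample.
// The per-channel contributions are reduced over channels before being
// atomically added to grad_attn_weight and grad_sampling_loc. Names are
// hypothetical.
static inline void exampleBackwardContribs(float v1, float v2, float v3,
                                           float v4, float g, float a,
                                           float lw, float lh, int width,
                                           int height, float *d_attn,
                                           float *d_loc_w, float *d_loc_h) {
  const float hw = 1.f - lw, hh = 1.f - lh;
  const float sample =
      hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
  *d_attn += g * sample;  // gradient w.r.t. the attention weight
  // Chain rule through x = loc_w * width - 0.5 and y = loc_h * height - 0.5.
  *d_loc_w += a * g * width * (-hh * v1 + hh * v2 - lh * v3 + lh * v4);
  *d_loc_h += a * g * height * (-hw * v1 - lw * v2 + hw * v3 + lw * v4);
}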
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
if (coreId == 0x80) {
return;
}
const int32_t split_num = 8;
const int32_t spatial_shapes_size = 64;
int32_t deal_num = PAD_DOWN(
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM);
float *grad_output_nram = (float *)nram_buffer;
float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
float *grad_weight = (float *)nram_buffer + 2 * deal_num;
float *grad_h_weight = (float *)nram_buffer + 3 * deal_num;
float *grad_w_weight = (float *)nram_buffer + 4 * deal_num;
float *top_grad = (float *)nram_buffer + 5 * deal_num;
float *top_grad_temp = (float *)nram_buffer + 6 * deal_num;
int32_t *spatial_shapes_nram =
(int32_t *)((float *)nram_buffer + 7 * deal_num);
float *sampling_loc_nram =
(float *)nram_buffer + 7 * deal_num + 2 * sizeof(int32_t);
const int32_t total_num = batch * num_query * num_heads * num_levels;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
int32_t end_per_core = start_per_core + num_per_core;
const int32_t C_repeat = channels / deal_num;
const int32_t C_tail = channels % deal_num;
const int32_t qid_stride = num_heads * channels;
int32_t base_ptr = 0;
for (int32_t num_loop = start_per_core; num_loop < end_per_core; ++num_loop) {
const int32_t l_col = num_loop % num_levels;
const int32_t m_col = num_loop / num_levels % num_heads;
const int32_t q_col = num_loop / num_levels / num_heads % num_query;
const int32_t b_col = num_loop / num_query / num_heads / num_levels;
int32_t data_weight_ptr = num_loop * num_points;
int32_t data_loc_w_ptr = data_weight_ptr << 1;
const int32_t value_offset = b_col * spatial_size * num_heads * channels;
const int32_t level_start_id = data_level_start_index[l_col];
int32_t spatial_h_ptr = l_col << 1;
int32_t grad_output_offset = b_col * num_query * num_heads * channels +
q_col * num_heads * channels +
m_col * channels;
__memcpy(spatial_shapes_nram, spatial_shapes + spatial_h_ptr,
2 * sizeof(int32_t), GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[0];
const int32_t spatial_w = spatial_shapes_nram[1];
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_attn_weight_out = num_loop * num_points;
const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
__memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
2 * LEN_FLOAT, GDRAM2NRAM);
const float loc_w = sampling_loc_nram[0];
const float loc_h = sampling_loc_nram[1];
const float weight = data_attn_weight[data_weight_ptr];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = spatial_w * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
float w1 = hh * hw;
float w2 = hh * lw;
float w3 = lh * hw;
float w4 = lh * lw;
for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
base_ptr = m_col * channels + C_loop * deal_num;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_loop * deal_num,
deal_num * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, deal_num, data_value_ptr);
}
if (C_tail != 0) {
base_ptr = m_col * channels + C_repeat * deal_num;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_repeat * deal_num,
C_tail * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, C_tail, data_value_ptr);
}
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
}
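// Illustrative sketch (not part of the original kernel): how the backward
// default kernel above decomposes its flat loop counter into (batch, query,
// head, level) indices; levels vary fastest, batch slowest. Names are
// hypothetical.
static inline void exampleDecomposeLoopIndex(int flat, int num_query,
                                             int num_heads, int num_levels,
                                             int *b, int *q, int *m, int *l) {
  *l = flat % num_levels;                            // level
  *m = (flat / num_levels) % num_heads;              // attention head
  *q = (flat / num_levels / num_heads) % num_query;  // query
  *b = flat / num_levels / num_heads / num_query;    // batch
}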
void __mlu_func__ computeGridMaskAndOffset(
float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
float *nram_lw, float *nram_hh, float *nram_hw,
float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
// [num_levels, 2] --> [2, num_levels]
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
__bang_transpose(nram_loc_w, nram_grad_output_tl,
num_per_time_real * num_heads * num_levels, num_points);
__bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_spatial_shapes,
(int32_t *)nram_spatial_shapes, num_levels * 2, 0);
__bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
num_levels, 2);
__bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
num_levels);
__memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
num_levels * 2 * sizeof(float), NRAM2NRAM);
__bang_cycle_mul(nram_loc_w, nram_loc_w,
(float *)nram_spatial_shapes + num_levels, num_deal_grid,
num_levels);
__bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
__bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
// get mask. (h_im > -1 && w_im > -1 &&
// h_im < spatial_h && w_im < spatial_w)
__bang_cycle_lt(nram_w_low_temp, nram_loc_w,
(float *)(nram_spatial_shapes + num_levels), num_deal_grid,
num_levels);
__bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
__bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_temp, nram_w_low_temp,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
__bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
__bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
__bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
__bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
__bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
__bang_transpose(nram_h_low_ptr_offset, nram_h_low,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
num_deal_grid);
__bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
num_deal_grid);
__bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
__bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
__bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
__bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
__bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset1,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset2,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset3,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset4,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
// h_low >= 0 && w_low >= 0 mask2
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
__bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
__bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
__bang_and(mask2, mask1, mask2, num_deal_grid);
__bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
// h_low >= 0 && w_high <= width - 1 mask1
__bang_transpose(mask3, nram_w_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
num_levels * 2);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
num_deal_grid, num_levels);
__bang_transpose(mask4, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask1, mask1, mask4, num_deal_grid);
__bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
// h_high <= height - 1 && w_high <= width - 1 mask3
__bang_transpose(mask3, nram_h_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
num_levels);
__bang_transpose(nram_h_low_temp, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
__bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
// h_high <= height - 1 && w_low >= 0 mask4
__bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
__bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
__bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
#endif
}
void __mlu_func__ loadValue(
float *nram_grad_output_tl, float *nram_grad_output_tr,
float *nram_grad_output_bl, float *nram_grad_output_br,
const float *data_value, const float *grad_output, float *grad_temp1,
float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_grad_weight,
int32_t *nram_level_start_index, int32_t offset_nram,
int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
int32_t value_offset_temp = 0;
__bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
__sync_io_move_compute();
__memcpy_async(
grad_temp2,
grad_output + (start_per_core + grid_loop * num_per_time_theory) *
num_heads * deal_num_real,
num_per_time_real * num_heads * deal_num_real * sizeof(float),
GDRAM2NRAM);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
const int32_t b_col =
(grid_offset + loop) / num_query / num_heads / num_levels / num_points;
const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
const int32_t level_start_id = nram_level_start_index[l_col];
value_offset_temp =
b_col * spatial_size * qid_stride + level_start_id * qid_stride;
if (mask2[loop]) {
__memcpy_async(
nram_grad_output_tl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset1[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask1[loop]) {
__memcpy_async(
nram_grad_output_tr + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset2[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask4[loop]) {
__memcpy_async(
nram_grad_output_bl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset3[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask3[loop]) {
__memcpy_async(
nram_grad_output_br + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset4[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
}
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
#endif
}
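// computeGradValue: multiply grad_output by the attention weight and by each
// bilinear corner weight (w1..w4), convert the per-corner offsets to global
// offsets into data_value, and atomically accumulate the products into
// grad_value at every in-bounds corner.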
void __mlu_func__ computeGradValue(
float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
float *nram_offset2, float *nram_offset3, float *nram_offset4,
int32_t *nram_level_start_index, int32_t deal_num_real,
const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t num_query,
int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride, float *nram_grid_offset1,
float *nram_grid_offset2) {
#if __BANG_ARCH__ >= 322
__bang_transpose(grad_temp3, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
deal_num_real * num_per_time_real * num_heads);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
num_levels / num_points) *
spatial_size * qid_stride;
}
__bang_transpose(nram_grid_offset2, nram_grid_offset1,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_level_start_index, nram_level_start_index,
num_levels, 0);
__bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
qid_stride, num_levels);
__bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
num_deal_grid, num_levels);
__bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
num_per_time_real * num_heads * num_levels);
__bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask2[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset1[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask1[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset2[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask4[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset3[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask3[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset4[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
#endif
}
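// computeGradAttnWeight: recombine the four corner values with the bilinear
// weights to rebuild the sampled value, accumulate the h/w weight gradients
// reused later for the sampling-location gradient, reduce the product with
// grad_output over channels, mask invalid points, and atomically add the
// result into grad_attn_weight.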
void __mlu_func__ computeGradAttnWeight(
float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *nram_grad_output_bl,
float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
const float *grad_attn_weight, float *nram_hw, float *nram_hh,
float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
int32_t grid_offset, float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_write_zero(grad_w_weight, 2 * offset_nram);
// grad_output_nram_tl
__bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
// nram_grad_output_tr
__bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
num_deal_grid * deal_num_real);
// nram_grad_output_bl

__bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
num_deal_grid * deal_num_real);
// nram_grad_output_br
__bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
num_deal_grid);
__bang_transpose(nram_grad_output_tr, nram_grad_output_br,
num_per_time_real * num_heads,
num_points * num_levels * deal_num_real);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
num_per_time_real * num_heads * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tr,
num_points * num_levels * deal_num_real,
num_per_time_real * num_heads);
__bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
num_deal_grid, deal_num_real);
recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
ALIGN_NUM);
__bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
0);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_atomic_add((float *)nram_grad_output_tr,
(float *)grad_attn_weight + grid_offset,
(float *)nram_grad_output_tr, num_deal_grid);
#endif
}
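// computeGradSampingLoc: scale the accumulated h/w weight gradients by the
// level's spatial height/width, multiply by the attention weight and
// grad_output, reduce over channels, mask invalid points, and atomically add
// the interleaved (w, h) pairs into grad_sampling_loc.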
void __mlu_func__ computeGradSampingLoc(
const float *grad_sampling_loc, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_transpose(nram_grad_output_tl, grad_h_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
num_levels);
__bang_transpose(grad_h_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
__bang_transpose(nram_grad_output_tr, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp1, nram_grad_output_tr,
num_levels * num_points * deal_num_real,
num_per_time_real * num_heads);
__bang_mul(grad_h_weight, grad_h_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid,
deal_num_real);
__memcpy_async(grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_h_weight, (char *)grad_h_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_transpose(nram_grad_output_tl, grad_w_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)(nram_spatial_shapes + num_levels),
num_deal_grid * deal_num_real, num_levels);
__bang_transpose(grad_w_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
__bang_mul(grad_w_weight, grad_w_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid,
deal_num_real);
__memcpy(grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_w_weight, (char *)grad_w_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__memcpy(grad_w_weight + num_deal_grid, grad_h_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid);
__bang_atomic_add((float *)nram_grad_output_tl,
(float *)grad_sampling_loc + grid_offset * 2,
(float *)nram_grad_output_tl, 2 * num_deal_grid);
#endif
}
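// MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel: splits
// batch * num_query across cores, carves NRAM into eight channel-sized
// working buffers plus 28 per-grid scratch buffers, and for every tile runs
// the mask/offset, load, grad_value, grad_attn_weight and grad_sampling_loc
// stages defined above.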
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
#if __BANG_ARCH__ > 322
const int32_t split_grid_num = 28;
const int32_t split_num_c = 8;
const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
const int32_t num_hlp = num_heads * num_levels * num_points;
int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) /
(split_num_c * C_align + split_grid_num) /
PAD_UP((num_hlp), ALIGN_NUM);
int32_t deal_grid_num_theory = num_per_time_theory * num_hlp;
const int32_t offset_nram = num_per_time_theory * C_align * num_hlp;
const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM);
float *nram_grad_output_tl = (float *)nram_buffer;
float *nram_grad_output_tr = (float *)nram_buffer + offset_nram;
float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram;
float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram;
float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram;
float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram;
float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram;
float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram;
float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram;
float *nram_loc_h =
(float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc;
float *nram_h_low =
(float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc;
float *nram_w_low =
(float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc;
float *nram_h_high =
(float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc;
float *nram_w_high =
(float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc;
float *nram_h_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc;
float *nram_h_high_temp =
(float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc;
float *nram_hw =
(float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc;
float *nram_hh =
(float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc;
float *nram_lw =
(float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc;
float *nram_lh =
(float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc;
float *nram_h_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc;
float *nram_h_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc;
float *nram_w_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc;
float *nram_w_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc;
float *nram_w1 =
(float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc;
float *nram_w2 =
(float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc;
float *nram_w3 =
(float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc;
float *nram_w4 =
(float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc;
float *nram_grad_weight =
(float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc;
float *nram_base_ptr =
(float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc;
float *nram_offset_temp =
(float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc;
float *nram_offset1 =
(float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc;
float *nram_offset2 =
(float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc;
float *nram_offset3 =
(float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc;
float *nram_offset4 =
(float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc;
float *nram_w_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc;
int32_t *nram_spatial_shapes =
(int32_t *)((float *)nram_buffer + split_num_c * offset_nram +
28 * offset_nram_calc);
int32_t *nram_level_start_index =
(int32_t *)(nram_spatial_shapes + 2 * num_levels);
float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels);
const int32_t total_num = batch * num_query;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
num_per_time_theory =
num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core;
int32_t num_deal_grid = num_per_time_theory * num_hlp;
if (num_per_core == 0) return;
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
const int32_t qid_stride = num_heads * channels;
int32_t deal_num_real = channels;
const int32_t repeat_times = num_per_core / num_per_time_theory;
const int32_t tail_num = num_per_core % num_per_time_theory;
int32_t num_per_time_real = num_per_time_theory;
for (int32_t loop = 0; loop < num_heads; ++loop) {
nram_base_ptr[loop] = loop * channels;
}
const int32_t w_stride = num_heads * channels;
for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) {
int32_t grid_offset =
(start_per_core + grid_loop * num_per_time_theory) * num_hlp;
if (grid_loop == repeat_times) {
if (tail_num == 0) {
continue;
} else {
grid_offset =
(start_per_core + repeat_times * num_per_time_theory) * num_hlp;
num_per_time_real = tail_num;
num_deal_grid = tail_num * num_hlp;
}
}
__memcpy_async(nram_spatial_shapes, spatial_shapes,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_level_start_index, data_level_start_index,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
__memcpy(nram_grad_weight, data_attn_weight + grid_offset,
num_deal_grid * sizeof(float), GDRAM2NRAM);
computeGridMaskAndOffset(
nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
w_stride, qid_stride);
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
nram_grad_output_br, data_value, grad_output, grad_temp1,
grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
nram_level_start_index, offset_nram, start_per_core, grid_loop,
num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
num_deal_grid, num_query, num_levels, num_points, grid_offset,
spatial_size, qid_stride);
float *nram_grid_offset1 = nram_loc_h;
float *nram_grid_offset2 = nram_loc_w;
computeGradValue(
grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
nram_grid_offset1, nram_grid_offset2);
// compute grad_weight
float *grad_weight = grad_temp1;
float *grad_h_weight = grad_temp4;
float *grad_w_weight = grad_temp3;
computeGradAttnWeight(
grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
// compute grad_sampling_loc
computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl,
nram_grad_output_tr, grad_h_weight, grad_w_weight,
nram_spatial_shapes, grad_temp1, grad_temp2,
nram_grad_weight, num_deal_grid, deal_num_real,
num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
}
#endif
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<<k_dim, k_type, queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<<k_dim, k_type,
queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "nms_utils.hpp"
#define COORD_DIM (4)
#define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024)
#define SIZE_SRAM_BUF (MAX_SRAM_SIZE)
__nram__ int8_t nram_buffer[SIZE_NRAM_BUF];
__mlu_shared__ int8_t sram_buffer[SIZE_SRAM_BUF];
enum Addr { SRAM, GDRAM };
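// nms_detection: greedy NMS loop for a single cluster (block / union1 job).
// Each iteration finds the highest-scoring remaining box, stores it when it
// passes the score threshold, and then zeroes the scores of all boxes whose
// IoU with it exceeds thresh_iou (scoreUpdate).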
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void nms_detection(
uint32_t &output_box_num, const int output_mode, OUT_DT *output_dram,
IN_DT *input_data_score, const IN_DT *input_data_box, const Addr input_ram,
IN_DT *sram, const int core_limit, const int input_num_boxes,
const int max_output_size, const float thresh_iou, const float thresh_score,
const float offset, const int algo) {
// global value
int32_t *exit_flag = (int32_t *)(sram + 28);
exit_flag[0] = 0;
// score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2
int nms_buffer_count1 = 9;
// temp nram buffer to store selected target.
int nram_save_limit_count = 256;
float div_thresh_iou = 1.0 / thresh_iou;
// input data ptr
const IN_DT *input_x1_ptr = input_data_box;
const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes;
const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes;
const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes;
int limit = 0; // find limit when GDRAM or SRAM
int max_seg_pad = 0; // the max length every repeat
int repeat = 0;
int remain = 0;
int remain_pad = 0;
int input_offset = 0; // offset of input_data for current core
int nram_save_count = 0;
if (output_mode == 0) {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
} else {
// 5 means: score, x1, y1, x2, y2
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * 5 * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
}
int max_seg_iou_compute = 0;
int repeat_iou_compute = 0;
int remain_iou_compute = 0;
int remain_pad_iou_compute = 0;
getComputeParamsBlockOrU1(sizeof(IN_DT), input_num_boxes, limit, core_limit,
input_offset, max_seg_pad, repeat, remain,
remain_pad, max_seg_iou_compute, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute);
// init the data ptr
IN_DT *score = (IN_DT *)nram_buffer;
IN_DT *x1 = score + max_seg_pad;
IN_DT *y1 = x1 + max_seg_pad;
IN_DT *x2 = y1 + max_seg_pad;
IN_DT *y2 = x2 + max_seg_pad;
IN_DT *inter_x1 = y2 + max_seg_pad;
IN_DT *inter_y1 = inter_x1 + max_seg_pad;
IN_DT *inter_x2 = inter_y1 + max_seg_pad;
IN_DT *inter_y2 = inter_x2 + max_seg_pad;
IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2
OUT_DT *nram_save =
(OUT_DT *)((char *)max_box +
NFU_ALIGN_SIZE);  // offset two lines from max_box
#if __BANG_ARCH__ >= 300
float max_box_x1 = 0;
float max_box_y1 = 0;
float max_box_x2 = 0;
float max_box_y2 = 0;
#endif
mluMemcpyDirection_t load_dir = SRAM2NRAM;
mluMemcpyDirection_t store_dir = NRAM2SRAM;
load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM;
store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM;
for (int keep = 0; keep < max_output_size;
keep++) { // loop until the max_score <= 0
if (core_limit != 1) {
__sync_cluster(); // sync before current loop
}
/******FIND MAX START******/
int max_index = 0; // the max score index
int global_max_index = 0; // for U1
float max_area = 0;  // the max score area
max_box[0] = 0; // init 0
findCoreMaxBox(input_data_score, score, inter_x1, max_box, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir,
input_offset, repeat, remain, remain_pad, max_seg_pad,
max_index);
if (core_limit == 1) {
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
input_data_score[max_index] = 0;
global_max_index = max_index;
} else if (core_limit == 4) {
__sync_cluster();
findClusterMaxBox(sram, max_box, inter_x1, input_data_score, core_limit);
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
global_max_index = ((uint32_t *)(max_box + 5))[0];
input_data_score[global_max_index] = 0;
}
// by now, we get: max_score|max_index|max_box|max_area
/******FIND MAX END******/
storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count,
max_output_size, thresh_score, output_mode, nram_save_count,
output_box_num);
// if the max score <= 0, end
if (core_limit == 1) {
if (float(max_box[0]) <= thresh_score) {
break;
}
} else {
if (float(max_box[0]) <= thresh_score) {
if (coreId == 0) {
exit_flag[0] = 1;
}
}
__sync_cluster();
if (exit_flag[0] == 1) {
break;
}
}
/******NMS STORE END******/
#if __BANG_ARCH__ >= 300
scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score,
inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box_x1,
max_box_y1, max_box_x2, max_box_y2, nram_save,
repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute,
max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou,
input_offset, offset, max_area, input_num_boxes, algo);
#else
scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score,
inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box[1],
max_box[2], max_box[3], max_box[4], nram_save,
repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute,
max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou,
input_offset, offset, max_area, input_num_boxes, algo);
#endif
} // for max_output_size
}
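// MLUUnion1KernelNMS: block / union1 entry point. The confidence scores are
// first copied into the GDRAM workspace so that nms_detection can zero out
// already-selected boxes in place, then the fp16 / fp32 and index / box
// output variants are dispatched.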
__mlu_global__ void MLUUnion1KernelNMS(
const void *input_boxes, const void *input_confidence,
const int input_num_boxes, const int max_output_size,
const float iou_threshold, const float confidence_threshold,
const int output_mode, void *workspace, void *result_num, void *output,
const cnrtDataType_t data_type_input, const float offset, const int algo) {
if (data_type_input == CNRT_FLOAT16) {
__memcpy(workspace, input_confidence, input_num_boxes * sizeof(half),
GDRAM2GDRAM);
} else if (data_type_input == CNRT_FLOAT32) {
__memcpy(workspace, input_confidence, input_num_boxes * sizeof(float),
GDRAM2GDRAM);
} else {
}
uint32_t output_box_num = 0;
float *score_data = (float *)workspace;
float *boxes_data = (float *)input_boxes;
float *sram = (float *)sram_buffer;
if (output_mode == 0) {
if (data_type_input == CNRT_FLOAT32) {
nms_detection(output_box_num, output_mode, (uint32_t *)output, score_data,
boxes_data, GDRAM, sram, taskDim, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, algo);
} else {
nms_detection(output_box_num, output_mode, (uint32_t *)output,
(half *)score_data, (half *)boxes_data, GDRAM, (half *)sram,
taskDim, input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, algo);
}
} else {
if (data_type_input == CNRT_FLOAT32) {
nms_detection(output_box_num, output_mode, (float *)output, score_data,
boxes_data, GDRAM, sram, taskDim, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, algo);
} else {
nms_detection(output_box_num, output_mode, (half *)output,
(half *)score_data, (half *)boxes_data, GDRAM, (half *)sram,
taskDim, input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, algo);
}
}
((uint32_t *)result_num)[0] = output_box_num;
}
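// nms_detection_ux: multi-cluster variant of the greedy NMS loop. Core 0 of
// every cluster finds its local best box, the partial maxima are reduced
// across clusters through SRAM (or through GDRAM on __BANG_ARCH__ >= 590),
// and the winning box then suppresses overlapping candidates via scoreUpdate.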
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void nms_detection_ux(
int32_t *exit_flag, uint32_t &output_box_num, OUT_DT *output_dram,
IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram,
const int input_num_boxes, const int max_output_size,
const float thresh_iou, const float thresh_score, const float offset,
const int output_mode, const int algo, char *cdma_gdram) {
exit_flag[0] = 0;
IN_DT *sram = (IN_DT *)sram_buffer;
// score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2
int nms_buffer_count1 = 9;
// temp nram buffer to store selected target.
int nram_save_limit_count = 256;
float div_thresh_iou = 1.0 / thresh_iou;
// input data ptr
const IN_DT *input_x1_ptr = boxes_data;
const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes;
const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes;
const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes;
int limit = 0; // find limit when GDRAM or SRAM
int max_seg_pad = 0; // the max length every repeat
int repeat = 0;
int remain = 0;
int remain_pad = 0;
int nram_save_count = 0;
if (output_mode == 0) {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
} else {
limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) -
nram_save_limit_count * INFO_NUM * sizeof(OUT_DT)) /
(nms_buffer_count1 * sizeof(IN_DT));
}
int input_offset = 0;
int max_seg_iou_compute = 0;
int repeat_iou_compute = 0;
int remain_iou_compute = 0;
int remain_pad_iou_compute = 0;
getComputeParamsUx(sizeof(IN_DT), input_num_boxes, limit, input_offset,
max_seg_pad, repeat, remain, remain_pad,
max_seg_iou_compute, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute);
// init the nram ptr
IN_DT *score = (IN_DT *)nram_buffer;
IN_DT *x1 = score + max_seg_pad;
IN_DT *y1 = x1 + max_seg_pad;
IN_DT *x2 = y1 + max_seg_pad;
IN_DT *y2 = x2 + max_seg_pad;
IN_DT *inter_x1 = y2 + max_seg_pad;
IN_DT *inter_y1 = inter_x1 + max_seg_pad;
IN_DT *inter_x2 = inter_y1 + max_seg_pad;
IN_DT *inter_y2 = inter_x2 + max_seg_pad;
IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2
OUT_DT *nram_save =
(OUT_DT *)((char *)max_box +
NFU_ALIGN_SIZE);  // offset two lines from max_box
#if __BANG_ARCH__ >= 300
float max_box_x1 = 0;
float max_box_y1 = 0;
float max_box_x2 = 0;
float max_box_y2 = 0;
#endif
mluMemcpyDirection_t load_dir = SRAM2NRAM;
mluMemcpyDirection_t store_dir = NRAM2SRAM;
load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM;
store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM;
for (int keep = 0; keep < max_output_size;
keep++) { // loop until the max_score <= 0
__sync_all();
int max_index = 0;
int global_max_index = 0; // for Ux
float max_area = 0;  // the max score area
max_box[0] = 0; // init 0
if (coreId == 0) {
findCoreMaxBox(score_data, score, inter_x1, max_box, input_x1_ptr,
input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir,
input_offset, repeat, remain, remain_pad, max_seg_pad,
max_index);
// copy max box info to sram
__memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
}
__sync_all();
#if __BANG_ARCH__ >= 590
__memcpy((char *)cdma_gdram + REDUCE_NUM * clusterId * sizeof(IN_DT), sram,
REDUCE_NUM * sizeof(IN_DT), SRAM2GDRAM);
__sync_all();
if (clusterId == 0 && coreId == 0) {
__bang_write_zero(inter_x1, NMS_SIZE);
__memcpy((char *)inter_x1, (char *)cdma_gdram, sizeof(IN_DT), GDRAM2NRAM,
sizeof(IN_DT), REDUCE_NUM * sizeof(IN_DT), clusterDim - 1);
__bang_max(max_box, inter_x1, NMS_SIZE);
int max_cluster = (sizeof(IN_DT) == sizeof(half))
? ((uint16_t *)max_box)[1]
: ((uint32_t *)max_box)[1];
__memcpy((char *)cdma_gdram,
(char *)cdma_gdram + max_cluster * REDUCE_NUM * sizeof(IN_DT),
REDUCE_NUM * sizeof(IN_DT), GDRAM2GDRAM);
}
__sync_all();
__memcpy(max_box, cdma_gdram, REDUCE_NUM * sizeof(IN_DT), GDRAM2NRAM);
#else
findGlobalMaxBox(max_box, sram, inter_x1);
#endif
#if __BANG_ARCH__ >= 300
calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1,
max_box_x2, max_box_y2);
#else
calMaxArea(max_box, algo, offset, max_area);
#endif
global_max_index = ((uint32_t *)(max_box + 5))[0];
if (coreId != MEMORY_CORE) {
score_data[global_max_index] = 0;
}
storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count,
max_output_size, thresh_score, output_mode, nram_save_count,
output_box_num);
if (float(max_box[0]) <= thresh_score) {
if (clusterId == 0 && coreId == 0) {
exit_flag[0] = 1; // dram
}
}
__sync_all();
if (exit_flag[0] == 1) {
break;
}
/******NMS STORE END******/
#if __BANG_ARCH__ >= 300
scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr,
input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1,
inter_y1, inter_x2, inter_y2, max_box, max_box_x1, max_box_y1,
max_box_x2, max_box_y2, nram_save, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute,
max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset,
max_area, input_num_boxes, algo);
#else
scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr,
input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1,
inter_y1, inter_x2, inter_y2, max_box, max_box[1], max_box[2],
max_box[3], max_box[4], nram_save, repeat_iou_compute,
remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute,
max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset,
max_area, input_num_boxes, algo);
#endif
} // for max_output_size
}
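// MLUUionXKernelNMS: UnionX entry point. Boxes and scores are staged in SRAM
// when they fit next to the reduction buffer; otherwise the scores are copied
// to the GDRAM workspace and the boxes are read directly from GDRAM.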
__mlu_global__ void MLUUionXKernelNMS(
const void *input_boxes, const void *input_confidence,
const int input_num_boxes, const int max_output_size,
const float iou_threshold, const float confidence_threshold,
const float offset, const cnrtDataType_t data_type_input,
const int output_mode, const int algo, void *workspace, void *result_num,
void *output) {
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
int32_t *exit_flag = (int32_t *)((char *)workspace +
INFO_NUM * input_num_boxes * input_dwidth);
char *cdma_addr = (char *)exit_flag + sizeof(int32_t);
int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth;
int available_sram_size = SIZE_SRAM_BUF - reduce_sram_size;
int cluster_score_size = input_num_boxes * input_dwidth;
int cluster_boxes_size = input_num_boxes * 4 * input_dwidth;
char *sram_score = (char *)sram_buffer + reduce_sram_size;
char *sram_boxes =
(char *)sram_buffer + reduce_sram_size + cluster_score_size;
Addr input_ram = GDRAM;
if ((cluster_score_size + cluster_boxes_size) < available_sram_size) {
input_ram = SRAM;
__memcpy(sram_score, input_confidence, cluster_score_size, GDRAM2SRAM);
__memcpy(sram_boxes, input_boxes, cluster_boxes_size, GDRAM2SRAM);
} else {
__memcpy(workspace, input_confidence, cluster_score_size, GDRAM2GDRAM);
}
__sync_cluster();
uint32_t output_box_num = 0;
float *score_data;
float *boxes_data;
score_data = (input_ram == SRAM) ? (float *)sram_score : (float *)workspace;
boxes_data = (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes;
if (output_mode == 0) {
if (data_type_input == CNRT_FLOAT32) {
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
score_data, boxes_data, input_ram, input_num_boxes,
max_output_size, iou_threshold, confidence_threshold,
offset, output_mode, algo, cdma_addr);
} else {
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
(half *)score_data, (half *)boxes_data, input_ram,
input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, output_mode, algo,
cdma_addr);
}
} else {
if (data_type_input == CNRT_FLOAT32) {
nms_detection_ux(exit_flag, output_box_num, (float *)output, score_data,
boxes_data, input_ram, input_num_boxes, max_output_size,
iou_threshold, confidence_threshold, offset, output_mode,
algo, cdma_addr);
} else {
nms_detection_ux(exit_flag, output_box_num, (half *)output,
(half *)score_data, (half *)boxes_data, input_ram,
input_num_boxes, max_output_size, iou_threshold,
confidence_threshold, offset, output_mode, algo,
cdma_addr);
}
}
((uint32_t *)result_num)[0] = output_box_num;
}
void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t data_type_input, const void *boxes_ptr,
const void *scores_ptr, const int input_num_boxes,
const int max_output_boxes, const float iou_threshold,
const float offset, void *workspace_ptr, void *output_size_ptr,
void *output_ptr) {
switch (k_type) {
default: { return; }
case CNRT_FUNC_TYPE_BLOCK:
case CNRT_FUNC_TYPE_UNION1: {
MLUUnion1KernelNMS<<<k_dim, k_type, queue>>>(
(void *)boxes_ptr, (void *)scores_ptr, input_num_boxes,
max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0,
/*output_mode=*/0, workspace_ptr, output_size_ptr, output_ptr,
data_type_input, offset, /*algo=*/1);
}; break;
case CNRT_FUNC_TYPE_UNION2:
case CNRT_FUNC_TYPE_UNION4:
case CNRT_FUNC_TYPE_UNION8:
case CNRT_FUNC_TYPE_UNION16: {
MLUUionXKernelNMS<<<k_dim, k_type, queue>>>(
(void *)boxes_ptr, (void *)scores_ptr, input_num_boxes,
max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, offset,
data_type_input, /*output_mode=*/0, /*algo=*/1, workspace_ptr,
output_size_ptr, output_ptr);
}; break;
}
}
mmcv/ops/csrc/common/mlu/nms_utils.hpp
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) [2019-2022] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef NMS_UTILS_HPP_
#define NMS_UTILS_HPP_
#include "common_mlu_helper.hpp"
#define NMS_SIZE (64)
#define NMS_UP(x, y) (x / y + (int)(x % y > 0)) * y
#define NMS_DOWN(x, y) (x / y) * y
#define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score
#define MEMORY_CORE (0x80)
#define REDUCE_NUM \
(7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input)
__mlu_func__ void pvLock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_lock(0, 0);
  }
#endif
}
__mlu_func__ void pvUnlock() {
#if __BANG_ARCH__ == 270
  if (coreId != MEMORY_CORE) {
    __bang_unlock(0, 0);
  }
#endif
}
template <typename T>
static __mlu_func__ void computeReluN(T *nram_dst, T *nram_src, void *nram_tmp,
                                      const int deal_num,
                                      const T threshold = 0) {
  if (threshold < 0) {
    return;
  }
  if (threshold) {
#if __BANG_ARCH__ >= 300
    __bang_relun(nram_dst, nram_src, deal_num, threshold);
#else
    int align_num = NFU_ALIGN_SIZE / sizeof(T);
    T *nram_aux_a = (T *)nram_tmp;
    T *nram_aux_b = nram_aux_a + deal_num;
    T *nram_zero = nram_aux_b + align_num;
    __bang_write_value(nram_aux_b, align_num, threshold);
    __bang_write_zero(nram_zero, align_num);
    __bang_cycle_lt((T *)nram_aux_a, nram_src, (T *)nram_aux_b, deal_num,
                    align_num);
    __bang_mul(nram_dst, nram_src, (T *)nram_aux_a, deal_num);
    __bang_cycle_eq((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_zero, deal_num,
                    align_num);
    __bang_cycle_mul((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_aux_b,
                     deal_num, align_num);
    __bang_add(nram_dst, nram_dst, (T *)nram_aux_a, deal_num);
    __bang_cycle_gt((T *)nram_aux_a, nram_dst, (T *)nram_zero, deal_num,
                    align_num);
    __bang_mul(nram_dst, nram_dst, (T *)nram_aux_a, deal_num);
#endif
  } else {
#if __BANG_ARCH__ >= 300
    __bang_relu(nram_dst, nram_src, deal_num);
#else
    __bang_active_relu(nram_dst, nram_src, deal_num);
#endif
  }
}
__mlu_func__ void getComputeParamsBlockOrU1(
    const int input_dwidth, const int input_box_num, const int limit,
    const int core_limit, int &input_offset, int &max_seg_pad, int &repeat,
    int &remain, int &remain_pad, int &max_seg_iou_compute,
    int &repeat_iou_compute, int &remain_iou_compute,
    int &remain_pad_iou_compute) {
  int avg_core = input_box_num / core_limit;
  int rem = input_box_num % core_limit;
  int len_core = avg_core + (coreId < rem ? 1 : 0);
  input_offset = avg_core * coreId + (coreId <= rem ? coreId : rem);
  max_seg_pad = NMS_DOWN(limit, NMS_SIZE);
  repeat = len_core / max_seg_pad;
  remain = len_core % max_seg_pad;
  remain_pad = NMS_UP(remain, NMS_SIZE);
  // if datatype is fp16, we should cvt to fp32 when compute iou
  max_seg_iou_compute = NMS_DOWN(max_seg_pad / (4 / input_dwidth), NMS_SIZE);
  repeat_iou_compute = len_core / max_seg_iou_compute;
  remain_iou_compute = len_core % max_seg_iou_compute;
  remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE);
}
__mlu_func__ void getComputeParamsUx(
    const int input_dwidth, const int input_num_boxes, const int limit,
    int &input_offset, int &max_seg_pad, int &repeat, int &remain,
    int &remain_pad, int &max_seg_iou_compute, int &repeat_iou_compute,
    int &remain_iou_compute, int &remain_pad_iou_compute) {
  // data split
  int avg_cluster = input_num_boxes / clusterDim;
  int rem_cluster = input_num_boxes % clusterDim;
  int len_cluster = avg_cluster + (clusterId < rem_cluster);
  int cluster_offset = avg_cluster * clusterId +
                       (clusterId <= rem_cluster ? clusterId : rem_cluster);
  int avg_core = len_cluster / coreDim;
  int rem_core = len_cluster % coreDim;
  int len_core = avg_core + (coreId < rem_core);
  int core_offset =
      avg_core * coreId + (coreId <= rem_core ? coreId : rem_core);
  input_offset = cluster_offset + core_offset;
  max_seg_pad = NMS_DOWN(limit, NMS_SIZE);
  // core 0 of each cluster calculate the max score index
  int max_index_len_core = avg_cluster + (clusterId < rem_cluster);
  repeat = max_index_len_core / max_seg_pad;
  remain = max_index_len_core % max_seg_pad;
  remain_pad = NMS_UP(remain, NMS_SIZE);
  // if datatype is fp16, we should cvt to fp32 when compute iou
  max_seg_iou_compute =
      NMS_DOWN(max_seg_pad / (sizeof(float) / input_dwidth), NMS_SIZE);
  repeat_iou_compute = len_core / max_seg_iou_compute;
  remain_iou_compute = len_core % max_seg_iou_compute;
  remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE);
}
template <typename IN_DT>
__mlu_func__ void findGlobalMaxBox(IN_DT *max_box, IN_DT *sram,
                                   IN_DT *inter_x1) {
  // copy all partial max to the sram of cluster 0
  if (clusterId != 0) {
    __memcpy(sram + REDUCE_NUM * clusterId, sram, REDUCE_NUM * sizeof(IN_DT),
             SRAM2SRAM, 0);
  }
  __sync_all();
  // reduce between clusters to get the global max box
  if (clusterId == 0) {
    if (coreId == 0) {
      __bang_write_zero(inter_x1, NMS_SIZE);
      __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT),
               REDUCE_NUM * sizeof(IN_DT), clusterDim - 1);
      __bang_max(max_box, inter_x1, NMS_SIZE);
      int max_cluster = (sizeof(IN_DT) == sizeof(half))
                            ? ((uint16_t *)max_box)[1]
                            : ((uint32_t *)max_box)[1];
      __memcpy(max_box, sram + max_cluster * REDUCE_NUM,
               REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM);
      __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
    }
    __sync_cluster();
    if (coreId == 0x80 && clusterDim > 1) {
      // broadcast global max box to each cluster's sram
      for (int cluster_idx = 1; cluster_idx < clusterDim; ++cluster_idx) {
        __memcpy(sram, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2SRAM,
                 cluster_idx);
      }
    }
    __sync_cluster();
  }
  __sync_all();
  // copy the global max box to max_box
  __memcpy(max_box, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM);
}
template <typename IN_DT>
__mlu_func__ void findCoreMaxBox(
    IN_DT *input_score_ptr, IN_DT *score, IN_DT *inter_x1, IN_DT *max_box,
    const IN_DT *input_x1_ptr, const IN_DT *input_y1_ptr,
    const IN_DT *input_x2_ptr, const IN_DT *input_y2_ptr,
    const mluMemcpyDirection_t load_dir, const int input_offset,
    const int repeat, const int remain, const int remain_pad,
    const int max_seg_pad, int &max_index) {
  if (coreId != 0x80) {
    for (int i = 0; i <= repeat; i++) {
      if (i == repeat && remain == 0) {
        break;
      }
      int seg_len = 0;  // the length every nms compute
      int cpy_len = 0;  // the length every nms memcpy
      i == repeat ? seg_len = remain_pad : seg_len = max_seg_pad;
      i == repeat ? cpy_len = remain : cpy_len = max_seg_pad;
      /******NMS LOAD START******/
      __bang_write_zero(score, seg_len);
      __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      /******NMS LOAD END******/
      __bang_max(inter_x1, score, seg_len);
      if (inter_x1[0] > max_box[0]) {
        max_box[0] = inter_x1[0];
        if (sizeof(IN_DT) == sizeof(half)) {
          max_index = ((uint16_t *)inter_x1)[1] + input_offset +
                      i * max_seg_pad;  // offset start from head of input_data
        } else if (sizeof(IN_DT) == sizeof(float)) {
          max_index = ((uint32_t *)inter_x1)[1] + input_offset +
                      i * max_seg_pad;  // offset start from head of input_data
        }
      }
    }  // for repeat
    // the max box's x1, y1, x2, y2 on every core
    max_box[1] = input_x1_ptr[max_index];
    max_box[2] = input_y1_ptr[max_index];
    max_box[3] = input_x2_ptr[max_index];
    max_box[4] = input_y2_ptr[max_index];
    ((uint32_t *)(max_box + 5))[0] = max_index;
  }
}
template <typename IN_DT>
__mlu_func__ void findClusterMaxBox(IN_DT *sram, IN_DT *max_box,
                                    IN_DT *inter_x1, IN_DT *input_data_score,
                                    const int core_limit) {
  // find the max with sram
  // copy every core's box info to sram, form: score---x1---y1---x2---y2---
  __memcpy(sram + REDUCE_NUM * coreId, max_box, REDUCE_NUM * sizeof(IN_DT),
           NRAM2SRAM);  // int32_t datatype
  __sync_cluster();
  // copy score from sram to nram and find the max
  __bang_write_zero(inter_x1, 64);
  __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT),
           REDUCE_NUM * sizeof(IN_DT), coreDim - 1);
  __bang_max(max_box, inter_x1, 64);
  int max_core = sizeof(IN_DT) == sizeof(half) ? ((uint16_t *)max_box)[1]
                                               : ((uint32_t *)max_box)[1];
  // copy the max box to max_box
  __memcpy(max_box, sram + max_core * REDUCE_NUM, REDUCE_NUM * sizeof(IN_DT),
           SRAM2NRAM);
}
/*****************************************************************************/
/*******************************CALCULATE MAX AREA****************************/
/*****************************************************************************/
template <typename IN_DT>
__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset,
                             float &max_area) {
  if (algo == 0 || offset == 0.0) {
    max_area = ((float)max_box[3] - (float)max_box[1]) *
               ((float)max_box[4] - (float)max_box[2]);
  } else {
    max_area = ((float)max_box[3] - (float)max_box[1] + offset) *
               ((float)max_box[4] - (float)max_box[2] + offset);
  }
}
template <typename IN_DT>
__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset,
                             float &max_area, float &max_box_x1,
                             float &max_box_y1, float &max_box_x2,
                             float &max_box_y2) {
  // the case of random inf will break the requirement of x1<=x2, y1<=y2
  // so exchange it if it happens.
  max_box_x1 = float(max_box[1]);
  max_box_x2 = float(max_box[3]);
  if (max_box[1] > max_box[3]) {
    max_box_x1 = float(max_box[3]);
    max_box_x2 = float(max_box[1]);
  }
  max_box_y1 = float(max_box[2]);
  max_box_y2 = float(max_box[4]);
  if (max_box[2] > max_box[4]) {
    max_box_y1 = float(max_box[4]);
    max_box_y2 = float(max_box[2]);
  }
  if (algo == 0 || offset == 0.0) {
    max_area = (max_box_x2 - max_box_x1) * (max_box_y2 - max_box_y1);
  } else {
    max_area = (max_box_x2 - max_box_x1 + offset) *
               (max_box_y2 - max_box_y1 + offset);
  }
}
/***********************************************************************/
/*******************************STORE RESULT****************************/
/***********************************************************************/
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void storeResult(IN_DT *max_box, OUT_DT *nram_save,
                              OUT_DT *&output_dram, const int keep,
                              const int nram_save_limit_count,
                              const int max_output_size,
                              const float thresh_score, const int output_mode,
                              int &nram_save_count, uint32_t &output_box_num) {
  /******NMS STORE START******/
  // store to nram
  if (float(max_box[0]) > thresh_score) {
    OUT_DT *save_ptr;
    int save_offset = 0;
    int save_str_num = 0;
    save_ptr = nram_save;
    save_offset = nram_save_count;
    save_str_num = nram_save_limit_count;
    if (clusterId == 0 && coreId == 0) {
      if (output_mode == 0) {  // index1, index2, ...
        save_ptr[save_offset] = ((uint32_t *)(max_box + INFO_NUM))[0];
      } else if (output_mode == 1) {  // score, x1, y1, x2, y2
        __memcpy(save_ptr + save_offset * INFO_NUM, max_box,
                 INFO_NUM * sizeof(IN_DT), NRAM2NRAM, INFO_NUM * sizeof(IN_DT),
                 INFO_NUM * sizeof(IN_DT), 0);
      } else if (output_mode == 2) {  // score---, x1---, y1---, x2---, y2---
        __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), NRAM2NRAM,
                 save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), 4);
      }
    }
    nram_save_count++;
    output_box_num++;
  }
  // store to sram/gdram
  if (output_box_num != 0) {
    if ((nram_save_count == nram_save_limit_count) ||
        (float(max_box[0]) <= thresh_score) || keep == max_output_size - 1) {
      if (nram_save_count != 0) {
        if (clusterId == 0 && coreId == 0) {
          if (output_mode == 0) {  // index1, index2, ...
            pvLock();
            __memcpy(output_dram, nram_save, nram_save_count * sizeof(uint32_t),
                     NRAM2GDRAM);
            pvUnlock();
            output_dram += nram_save_count;
          } else if (output_mode == 1) {  // score, x1, y1, x2, y2
            pvLock();
            __memcpy(output_dram, nram_save,
                     nram_save_count * INFO_NUM * sizeof(IN_DT), NRAM2GDRAM);
            pvUnlock();
            output_dram += nram_save_count * INFO_NUM;
          } else if (output_mode == 2) {
            // score---, x1---, y1---, x2---, y2---
            pvLock();
            __memcpy(output_dram, nram_save, nram_save_count * sizeof(IN_DT),
                     NRAM2GDRAM, max_output_size * sizeof(IN_DT),
                     nram_save_limit_count * sizeof(IN_DT), 4);
            pvUnlock();
            output_dram += nram_save_count;
          }
          nram_save_count = 0;
        }
      }
    }  // if move data nram->sram/gdram
  }    // if dst
}
template <typename IN_DT, typename OUT_DT>
__mlu_func__ void scoreUpdate(
    IN_DT *input_score_ptr, const mluMemcpyDirection_t load_dir,
    const mluMemcpyDirection_t store_dir, const IN_DT *input_x1_ptr,
    const IN_DT *input_y1_ptr, const IN_DT *input_x2_ptr,
    const IN_DT *input_y2_ptr, IN_DT *x1, IN_DT *y1, IN_DT *x2, IN_DT *y2,
    IN_DT *score, IN_DT *inter_x1, IN_DT *inter_y1, IN_DT *inter_x2,
    IN_DT *inter_y2, IN_DT *max_box, const float max_box_x1,
    const float max_box_y1, const float max_box_x2, const float max_box_y2,
    OUT_DT *nram_save, int repeat_iou_compute, int remain_iou_compute,
    int remain_pad_iou_compute, int max_seg_iou_compute, int max_seg_pad,
    const float thresh_iou, const float div_thresh_iou, const int input_offset,
    const float offset, const float max_area, const int input_num_boxes,
    const int algo) {
  for (int i = 0; i <= repeat_iou_compute; i++) {
    if (i == repeat_iou_compute && remain_iou_compute == 0) {
      break;
    }
    int seg_len = (i == repeat_iou_compute) ? remain_pad_iou_compute
                                            : max_seg_iou_compute;
    int cpy_len =
        (i == repeat_iou_compute) ? remain_iou_compute : max_seg_iou_compute;
    /******NMS LOAD START******/
    int dt_offset = 0;
    if (sizeof(IN_DT) == sizeof(float)) {
      __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      dt_offset = 0;
    } else if (sizeof(IN_DT) == sizeof(half)) {
      __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute,
               cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT),
               cpy_len * sizeof(IN_DT), 0);
      __bang_half2float((float *)score, (half *)x1, seg_len);
      dt_offset = max_seg_iou_compute;
    }
#if __BANG_ARCH__ >= 300
    __memcpy(inter_x1 + dt_offset,
             input_x1_ptr + input_offset + i * max_seg_iou_compute,
             cpy_len * sizeof(IN_DT), load_dir, max_seg_pad * sizeof(IN_DT),
             input_num_boxes * sizeof(IN_DT), 3);
    if (sizeof(IN_DT) == sizeof(half)) {
      __bang_half2float((float *)inter_x1,
                        (half *)inter_x1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_y1,
                        (half *)inter_y1 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_x2,
                        (half *)inter_x2 + max_seg_iou_compute, seg_len);
      __bang_half2float((float *)inter_y2,
                        (half *)inter_y2 + max_seg_iou_compute, seg_len);
    }
    // box transfer
    __bang_minequal((float *)x1, (float *)inter_x1, (float *)inter_x2, seg_len);
    __bang_maxequal((float *)x2, (float *)inter_x1, (float *)inter_x2, seg_len);
    __bang_minequal((float *)y1, (float *)inter_y1, (float *)inter_y2, seg_len);
    __bang_maxequal((float *)y2, (float *)inter_y1, (float *)inter_y2, seg_len);
    // 1、 compute IOU
    // get the area_I
    __bang_maxeq_scalar((float *)inter_x1, (float *)x1, max_box_x1,
                        seg_len);  // inter_x1
    __bang_mineq_scalar((float *)inter_x2, (float *)x2, max_box_x2,
                        seg_len);  // inter_x2
    __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_x1, (float *)inter_x1, offset, seg_len);
    }
    computeReluN((float *)inter_x1, (float *)inter_x1, NULL,
                 seg_len);  // inter_w
    __bang_maxeq_scalar((float *)inter_y1, (float *)y1, float(max_box_y1),
                        seg_len);  // inter_y1
    __bang_mineq_scalar((float *)inter_y2, (float *)y2, float(max_box_y2),
                        seg_len);  // inter_y2
    __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1,
               seg_len);
    if (algo == 1 && offset != 0.0) {
      __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len);
    }
    computeReluN((float *)inter_y1, (float *)inter_y1, NULL,
                 seg_len);  // inter_h
    __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1,
               seg_len);  // area_I
    // get the area of input_box: area = (x2 - x1) * (y2 - y1);
    if (algo == 1 && offset != 0.0) {
      __bang_fusion(FUSION_FSA, (float *)inter_y1, (float *)x2, (float *)x1,
                    offset, seg_len, seg_len);
      __bang_fusion(FUSION_FSA, (float *)inter_y2, (float *)y2, (float *)y1,
                    offset, seg_len, seg_len);
      __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2,
                 seg_len);  // area
    } else {
      __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len);
      __bang_fusion(FUSION_FSM, (float *)inter_x2, (float *)y2, (
float
*
)
y1
,
(
float
*
)
inter_y1
,
seg_len
,
seg_len
);
}
// get the area_U: area + max_area - area_I
__bang_fusion
(
FUSION_FAS
,
(
float
*
)
inter_x2
,
(
float
*
)
inter_x2
,
max_area
,
(
float
*
)
inter_x1
,
seg_len
,
seg_len
);
// 2、 select the box
// if IOU greater than thres, set the score to zero, abort it: area_U >
// area_I * (1 / thresh)?
if
(
thresh_iou
>
0.0
)
{
__bang_mul_scalar
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
div_thresh_iou
,
seg_len
);
}
else
{
__bang_mul_scalar
((
float
*
)
inter_x2
,
(
float
*
)
inter_x2
,
thresh_iou
,
seg_len
);
}
// process for nan
__bang_lt
((
float
*
)
inter_x1
,
(
float
*
)
inter_x2
,
(
float
*
)
inter_x1
,
seg_len
);
__bang_not
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
seg_len
);
__bang_mul
((
float
*
)
score
,
(
float
*
)
score
,
(
float
*
)
inter_x1
,
seg_len
);
/******NMS COMPUTE END******/
#else
__memcpy
(
x1
+
dt_offset
,
input_x1_ptr
+
input_offset
+
i
*
max_seg_iou_compute
,
cpy_len
*
sizeof
(
IN_DT
),
load_dir
,
max_seg_pad
*
sizeof
(
IN_DT
),
input_num_boxes
*
sizeof
(
IN_DT
),
3
);
if
(
sizeof
(
IN_DT
)
==
sizeof
(
half
))
{
__bang_half2float
((
float
*
)
x1
,
(
half
*
)
x1
+
max_seg_iou_compute
,
seg_len
);
__bang_half2float
((
float
*
)
y1
,
(
half
*
)
y1
+
max_seg_iou_compute
,
seg_len
);
__bang_half2float
((
float
*
)
x2
,
(
half
*
)
x2
+
max_seg_iou_compute
,
seg_len
);
__bang_half2float
((
float
*
)
y2
,
(
half
*
)
y2
+
max_seg_iou_compute
,
seg_len
);
}
// 1、 compute IOU
// get the area_I
__bang_write_value
((
float
*
)
inter_y1
,
seg_len
,
float
(
max_box
[
1
]));
// max_x1
__bang_maxequal
((
float
*
)
inter_x1
,
(
float
*
)
x1
,
(
float
*
)
inter_y1
,
seg_len
);
// inter_x1
__bang_write_value
((
float
*
)
inter_y2
,
seg_len
,
float
(
max_box
[
3
]));
// max_x2
__bang_minequal
((
float
*
)
inter_x2
,
(
float
*
)
x2
,
(
float
*
)
inter_y2
,
seg_len
);
// inter_x2
__bang_sub
((
float
*
)
inter_x1
,
(
float
*
)
inter_x2
,
(
float
*
)
inter_x1
,
seg_len
);
if
(
algo
==
1
&&
offset
!=
0.0
)
{
__bang_add_scalar
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
offset
,
seg_len
);
}
computeReluN
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
NULL
,
seg_len
);
// inter_w
__bang_write_value
((
float
*
)
inter_x2
,
seg_len
,
float
(
max_box
[
2
]));
// max_y1
__bang_maxequal
((
float
*
)
inter_y1
,
(
float
*
)
y1
,
(
float
*
)
inter_x2
,
seg_len
);
// inter_y1
__bang_write_value
((
float
*
)
inter_x2
,
seg_len
,
float
(
max_box
[
4
]));
// max_y2
__bang_minequal
((
float
*
)
inter_y2
,
(
float
*
)
y2
,
(
float
*
)
inter_x2
,
seg_len
);
// inter_y2
__bang_sub
((
float
*
)
inter_y1
,
(
float
*
)
inter_y2
,
(
float
*
)
inter_y1
,
seg_len
);
if
(
algo
==
1
&&
offset
!=
0.0
)
{
__bang_add_scalar
((
float
*
)
inter_y1
,
(
float
*
)
inter_y1
,
offset
,
seg_len
);
}
computeReluN
((
float
*
)
inter_y1
,
(
float
*
)
inter_y1
,
NULL
,
seg_len
);
// inter_h
__bang_mul
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
(
float
*
)
inter_y1
,
seg_len
);
// area_I
// get the area of input_box: area = (x2 - x1) * (y2 - y1);
__bang_sub
((
float
*
)
inter_y1
,
(
float
*
)
x2
,
(
float
*
)
x1
,
seg_len
);
__bang_sub
((
float
*
)
inter_y2
,
(
float
*
)
y2
,
(
float
*
)
y1
,
seg_len
);
if
(
algo
==
1
&&
offset
!=
0.0
)
{
__bang_add_scalar
((
float
*
)
inter_y1
,
(
float
*
)
inter_y1
,
offset
,
seg_len
);
__bang_add_scalar
((
float
*
)
inter_y2
,
(
float
*
)
inter_y2
,
offset
,
seg_len
);
}
__bang_mul
((
float
*
)
inter_x2
,
(
float
*
)
inter_y1
,
(
float
*
)
inter_y2
,
seg_len
);
// area
// get the area_U: area + max_area - area_I
__bang_add_scalar
((
float
*
)
inter_x2
,
(
float
*
)
inter_x2
,
float
(
max_area
),
seg_len
);
__bang_sub
((
float
*
)
inter_x2
,
(
float
*
)
inter_x2
,
(
float
*
)
inter_x1
,
seg_len
);
// area_U
// 2、 select the box
// if IOU greater than thresh, set the score to zero, abort it: area_U >
// area_I * (1 / thresh)?
if
(
thresh_iou
>
0.0
)
{
__bang_mul_scalar
((
float
*
)
inter_x1
,
(
float
*
)
inter_x1
,
div_thresh_iou
,
seg_len
);
}
else
{
__bang_mul_scalar
((
float
*
)
inter_x2
,
(
float
*
)
inter_x2
,
thresh_iou
,
seg_len
);
}
__bang_ge
((
float
*
)
inter_x1
,
(
float
*
)
inter_x2
,
(
float
*
)
inter_x1
,
seg_len
);
__bang_mul
((
float
*
)
score
,
(
float
*
)
score
,
(
float
*
)
inter_x1
,
seg_len
);
/******NMS COMPUTE END******/
#endif
// update the score
if
(
sizeof
(
IN_DT
)
==
sizeof
(
half
))
{
convertFloat2half
((
half
*
)
score
,
(
float
*
)
score
,
seg_len
);
}
pvLock
();
__memcpy
(
input_score_ptr
+
input_offset
+
i
*
max_seg_iou_compute
,
score
,
cpy_len
*
sizeof
(
IN_DT
),
store_dir
,
cpy_len
*
sizeof
(
IN_DT
),
cpy_len
*
sizeof
(
IN_DT
),
0
);
pvUnlock
();
}
}
#endif // NMS_UTILS_HPP_
mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 5
__nram__ char buffer[MAX_NRAM_SIZE];
namespace forward {
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T y, T x, T *w1,
T *w2, T *w3, T *w4, int *x_low,
int *x_high, int *y_low, int *y_high,
bool *empty) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true;
return;
}
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low_ = int(y);
int x_low_ = int(x);
if (y_low_ >= input_height - 1) {
*y_high = y_low_ = input_height - 1;
y = (T)y_low_;
} else {
*y_high = y_low_ + 1;
}
if (x_low_ >= input_width - 1) {
*x_high = x_low_ = input_width - 1;
x = T(x_low_);
} else {
*x_high = x_low_ + 1;
}
*y_low = y_low_;
*x_low = x_low_;
T ly = y - y_low_;
T lx = x - x_low_;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
template <typename T>
__mlu_func__ void computeChannel(T *input_core, T *nram_in, T *output_core,
T *nram_out, const int roi_bin_grid_h,
const int roi_bin_grid_w, const T roi_start_h,
const T roi_start_w, const int ph,
const int pw, const T bin_size_h,
const T bin_size_w, const float count,
const int input_height, const int input_width,
const int channels, const int cyc_num,
const int max_elements) {
int cyc_channel = max_elements;
for (int i = 0; i < cyc_num; i++) {
int real_channel =
(i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel;
int align_channel = PAD_UP(real_channel, NFU_ALIGN_SIZE / sizeof(T));
__bang_write_zero(nram_out, align_channel);
uint32_t real_size = real_channel * sizeof(T);
int iy, ix;
for (iy = 0; iy < roi_bin_grid_h; iy++) {
// 1. compute the coordinates of the y axis in the current roi_bin_grid_h
T y = roi_start_h + ph * bin_size_h +
(T)(iy + 0.5) * bin_size_h / (T)(roi_bin_grid_h);
for (ix = 0; ix < roi_bin_grid_w; ix++) {
// 2. compute the coordinates of the x axis in the current
// roi_bin_grid_w
T x = roi_start_w + pw * bin_size_w +
(T)(ix + 0.5) * bin_size_w / (T)(roi_bin_grid_w);
// 3. compute the four weights (w1, w2, w3 and w4), the height (y_low
// and y_high) and weight (x_low and x_high) of input feature map in
// the current roi bin grid, and the flag (empty) which shows if x, y
// are out of input feature map ranges
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bool empty = false;
bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, &w3, &w4,
&x_low, &x_high, &y_low, &y_high, &empty);
// 4. compute interpolation of the current roi bin grid
// tmp_cyc1, temp_cyc2, tmp_cyc3 and tmp_cyc4 store the input values
// to compute the interpolation, and then reused to compute
// the argmax_x and argmax_y.
T *tmp_cyc1 = nram_in + cyc_channel;
T *tmp_cyc2 = nram_in + cyc_channel * 2;
T *tmp_cyc3 = nram_in + cyc_channel * 3;
T *tmp_cyc4 = nram_in + cyc_channel * 4;
if (empty) {  // abnormal values exist (x, y out of the feature map)
__bang_write_zero(nram_in, align_channel);
} else {
__bang_write_zero(nram_in, align_channel);
uint32_t offset1 = (y_low * input_width + x_low) * channels;
uint32_t offset2 = (y_low * input_width + x_high) * channels;
uint32_t offset3 = (y_high * input_width + x_low) * channels;
uint32_t offset4 = (y_high * input_width + x_high) * channels;
T *input1 = (T *)input_core + offset1 + i * cyc_channel;
T *input2 = (T *)input_core + offset2 + i * cyc_channel;
T *input3 = (T *)input_core + offset3 + i * cyc_channel;
T *input4 = (T *)input_core + offset4 + i * cyc_channel;
// load the four pixels (p1, p2, p3 and p4) of input feature map to
// compute interpolation
__memcpy(tmp_cyc1, input1, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc2, input2, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc3, input3, real_size, GDRAM2NRAM);
__memcpy(tmp_cyc4, input4, real_size, GDRAM2NRAM);
// interpolation value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4
__bang_mul_scalar(tmp_cyc1, tmp_cyc1, w1, align_channel);
__bang_mul_scalar(tmp_cyc2, tmp_cyc2, w2, align_channel);
__bang_mul_scalar(tmp_cyc3, tmp_cyc3, w3, align_channel);
__bang_mul_scalar(tmp_cyc4, tmp_cyc4, w4, align_channel);
__bang_add(nram_in, tmp_cyc1, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc2, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc3, nram_in, align_channel);
__bang_add(nram_in, tmp_cyc4, nram_in, align_channel);
}
// 5. compute sum value and corresponding coordinates of x axis and y
// axis. Update the sum value.
__bang_add(nram_out, nram_in, nram_out, align_channel);
} // loop_roi_grid_w
} // loop_roi_grid_h
T count_value = (T)(1.0 / count);
__bang_mul_scalar(nram_out, nram_out, count_value, align_channel);
__memcpy(output_core + i * cyc_channel, nram_out, real_size, NRAM2GDRAM);
} // loop_cyc_num
}
template <typename T>
__mlu_func__ void roialignForwardAvg(
T *input, T *rois, T *output, const bool aligned, const int channels,
const int pooled_height, const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio, const T spatial_scale,
const int num_rois) {
// find limit for channel, the nram space is divided to 6 parts that are
// input, 4 weights to compute the interpolation (w1, w2, w3, w4), output
// max_elements : 300 : float datatype : 27296, half datatype : 54592
// max_elements : 200 : float datatype : 16384, half datatype : 32768
int max_elements = (PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE)) / sizeof(T);
int cyc_num = channels / max_elements + (int)(channels % max_elements != 0);
T offset = aligned ? (T)0.5 : (T)0.0;
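  // each output bin (num_rois * pooled_height * pooled_width in total) is one
  // task; bins are distributed across cores by striding taskId with taskDim.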
int task_num = num_rois * pooled_height * pooled_width;
T *nram_out = (T *)buffer;
T *nram_in = nram_out + max_elements;
if (task_num < taskDim) {
if (taskId >= task_num) {
return;
}
}
for (int bin_idx = taskId; bin_idx < task_num; bin_idx = bin_idx + taskDim) {
if (bin_idx >= task_num) {
return;
}
// (n, ph, pw) indexes one channel vector in the pooled output
int pw = bin_idx % pooled_width;
int ph = (bin_idx / pooled_width) % pooled_height;
int n = bin_idx / pooled_width / pooled_height;
T *roi_id_tmp = rois + n * ROI_OFFSET;
// 1. compute width and height of roi region.
int batch_idx = (int)roi_id_tmp[0];
T roi_x1 = roi_id_tmp[1];
T roi_y1 = roi_id_tmp[2];
T roi_x2 = roi_id_tmp[3];
T roi_y2 = roi_id_tmp[4];
T roi_start_w = roi_x1 * spatial_scale - offset;
T roi_start_h = roi_y1 * spatial_scale - offset;
T roi_end_w = roi_x2 * spatial_scale - offset;
T roi_end_h = roi_y2 * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
roi_width = roi_width > (T)(1.0) ? roi_width : (T)(1.0);
roi_height = roi_height > (T)(1.0) ? roi_height : (T)(1.0);
}
// 2. compute float-type width and height of roi bin region.
T bin_size_w = (T)roi_width / (T)pooled_width;
T bin_size_h = (T)roi_height / (T)pooled_height;
// 3. compute int-type width and height of roi bin region.
int roi_bin_grid_h, roi_bin_grid_w;
roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: int(ceilf(roi_height / pooled_height));
roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: int(ceilf(roi_width / pooled_width));
float count = (float)((roi_bin_grid_h * roi_bin_grid_w) > 1
? roi_bin_grid_h * roi_bin_grid_w
: 1.0);
T *input_core = input + batch_idx * channels * input_width * input_height;
T *output_core = output + bin_idx * channels;
// 4. compute avg value and corresponding coordinates of x axis and y axis.
computeChannel(input_core, nram_in, output_core, nram_out, roi_bin_grid_h,
roi_bin_grid_w, roi_start_h, roi_start_w, ph, pw, bin_size_h,
bin_size_w, count, input_height, input_width, channels,
cyc_num, max_elements);
}
}
__mlu_global__ void MLUUnion1KernelRoiAlignAvg(
const void *input, const void *rois, const int channels, const bool aligned,
const int pooled_height, const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio, const float spatial_scale,
const int num_rois, const cnrtDataType_t data_type, void *output) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_type) {
case CNRT_FLOAT16: {
roialignForwardAvg((half *)input, (half *)rois, (half *)output, aligned,
channels, pooled_height, pooled_width, input_height,
input_width, sampling_ratio, (half)spatial_scale,
num_rois);
}; break;
case CNRT_FLOAT32: {
roialignForwardAvg((float *)input, (float *)rois, (float *)output,
aligned, channels, pooled_height, pooled_width,
input_height, input_width, sampling_ratio,
(float)spatial_scale, num_rois);
}; break;
default:
break;
}
return;
}
} // namespace forward
namespace backward {
__mlu_func__ void bilinearInterpolateGradient(int height, int width, float y,
float x, float *w1, float *w2,
float *w3, float *w4, int *x_low,
int *x_high, int *y_low,
int *y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
*w1 = 0.0, *w2 = 0.0, *w3 = 0.0, *w4 = 0.0;
*x_low = -1, *x_high = -1, *y_low = -1, *y_high = -1;
return;
}
if (y <= 0) {
y = 0;
}
if (x <= 0) {
x = 0;
}
*y_low = (int)y;
*x_low = (int)x;
if (*y_low >= height - 1) {
*y_high = height - 1, *y_low = height - 1;
y = (float)(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= width - 1) {
*x_high = width - 1, *x_low = width - 1;
x = (float)(*x_low);
} else {
*x_high = *x_low + 1;
}
float ly = y - *y_low, lx = x - *x_low;
float hy = 1.0 - ly, hx = 1.0 - lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
template <typename T>
__mlu_func__ void unionRoiAlignBp(
T *grads, T *boxes, T *grads_image, const int boxes_num, const int hi,
const int wi, const int c, const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio, const bool aligned) {
int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T));
int deal_all = boxes_num * hi * wi;
int deal_this_core = deal_all / taskDim + (int)(taskId < deal_all % taskDim);
for (int i = 0; i < deal_this_core; ++i) {
int bhw_id = i * taskDim + taskId;
int box_id = bhw_id / (hi * wi);
int ih = (bhw_id / wi) % hi;
int iw = bhw_id % wi;
T *box = boxes + box_id * 5;
int image_id = (int)box[0];
T *image_offset = grads_image + image_id * ho * wo * c;
T *grads_ = grads + box_id * hi * wi * c + ih * wi * c + iw * c;
float offset = aligned ? 0.5 : 0.0;
float x1 = box[1] * spatial_scale - offset;
float y1 = box[2] * spatial_scale - offset;
float x2 = box[3] * spatial_scale - offset;
float y2 = box[4] * spatial_scale - offset;
float roi_width = x2 - x1;
float roi_height = y2 - y1;
if (!aligned) {
roi_width = (roi_width > 1.0) ? roi_width : 1.0;
roi_height = (roi_height > 1.0) ? roi_height : 1.0;
}
float bin_size_h = roi_height / hi;
float bin_size_w = roi_width / wi;
int roi_grid_h =
(sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi);
int roi_grid_w =
(sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / wi);
const T count = roi_grid_h * roi_grid_w;
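    // each sampled grid point scatters grad * w_k / count to its four bilinear
    // neighbours with __bang_atomic_add; channels are handled in one pass when
    // two channel-sized buffers fit in NRAM, otherwise in chunks of deal_once.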
if (c_align * sizeof(T) * 2 <= MAX_NRAM_SIZE) {
for (int iy = 0; iy < roi_grid_h; ++iy) {
const float y =
y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h;
for (int ix = 0; ix < roi_grid_w; ++ix) {
const float x =
x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w;
float w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high);
if (x_low >= 0 && y_low >= 0) {
__memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w1,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_low * wo * c + x_low * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w2,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_low * wo * c + x_high * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w3,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_high * wo * c + x_low * c,
(T *)buffer + c_align, c);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w4,
c_align);
__bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align,
1 / count, c_align);
__bang_atomic_add((T *)buffer + c_align,
image_offset + y_high * wo * c + x_high * c,
(T *)buffer + c_align, c);
} // x_low && y_low
} // ix
} // iy
} else {
for (int iy = 0; iy < roi_grid_h; ++iy) {
const float y =
y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h;
for (int ix = 0; ix < roi_grid_w; ++ix) {
const float x =
x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w;
float w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high);
if (x_low >= 0 && y_low >= 0) {
int deal_once =
PAD_DOWN(MAX_NRAM_SIZE / 2, NFU_ALIGN_SIZE) / sizeof(T);
int c_repeat = c / deal_once + (int)(c % deal_once != 0);
for (int i = 0; i < c_repeat; ++i) {
int deal_c = deal_once;
int align_c = deal_once;
if (i == c_repeat - 1) {
deal_c = c - i * deal_once;
align_c = c_align - i * deal_once;
}
__memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w1,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_low * wo * c + x_low * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w2,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_low * wo * c + x_high * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w3,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_high * wo * c + x_low * c + i * deal_once,
(T *)buffer + align_c, deal_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w4,
align_c);
__bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c,
1 / count, align_c);
__bang_atomic_add(
(T *)buffer + align_c,
image_offset + y_high * wo * c + x_high * c + i * deal_once,
(T *)buffer + align_c, deal_c);
} // for c_repeat
} // x_low >= 0 && y_low >= 0
} // ix
} // iy
} // if c
} // i
}
__mlu_global__ void MLUUnion1KernelRoiAlignBackward(
const void *grads, const void *boxes, void *grads_image,
const cnrtDataType_t dtype, const int boxes_num, const int hi, const int wi,
const int c, const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio, const bool aligned) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (dtype) {
case CNRT_FLOAT16: {
unionRoiAlignBp((half *)grads, (half *)boxes, (half *)grads_image,
boxes_num, hi, wi, c, no, ho, wo, spatial_scale,
sampling_ratio, aligned);
}; break;
case CNRT_FLOAT32: {
unionRoiAlignBp((float *)grads, (float *)boxes, (float *)grads_image,
boxes_num, hi, wi, c, no, ho, wo, spatial_scale,
sampling_ratio, aligned);
}; break;
default: { return; }
}
}
} // namespace backward
void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *rois, const int channels,
const bool aligned, const int pooled_height,
const int pooled_width, const int input_height,
const int input_width, const int sampling_ratio,
const float spatial_scale, const int num_rois,
void *output) {
forward::MLUUnion1KernelRoiAlignAvg<<<k_dim, k_type, queue>>>(
input, rois, channels, aligned, pooled_height, pooled_width, input_height,
input_width, sampling_ratio, spatial_scale, num_rois, d_type, output);
}
void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t dtype,
const void *grads, const void *boxes,
void *grads_image, const int boxes_num,
const int hi, const int wi, const int c,
const int no, const int ho, const int wo,
const float spatial_scale, const int sampling_ratio,
const bool aligned) {
backward::MLUUnion1KernelRoiAlignBackward<<<k_dim, k_type, queue>>>(
grads, boxes, grads_image, dtype, boxes_num, hi, wi, c, no, ho, wo,
spatial_scale, sampling_ratio, aligned);
}
mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 7
#define FLOAT_NRAM_BUFFER_NUM 14
#define HALF_NRAM_BUFFER_NUM 25
#define ALIGN_NUM 64
__nram__ char data_nram[MAX_NRAM_SIZE];
template <typename T>
__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels(
const int pool_method, const int boxes_num, const int pts_num,
const int max_pts_each_voxel, const int out_x, const int out_y,
const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) {
// params (T)rois: (boxes_num, 7)
// params (T)pts: (3, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int nram_pts_num = 0;
if (sizeof(T) == sizeof(float)) {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM);
} else {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM);
}
char *X = NULL;
char *Y = NULL;
char *Z = NULL;
char *local_X = NULL;
char *local_Y = NULL;
char *local_Z = NULL;
char *nram_pts_in_flag = NULL;
float *temp_buffer1 = NULL;
float *temp_buffer2 = NULL;
float *temp_buffer3 = NULL;
float *temp_buffer4 = NULL;
float *temp_buffer5 = NULL;
float *nram_voxel_offset = NULL;
int *nram_pts_idx_seq = NULL;
float *fp_local_X = NULL;
float *fp_local_Y = NULL;
float *fp_local_Z = NULL;
float *fp_nram_pts_in_flag = NULL;
if (sizeof(T) == sizeof(float)) {
X = (char *)((float *)data_nram);
Y = (char *)((float *)data_nram + nram_pts_num);
Z = (char *)((float *)data_nram + nram_pts_num * 2);
local_X = (char *)((float *)data_nram + nram_pts_num * 3);
local_Y = (char *)((float *)data_nram + nram_pts_num * 4);
local_Z = (char *)((float *)data_nram + nram_pts_num * 5);
nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6);
temp_buffer1 = (float *)data_nram + nram_pts_num * 7;
temp_buffer2 = (float *)data_nram + nram_pts_num * 8;
temp_buffer3 = (float *)data_nram + nram_pts_num * 9;
temp_buffer4 = (float *)data_nram + nram_pts_num * 10;
temp_buffer5 = (float *)data_nram + nram_pts_num * 11;
nram_voxel_offset = (float *)data_nram + nram_pts_num * 12;
nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13);
fp_local_X = (float *)local_X;
fp_local_Y = (float *)local_Y;
fp_local_Z = (float *)local_Z;
fp_nram_pts_in_flag = (float *)nram_pts_in_flag;
} else {
X = (char *)((half *)data_nram);
Y = (char *)((half *)data_nram + nram_pts_num);
Z = (char *)((half *)data_nram + nram_pts_num * 2);
local_X = (char *)((half *)data_nram + nram_pts_num * 4);
local_Y = (char *)((half *)data_nram + nram_pts_num * 6);
local_Z = (char *)((half *)data_nram + nram_pts_num * 8);
nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10);
temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11);
temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13);
temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15);
temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17);
temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19);
nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21);
nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23);
fp_local_X = (float *)((half *)local_X - nram_pts_num);
fp_local_Y = (float *)((half *)local_Y - nram_pts_num);
fp_local_Z = (float *)((half *)local_Z - nram_pts_num);
fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num);
}
for (int i = 0; i < nram_pts_num; i++) {
nram_pts_idx_seq[i] = i;
}
int nram_pts_loop_times = pts_num / nram_pts_num;
int rem_nram_num = pts_num % nram_pts_num;
for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) {
const T *cur_roi = rois + roi_index * ROI_OFFSET;
T cx = cur_roi[0];
T cy = cur_roi[1];
T cz = cur_roi[2];
T dx = cur_roi[3];
T dy = cur_roi[4];
T dz = cur_roi[5];
T rz = cur_roi[6];
T dx_2 = dx / 2.0;
T dy_2 = dy / 2.0;
T dz_2 = dz / 2.0;
for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) {
int load_pts_num =
(loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num;
if (load_pts_num == 0) {
break;
}
int pts_offset_cur_loop = nram_pts_num * loop_idx;
int compute_pts_num = (loop_idx == nram_pts_loop_times)
? PAD_UP(rem_nram_num, ALIGN_NUM)
: nram_pts_num;
// load pts
__memcpy((void *)X, (T *)pts + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
// fabs(local_z)
__bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2),
compute_pts_num);
__bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num);
#if __BANG_ARCH__ >= 322
__bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2));
__bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
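      // rotate (X - cx, Y - cy) by -rz into the box's local frame:
      // local_x = (X - cx) * cosa - (Y - cy) * sina
      // local_y = (X - cx) * sina + (Y - cy) * cosa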
T cosa = std::cos(-rz);
T sina = std::sin(-rz);
__bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num);
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina,
compute_pts_num);
// local_x
__bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_x)
__bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num);
// fabs(local_x) < dx/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num); // flush res
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa,
compute_pts_num);
// local_y
__bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_y)
__bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num);
// fabs(local_y) < dy/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num); // flush res
T x_res = dx / out_x;
T y_res = dy / out_y;
T z_res = dz / out_z;
__bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num);
__bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num);
// local_Z does not need to add dz/2.0
#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372)
__bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num);
__bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num);
__bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num);
#else
__bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res),
compute_pts_num);
#endif
// float = float2int + int2float, half = half2int + int2float
if (sizeof(T) == sizeof(float)) {
#if __BANG_ARCH__ >= 322
__bang_float2int32_tz((int *)temp_buffer1, (float *)local_X,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2,
(float *)fp_local_X, (float *)temp_buffer3,
compute_pts_num);
convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3,
(float *)fp_local_Y, (float *)temp_buffer4,
compute_pts_num);
convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4,
(float *)fp_local_Z, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_X, (float *)temp_buffer4,
(int *)temp_buffer1, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4,
(int *)temp_buffer2, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4,
(int *)temp_buffer3, (float *)temp_buffer5,
compute_pts_num);
#endif
} else {
__bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag,
compute_pts_num);
__bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4,
compute_pts_num * sizeof(float));
#if __BANG_ARCH__ >= 322
__bang_half2int32_tz((int *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
__bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3,
compute_pts_num, 0);
#endif
}
// process index >= 0
__bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f);
__bang_maxequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer4, compute_pts_num);
// process index <= (out_x - 1)
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_x - 1));
__bang_minequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_y - 1));
__bang_minequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_z - 1));
__bang_minequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer5, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X,
(float)(out_y * out_z), compute_pts_num);
__bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y,
(float)out_z, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0,
compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)temp_buffer1,
(float *)temp_buffer2, compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float *)temp_buffer3, compute_pts_num);
__bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float)max_pts_each_voxel, compute_pts_num);
if (compute_pts_num != load_pts_num) {
__memset_nram((float *)fp_nram_pts_in_flag + load_pts_num,
compute_pts_num - load_pts_num, (float)0.0);
}
__bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq,
(float *)fp_nram_pts_in_flag, compute_pts_num);
int pts_num_in_cur_roi =
(int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num);
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels +
roi_index * out_x * out_y * out_z * max_pts_each_voxel;
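      // each voxel stores [count, idx_0, idx_1, ...]: slot 0 holds the number
      // of points already assigned, followed by at most
      // (max_pts_each_voxel - 1) point indices.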
for (int idx = 0; idx < pts_num_in_cur_roi; idx++) {
int cur_pts_idx = *((int *)temp_buffer4 + idx);
int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx));
int cnt = pts_idx_cur_voxels[offset];
if (cnt < max_pts_each_voxel - 1) {
pts_idx_cur_voxels[offset + cnt + 1] =
cur_pts_idx + loop_idx * nram_pts_num;
pts_idx_cur_voxels[offset]++;
}
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward(
const int pool_method, const int boxes_num, const int pts_num,
const int channels, const int max_pts_each_voxel, const int out_x,
const int out_y, const int out_z, const T *pts_feature,
const int *pts_idx_of_voxels, T *pooled_features, int *argmax) {
// params (T)pts_feature: (channels, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel) params (int)argmax: (boxes_num, out_x, out_y, out_z,
// channels) params (T)pooled_features: (boxes_num, out_x, out_y, out_z,
// channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit =
PAD_DOWN((MAX_NRAM_SIZE - 128 -
align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) /
((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)),
align_num);
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_max_pts_feature_tmp =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_max_pts_feature_tmp [align_max_pts_each_voxel]
T *nram_pts_feature_in_voxel =
((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel);
// nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel]
T *nram_pooled_features_cur_voxel =
((T *)nram_pts_feature_in_voxel +
nram_channels_limit * align_max_pts_each_voxel);
// nram_pooled_features_cur_voxel [nram_channels_limit]
int *nram_argmax_cur_voxel =
(int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit);
// nram_argmax_cur_voxel [nram_channels_limit]
char *one_pooled_feature =
(char *)((int *)nram_argmax_cur_voxel + nram_channels_limit);
// one_pooled_feature [128]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
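  // for every voxel, channels are processed in chunks of nram_channels_limit;
  // pool_method == 0 does max pooling (recording argmax), pool_method == 1
  // averages the features of the points inside the voxel.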
for (int voxel_index = taskId;
voxel_index < boxes_num * out_x * out_y * out_z;
voxel_index += taskDim) {
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0];
if (pts_num_cur_voxel == 0) {
continue;
}
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
int channels_offset = nram_channels_limit * channels_loop_idx;
#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300))
int compute_channels_num = (channels_loop_idx == channels_loop_times)
? PAD_UP(rem_channels, align_num)
: nram_channels_limit;
if (pool_method == 0) {
__bang_write_value((void *)nram_pts_feature_in_voxel,
compute_channels_num * align_max_pts_each_voxel,
(T)-INFINITY);
}
#endif
T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num;
for (int idx = 0; idx < pts_num_cur_voxel; idx++) {
__memcpy((T *)nram_pts_feature_in_voxel + idx,
(T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1],
sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T),
pts_num * sizeof(T), actual_channels_num - 1);
}
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
if (pool_method == 0) {
#if __BANG_ARCH__ >= 322
__bang_argmax((T *)one_pooled_feature,
(T *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
pts_num_cur_voxel);
T max_val = ((T *)one_pooled_feature)[0];
int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1));
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
#else
// __bang_max needs an aligned num on mlu200 series
if (sizeof(T) == sizeof(float)) {
__bang_max((float *)one_pooled_feature,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
align_max_pts_each_voxel);
float max_val = ((float *)one_pooled_feature)[0];
__bang_write_value((void *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel, (float)max_val);
__bang_eq((float *)nram_max_pts_feature_tmp,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
(float *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel);
int max_idx = (int)__bang_findfirst1(
(float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel);
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1
: nram_pts_idx_cur_voxel[max_idx + 1];
} else {
int max_idx = -1;
float max_val = -INFINITY;
for (int k = 0; k < pts_num_cur_voxel; k++) {
float pts_feature_cur_channel = __half2float_rd(
*((half *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel + k));
if (pts_feature_cur_channel > max_val) {
max_val = pts_feature_cur_channel;
max_idx = k;
}
}
nram_pooled_features_cur_voxel[channel_idx] =
(max_idx == -1) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_idx == -1) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
}
#endif
} else if (pool_method == 1) {
float sum_val_cur_channel = 0;
for (int k = 0; k < pts_num_cur_voxel; k++) {
sum_val_cur_channel += static_cast<float>(
((T *)nram_pts_feature_in_voxel)[channel_idx *
align_max_pts_each_voxel +
k]);
}
nram_pooled_features_cur_voxel[channel_idx] =
(T)(sum_val_cur_channel / pts_num_cur_voxel);
}
}
// store
__memcpy((T *)pooled_features + voxel_index * channels + channels_offset,
(void *)nram_pooled_features_cur_voxel,
actual_channels_num * sizeof(T), NRAM2GDRAM);
if (pool_method == 0) {
__memcpy((int *)argmax + voxel_index * channels + channels_offset,
(void *)nram_argmax_cur_voxel,
actual_channels_num * sizeof(int), NRAM2GDRAM);
}
}
}
}
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const int pool_method, const int boxes_num,
const int pts_num, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z,
const void *rois, const void *pts,
int *pts_idx_of_voxels) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelPtsIdxOfVoxels<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelPtsIdxOfVoxels<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels);
}; break;
default: {
break;
}
}
}
void KernelRoiawarePool3dForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int pts_num, const int channels, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z, const void *pts_feature,
const int *pts_idx_of_voxels, void *pooled_features, int *argmax) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawarePool3dForward<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels,
(float *)pooled_features, (int *)argmax);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawarePool3dForward<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels,
(half *)pooled_features, (int *)argmax);
}; break;
default: {
break;
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int *argmax, const T *grad_out, T *grad_in) {
// params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int nram_channels_limit =
(MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int));
int *nram_argmax_cur_loop = (int *)data_nram;
// nram_argmax_cur_loop [nram_channels_limit]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_argmax_cur_loop + nram_channels_limit);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_channel =
(T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_channel [1]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
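  // route each output gradient back to the point recorded in argmax with an
  // atomic add; argmax == -1 marks an empty voxel and is skipped.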
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const int *argmax_cur_voxel = argmax + voxel_index * channels;
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const int *argmax_cur_loop =
argmax_cur_voxel + nram_channels_limit * channels_loop_idx;
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop,
actual_channels_num * sizeof(int), GDRAM2NRAM);
__memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx;
T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx;
if (nram_argmax_cur_channel[0] == -1) {
continue;
}
T *grad_in_cur_channel =
grad_in + nram_argmax_cur_channel[0] * channels +
nram_channels_limit * channels_loop_idx + channel_idx;
__bang_atomic_add((T *)nram_grad_in_cur_channel,
(T *)grad_in_cur_channel,
(T *)(nram_grad_out_cur_channel), 1);
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int max_pts_each_voxel,
const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) {
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
// max_pts_each_voxel) params (T)grad_out: (boxes_num, out_x, out_y, out_z,
// channels) params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit = PAD_DOWN(
(MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T),
align_num);
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_loop [nram_channels_limit]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
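  // scale each output gradient by 1 / total_pts_of_voxel and atomically add it
  // to every point that fell into the voxel.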
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
const int *pts_idx_cur_voxel =
pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int total_pts_of_voxel = nram_pts_idx_cur_voxel[0];
if (total_pts_of_voxel <= 0) {
continue;
}
float cur_grad = 1.0 / ((float)total_pts_of_voxel);
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
int align_actual_channels_num = PAD_UP(actual_channels_num, align_num);
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)nram_grad_out_cur_loop,
(half *)nram_grad_in_cur_loop,
align_actual_channels_num);
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop, (float)cur_grad,
align_actual_channels_num);
convertFloat2half((half *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop,
align_actual_channels_num);
} else {
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_in_cur_loop, (float)cur_grad,
align_actual_channels_num);
}
for (int k = 1; k <= total_pts_of_voxel; k++) {
T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels +
nram_channels_limit * channels_loop_idx;
__bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop,
(T *)nram_grad_out_cur_loop, actual_channels_num);
}
}
}
}
void KernelRoiawarePool3dBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int out_x, const int out_y, const int out_z, const int channels,
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
const int *argmax, const void *grad_out, void *grad_in) {
if (pool_method == 0) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareMaxPool3dBackward<float>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (float *)grad_out,
(float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareMaxPool3dBackward<half>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (half *)grad_out,
(half *)grad_in);
}; break;
default: {
break;
}
}
} else {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareAvgPool3dBackward<float>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareAvgPool3dBackward<half>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in);
}; break;
default: {
break;
}
}
}
}
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu
deleted
100644 → 0
View file @
59c1418e
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <algorithm>
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
/**
* returns the index of ret, which is stored at the 1st position of the `ret`,
* used after bang_min
*/
__mlu_func__ uint32_t getIndice(half *ret) {
uint32_t indice = *((uint32_t *)((uint16_t *)ret + 1));
return indice;
}
/**
* returns the index of ret, which is stored at the 1st position of the `ret`,
* used after bang_min
*/
__mlu_func__ uint32_t getIndice(float *ret) {
uint32_t indice = ((uint32_t *)ret)[1];
return indice;
}
#endif
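// fallback argmin for architectures without __bang_argmin: take the minimum
// value with __bang_min, build an equality mask against it, and return the
// first matching index via __bang_findfirst1.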
template <typename T>
__mlu_func__ void auxArgmin(T *nram_dst, T *nram_src, const int num_deal,
T *value, int *index) {
__bang_min(nram_dst, nram_src, num_deal);
*value = nram_dst[0];
__bang_write_value(nram_dst, num_deal, *value);
__bang_eq(nram_dst, nram_src, nram_dst, num_deal);
__bang_findfirst1((uint32_t *)nram_dst, nram_dst, num_deal);
*index = *((int *)nram_dst);
}
template <typename T>
__mlu_func__ void auxFuncFind3Min(T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset,
T *nram_dest, T *nram_aux_sort_a,
int *nram_aux_sort_b, const int deal_offset) {
__bang_write_value(nram_aux_sort_a, auxa_offset, (T)(INFINITY));
__bang_write_value(nram_aux_sort_b, auxb_offset, (int)0);
int index = 0;
for (int i = 0; i < 3; i++) {
#if __BANG_ARCH__ >= 322
__bang_argmin(nram_dest, nram_aux_a, auxa_offset);
nram_aux_sort_a[i] = nram_dest[0];
index = getIndice(nram_dest);
#else
T value = 0;
auxArgmin(nram_dest, nram_aux_a, auxa_offset, &value, &index);
nram_aux_sort_a[i] = value;
#endif
nram_aux_sort_b[i] = nram_aux_b[index];
__memset_nram(nram_aux_a + index, 1, (T)(INFINITY));
}
__memcpy((char *)nram_aux_a, (char *)nram_aux_sort_a, auxa_offset * sizeof(T),
NRAM2NRAM);
__memcpy((char *)nram_aux_b, (char *)nram_aux_sort_b,
auxb_offset * sizeof(int), NRAM2NRAM);
}
template <typename T>
__mlu_func__ void auxFuncSort(T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset,
T *nram_dest, T *nram_help_value,
int *nram_help_idx, const int num_deal,
const int deal_offset) {
for (int k = 0; k < num_deal; ++k) {
auxFuncFind3Min(nram_aux_a + k * auxa_offset, auxa_offset,
nram_aux_b + k * auxb_offset, auxb_offset, nram_dest,
nram_help_value, nram_help_idx, deal_offset);
}
}
template <typename T>
__mlu_func__ void auxFuncNN(
size_t *output_aux_sort_a_gap, size_t *output_aux_sort_b_gap,
size_t *output_aux_dest_gap, size_t *output_unknown_gap,
size_t *output_known_gap, size_t *output_dist_gap, size_t *auxillary_a_gap,
size_t *auxillary_b_gap, size_t *known_num_deal, size_t *unknown_num_deal,
size_t *align_num, size_t *auxa_offset, size_t *auxb_offset) {
/*
* nram partition:
* |-NFU_ALIGN_SIZE-|-2*NFU_ALIGN_SIZE-|-X*3*sizeof(T)-|
* space: | aux_sort_a | aux_sort_b | nram_unknown |
*
* | ------ (Y * 7 *sizeof(T)) ---------------- |
* | nram_known | nram_dist | nram_dest |
*
* | -X * NFU_ALIGN_SIZE ---|---X * 2 * NFU_ALIGN_SIZE-|
* | output_dist(aux_a) | output_dist(aux_b) |
* 200 series
* X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (2/3) / (3 * sizeof(T) + 3 *
* NFU_ALIGN_SIZE)
* Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (1/3) / (7 * sizeof(T))
* 300 series
* X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (4/5) / (3 *
* sizeof(T) + 3 * NFU_ALIGN_SIZE)
* Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) *
* (1/5) / (7 * sizeof(T))
*
*/
*align_num = NFU_ALIGN_SIZE / sizeof(T);
*auxa_offset = NFU_ALIGN_SIZE / sizeof(T);
*auxb_offset = 2 * NFU_ALIGN_SIZE / sizeof(int);
#if __BANG_ARCH__ >= 322
*known_num_deal = PAD_DOWN(
(MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 / (7 * sizeof(T)), *align_num);
*unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 * 4 /
(3 * sizeof(T) + 3 * NFU_ALIGN_SIZE),
*align_num);
#else
*known_num_deal = PAD_DOWN(
(MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 / (7 * sizeof(T)), *align_num);
*unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 * 2 /
(3 * sizeof(T) + 3 * NFU_ALIGN_SIZE),
*align_num);
#endif
*output_aux_sort_a_gap = 0;
*output_aux_sort_b_gap = *output_aux_sort_a_gap + NFU_ALIGN_SIZE;
*output_aux_dest_gap = *output_aux_sort_b_gap + 2 * NFU_ALIGN_SIZE;
*output_unknown_gap = *output_aux_dest_gap + *known_num_deal * sizeof(T);
*output_known_gap = *output_unknown_gap + *unknown_num_deal * 3 * sizeof(T);
*output_dist_gap = *output_known_gap + *known_num_deal * 3 * sizeof(T);
*auxillary_a_gap = *output_dist_gap + *known_num_deal * 3 * sizeof(T);
*auxillary_b_gap = *auxillary_a_gap + *unknown_num_deal * NFU_ALIGN_SIZE;
}
#if __BANG_ARCH__ >= 322
template <typename T>
__mlu_func__ bool containNanInf(T *nram_unknown) {
if (std::isnan(nram_unknown[0]) || std::isnan(nram_unknown[1]) ||
std::isnan(nram_unknown[2]) || std::isinf(nram_unknown[0]) ||
std::isinf(nram_unknown[1]) || std::isinf(nram_unknown[2]))
return true;
else
return false;
}
#endif
template <typename T>
__mlu_func__ void computeThreeNN(T *nram_unknown, T *nram_known, T *nram_dist,
T *nram_dest, T *nram_aux_a,
T *nram_aux_sort_a, int *nram_aux_b,
int *nram_aux_sort_b, const int known_num_deal,
const int known_seg_num, const int deal_offset,
const int known_count,
const int known_count_align) {
__bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY));
#if __BANG_ARCH__ >= 322
if (!containNanInf(nram_unknown)) {
#endif
// x1 - x2
__bang_sub_scalar(nram_dist, nram_known, nram_unknown[0],
known_count_align);
// y1 - y2
__bang_sub_scalar(nram_dist + known_count_align,
nram_known + known_count_align, nram_unknown[1],
known_count_align);
// z1 - z2
__bang_sub_scalar(nram_dist + 2 * known_count_align,
nram_known + 2 * known_count_align, nram_unknown[2],
known_count_align);
__bang_square(nram_dist, nram_dist, 3 * known_count_align);
__bang_add(nram_dist, nram_dist, nram_dist + known_count_align,
known_count_align);
__bang_add(nram_dist, nram_dist, nram_dist + 2 * known_count_align,
known_count_align);
#if __BANG_ARCH__ >= 322
}
#endif
int index = 0;
for (int i = 0; i < 3; i++) {
#if __BANG_ARCH__ >= 322
__bang_argmin(nram_dest, nram_dist, known_count_align);
nram_aux_a[i + deal_offset] = nram_dest[0];
index = getIndice(nram_dest);
#else
T value = 0;
auxArgmin(nram_dest, nram_dist, known_count_align, &value, &index);
nram_aux_a[i + deal_offset] = value;
#endif
nram_aux_b[i + deal_offset] = index + known_seg_num * known_num_deal;
__memset_nram(nram_dist + index, 1, (T)(INFINITY));
}
}
template <typename T>
__mlu_func__ void loadTransposedKnownTensor(
char *nram_known, char *nram_dist, const char *known_gdram,
const int known_num_deal, const int batch_id, const int m,
const int known_seg_num, const int count, const int count_align_num) {
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
#if __BANG_ARCH__ >= 322
__bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY));
__memcpy(nram_dist,
known_gdram +
(batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T),
count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T),
m * sizeof(T), 2);
__bang_minequal((T *)nram_known, (T *)nram_known, (T *)nram_dist,
3 * count_align_num);
#else
__memcpy(nram_known,
known_gdram +
(batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T),
count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T),
m * sizeof(T), 2);
#endif
}
template <typename T>
__mlu_func__ void loadUnknownTensor(char *nram_unknown,
const char *unknown_gdram,
const int unknown_num_deal,
const int unknown_seg_num, const int count,
const int count_align_num) {
__memcpy(nram_unknown,
unknown_gdram + unknown_seg_num * unknown_num_deal * 3 * sizeof(T),
count * 3 * sizeof(T), GDRAM2NRAM);
}
template <typename T>
__mlu_func__ void auxProcessSegment(
const int m, const int n, T *nram_unknown, T *nram_known, T *nram_dist,
T *nram_dest, T *known_gdram, T *nram_aux_a, const int auxa_offset,
int *nram_aux_b, const int auxb_offset, T *nram_aux_sort_a,
int *nram_aux_sort_b, const int unknown_num_deal, const int known_num_deal,
const int known_seg_num, const int unknown_seg_num, const int unknown_count,
const int known_count, const int known_count_align, const int start_idx,
int *deal_offset) {
int pre_batch_id = -1;
int cur_batch_id = -1;
pre_batch_id = start_idx / n;
// if aux_a space is not enough, get the first 3 min among aux_a and clear.
if (*deal_offset >= PAD_DOWN(auxa_offset, 3)) {
auxFuncSort(nram_aux_a, auxa_offset, nram_aux_b, auxb_offset, nram_dest,
nram_aux_sort_a, nram_aux_sort_b, unknown_count, *deal_offset);
*deal_offset = 3;
}
// load i'th segment of known batch data.
loadTransposedKnownTensor<T>((char *)nram_known, (char *)nram_dist,
(char *)known_gdram, known_num_deal,
pre_batch_id, m, known_seg_num, known_count,
known_count_align);
for (int k = 0; k < unknown_count; ++k) {
cur_batch_id = (start_idx + k) / n;
if (cur_batch_id != pre_batch_id) { // if batch id of unknown data changed,
// load corresponding known batch data
pre_batch_id = cur_batch_id;
loadTransposedKnownTensor<T>((char *)nram_known, (char *)nram_dist,
(char *)known_gdram, known_num_deal,
pre_batch_id, m, known_seg_num, known_count,
known_count_align);
}
computeThreeNN(nram_unknown + 3 * k, nram_known, nram_dist, nram_dest,
nram_aux_a + k * auxa_offset, nram_aux_sort_a,
nram_aux_b + k * auxb_offset, nram_aux_sort_b,
known_num_deal, known_seg_num, *deal_offset, known_count,
known_count_align);
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelThreeNN(const int b, const int n,
const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram,
int *idx_gdram) {
if (coreId == 0x80) {
return;
}
size_t output_aux_sort_a_gap = 0, output_aux_sort_b_gap = 0,
output_dest_gap = 0, output_unknown_gap = 0, output_known_gap = 0,
output_dist_gap = 0, auxillary_a_gap = 0, auxillary_b_gap = 0,
known_num_deal = 0, unknown_num_deal = 0, align_num = 0,
auxa_offset = 0, auxb_offset = 0;
auxFuncNN<T>(&output_aux_sort_a_gap, &output_aux_sort_b_gap, &output_dest_gap,
&output_unknown_gap, &output_known_gap, &output_dist_gap,
&auxillary_a_gap, &auxillary_b_gap, &known_num_deal,
&unknown_num_deal, &align_num, &auxa_offset, &auxb_offset);
int num_per_core = b * n / taskDim;
const int core_offset = num_per_core;
char *unknown_gdram_start =
unknown_gdram + taskId * 3 * core_offset * sizeof(T);
char *known_gdram_start = known_gdram;
char *output_dist_start = dist2_gdram + taskId * 3 * core_offset * sizeof(T);
int *output_idx_start = idx_gdram + taskId * 3 * core_offset;
const int rem = (b * n) % taskDim;
if (taskId == taskDim - 1) {
num_per_core += rem;
}
const int unknown_repeat =
    num_per_core / unknown_num_deal;  // if the unknown count is large, process
                                      // it in unknown_repeat passes.
const int unknown_rem = num_per_core % unknown_num_deal;  // unknown remainder
const int unknown_rem_align = PAD_UP(unknown_rem, align_num);
const int known_repeat =
    m / known_num_deal;  // if the known count is large, process it in
                         // known_repeat passes.
const int known_rem = m % known_num_deal;  // known remainder
const int known_rem_align = PAD_UP(known_rem, align_num);
char *nram_aux_sort_a = nram_buffer;
int *nram_aux_sort_b = (int *)(nram_buffer + output_aux_sort_b_gap);
char *nram_dest = nram_buffer + output_dest_gap;
char *nram_unknown = nram_buffer + output_unknown_gap;
char *nram_known = nram_buffer + output_known_gap;
char *nram_dist = nram_buffer + output_dist_gap;
char *nram_aux_a = nram_buffer + auxillary_a_gap;
int *nram_aux_b = (int *)(nram_buffer + auxillary_b_gap);
int deal_offset = 0;
int start_idx = -1;
for (int j = 0; j < unknown_repeat;
     ++j) {  // process one unknown_repeat pass of data
  // if the unknown points must be processed segment by segment, use the
  // aux_a and aux_b space to keep the first 3 minimum distances.
__bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset,
(T)(INFINITY));
__bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0);
loadUnknownTensor<T>(nram_unknown, unknown_gdram_start, unknown_num_deal, j,
unknown_num_deal, unknown_num_deal);
deal_offset = 0;
start_idx = taskId * core_offset + j * unknown_num_deal;
for (int i = 0; i < known_repeat;
     ++i) {  // process the known data segment by segment.
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, i, j, unknown_num_deal,
known_num_deal, known_num_deal, start_idx, &deal_offset);
deal_offset += 3;
}
if (known_rem > 0) { // process known rem
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, known_repeat, j, unknown_num_deal,
known_rem, known_rem_align, start_idx, &deal_offset);
}
deal_offset += 3;
if (deal_offset > 3) {
auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset,
(T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, deal_offset);
deal_offset = 0;
}
__memcpy((char *)output_dist_start + j * unknown_num_deal * 3 * sizeof(T),
(char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T),
auxa_offset * sizeof(T), unknown_num_deal - 1);
__memcpy((char *)output_idx_start + j * unknown_num_deal * 3 * sizeof(int),
(char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int),
auxb_offset * sizeof(int), unknown_num_deal - 1);
}
if (unknown_rem > 0) { // process unknown rem
deal_offset = 0;
__bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset,
(T)(INFINITY));
__bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0);
loadUnknownTensor<T>(nram_unknown, unknown_gdram_start, unknown_num_deal,
unknown_repeat, unknown_rem, unknown_rem_align);
start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal;
for (int i = 0; i < known_repeat; ++i) {
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, i, unknown_repeat, unknown_rem,
known_num_deal, known_num_deal, start_idx, &deal_offset);
deal_offset += 3;
}
if (known_rem > 0) {
__bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY));
start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal;
auxProcessSegment<T>(
m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist,
(T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset,
nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_num_deal, known_num_deal, known_repeat, unknown_repeat,
unknown_rem, known_rem, known_rem_align, start_idx, &deal_offset);
deal_offset += 3;
}
if (deal_offset > 3) {
auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset,
(T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b,
unknown_rem, deal_offset);
deal_offset = 0;
}
__memcpy((char *)output_dist_start +
unknown_repeat * unknown_num_deal * 3 * sizeof(T),
(char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T),
auxa_offset * sizeof(T), unknown_rem - 1);
__memcpy((char *)output_idx_start +
unknown_repeat * unknown_num_deal * 3 * sizeof(int),
(char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int),
auxb_offset * sizeof(int), unknown_rem - 1);
}
}
template __mlu_global__ void MLUUnion1KernelThreeNN<float>(
const int b, const int n, const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram, int *idx_gdram);
template __mlu_global__ void MLUUnion1KernelThreeNN<half>(
const int b, const int n, const int m, char *unknown_gdram,
char *known_gdram, char *dist2_gdram, int *idx_gdram);
void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *unknown, const void *known, void *dist2,
int *idx, const int b, const int n, const int m) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1KernelThreeNN<half><<<k_dim, k_type, queue>>>(
b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx);
}; break;
case CNRT_FLOAT32: {
MLUUnion1KernelThreeNN<float><<<k_dim, k_type, queue>>>(
b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx);
}; break;
default: {
break;
}
}
}
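As an aside for readers of the kernel above: per unknown point it keeps the three smallest squared distances to the known points of the same batch, together with their indices. The following is a minimal CPU reference of that computation (illustrative only; the function name, containers and single-batch layout are assumptions, not part of this commit).

#include <array>
#include <cmath>
#include <utility>
#include <vector>

// Single-batch reference: for each unknown point, keep the 3 smallest squared
// distances to the known points and the indices they came from.
void three_nn_cpu(const std::vector<std::array<float, 3>> &unknown,
                  const std::vector<std::array<float, 3>> &known,
                  std::vector<std::array<float, 3>> &dist2,
                  std::vector<std::array<int, 3>> &idx) {
  dist2.assign(unknown.size(), {INFINITY, INFINITY, INFINITY});
  idx.assign(unknown.size(), {0, 0, 0});
  for (size_t u = 0; u < unknown.size(); ++u) {
    for (size_t k = 0; k < known.size(); ++k) {
      float dx = unknown[u][0] - known[k][0];
      float dy = unknown[u][1] - known[k][1];
      float dz = unknown[u][2] - known[k][2];
      float d = dx * dx + dy * dy + dz * dz;
      int j = static_cast<int>(k);
      // bubble the candidate into the running top-3 (smallest first)
      for (int s = 0; s < 3; ++s) {
        if (d < dist2[u][s]) {
          std::swap(d, dist2[u][s]);
          std::swap(j, idx[u][s]);
        }
      }
    }
  }
}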
mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu
deleted
100644 → 0
View file @ 59c1418e
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
__mlu_func__ void computeDynamicVoxelize(
char *points_x, char *points_y, char *points_z, char *auxiliary_a,
char *auxiliary_b, char *auxiliary_c, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float voxel_x,
const float voxel_y, const float voxel_z, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t deal_num) {
// x - coors_x_min
__bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min,
deal_num);
// y - coors_y_min
__bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min,
deal_num);
// z - coors_z_min
__bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min,
deal_num);
// (x - coors_x_min) / voxel_x
__bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x,
deal_num);
// (y - coors_y_min) / voxel_y
__bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y,
deal_num);
// (z - coors_z_min) / voxel_z
__bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z,
deal_num);
// c_x = floor((x - coors_x_min) / voxel_x)
__bang_floor((float *)auxiliary_a, (float *)points_x, deal_num);
__bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0);
// c_y = floor((y - coors_y_min) / voxel_y)
__bang_floor((float *)auxiliary_a, (float *)points_y, deal_num);
__bang_float2int32((int32_t *)points_y, (float *)auxiliary_a, deal_num, 0);
// c_z = floor((z - coors_z_min) / voxel_z)
__bang_floor((float *)auxiliary_a, (float *)points_z, deal_num);
__bang_float2int32((int32_t *)points_z, (float *)auxiliary_a, deal_num, 0);
// c_x >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_x, (int32_t)0,
deal_num);
// c_x < grid_x
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x,
deal_num);
// 0 <= c_x < grid_x
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_y >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0,
deal_num);
// c_y < grid_y
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y,
deal_num);
// 0 <= c_y < grid_y
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
// c_z >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0,
deal_num);
// c_z < grid_z
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z,
deal_num);
// 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
__bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num);
__bang_mul((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_a,
deal_num);
__bang_mul_scalar((int32_t *)auxiliary_b, (int32_t *)auxiliary_c,
(int32_t)(-1), deal_num);
__bang_add((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b,
deal_num);
}
__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y,
char *coors_z, const int32_t c_x,
const int32_t c_y, const int32_t c_z,
const int32_t max_points, int32_t *num,
int32_t *first_point,
const int32_t deal_idx,
const int32_t deal_num) {
__bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num);
__bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num);
__bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_y,
deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_z,
deal_num);
if (*num == 0) {
*num = (int32_t)__bang_count((float *)coors_x, deal_num);
if (*num > 0) {
*first_point =
(int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx;
}
} else {
*num += (int32_t)__bang_count((float *)coors_x, deal_num);
}
}
#endif
__mlu_global__ void MLUUnion1KernelDynamicVoxelize(
const float *points, int32_t *coors, const float voxel_x,
const float voxel_y, const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float coors_x_max,
const float coors_y_max, const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t points_rem = num_points % taskDim;
const int32_t points_per_core =
taskId < points_rem ? num_points / taskDim + 1 : num_points / taskDim;
const int32_t points_start = taskId < points_rem
? taskId * points_per_core
: taskId * points_per_core + points_rem;
const int32_t split_num = 9;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE);
const int32_t repeat = points_per_core / deal_num;
const int32_t rem = points_per_core % deal_num;
const int32_t ping_pong_gap = 3 * deal_num * sizeof(float);
char *points_x = nram_buffer;
char *points_y = points_x + deal_num * sizeof(float);
char *points_z = points_y + deal_num * sizeof(float);
char *auxiliary_a = points_x + 2 * ping_pong_gap;
char *auxiliary_b = auxiliary_a + deal_num * sizeof(float);
char *auxiliary_c = auxiliary_b + deal_num * sizeof(float);
int32_t *coors_z_start = coors + points_start;
int32_t *coors_y_start = coors + num_points + points_start;
int32_t *coors_x_start = coors + num_points * 2 + points_start;
if (repeat > 0) {
__memcpy_async(points_x, points + points_start * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y, points + points_start * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z, points + points_start * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__asm__ volatile("sync;");
}
if (repeat > 1) {
__memcpy_async(points_x + ping_pong_gap,
points + (points_start + deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y + ping_pong_gap,
points + (points_start + deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z + ping_pong_gap,
points + (points_start + deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
computeDynamicVoxelize(points_x, points_y, points_z, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 2; ++i) {
__memcpy_async(coors_x_start + i * deal_num,
points_x + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + i * deal_num,
points_y + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + i * deal_num,
points_z + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(points_x + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(
points_y + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
__memcpy_async(
points_z + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
computeDynamicVoxelize(points_x + ((i + 1) % 2) * ping_pong_gap,
points_y + ((i + 1) % 2) * ping_pong_gap,
points_z + ((i + 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
if (repeat >= 2) {
__memcpy_async(coors_x_start + (repeat - 2) * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 2) * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 2) * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
__memcpy_async(points_x + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), rem - 1);
__memcpy_async(
points_y + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
__memcpy_async(
points_z + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
}
if (repeat > 0) {
computeDynamicVoxelize(points_x + ((repeat - 1) % 2) * ping_pong_gap,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
}
__asm__ volatile("sync;");
if (repeat > 0) {
__memcpy_async(coors_x_start + (repeat - 1) * deal_num,
points_x + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 1) * deal_num,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 1) * deal_num,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
computeDynamicVoxelize(points_x + (repeat % 2) * ping_pong_gap,
points_y + (repeat % 2) * ping_pong_gap,
points_z + (repeat % 2) * ping_pong_gap, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, rem);
__asm__ volatile("sync;");
__memcpy_async(coors_x_start + repeat * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + repeat * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + repeat * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
}
#endif
}
__mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors,
int32_t *point_to_pointidx,
int32_t *point_to_voxelidx,
const int32_t num_points,
const int32_t max_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t split_num = 6;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE);
const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t);
char *coors_x = nram_buffer;
char *coors_y = coors_x + deal_num * sizeof(int32_t);
char *coors_z = coors_y + deal_num * sizeof(int32_t);
int32_t *coors_z_start = coors;
int32_t *coors_y_start = coors + num_points;
int32_t *coors_x_start = coors + num_points * 2;
for (int32_t point_idx = taskId; point_idx < num_points;
point_idx += taskDim) {
if (coors_x_start[point_idx] == -1) {
point_to_pointidx[point_idx] = -1;
point_to_voxelidx[point_idx] = -1;
continue;
}
int32_t c_x = coors_x_start[point_idx];
int32_t c_y = coors_y_start[point_idx];
int32_t c_z = coors_z_start[point_idx];
int32_t deal_total_num = point_idx;
int32_t repeat = deal_total_num / deal_num;
int32_t rem = deal_total_num % deal_num;
int32_t num = 0;
int32_t first_point = -1;
if (repeat > 0) {
__memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y, coors_y_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z, coors_z_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 1; ++i) {
__memcpy_async(coors_x + ((i + 1) % 2) * ping_pong_gap,
coors_x_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_y + ((i + 1) % 2) * ping_pong_gap,
coors_y_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap,
coors_z_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
computePoint2Voxel(
coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap,
coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num,
&first_point, i * deal_num, deal_num);
__asm__ volatile("sync;");
}
if (rem > 0) {
__memcpy_async(coors_x + (repeat % 2) * ping_pong_gap,
coors_x_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y + (repeat % 2) * ping_pong_gap,
coors_y_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z + (repeat % 2) * ping_pong_gap,
coors_z_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
}
if (repeat > 0) {
computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap,
coors_y + ((repeat - 1) % 2) * ping_pong_gap,
coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y,
c_z, max_points, &num, &first_point,
(repeat - 1) * deal_num, deal_num);
}
__asm__ volatile("sync;");
if (rem > 0) {
computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap,
coors_y + (repeat % 2) * ping_pong_gap,
coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z,
max_points, &num, &first_point, repeat * deal_num,
rem);
__asm__ volatile("sync;");
}
if (num == 0) {
point_to_pointidx[point_idx] = point_idx;
} else if (num > 0) {
point_to_pointidx[point_idx] = first_point;
}
if (num < max_points) {
point_to_voxelidx[point_idx] = num;
} else {
point_to_voxelidx[point_idx] = -1;
}
}
#endif
}
__mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel(
int32_t *point_to_pointidx, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel,
int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0) {
int32_t voxel_num_temp = 0;
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
int32_t point_pos_in_voxel = point_to_voxelidx[point_idx];
coor_to_voxelidx[point_idx] = -1;
if (point_pos_in_voxel == -1) {
continue;
} else if (point_pos_in_voxel == 0) {
int32_t voxel_idx = voxel_num_temp;
if (voxel_num_temp >= max_voxels) {
continue;
}
voxel_num_temp += 1;
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] = 1;
} else {
int32_t point_idx_temp = point_to_pointidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx_temp];
if (voxel_idx != -1) {
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] += 1;
}
}
}
*voxel_num = voxel_num_temp;
}
#endif
}
__mlu_global__ void MLUUnion1KernelAssignVoxelsCoors(
const float *points, int32_t *temp_coors, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, float *voxels, int32_t *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
int32_t points_per_core = num_points / taskDim;
int32_t points_rem = num_points % taskDim;
int32_t points_start = taskId < points_rem
? taskId * (points_per_core + 1)
: taskId * points_per_core + points_rem;
int32_t points_end = taskId < points_rem ? points_start + points_per_core + 1
: points_start + points_per_core;
for (int32_t point_idx = points_start; point_idx < points_end; ++point_idx) {
int32_t num = point_to_voxelidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx];
if (num > -1 && voxel_idx > -1) {
float *voxels_offset =
voxels + voxel_idx * max_points * num_features + num * num_features;
const float *points_offset = points + point_idx * num_features;
__memcpy_async(voxels_offset, points_offset, num_features * sizeof(float),
GDRAM2GDRAM);
if (num == 0) {
int32_t *coors_offset = coors + voxel_idx * 3;
__memcpy_async(coors_offset, temp_coors + point_idx, sizeof(int32_t),
GDRAM2GDRAM, sizeof(int32_t),
num_points * sizeof(int32_t), 2);
}
}
}
__asm__ volatile("sync;");
#endif
}
void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points, void *coors,
const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min,
const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z,
const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelDynamicVoxelize<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)coors, voxel_x, voxel_y, voxel_z, coors_x_min,
coors_y_min, coors_z_min, coors_x_max, coors_y_max, coors_z_max, grid_x,
grid_y, grid_z, num_points, num_features);
}
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *coors, void *point_to_pointidx,
void *point_to_voxelidx, const int32_t num_points,
const int32_t max_points) {
MLUUnion1KernelPoint2Voxel<<<k_dim, k_type, queue>>>(
(int32_t *)coors, (int32_t *)point_to_pointidx,
(int32_t *)point_to_voxelidx, num_points, max_points);
}
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *point_to_pointidx,
void *point_to_voxelidx, void *coor_to_voxelidx,
void *num_points_per_voxel, void *voxel_num,
const int32_t max_voxels,
const int32_t num_points) {
MLUUnion1KernelCalcPointsPerVoxel<<<k_dim, k_type, queue>>>(
(int32_t *)point_to_pointidx, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (int32_t *)num_points_per_voxel,
(int32_t *)voxel_num, max_voxels, num_points);
}
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points,
void *temp_coors, void *point_to_voxelidx,
void *coor_to_voxelidx, void *voxels, void *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelAssignVoxelsCoors<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)temp_coors, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (float *)voxels, (int32_t *)coors,
max_points, num_points, num_features);
}
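For orientation, the per-point arithmetic of computeDynamicVoxelize above boils down to the following host-side reference (a sketch under the same grid parameters; the function itself and the int output array are illustrative, not part of the commit).

#include <cmath>

// Quantize one point into integer voxel coordinates, or write -1 for all three
// when it falls outside the grid, matching the kernel's 0 <= c < grid test.
// Note: like the kernel's temporary coors layout, z is stored first.
void dynamic_voxelize_point(float x, float y, float z, float voxel_x,
                            float voxel_y, float voxel_z, float coors_x_min,
                            float coors_y_min, float coors_z_min, int grid_x,
                            int grid_y, int grid_z, int coor[3]) {
  int c_x = static_cast<int>(std::floor((x - coors_x_min) / voxel_x));
  int c_y = static_cast<int>(std::floor((y - coors_y_min) / voxel_y));
  int c_z = static_cast<int>(std::floor((z - coors_z_min) / voxel_z));
  bool in_range = c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y &&
                  c_z >= 0 && c_z < grid_z;
  coor[0] = in_range ? c_z : -1;
  coor[1] = in_range ? c_y : -1;
  coor[2] = in_range ? c_x : -1;
}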
mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
0 → 100644
View file @ 0c23eb02
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
                                    Tensor ious, const int mode_flag,
                                    const bool aligned) {
  // get compute handle
  auto handle = mluOpGetCurrentHandle();

  auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes1, boxes1.suggest_memory_format());
  auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes2, boxes2.suggest_memory_format());
  auto ious_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      ious, ious.suggest_memory_format());

  MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc;
  boxes1_desc.set(boxes1_contiguous);
  boxes2_desc.set(boxes2_contiguous);
  ious_desc.set(ious_contiguous);

  auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous);
  auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous);
  auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous);
  auto boxes1_ptr = boxes1_impl->cnnlMalloc();
  auto boxes2_ptr = boxes2_impl->cnnlMalloc();
  auto ious_ptr = ious_impl->cnnlMalloc();

  CNLOG(INFO) << "Call mluOpBoxIouRotated().";
  mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
                     boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
                     ious_ptr);
}

void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                         const int mode_flag, const bool aligned) {
  BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned);
}

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                          const int mode_flag, const bool aligned);

REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu);
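For context only: `box_iou_rotated_impl` declared above is the device-generic entry point that mmcv is assumed to define elsewhere via its device registry, so the registration line routes MLU tensors to `box_iou_rotated_mlu`. A sketch of that pairing (names follow the registry pattern, not this file):

// Assumed dispatch side of REGISTER_DEVICE_IMPL (sketch, based on the
// pytorch_device_registry.hpp pattern): calling the generic impl selects the
// backend registered for the inputs' device type.
void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                          const int mode_flag, const bool aligned) {
  DISPATCH_DEVICE_IMPL(box_iou_rotated_impl, boxes1, boxes2, ious, mode_flag,
                       aligned);
}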
mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
View file @ 0c23eb02
...
...
@@ -10,114 +10,30 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
                 const cnrtDataType_t data_type_input, const void *boxes_dram,
                 const int input_box_num, const float iou_threshold,
                 void *workspace, void *output_size, void *output);

int selectType(uint32_t use_job, int box_num_per_core) {
  // the box_num_per_core should be at least 256, otherwise the real IO
  // bandwidth would be very low
  while (box_num_per_core < 256 && use_job >= 4) {
    box_num_per_core *= 2;
    use_job /= 2;
  }
  return use_job;
}

static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                               int &core_num_per_class,
                               const int input_box_num) {
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  uint32_t job_limit = getJobLimitCapability();
  uint32_t core_number = job_limit;
  int box_num_per_core = (input_box_num + core_number - 1) / core_number;
  int use_job = selectType(job_limit, box_num_per_core);
  // initiate k_type as Union1
  k_dim->x = core_dim;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
  switch (job_limit) {
    case CN_KERNEL_CLASS_BLOCK:
    case CN_KERNEL_CLASS_UNION:
    case CN_KERNEL_CLASS_UNION2:
    case CN_KERNEL_CLASS_UNION4:
    case CN_KERNEL_CLASS_UNION8:
    case CN_KERNEL_CLASS_UNION16: {
      if (use_job < 4) {
        k_dim->x = 1;
        *k_type = CNRT_FUNC_TYPE_BLOCK;
      } else if (use_job == 4) {
        k_dim->x = core_dim;
        *k_type = CNRT_FUNC_TYPE_UNION1;
      } else {
        k_dim->x = use_job;
        *k_type = (cnrtFunctionType_t)use_job;
      }
    }; break;
    default:
      LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number."
                   << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task.";
  }
  return CNNL_STATUS_SUCCESS;
}

#include "mlu_common_helper.h"

void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
                                 float iou_threshold) {
  // dimension parameters check
  TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ",
              boxes.dim(), "D");
  TORCH_CHECK(boxes.size(1) == 7,
              "boxes should have 7 elements in dimension 1, got ",
              boxes.size(1));
  // data type check
  TORCH_CHECK(
      boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf,
      "data type of boxes should be Float or Half, got ", boxes.scalar_type());
  if (boxes.numel() == 0) {
    return;
  }
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(boxes.numel() < max_input_num,
              "boxes.numel() should be less than 2147483648, got ",
              boxes.numel());
  int input_box_num = boxes.size(0);
  cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype());
  cnrtDim3_t k_dim;
  cnrtJobType_t k_type;
  int core_num_per_class;
  policyFunc(&k_dim, &k_type, core_num_per_class, input_box_num);
  // transpose boxes (n, 7) to (7, n) for better performance
  auto boxes_t = boxes.transpose(0, 1);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t);
  auto output = at::empty({input_box_num}, boxes.options().dtype(at::kLong));
  int input_box_num = boxes.size(0);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto output = keep.to(boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, boxes.options().dtype(at::kInt));
  // workspace
  const int info_num = 7;  // x, y, z, dx, dy, dz, angle
  size_t space_size = 0;
  if (boxes.scalar_type() == at::kHalf) {
    space_size = input_box_num * sizeof(int16_t) * info_num +
                 input_box_num * sizeof(float) + sizeof(float);
  } else {
    space_size =
        input_box_num * sizeof(float) * (info_num + 1) + sizeof(float);
  }
  MluOpTensorDescriptor boxes_desc, output_desc;
  boxes_desc.set(boxes_);
  output_desc.set(output);
  auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
  // workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size);
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
...
...
@@ -127,11 +43,29 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
  auto output_size_impl = torch_mlu::getMluTensorImpl(keep_num);
  auto output_size_ptr = output_size_impl->cnnlMalloc();
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  CNLOG(INFO) << "Launch Kernel KernelIou3d<<<Union" << k_type / core_dim
              << ", " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z
              << ">>>";
  KernelIou3d(k_dim, k_type, queue, data_type_input, boxes_ptr, input_box_num,
              iou_threshold, workspace_ptr, output_size_ptr, output_ptr);
  // nms desc
  mluOpNmsDescriptor_t nms_desc;
  const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
  const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
  const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
  const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
  const float soft_nms_sigma = 0.0;
  const float confidence_threshold = 0.0;
  const int input_layout = 0;
  const bool pad_to_max_output_size = false;
  const int max_output_size = input_box_num;
  const float offset = 0.0;
  mluOpCreateNmsDescriptor(&nms_desc);
  mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
                        iou_threshold, soft_nms_sigma, max_output_size,
                        confidence_threshold, offset, input_layout,
                        pad_to_max_output_size);
  mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
           workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
           output_size_ptr);
  mluOpDestroyNmsDescriptor(nms_desc);
}

void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
...
...
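A brief usage note on the launcher's outputs (an assumption about the caller, not part of this diff): `keep` receives candidate box indices and `keep_num` the number of boxes that survive NMS, so the caller is expected to truncate `keep` to that count, roughly as sketched below.

#include <ATen/ATen.h>

// Hypothetical host-side consumption of IoU3DNMS3DMLUKernelLauncher's outputs.
void consume_nms3d_result(const at::Tensor &keep, const at::Tensor &keep_num) {
  const int num_out = keep_num.item<int>();
  auto kept_indices = keep.slice(/*dim=*/0, /*start=*/0, /*end=*/num_out);
  (void)kept_indices;  // indices of the boxes kept after NMS
}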
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
View file @ 0c23eb02
...
...
@@ -18,8 +18,8 @@
#include "pytorch_device_registry.hpp"
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 5
#define MLUOP_PATCHLEVEL 302
#define MLUOP_MINOR 6
#define MLUOP_PATCHLEVEL 0

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
...
...
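The macros above pin the mlu-ops version this binding targets (bumped from 0.5.302 to 0.6.0 in this commit). A guard built on them could look like the sketch below (illustrative; the actual helper may check the version at runtime instead).

// Illustrative compile-time gate for code that needs mlu-ops >= 0.6.0.
#if (MLUOP_MAJOR > 0) || (MLUOP_MAJOR == 0 && MLUOP_MINOR >= 6)
#define MMCV_HAS_MLUOP_0_6 1
#else
#define MMCV_HAS_MLUOP_0_6 0
#endif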
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
View file @ 0c23eb02
...
...
@@ -9,495 +9,117 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
typedef enum {
  MS_DEFORM_ATTN_FORWARD_INVALID = 0,       /*!< Index is invalid. */
  MS_DEFORM_ATTN_FORWARD_DEFAULT = 1,       /*!< MLUKernelMsDeformAttnForwardDefault */
  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL = 2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
} MsDeformAttnForwardPolicy;

void KernelMsDeformAttnForwardDefault(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char *data_value_gdram,
    const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels,
    const int32_t num_queries, const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char *data_value_gdram,
    const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels,
    const int32_t num_queries, const int32_t num_points, char *data_col_gdram);

typedef enum {
  MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
} MsDeformAttnBackwardKernelPolicy;

MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
    const int32_t channels, const int32_t num_levels, const int32_t num_points,
    const int32_t num_heads) {
  const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const int num_hlp = num_heads * num_levels * num_points;
  int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
                             3 * num_levels * sizeof(int32_t)) /
                            sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
                            PAD_UP((num_hlp), 32);
  if (num_per_time_theory >= 1) {
    return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
  }
  return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}

void KernelMsDeformAttnBackwardDefaultKernel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float *data_value,
    const int32_t *spatial_shapes, const int32_t *data_level_start_index,
    const float *data_sampling_loc, const float *data_attn_weight,
    const float *grad_output, const int32_t batch_size, const int32_t num_keys,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_queries, const int32_t num_points, float *grad_value,
    float *grad_sampling_loc, float *grad_attn_weight);

void KernelMsDeformAttnBackwardSmallChannelsKernel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float *data_value,
    const int32_t *spatial_shapes, const int32_t *data_level_start_index,
    const float *data_sampling_loc, const float *data_attn_weight,
    const float *grad_output, const int32_t batch, const int32_t spatial_size,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_query, const int32_t num_points, float *grad_value,
    float *grad_sampling_loc, float *grad_attn_weight);

// policy function
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, const int32_t batch_size,
    const int32_t num_keys, const int32_t num_heads, const int32_t channels,
    const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points) {
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y =
      MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
          torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
  k_dim->z = 1;
#if __BANG_ARCH__ == 520
  *k_type = CNRT_FUNC_TYPE_BLOCK;
#else
  *k_type = CNRT_FUNC_TYPE_UNION1;
#endif

  int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
             channels < 16) {
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else {
    return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
  }
}

// policy function for backward
static void policyFuncBackward(const int32_t batch_size,
                               const int32_t num_queries,
                               const int32_t num_heads,
                               const int32_t num_levels,
                               cnrtFunctionType_t *k_type, cnrtDim3_t *k_dim) {
  size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->x = core_limit;
  int32_t total_num = batch_size * num_queries * num_heads * num_levels;
  size_t total_num_align = CEIL_ALIGN(total_num, core_limit);
  k_dim->y = (total_num_align / core_limit) > cluster_limit
                 ? cluster_limit
                 : (total_num_align / core_limit);
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}

Tensor ms_deform_attn_mlu_forward(const Tensor &value,
                                  const Tensor &spatial_shapes,
                                  const Tensor &level_start_index,
                                  const Tensor &sampling_loc,
                                  const Tensor &attn_weight,
                                  const int im2col_step) {
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ",
              spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ",
              level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ",
              sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ",
              attn_weight.scalar_type(), ".");
  // check shape
  TORCH_CHECK(value.dim() == 4, "value should be a 4d tensor, got ",
              value.dim(), "D.");
  TORCH_CHECK(spatial_shapes.dim() == 2,
              "spatial_shapes should be a 2d tensor, got ",
              spatial_shapes.dim(), "D.");
  TORCH_CHECK(level_start_index.dim() == 1,
              "level_start_index should be a 1d tensor, got ",
              level_start_index.dim(), "D.");
  TORCH_CHECK(sampling_loc.dim() == 6,
              "sampling_loc should be a 6d tensor, got ", sampling_loc.dim(),
              "D.");
  TORCH_CHECK(attn_weight.dim() == 5,
              "attn_weight should be a 5d tensor, got ", attn_weight.dim(),
              "D.");
/*************************************************************************
 * This MACRO turns a plain tensor into its MLU-side counterparts:
 * the _contiguous, _desc, _impl and _ptr variables are generated
 * automatically by this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();
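To make the macro concrete, expanding INITIAL_MLU_PARAM_WITH_TENSOR(value) mechanically (following the definition above, spelling included) produces:

// Expansion of INITIAL_MLU_PARAM_WITH_TENSOR(value):
auto value_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(
    value, value.suggest_memory_format());
MluOpTensorDescriptor value_desc;
value_desc.set(value_contigous);
auto value_impl = torch_mlu::getMluTensorImpl(value_contigous);
auto value_ptr = value_impl->cnnlMalloc();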
Tensor MsDeformAttnForwardLauncher(const Tensor &value,
                                   const Tensor &spatial_shapes,
                                   const Tensor &level_start_index,
                                   const Tensor &sampling_loc,
                                   const Tensor &attn_weight,
                                   const int im2col_step) {
  auto handle = mluOpGetCurrentHandle();
  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);

  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK(sampling_loc.size(5) == 2,
              "the 6th dimensions of sampling_loc should be 2, got ",
              sampling_loc.size(5), ".");
  TORCH_CHECK((sampling_loc.size(0) == batch_size),
              "the 1st dimensions of sampling_loc should be batch_size, ",
              "but now the 1st dimension of sampling_loc is ",
              sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((attn_weight.size(0) == batch_size),
              "the 1st dimensions of attn_weight should be batch_size, ",
              "but now the 1st dimension of attn_weight is ",
              attn_weight.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((sampling_loc.size(2) == num_heads),
              "the 3rd dimensions of sampling_loc should be num_heads, ",
              "but now the 3rd dimension of sampling_loc is ",
              sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((attn_weight.size(2) == num_heads),
              "the 3rd dimensions of attn_weight should be num_heads, ",
              "but now the 3rd dimension of attn_weight is ",
              attn_weight.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((level_start_index.size(0) == num_levels),
              "the 1st dimensions of level_start_index should be num_levels, ",
              "but now the 1st dimension of level_start_index is ",
              level_start_index.size(0), ", and num_levels is ", num_levels,
              ".");
  TORCH_CHECK((sampling_loc.size(3) == num_levels),
              "the 4th dimensions of sampling_loc should be num_levels, ",
              "but now the 4th dimension of sampling_loc is ",
              sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(3) == num_levels),
              "the 4th dimensions of attn_weight should be num_levels, ",
              "but now the 4th dimension of attn_weight is ",
              attn_weight.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(1) == num_queries),
              "the 2nd dimensions of attn_weight should be num_queries, ",
              "but now the 2nd dimension of attn_weight is ",
              attn_weight.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK((attn_weight.size(4) == num_points),
              "the 5th dimensions of attn_weight should be num_points, ",
              "but now the 5th dimension of attn_weight is ",
              attn_weight.size(4), ", and num_points is ", num_points, ".");
  auto output = at::zeros({batch_size, num_queries, num_heads, channels},
                          value.options());
  // large tensor check
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(value.numel() < max_input_size,
              "value element num should be less than 2^31, got ",
              value.numel(), ".");
  TORCH_CHECK(sampling_loc.numel() < max_input_size,
              "sampling_loc element num should be less than 2^31, got ",
              sampling_loc.numel(), ".");
  TORCH_CHECK(output.numel() < max_input_size,
              "output element num should be less than 2^31, got ",
              output.numel(), ".");
  // check zero element
  TORCH_CHECK(batch_size != 0, "batch_size should not be zero");
  TORCH_CHECK(num_heads != 0, "num_heads should not be zero");
  TORCH_CHECK(channels != 0, "channels should not be zero");
  TORCH_CHECK(num_queries != 0, "num_queries should not be zero");
  if (num_keys == 0 || num_levels == 0 || num_points == 0) {
    return output;
  }
  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
      &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
      num_queries, num_points);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto spatial_shapes_ = spatial_shapes.to(at::kInt);
  auto level_start_index_ = level_start_index.to(at::kInt);
  // get ptr of tensors
  auto value_impl = torch_mlu::getMluTensorImpl(value);
  auto value_ptr = value_impl->cnnlMalloc();
  auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_);
  auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
  auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_);
  auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
  auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
  auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
  auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
  auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
  // launch kernel
  switch (policy) {
    default: {
      VLOG(5) << "MsDeformAttnForward Policy not supported";
    }; break;
    case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
      KernelMsDeformAttnForwardDefault(
          k_dim, k_type, queue, data_type, (char *)value_ptr,
          (char *)spatial_shapes_ptr, (char *)level_start_index_ptr,
          (char *)sampling_loc_ptr, (char *)attn_weight_ptr, batch_size,
          num_keys, num_heads, channels, num_levels, num_queries, num_points,
          (char *)output_ptr);
      break;
    }
    case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
      KernelMsDeformAttnForwardSmallChannel(
          k_dim, k_type, queue, data_type, (char *)value_ptr,
          (char *)spatial_shapes_ptr, (char *)level_start_index_ptr,
          (char *)sampling_loc_ptr, (char *)attn_weight_ptr, batch_size,
          num_keys, num_heads, channels, num_levels, num_queries, num_points,
          (char *)output_ptr);
      break;
    }
  }
  auto spatial_shapes_int = spatial_shapes.to(at::kInt);
  auto level_start_index_int = level_start_index.to(at::kInt);
  INITIAL_MLU_PARAM_WITH_TENSOR(output);
  INITIAL_MLU_PARAM_WITH_TENSOR(value);
  INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);

  mluOpMsDeformAttnForward(
      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
      attn_weight_desc.desc(), attn_weight_ptr, im2col_step,
      output_desc.desc(), output_ptr);
  output = output.view({batch_size, num_queries, num_heads * channels});
  return output;
}

void ms_deform_attn_mlu_backward(
void MsDeformAttnBackwardLauncher(
    const Tensor &value, const Tensor &spatial_shapes,
    const Tensor &level_start_index, const Tensor &sampling_loc,
    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
    const int im2col_step) {
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");
  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ",
              spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ",
              level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ",
              sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ",
              attn_weight.scalar_type(), ".");
  TORCH_CHECK((grad_output.scalar_type() == at::kFloat),
              "grad_output type should be Float, got ",
              grad_output.scalar_type(), ".");

  auto handle = mluOpGetCurrentHandle();
  auto spatial_shapes_int = spatial_shapes.to(at::kInt);
  auto level_start_index_int = level_start_index.to(at::kInt);
  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);
  // Check shape.
  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK((level_start_index
.
size
(
0
)
==
num_levels
),
"the 1st dimensions of level_start_index should be num_levels, "
,
"but now the 1st dimension of level_start_index is "
,
level_start_index
.
size
(
0
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of sampling_loc should be batch_size, "
,
"but now the 1st dimension of sampling_loc is "
,
sampling_loc
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of sampling_loc should be num_heads, "
,
"but now the 3rd dimension of sampling_loc is "
,
sampling_loc
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of sampling_loc should be num_levels, "
,
"but now the 4th dimension of sampling_loc is "
,
sampling_loc
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
(
sampling_loc
.
size
(
5
)
==
2
,
"the 6th dimensions of sampling_loc should be 2, got "
,
sampling_loc
.
size
(
5
),
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of attn_weight should be batch_size, "
,
"but now the 1st dimension of attn_weight is "
,
attn_weight
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
1
)
==
num_queries
),
"the 2nd dimensions of attn_weight should be num_queries, "
,
"but now the 2nd dimension of attn_weight is "
,
attn_weight
.
size
(
1
),
", and num_queries is "
,
num_queries
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of attn_weight should be num_heads, "
,
"but now the 3rd dimension of attn_weight is "
,
attn_weight
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of attn_weight should be num_levels, "
,
"but now the 4th dimension of attn_weight is "
,
attn_weight
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
4
)
==
num_points
),
"the 5th dimensions of attn_weight should be num_points, "
,
"but now the 5th dimension of attn_weight is "
,
attn_weight
.
size
(
4
),
", and num_points is "
,
num_points
,
"."
);
TORCH_CHECK
((
grad_output
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of grad_output should be batch_size, "
,
"but now the 1st dimension of grad_output is "
,
grad_output
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
grad_output
.
size
(
1
)
==
num_queries
),
"the 2nd dimensions of grad_output should be num_queries, "
,
"but now the 2nd dimension of grad_output is "
,
grad_output
.
size
(
1
),
", and num_queries is "
,
num_queries
,
"."
);
TORCH_CHECK
(
(
grad_output
.
size
(
2
)
==
num_heads
*
channels
),
"the 3rd dimensions of grad_output should be num_heads * channels, "
,
"but now the 3rd dimension of grad_output is "
,
grad_output
.
size
(
2
),
", and num_heads * channels is "
,
num_heads
*
channels
,
"."
);
// check zero element
TORCH_CHECK
(
batch_size
!=
0
,
"The batch_size is zero."
);
TORCH_CHECK
(
channels
!=
0
,
"The channels is zero."
);
TORCH_CHECK
(
num_keys
!=
0
,
"The num_keys is zero."
);
TORCH_CHECK
(
num_heads
!=
0
,
"The num_heads is zero."
);
TORCH_CHECK
(
num_queries
!=
0
,
"The num_queries is zero."
);
if
(
num_levels
==
0
||
num_points
==
0
)
{
return
;
}
// calculate task dimension
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFuncBackward
(
batch_size
,
num_queries
,
num_heads
,
num_levels
,
&
k_type
,
&
k_dim
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
// get ptr of tensors
auto
value_impl
=
torch_mlu
::
getMluTensorImpl
(
value
);
auto
value_ptr
=
value_impl
->
cnnlMalloc
();
auto
spatial_shapes_impl
=
torch_mlu
::
getMluTensorImpl
(
spatial_shapes
);
auto
spatial_shapes_ptr
=
spatial_shapes_impl
->
cnnlMalloc
();
auto
level_start_index_impl
=
torch_mlu
::
getMluTensorImpl
(
level_start_index
);
auto
level_start_index_ptr
=
level_start_index_impl
->
cnnlMalloc
();
auto
sampling_loc_impl
=
torch_mlu
::
getMluTensorImpl
(
sampling_loc
);
auto
sampling_loc_ptr
=
sampling_loc_impl
->
cnnlMalloc
();
auto
attn_weight_impl
=
torch_mlu
::
getMluTensorImpl
(
attn_weight
);
auto
attn_weight_ptr
=
attn_weight_impl
->
cnnlMalloc
();
auto
grad_output_impl
=
torch_mlu
::
getMluTensorImpl
(
grad_output
);
auto
grad_output_ptr
=
grad_output_impl
->
cnnlMalloc
();
auto
grad_value_impl
=
torch_mlu
::
getMluTensorImpl
(
grad_value
);
auto
grad_value_ptr
=
grad_value_impl
->
cnnlMalloc
();
auto
grad_sampling_loc_impl
=
torch_mlu
::
getMluTensorImpl
(
grad_sampling_loc
);
auto
grad_sampling_loc_ptr
=
grad_sampling_loc_impl
->
cnnlMalloc
();
auto
grad_attn_weight_impl
=
torch_mlu
::
getMluTensorImpl
(
grad_attn_weight
);
auto
grad_attn_weight_ptr
=
grad_attn_weight_impl
->
cnnlMalloc
();
auto
grad_output_dim4
=
grad_output
.
view
({
batch_size
,
num_queries
,
num_heads
,
channels
});
// auto grad_output_dim4 = grad_output.view({batch_size, num_queries,
// num_heads, channels}).detach();
INITIAL_MLU_PARAM_WITH_TENSOR
(
value
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
spatial_shapes_int
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
level_start_index_int
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
sampling_loc
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
attn_weight
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
grad_output_dim4
);
// INITIAL_MLU_PARAM_WITH_TENSOR(grad_output);
INITIAL_MLU_PARAM_WITH_TENSOR
(
grad_value
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
grad_sampling_loc
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
grad_attn_weight
);
mluOpMsDeformAttnBackward
(
handle
,
value_desc
.
desc
(),
value_ptr
,
spatial_shapes_int_desc
.
desc
(),
spatial_shapes_int_ptr
,
level_start_index_int_desc
.
desc
(),
level_start_index_int_ptr
,
sampling_loc_desc
.
desc
(),
sampling_loc_ptr
,
attn_weight_desc
.
desc
(),
attn_weight_ptr
,
grad_output_dim4_desc
.
desc
(),
grad_output_dim4_ptr
,
im2col_step
,
grad_value_desc
.
desc
(),
grad_value_ptr
,
grad_sampling_loc_desc
.
desc
(),
grad_sampling_loc_ptr
,
grad_attn_weight_desc
.
desc
(),
grad_attn_weight_ptr
);
return
;
}
// get comput dtype of input
cnrtDataType_t
data_type
=
torch_mlu
::
toCnrtDtype
(
value
.
dtype
());
Tensor
ms_deform_attn_mlu_forward
(
const
Tensor
&
value
,
const
Tensor
&
spatial_shapes
,
const
Tensor
&
level_start_index
,
const
Tensor
&
sampling_loc
,
const
Tensor
&
attn_weight
,
const
int
im2col_step
)
{
return
MsDeformAttnForwardLauncher
(
value
,
spatial_shapes
,
level_start_index
,
sampling_loc
,
attn_weight
,
im2col_step
);
}
// launch kernel
CNLOG
(
INFO
)
<<
"Launch Kernel MLUKernelMsDeformAttnBackward<<<"
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
MsDeformAttnBackwardKernelPolicy
kernelPolicy
=
msDeformAttnBackwardPolicyFunc
(
channels
,
num_levels
,
num_points
,
num_heads
);
switch
(
kernelPolicy
)
{
default:
{
VLOG
(
5
)
<<
"NotImplemented."
;
}
break
;
case
MS_DEFORM_ATTN_BACKWARD_DEFAULT
:
{
KernelMsDeformAttnBackwardDefaultKernel
(
k_dim
,
k_type
,
queue
,
data_type
,
(
float
*
)
value_ptr
,
(
int32_t
*
)
spatial_shapes_ptr
,
(
int32_t
*
)
level_start_index_ptr
,
(
float
*
)
sampling_loc_ptr
,
(
float
*
)
attn_weight_ptr
,
(
float
*
)
grad_output_ptr
,
batch_size
,
num_keys
,
num_heads
,
channels
,
num_levels
,
num_queries
,
num_points
,
(
float
*
)
grad_value_ptr
,
(
float
*
)
grad_sampling_loc_ptr
,
(
float
*
)
grad_attn_weight_ptr
);
}
break
;
case
MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL
:
{
KernelMsDeformAttnBackwardSmallChannelsKernel
(
k_dim
,
k_type
,
queue
,
data_type
,
(
float
*
)
value_ptr
,
(
int32_t
*
)
spatial_shapes_ptr
,
(
int32_t
*
)
level_start_index_ptr
,
(
float
*
)
sampling_loc_ptr
,
(
float
*
)
attn_weight_ptr
,
(
float
*
)
grad_output_ptr
,
batch_size
,
num_keys
,
num_heads
,
channels
,
num_levels
,
num_queries
,
num_points
,
(
float
*
)
grad_value_ptr
,
(
float
*
)
grad_sampling_loc_ptr
,
(
float
*
)
grad_attn_weight_ptr
);
}
break
;
}
void
ms_deform_attn_mlu_backward
(
const
Tensor
&
value
,
const
Tensor
&
spatial_shapes
,
const
Tensor
&
level_start_index
,
const
Tensor
&
sampling_loc
,
const
Tensor
&
attn_weight
,
const
Tensor
&
grad_output
,
Tensor
&
grad_value
,
Tensor
&
grad_sampling_loc
,
Tensor
&
grad_attn_weight
,
const
int
im2col_step
)
{
return
MsDeformAttnBackwardLauncher
(
value
,
spatial_shapes
,
level_start_index
,
sampling_loc
,
attn_weight
,
grad_output
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
,
im2col_step
);
}
Tensor
ms_deform_attn_impl_forward
(
const
Tensor
&
value
,
...
...
@@ -515,5 +137,6 @@ void ms_deform_attn_impl_backward(
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU, ms_deform_attn_mlu_forward);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU, ms_deform_attn_mlu_backward);
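For context only (not part of the diff): the shape contract enforced by the launcher checks above corresponds to the usual Python-side entry point of this op. The sketch below assumes the public MultiScaleDeformableAttnFunction wrapper keeps its standard signature and that a torch_mlu-enabled build exposes the Cambricon device as the 'mlu' device string; both are assumptions, and the sizes are illustrative.

import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

device = 'mlu'  # assumption: torch_mlu build; use 'cuda' on NVIDIA GPUs
bs, num_queries, num_heads, channels = 2, 100, 8, 32
num_levels, num_points, im2col_step = 2, 4, 64

spatial_shapes = torch.tensor([[64, 64], [32, 32]], device=device)      # (num_levels, 2)
level_start_index = torch.cat((spatial_shapes.new_zeros(1),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))  # (num_levels,)
num_keys = int(spatial_shapes.prod(1).sum())

value = torch.randn(bs, num_keys, num_heads, channels, device=device)   # float32, as checked above
sampling_loc = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2, device=device)
attn_weight = torch.rand(bs, num_queries, num_heads, num_levels, num_points, device=device)

out = MultiScaleDeformableAttnFunction.apply(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step)
# out: (bs, num_queries, num_heads * channels), matching the final view() in the launcher.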
mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
View file @
0c23eb02
...
...
@@ -10,123 +10,35 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
               const cnrtDataType_t data_type_input, const void *boxes_ptr,
               const void *scores_ptr, const int input_num_boxes,
               const int max_output_boxes, const float iou_threshold,
               const float offset, void *workspace_ptr, void *output_size_ptr,
               void *output_ptr);
int selectUnionType(uint32_t use_job, int box_num_per_core) {
  // the box_num_per_core should be at least 256, otherwise the real IO
  // bandwidth would be very low
  while (box_num_per_core < 256 && use_job >= 4) {
    box_num_per_core *= 2;
    use_job /= 2;
  }
  return use_job;
}
static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                               int &core_num_per_class, const int input_box_num) {
  uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  uint32_t cluster_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  uint32_t job_limit = getJobLimitCapability();
  uint32_t core_number = job_limit;
  int box_num_per_core = (input_box_num + core_number - 1) / core_number;
  int use_job = selectUnionType(job_limit, box_num_per_core);
  // initiate k_type as Union1
  k_dim->x = core_dim;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
  switch (job_limit) {
    case CN_KERNEL_CLASS_BLOCK:
    case CN_KERNEL_CLASS_UNION:
    case CN_KERNEL_CLASS_UNION2:
    case CN_KERNEL_CLASS_UNION4:
    case CN_KERNEL_CLASS_UNION8:
    case CN_KERNEL_CLASS_UNION16: {
      if (use_job < 4) {
        k_dim->x = 1;
        *k_type = CNRT_FUNC_TYPE_BLOCK;
      } else if (use_job == 4) {
        k_dim->x = core_dim;
        *k_type = CNRT_FUNC_TYPE_UNION1;
      } else {
        k_dim->x = use_job;
        *k_type = (cnrtFunctionType_t)use_job;
      }
    }; break;
    default:
      LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number."
                   << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task.";
  }
  return CNNL_STATUS_SUCCESS;
}
#include "mlu_common_helper.h"
Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
  // dimension parameters check
  TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", boxes.dim(), "D");
  TORCH_CHECK(boxes.size(1) == 4,
              "boxes should have 4 elements in dimension 1, got ", boxes.size(1));
  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
  // data type check
  TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(),
              "boxes should have the same type as scores");
  TORCH_CHECK(boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf,
              "data type of boxes should be Float or Half, got ", boxes.scalar_type());
  if (boxes.numel() == 0) {
    return at::empty({0}, boxes.options().dtype(at::kLong));
  }
  int input_num_boxes = boxes.size(0);
  int max_output_boxes = boxes.size(0);
  cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype());
  cnrtDim3_t k_dim;
  cnrtJobType_t k_type;
  int core_num_per_class;
  policyFunc(&k_dim, &k_type, core_num_per_class, input_num_boxes);
  // transpose boxes (n, 4) to (4, n) for better performance
  auto boxes_t = boxes.transpose(0, 1);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
  auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong));
  auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, scores.options().dtype(at::kInt));
  MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
  boxes_desc.set(boxes_);
  scores_desc.set(scores_);
  output_desc.set(output);
  // workspace
  const int info_num = 5;  // x1, x2, y1, y2 and score
  size_t space_size = 0;
  if (boxes.scalar_type() == at::kHalf) {
    space_size = input_num_boxes * sizeof(int16_t) * info_num + sizeof(float);
  } else {
    space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float);
  }
#if __BANG_ARCH__ > 370
  int cluster_num = getCoreNumOfJobLimitCapability() /
                    torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  space_size += cluster_number * sizeof(float) * 7;
#endif
  auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(), &workspace_size);
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto scores_impl = torch_mlu::getMluTensorImpl(scores_
...
...
@@ -138,14 +50,31 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
auto output_size_ptr = output_size_impl->cnnlMalloc();
uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
CNLOG(INFO) << "Launch Kernel MLUUnionX NMS<<<Union" << k_type / core_dim
            << ", " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr,
          input_num_boxes, max_output_boxes, iou_threshold, offset,
          workspace_ptr, output_size_ptr, output_ptr);
// nms desc
mluOpNmsDescriptor_t nms_desc;
const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
const float soft_nms_sigma = 0.0;
const float confidence_threshold = 0.0;
const int input_layout = 0;
const bool pad_to_max_output_size = false;
const int max_output_size = max_output_boxes;
mluOpCreateNmsDescriptor(&nms_desc);
mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
                      iou_threshold, soft_nms_sigma, max_output_size,
                      confidence_threshold, (float)offset, input_layout,
                      pad_to_max_output_size);
mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
         scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
         output_ptr, output_size_ptr);
mluOpDestroyNmsDescriptor(nms_desc);
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
return output.slice(0, 0, output_num);
auto ret = output.to(boxes.options().dtype(at::kLong));
return ret.slice(0, 0, output_num);
}
Tensor nms_mlu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
...
...
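For context only (not part of the diff): the launcher above backs the regular mmcv.ops.nms call, so on an MLU build the same Python API dispatches to mluOpNms. A minimal usage sketch follows; the 'mlu' device string and the availability of a torch_mlu build are assumptions, and the box values are illustrative.

import torch
from mmcv.ops import nms

device = 'mlu'  # assumption: torch_mlu-enabled build; the same call works on 'cuda' or 'cpu'
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.],
                      [100., 100., 140., 140.]], device=device)
scores = torch.tensor([0.9, 0.8, 0.7], device=device)

# dets is (k, 5) as [x1, y1, x2, y2, score]; inds indexes the kept boxes.
dets, inds = nms(boxes, scores, iou_threshold=0.5)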
mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,26 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
                    const cnrtDataType_t d_type, const void *input, const void *rois,
                    const int channels, const bool aligned, const int pooled_height,
                    const int pooled_width, const int input_height, const int input_width,
                    const int sampling_ratio, const float spatial_scale,
                    const int num_rois, void *output);
void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                            cnrtQueue_t queue, const cnrtDataType_t dtype,
                            const void *grads, const void *boxes, void *grads_image,
                            const int boxes_num, const int hi, const int wi,
                            const int c, const int no, const int ho, const int wo,
                            const float spatial_scale, const int sampling_ratio,
                            const bool aligned);
#include "mlu_common_helper.h"
void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      Tensor argmax_y, Tensor argmax_x,
...
...
@@ -36,17 +17,7 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
float spatial_scale, int sampling_ratio, int pool_mode, bool aligned) {
  // params check
  TORCH_CHECK(input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
              "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(rois.scalar_type() == input.scalar_type(),
              "rois should have the same type as input");
  TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ", input.dim(), "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), "D");
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_tensor =
...
...
@@ -57,52 +28,56 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
int height = input.size(2);
int width = input.size(3);
if (output.numel() == 0) {
  output = at::zeros({num_rois, channels, aligned_height, aligned_width}, input.options());
  return;
}
at::Tensor output_tmp =
auto output_contiguous =
    at::empty({num_rois, channels, aligned_height, aligned_width},
              input.options(), memory_format);
// get tensor impl
auto self_impl = torch_mlu::getMluTensorImpl(input_tensor);
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto output_impl = torch_mlu::getMluTensorImpl(output_tmp);
auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
// get compute queue
auto queue = torch_mlu::getCurQueue();
MluOpTensorDescriptor input_desc, rois_desc, argmax_y_desc, argmax_x_desc, output_desc;
input_desc.set_with_layout(input_tensor, MLUOP_LAYOUT_NHWC);
rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
// get the mlu ptr
auto self_ptr = self_impl->cnnlMalloc();
auto rois_ptr = rois_impl->cnnlMalloc();
auto output_ptr = output_impl->cnnlMalloc();
cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
cnrtDim3_t k_dim;
k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim.z = 1;
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());
KernelRoiAlign(k_dim, k_type, queue, data_type, self_ptr, rois_ptr, channels,
               aligned, aligned_height, aligned_width, height, width,
               sampling_ratio, spatial_scale, num_rois, output_ptr);
output.copy_(output_tmp);
}
static int nearestPower2(int x) {
  x--;
  x |= x >> 1;
  x |= x >> 2;
  x |= x >> 4;
  x |= x >> 8;
  x |= x >> 16;
  x++;
  return x;
mluOpRoiAlignForwardDescriptor_t roialign_desc;
mluOpCreateRoiAlignForwardDescriptor(&roialign_desc);
mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height, aligned_width,
                                     sampling_ratio, spatial_scale, pool_mode, aligned);
auto handle = mluOpGetCurrentHandle();
if (pool_mode == 0) {
  auto argmax_y_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
  auto argmax_x_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
  auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
  auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
  auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
  auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
  argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
  argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
  mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
                          rois_desc.desc(), rois_ptr, output_desc.desc(), output_ptr,
                          argmax_x_desc.desc(), argmax_x_ptr,
                          argmax_y_desc.desc(), argmax_y_ptr);
  argmax_x.copy_(argmax_x_contiguous);
  argmax_y.copy_(argmax_y_contiguous);
} else {
  mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
                          rois_desc.desc(), rois_ptr, output_desc.desc(), output_ptr,
                          NULL, NULL, NULL, NULL);
}
mluOpDestroyRoiAlignForwardDescriptor(roialign_desc);
output.copy_(output_contiguous);
}
void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
...
...
@@ -112,17 +87,7 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
int sampling_ratio, int pool_mode, bool aligned) {
  // params check
  TORCH_CHECK(grad.scalar_type() == at::kFloat || grad.scalar_type() == at::kHalf,
              "grad type should be Float or Half, got ", grad.scalar_type());
  TORCH_CHECK(rois.scalar_type() == grad.scalar_type(),
              "rois should have the same type as grad");
  TORCH_CHECK(grad.dim() == 4, "grad should be a 4d tensor, got ", grad.dim(), "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), "D");
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  int batch_size = grad_input.size(0);
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
...
...
@@ -148,26 +113,40 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get the mlu ptr
auto grad_ptr = grad_impl->cnnlMalloc();
auto rois_ptr = rois_impl->cnnlMalloc();
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
int need_core = nearestPower2(boxes_num);
int union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
uint32_t dim_x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
uint32_t dim_y = (need_core - 1) / dim_x + 1;
dim_y = (dim_y > union_number) ? union_number : dim_y;
cnrtDim3_t k_dim = {dim_x, dim_y, 1};
cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad.dtype());
KernelRoiAlignBackward(k_dim, k_type, queue, k_dtype, grad_ptr, rois_ptr,
                       grad_input_ptr, boxes_num, hi, wi, c, no, ho, wo,
                       spatial_scale, sampling_ratio, aligned);
MluOpTensorDescriptor grads_desc, rois_desc, argmax_y_desc, argmax_x_desc, grad_input_desc;
grads_desc.set_with_layout(grad_, MLUOP_LAYOUT_NHWC);
rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
auto handle = mluOpGetCurrentHandle();
if (pool_mode == 0) {
  auto argmax_y_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
  auto argmax_x_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
  auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
  auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
  auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
  auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
  argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
  argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
  mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr, rois_desc.desc(),
                           rois_ptr, argmax_y_desc.desc(), argmax_x_ptr,
                           argmax_y_desc.desc(), argmax_y_ptr, spatial_scale,
                           sampling_ratio, aligned, pool_mode,
                           grad_input_desc.desc(), grad_input_ptr);
} else {
  mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr, rois_desc.desc(),
                           rois_ptr, NULL, NULL, NULL, NULL, spatial_scale,
                           sampling_ratio, aligned, pool_mode,
                           grad_input_desc.desc(), grad_input_ptr);
}
grad_input.copy_(grad_input_);
}
...
...
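For context only (not part of the diff): on an MLU build the functional mmcv.ops.roi_align dispatches to the forward launcher above, which currently only accepts average pooling. A minimal usage sketch, with the 'mlu' device string assumed to come from a torch_mlu build and all sizes purely illustrative:

import torch
from mmcv.ops import roi_align

device = 'mlu'  # assumption: torch_mlu-enabled build
feat = torch.randn(1, 16, 32, 32, device=device)
# rois are rows of (batch_idx, x1, y1, x2, y2) in input coordinates
rois = torch.tensor([[0., 4., 4., 20., 20.]], device=device)

# pool_mode='avg' is the only mode supported by the MLU launcher above.
pooled = roi_align(feat, rois, output_size=(7, 7), spatial_scale=1.0,
                   sampling_ratio=2, pool_mode='avg', aligned=True)
# pooled: (num_rois, 16, 7, 7)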
mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,49 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, const cnrtDataType_t d_type,
                          const int pool_method, const int boxes_num,
                          const int pts_num, const int max_pts_each_voxel,
                          const int out_x, const int out_y, const int out_z,
                          const void *rois, const void *pts, int *pts_idx_of_voxels);
void KernelRoiawarePool3dForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                                 cnrtQueue_t queue, const cnrtDataType_t d_type,
                                 const int pool_method, const int boxes_num,
                                 const int pts_num, const int channels,
                                 const int max_pts_each_voxel, const int out_x,
                                 const int out_y, const int out_z,
                                 const void *pts_feature, const int *pts_idx_of_voxels,
                                 void *pooled_features, int *argmax);
// policy function
static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num, cnrtDim3_t *k_dim,
                                           cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  unsigned int use_cluster = (boxes_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
static void kernelRoiawarePool3dForwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
#include "mlu_common_helper.h"
void RoiawarePool3dForwardMLUKernelLauncher(
    const int pool_method, const int boxes_num, const int pts_num,
...
...
@@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher(
const int out_y, const int out_z, const Tensor rois, const Tensor pts,
    const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
    Tensor argmax) {
  // check datatype
  TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
               (pts_feature.scalar_type() == rois.scalar_type()) &&
               (pooled_features.scalar_type() == rois.scalar_type())),
              "data types of rois, rois, pts_feature and pooled_features "
              "should be the same, ",
              "but now rois type is ", rois.scalar_type(), ", pts type is ",
              pts.scalar_type(), ", pts_feature type is ", pts_feature.scalar_type(),
              ", pooled_features type is ", pooled_features.scalar_type(), ".");
  TORCH_CHECK((rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
              "rois type should be Float or Half, got ", rois.scalar_type(), ".");
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  // check dim
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(), "D.");
  TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(), "D.");
  TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
              pts_feature.dim(), "D.");
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5, "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(pooled_features.dim() == 5, "pooled_features should be a 5D tensor, got ",
              pooled_features.dim(), "D.");
  // check shape
  TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
              "the dimensions of rois should be (boxes_num, 7), ", "but got (",
              rois.size(0), ", ", rois.size(1), ") .");
  TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
              "the dimensions of pts should be (pts_num, 3), ", "but got (",
              pts.size(0), ",", pts.size(1), ").");
  TORCH_CHECK(((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
              "the dimensions of pts_feature should be (pts_num, channels), ",
              "but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",", pts_idx_of_voxels.size(1),
              ",", pts_idx_of_voxels.size(2), ",", pts_idx_of_voxels.size(3), ",",
              pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
               (pooled_features.size(1) == out_x) &&
               (pooled_features.size(2) == out_y) &&
               (pooled_features.size(3) == out_z) &&
               (pooled_features.size(4) == channels)),
              "the dimensions of pooled_features should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", pooled_features.size(0), ",", pooled_features.size(1), ",",
              pooled_features.size(2), ",", pooled_features.size(3), ",",
              pooled_features.size(4), ").");
  // check other params : pool_mothod
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ",
              "but got ", pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(rois.numel() < max_input_size,
              "rois element num should be less than 2^31, got ", rois.numel(), ".");
  TORCH_CHECK(pts.numel() < max_input_size,
              "pts element num should be less than 2^31, got ", pts.numel(), ".");
  TORCH_CHECK(pts_feature.numel() < max_input_size,
              "pts_feature element num should be less than 2^31, got ",
              pts_feature.numel(), ".");
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(pooled_features.numel() < max_input_size,
              "pooled_features element num should be less than 2^31, got ",
              pooled_features.numel(), ".");
  // check zero element
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", rois.numel());
  TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ", pts.numel());
  TORCH_CHECK(pts_feature.numel() != 0,
              "pts_feature.numel() should not be zero, got ", pts_feature.numel());
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(pooled_features.numel() != 0,
              "pooled_features.numel() should not be zero, got ", pooled_features.numel());
  if (pool_method == 0) {
    // check datatype
    TORCH_CHECK((argmax.scalar_type() == at::kInt),
                "argmax type should be Int, got ", argmax.scalar_type(), ".");
    // check dim
    TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ", argmax.dim(), "D.");
    // check shape
    TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
                 (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
                 (argmax.size(4) == channels)),
                "the dimensions of argmax should be (boxes_num, out_x, out_y, "
                "out_z, channels), ",
                "but got (", argmax.size(0), ",", argmax.size(1), ",", argmax.size(2),
                ",", argmax.size(3), ",", argmax.size(4), ").");
    // check large tensor
    TORCH_CHECK(argmax.numel() < max_input_size,
                "argmax element num should be less than 2^31, got ", argmax.numel(), ".");
    // check zero element
    TORCH_CHECK(argmax.numel() != 0,
                "argmax.numel() should not be zero, got ", argmax.numel());
    // when pool_method is 0, which is max pool, init argmax data value to -1
    argmax.fill_(static_cast<int>(-1));
  }
  // calculate task one dimension
  cnrtDim3_t k1_dim;
  cnrtFunctionType_t k1_type;
  kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
  cnrtDim3_t k2_dim;
  cnrtFunctionType_t k2_type;
  kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim, &k2_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto pts_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format());
  auto pts_feature_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(pts_feature, pts_feature.suggest_memory_format());
  auto argmax_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(argmax, argmax.suggest_memory_format());
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pooled_features, pooled_features.suggest_memory_format());
  MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc,
      pts_idx_of_voxels_desc, pooled_features_desc;
  rois_desc.set(rois_contiguous);
  pts_desc.set(pts_contiguous);
  pts_feature_desc.set(pts_feature_contiguous);
  argmax_desc.set(argmax_contiguous);
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  pooled_features_desc.set(pooled_features_contiguous);
  // allocate extra space for workspace
  size_t workspace_size = 0;
  mluOpGetRoiawarePool3dForwardWorkspaceSize(handle, rois_desc.desc(), pts_desc.desc(),
                                             pts_feature_desc.desc(), &workspace_size);
  auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
  auto workspace_ptr = workspace_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous);
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  // transpose points [pts_num, 3] -> [3, pts_num]
  auto pts_ = pts.permute({1, 0}).contiguous();
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
  auto pts_ptr = pts_impl->cnnlMalloc();
  // transpose points_features [pts_num, channels] -> [channels, pts_num]
  auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype());
  // launch kernel PtsIdxOfVoxels
  CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", "
              << k1_dim.y << ", " << k1_dim.z << ">>>";
  KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method, boxes_num,
                       pts_num, max_pts_each_voxel, out_x, out_y, out_z, rois_ptr,
                       pts_ptr, (int *)pts_idx_of_voxels_ptr);
  // launch kernel RoiawarePool3dForward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x
              << ", " << k2_dim.y << ", " << k2_dim.z << ">>>";
  KernelRoiawarePool3dForward(k2_dim, k2_type, queue, data_type, pool_method,
                              boxes_num, pts_num, channels, max_pts_each_voxel,
                              out_x, out_y, out_z, pts_feature_ptr,
                              (int *)pts_idx_of_voxels_ptr, pooled_features_ptr,
                              (int *)argmax_ptr);
  CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
  mluOpRoiawarePool3dForward(
      handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(), rois_ptr,
      pts_desc.desc(), pts_ptr, pts_feature_desc.desc(), pts_feature_ptr,
      workspace_ptr, workspace_size, max_pts_each_voxel, out_x, out_y, out_z,
      argmax_desc.desc(), argmax_ptr, pts_idx_of_voxels_desc.desc(),
      pts_idx_of_voxels_ptr, pooled_features_desc.desc(), pooled_features_ptr);
}
void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
...
...
@@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU, roiaware_pool3d_forward_mlu);
void KernelRoiawarePool3dBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                                  cnrtQueue_t queue, const cnrtDataType_t d_type,
                                  const int pool_method, const int boxes_num,
                                  const int out_x, const int out_y, const int out_z,
                                  const int channels, const int max_pts_each_voxel,
                                  const int *pts_idx_of_voxels, const int *argmax,
                                  const void *grad_out, void *grad_in);
static void kernelRoiawarePool3dBackwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
void RoiawarePool3dBackwardMLUKernelLauncher(
    int pool_method, int boxes_num, int out_x, int out_y, int out_z, int channels,
    int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
    const Tensor grad_out, Tensor grad_in) {
  // check datatype
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  TORCH_CHECK((argmax.scalar_type() == at::kInt),
              "argmax type should be Int, got ", argmax.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == at::kFloat || grad_out.scalar_type() == at::kHalf),
              "grad_out type should be Float or Half, got ", grad_out.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
              "data types of grad_out, grad_in, should be the same, ",
              "but now grad_out type is ", grad_out.scalar_type(),
              ", grad_in type is ", grad_in.scalar_type(), ".");
  // check dim
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5, "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ", argmax.dim(), "D.");
  TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ", grad_out.dim(), "D.");
  TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ", grad_in.dim(), "D.");
  // check shape
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",", pts_idx_of_voxels.size(1),
              ",", pts_idx_of_voxels.size(2), ",", pts_idx_of_voxels.size(3), ",",
              pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
               (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
               (argmax.size(4) == channels)),
              "the dimensions of argmax should be (boxes_num, out_x, out_y, "
              "out_z, channels), ",
              "but got (", argmax.size(0), ",", argmax.size(1), ",", argmax.size(2),
              ",", argmax.size(3), ",", argmax.size(4), ").");
  TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
               (grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
               (grad_out.size(4) == channels)),
              "the dimensions of grad_out should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", grad_out.size(0), ",", grad_out.size(1), ",",
              grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4), ").");
  TORCH_CHECK((grad_in.size(1) == channels),
              "the 1st dimensions of grad_in should be channels, ",
              "but got ", grad_in.size(1), ".");
  // check other params : pool_mothod
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ",
              "but got ", pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(argmax.numel() < max_input_size,
              "argmax element num should be less than 2^31, got ", argmax.numel(), ".");
  TORCH_CHECK(grad_out.numel() < max_input_size,
              "grad_out element num should be less than 2^31, got ", grad_out.numel(), ".");
  TORCH_CHECK(grad_in.numel() < max_input_size,
              "grad_in element num should be less than 2^31, got ", grad_in.numel(), ".");
  // check zero element
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ", argmax.numel());
  TORCH_CHECK(grad_out.numel() != 0,
              "grad_out.numel() should not be zero, got ", grad_out.numel());
  TORCH_CHECK(grad_in.numel() != 0,
              "grad_in.numel() should not be zero, got ", grad_in.numel());
  // calculate task one dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim, &k_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // transpose points_features [pts_num, channels] -> [channels, pts_num]
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto argmax_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(argmax, argmax.suggest_memory_format());
  auto grad_out_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(grad_out, grad_out.suggest_memory_format());
  auto grad_in_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(grad_in, grad_in.suggest_memory_format());
  MluOpTensorDescriptor pts_idx_of_voxels_desc, argmax_desc, grad_out_desc, grad_in_desc;
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  argmax_desc.set(argmax_contiguous);
  grad_out_desc.set(grad_out_contiguous);
  grad_in_desc.set(grad_in_contiguous);
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous);
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
  auto grad_out_ptr = grad_out_impl->cnnlMalloc();
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
  auto grad_in_ptr = grad_in_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
  // launch kernel RoiawarePool3dForward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
                               boxes_num, out_x, out_y, out_z, channels,
                               max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
                               (int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
  CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
  mluOpRoiawarePool3dBackward(handle, pool_method, boxes_num, out_x, out_y, out_z,
                              channels, max_pts_each_voxel,
                              pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
                              argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(),
                              grad_out_ptr, grad_in_desc.desc(), grad_in_ptr);
}
void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
...
...
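For context only (not part of the diff): this launcher backs the mmcv.ops.RoIAwarePool3d module, so the same Python-side call runs on MLU once the device extension is installed. A minimal sketch follows; the 'mlu' device string, the torch_mlu build, and the 7-value box layout shown in the comment are assumptions, and all sizes are illustrative.

import torch
from mmcv.ops import RoIAwarePool3d

device = 'mlu'  # assumption: torch_mlu-enabled build
# mode='avg' corresponds to pool_method 1; mode='max' additionally uses the argmax tensor checked above.
pool = RoIAwarePool3d(out_size=(4, 4, 4), max_pts_per_voxel=128, mode='avg')

rois = torch.tensor([[0.0, 0.0, 0.0, 3.2, 1.6, 1.5, 0.3]], device=device)  # (boxes_num, 7) boxes
pts = torch.rand(1024, 3, device=device)            # (pts_num, 3) point coordinates
pts_feature = torch.rand(1024, 16, device=device)   # (pts_num, channels) point features

pooled = pool(rois, pts, pts_feature)               # (boxes_num, 4, 4, 4, channels)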
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
View file @
0c23eb02
...
...
@@ -9,84 +9,47 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, cnrtDataType_t data_type,
                          const void *unknown, const void *known, void *dist2,
                          int *idx, const int b, const int n, const int m);
#include "mlu_common_helper.h"
void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
                              const Tensor known, Tensor dist2, Tensor idx) {
  // Check dtype.
  TORCH_CHECK(unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf,
              "unknown type should be Float or Half, got ", unknown.scalar_type(), ".");
  TORCH_CHECK(unknown.scalar_type() == known.scalar_type(),
              "known should have the same type as unknown.");
  TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(),
              "dist2 should have the same type as unknown.");
  TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int.");
  // Check shape.
  TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ", unknown.dim(), "D.");
  TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(), "D.");
  TORCH_CHECK(unknown.size(0) == known.size(0),
              "known.dim0 should be equal to unknown.dim0, got ", known.size(0), ".");
  TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ", unknown.size(2), ".");
  TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2), ".");
  // zero element check
  TORCH_CHECK(unknown.numel() > 0,
              "unknown.numel should greater than zero, got ", unknown.numel(), ".");
  if (known.numel() == 0) {
    // return if known zero element
    return;
  }
  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(unknown.numel() < max_input_num,
              "unknown.numel() should be less than 2147483648, got ", unknown.numel(), ".");
  TORCH_CHECK(known.numel() < max_input_num,
              "known.numel() should be less than 2147483648, got ", known.numel(), ".");
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown);
  auto unknown_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(unknown, unknown.suggest_memory_format());
  auto known_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(known, known.suggest_memory_format());
  auto dist2_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(dist2, dist2.suggest_memory_format());
  auto idx_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(idx, idx.suggest_memory_format());
  MluOpTensorDescriptor unknown_desc, known_desc, dist2_desc, idx_desc;
  unknown_desc.set(unknown_contiguous);
  known_desc.set(known_contiguous);
  dist2_desc.set(dist2_contiguous);
  idx_desc.set(idx_contiguous);
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(), &workspace_size);
  auto known_workspace = at::empty(workspace_size, known.options().dtype(at::kByte));
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown_contiguous);
  auto known_impl = torch_mlu::getMluTensorImpl(known_contiguous);
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2_contiguous);
  auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
  auto workspace_impl = torch_mlu::getMluTensorImpl(known_workspace);
  auto unknown_ptr = unknown_impl->cnnlMalloc();
  auto known_t = known.permute({0, 2, 1}).contiguous();
  auto known_impl = torch_mlu::getMluTensorImpl(known_t);
  auto known_ptr = known_impl->cnnlMalloc();
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2);
  auto dist2_ptr = dist2_impl->cnnlMalloc();
  auto idx_impl = torch_mlu::getMluTensorImpl(idx);
  auto idx_ptr = idx_impl->cnnlMalloc();
  auto workspace_ptr = workspace_impl->cnnlMalloc();
  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
  cnrtDim3_t k_dim;
  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim.z = 1;
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype());
  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", "
              << k_dim.y << ", " << k_dim.z << ">>>.";
  KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr,
                       dist2_ptr, (int *)idx_ptr, b, n, m);
  mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr, known_desc.desc(),
                      known_ptr, workspace_ptr, workspace_size, dist2_desc.desc(),
                      dist2_ptr, idx_desc.desc(), idx_ptr);
}
void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
...
...
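For context only (not part of the diff): the launcher above serves the mmcv.ops.three_nn function, which for every query point returns the distances to and indices of its three nearest reference points. A minimal usage sketch, with the 'mlu' device string and the torch_mlu build assumed and the sizes illustrative:

import torch
from mmcv.ops import three_nn

device = 'mlu'  # assumption: torch_mlu-enabled build
unknown = torch.rand(2, 256, 3, device=device)   # (B, N, 3) query points
known = torch.rand(2, 64, 3, device=device)      # (B, M, 3) reference points

# dist and idx both have shape (B, N, 3); idx is int32, as required by the checks above.
dist, idx = three_nn(unknown, known)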