Unverified Commit 733e6ff8 authored by bdf, committed by GitHub

Pick MLU modifications from master (1.x) to main (2.x) (#2704)



* [Feature] Support Voxelization with cambricon MLU device (#2500)

* [Feature] Support hard_voxelize with cambricon MLU backend

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Enhance] Optimize the performance of ms_deform_attn for MLU device (#2510)

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization

* [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520)

* [Feature] Support ball_query with cambricon MLU backend and mlu-ops library.

* [Fix] update operator data layout setting.

* [Fix] add cxx compile option to avoid symbol conflict.

* [Fix] fix lint errors.

* [Fix] update ops.md with info of ball_query support by MLU backend.

* [Feature] Fix typo.

* [Fix] Remove print.

* [Fix] get mlu-ops from MMCV_MLU_OPS_PATH env.

* [Fix] update MMCV_MLU_OPS_PATH check logic.

* [Fix] update error info when failed to download mlu-ops.

* [Fix] check mlu-ops version matching info in mmcv.

* [Fix] revise wrong filename.

* [Fix] remove f.close and re.

* [Docs] Steps to compile mmcv-full on MLU machine (#2571)

* [Docs] Steps to compile mmcv-full on MLU machine

* [Docs] Adjust paragraph order

* Update docs/zh_cn/get_started/build.md
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update docs/zh_cn/get_started/build.md
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update docs/en/get_started/build.md
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update docs/en/get_started/build.md
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Docs] Modify the format

---------
Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Fix] Fix tensor descriptor setting in MLU ball_query. (#2579)

* [Feature] Add MLU support for Sparse Convolution op (#2589)

* [Feature] Add sparse convolution MLU API

* [Feature] update cpp code style

* end-of-file

* delete libext.a

* code style

* update ops.md

---------
Co-authored-by: budefei <budefei@cambricon.com>

* [Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598)

* [Feature] Replace the implementation of deform_roi_pool with mlu-ops

* [Feature] Modify code

---------
Co-authored-by: budefei <budefei@cambricon.com>

* [Enhancement] ms_deform_attn performance optimization (#2616)

* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* ms_opt_v2_1

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

---------
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>

* [Feature] Support NmsRotated with cambricon MLU backend (#2643)

* [Feature] Support NmsRotated with cambricon MLU backend

* [Feature] remove foolproofs in nms_rotated_mlu.cpp

* [Feature] fix lint in test_nms_rotated.py

* [Feature] fix kMLU not found in nms_rotated.cpp

* [Feature] modify mlu support in nms.py

* [Feature] modify nms_rotated support in ops.md

* [Feature] modify ops/nms.py

* [Enhance] Add a default value for MMCV_MLU_ARGS (#2688)

* add mlu_args

* add mlu_args

* Modify the code

---------
Co-authored-by: budefei <budefei@cambricon.com>

* [Enhance] Ignore mlu-ops files (#2691)
Co-authored-by: budefei <budefei@cambricon.com>

---------
Co-authored-by: ZShaopeng <108382403+ZShaopeng@users.noreply.github.com>
Co-authored-by: BinZheng <38182684+Wickyzheng@users.noreply.github.com>
Co-authored-by: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com>
Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: duzekun <108381389+duzekunKTH@users.noreply.github.com>
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
Co-authored-by: liuyuan1-v <125547457+liuyuan1-v@users.noreply.github.com>
parent 1f161f68
@@ -27,6 +27,8 @@ wheels/
.installed.cfg
*.egg
MANIFEST
mlu-ops/
mlu-ops.*
# PyInstaller
# Usually these files are written by a python script from a template
......
@@ -290,3 +290,60 @@ If you need to use PyTorch-related modules, make sure PyTorch has been successfu
```bash
python -c 'import mmcv;print(mmcv.__version__)'
```
### Build mmcv-full on Cambricon MLU Devices
#### Install torch_mlu
##### Option 1: Install mmcv-full based on the Cambricon docker image
First, pull the Cambricon docker image (please email service@cambricon.com for the latest release docker):
```bash
docker pull ${docker image}
```
Run and attach to the docker, then [install mmcv-full on the MLU device](#install-mmcv\-full-on-cambricon-mlu-device) and [make sure you've installed mmcv-full on the MLU device successfully](#test-code).
##### Option 2: Install mmcv-full by compiling Cambricon PyTorch from source
Please email service@cambricon.com or contact Cambricon engineers for a suitable version of the CATCH package. Once you have it, follow the steps in ${CATCH-path}/CONTRIBUTING.md to install Cambricon PyTorch.
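Whichever option you choose, you can quickly check that torch_mlu works by moving a small tensor onto the MLU device. This is a minimal sketch using the same `.mlu()` call as the test code further below:
```python
import torch
import torch_mlu  # registers the MLU backend with PyTorch

# If torch_mlu is installed correctly, this tensor lands on the MLU device.
x = torch.randn(2, 3).mlu()
print(x.device)
```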
#### Install mmcv-full on Cambricon MLU device
Clone the repo
```bash
git clone https://github.com/open-mmlab/mmcv.git
```
The mlu-ops library will be downloaded to the default directory (mmcv/mlu-ops) while building MMCV. You can also set `MMCV_MLU_OPS_PATH` to an existing mlu-ops library before building as follows:
```bash
export MMCV_MLU_OPS_PATH=/xxx/xxx/mlu-ops
```
Install mmcv-full
```bash
cd mmcv
export MMCV_WITH_OPS=1
export FORCE_MLU=1
python setup.py install
```
#### Test Code
After finishing the previous steps, you can run the following Python code to make sure that mmcv-full is installed on the MLU device successfully:
```python
import torch
import torch_mlu
from mmcv.ops import sigmoid_focal_loss
x = torch.randn(3, 10).mlu()
x.requires_grad = True
y = torch.tensor([1, 5, 3]).mlu()
w = torch.ones(10).float().mlu()
output = sigmoid_focal_loss(x, y, 2.0, 0.25, w, 'none')
print(output)
```
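The ops that gain MLU support in this PR can be smoke-tested the same way. Below is a minimal sketch for rotated NMS, assuming the usual `mmcv.ops.nms_rotated` interface (boxes given as `(cx, cy, w, h, angle)`, IoU threshold passed positionally); adjust if your mmcv version differs:
```python
import torch
import torch_mlu
from mmcv.ops import nms_rotated

# Two heavily overlapping rotated boxes and one far away, (cx, cy, w, h, angle).
dets = torch.tensor([[50.0, 50.0, 30.0, 20.0, 0.1],
                     [51.0, 50.0, 30.0, 20.0, 0.1],
                     [150.0, 120.0, 40.0, 30.0, 0.5]]).mlu()
scores = torch.tensor([0.9, 0.8, 0.7]).mlu()
dets_out, keep = nms_rotated(dets, scores, 0.5)
print(keep)  # the second box should be suppressed
```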
@@ -6,7 +6,7 @@ We implement common ops used in detection, segmentation, etc.
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter | √ | √ | | | |
| AssignScoreWithK | | √ | | | |
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | | | |
@@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
| ModulatedDeformConv2d | √ | √ | | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
@@ -52,13 +52,13 @@ We implement common ops used in detection, segmentation, etc.
| SigmoidFocalLoss | | √ | √ | | √ |
| SoftmaxFocalLoss | | √ | | | √ |
| SoftNMS | | √ | | | |
| Sparse Convolution | | √ | √ | | |
| Synchronized BatchNorm | | √ | | | |
| ThreeInterpolate | | √ | | | |
| ThreeNN | | √ | √ | | |
| TINShift | | √ | √ | | |
| UpFirDn2d | | √ | | | |
| Voxelization | √ | √ | √ | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
......
@@ -298,3 +298,59 @@ mmcv has two versions:
```bash
python -c 'import mmcv;print(mmcv.__version__)'
```
### Build mmcv-full on Cambricon MLU machines
#### Install torch_mlu
##### Option 1: Install based on the Cambricon docker image
First, download and pull the Cambricon docker image (please email service@cambricon.com to obtain the latest Cambricon PyTorch release docker).
```
docker pull ${docker image}
```
Enter the docker, then [build MMCV on the MLU](#编译mmcv-mlu) and [verify the installation](#验证是否成功安装).
##### Option 2: Install by compiling Cambricon PyTorch from source
Please email service@cambricon.com or contact Cambricon engineers for a suitable version of the CATCH package. Once you have it, follow the steps in ${CATCH-path}/CONTRIBUTING.md to install CATCH.
#### Build MMCV
Clone the repository
```bash
git clone https://github.com/open-mmlab/mmcv.git
```
The mlu-ops library is downloaded automatically to the default path (mmcv/mlu-ops) when building MMCV. You can also set the environment variable MMCV_MLU_OPS_PATH to point to an existing mlu-ops library before building.
```bash
export MMCV_MLU_OPS_PATH=/xxx/xxx/mlu-ops
```
Start building
```bash
cd mmcv
export MMCV_WITH_OPS=1
export FORCE_MLU=1
python setup.py install
```
#### Verify the installation
After finishing the installation steps above, you can run the following Python code to check whether mmcv-full has been installed successfully on the MLU device
```python
import torch
import torch_mlu
from mmcv.ops import sigmoid_focal_loss
x = torch.randn(3, 10).mlu()
x.requires_grad = True
y = torch.tensor([1, 5, 3]).mlu()
w = torch.ones(10).float().mlu()
output = sigmoid_focal_loss(x, y, 2.0, 0.25, w, 'none')
```
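Similarly, `ball_query` (also ported to the MLU backend in this PR) can be exercised directly. A minimal sketch, assuming the positional signature `ball_query(min_radius, max_radius, sample_num, xyz, center_xyz)`:
```python
import torch
import torch_mlu
from mmcv.ops import ball_query

# One batch of 32 random points; query neighbors around the first 4 points.
xyz = torch.rand(1, 32, 3).mlu()
center_xyz = xyz[:, :4, :].contiguous()
# Up to 16 neighbor indices per center within radius 0.4 (min radius 0).
idx = ball_query(0.0, 0.4, 16, xyz, center_xyz)
print(idx.shape)  # expected: torch.Size([1, 4, 16])
```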
@@ -6,7 +6,7 @@ MMCV provides common operators used in detection, segmentation and other tasks
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter | √ | √ | | | |
| AssignScoreWithK | | √ | | | |
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | | | |
@@ -35,7 +35,7 @@ MMCV provides common operators used in detection, segmentation and other tasks
| ModulatedDeformConv2d | √ | √ | | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
@@ -52,13 +52,13 @@ MMCV provides common operators used in detection, segmentation and other tasks
| SigmoidFocalLoss | | √ | √ | | √ |
| SoftmaxFocalLoss | | √ | | | √ |
| SoftNMS | | √ | | | |
| Sparse Convolution | | √ | √ | | |
| Synchronized BatchNorm | | √ | | | |
| ThreeInterpolate | | √ | | | |
| ThreeNN | | √ | √ | | |
| TINShift | | √ | √ | | |
| UpFirDn2d | | √ | | | |
| Voxelization | √ | √ | √ | | √ |
| PrRoIPool | | √ | | | |
| BezierAlign | √ | √ | | | |
| BiasAct | | √ | | | |
......
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <iostream>
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 5
#define FOURSPLIT 4
#define FIVESPLIT 5
#define NINESPLIT 9
#define THIRTEENSPLIT 13
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x,
T *w1, T *w2, T *w3, T *w4,
int *x_low, int *x_high,
const int y_low, bool *is_empty) {
if (x < -1.0 || x > input_width) {
*is_empty = true;
return;
}
if (x <= 0) x = 0;
*x_low = int(x);
if (*x_low >= input_width - 1) {
*x_high = *x_low = input_width - 1;
x = T(*x_low);
} else {
*x_high = *x_low + 1;
}
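  // Bilinear weights: (w1, w2, w3, w4) weight the (y_low, x_low), (y_low, x_high),
  // (y_high, x_low) and (y_high, x_high) neighbors, respectively.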
T ly = y - y_low;
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx;
*w2 = hy * lx;
*w3 = ly * hx;
*w4 = ly * lx;
}
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolForward(
const T *input, const T *rois, const T *offset, T *output,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int sampling_ratio, const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / NINESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1;
int c_slice = nram_limit / FOURSPLIT / type_align * type_align;
int c_slice_align = 0;
/* NRAM partition
*
* | | ping | pong |
* |----------|-------------------|-------------------|
* | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
*
*/
T *nram_out = (T *)nram_buffer;
T *nram_ping = nram_out + channels_align;
T *nram_pong = nram_ping + nram_limit;
__bang_write_value((T *)nram_out, channels_align, (T)0);
__bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0);
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
bool is_ping_empty = true;
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
if (y < -1.0 || y > height) {
is_ping_empty = true;
continue;
}
if (y <= 0) {
y = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high,
y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0);
__asm__ volatile("sync;");
__memcpy(nram_ping,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + c_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 2 * c_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 3 * c_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
y = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
x = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
if (y < -1.0 || y > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y <= 0) {
y = 0;
}
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp,
&w4_tmp, &x_low, &x_high, y_low, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align,
(T)0);
__asm__ volatile("sync;");
if (!is_empty) {
__memcpy_async(nram_pong,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + pongc_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 2 * pongc_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 3 * pongc_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align,
c_slice_align);
__bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping,
c_slice_align);
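          // Swap ping/pong buffers: the next sample's data was prefetched into
          // pong via __memcpy_async while this sample was being computed.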
T *nram_tmp = nram_ping;
nram_ping = nram_pong;
nram_pong = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
__bang_mul_scalar(nram_out, nram_out, value_div, channels_align);
__memcpy(output + channels * bin_index + channel_offset, nram_out,
channels_num * sizeof(T), NRAM2GDRAM);
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolForward(
cnrtDataType_t data_type, const void *input, const void *rois,
const void *offset, void *output, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const float spatial_scale, const int sampling_ratio,
const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset,
(half *)output, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale),
sampling_ratio, static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolForward(
(float *)input, (float *)rois, (float *)offset, (float *)output,
channels, height, width, num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input, const void *rois,
const void *offset, void *output,
const int channels, const int height,
const int width, const int num_rois,
const int pooled_height, const int pooled_width,
const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolForward<<<k_dim, k_type, queue>>>(
data_type, input, rois, offset, output, channels, height, width, num_rois,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma);
}
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolBackward(
const T *grad_output, const T *input, const T *rois, const T *offset,
T *grad_input, T *grad_offset, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const T spatial_scale, const int sampling_ratio,
const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
T *offset_grad_input = grad_input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
/* NRAM partition
*
* If offset != NULL, the NRAM partition is as follows.
* | |
* ping | pong |
* |---------------------------------------------------------------------|-----------|-----------|
* |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4|
*
* If offset == NULL, the ping and pong buffers will not be needed.
* | |
* |----------------------------------------------------------------------------------|
* | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output |
*
*/
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / FIVESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
if (offset != NULL) {
channels_nram_split =
channels_max_num_nram / THIRTEENSPLIT / type_align * type_align;
channel_rem = channels % channels_nram_split;
channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
}
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align -
nram_sum_tmp_channel) >>
1;
int c_slice = 0;
int c_slice_align = 0;
T *nram_tmp1 = (T *)nram_buffer;
T *nram_tmp2 = (T *)nram_buffer + channels_align;
T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align;
T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align;
T *nram_grad_output = nram_tmp4 + channels_align;
T *nram_sum_tmp = NULL;
T *nram_ping_input = NULL;
T *nram_pong_input = NULL;
__bang_write_value((T *)nram_grad_output, channels_align, (T)0);
__asm__ volatile("sync;");
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
nram_sum_tmp = nram_grad_output + channels_align;
nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel;
nram_pong_input = nram_ping_input + FOURSPLIT * c_slice;
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0);
__asm__ volatile("sync;");
}
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
bool is_ping_empty = true;
__memcpy(nram_grad_output,
grad_output + channels * bin_index + channel_offset,
channels_num * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(nram_grad_output, nram_grad_output, value_div,
channels_align);
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T y_tmp = y;
if (y_tmp < -1.0 || y_tmp > height) {
is_ping_empty = true;
continue;
}
if (y_tmp <= 0) {
y_tmp = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y_tmp);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y_tmp = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
__bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1,
channels_align);
__bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2,
channels_align);
__bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3,
channels_align);
__bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4,
channels_align);
__asm__ volatile("sync;");
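        // Scatter the four weighted copies of grad_output into grad_input at the
        // four bilinear neighbor locations; atomic adds are required because bins
        // processed by different cores may touch the same input pixels.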
__bang_atomic_add(
(T *)nram_tmp1,
(T *)(offset_grad_input + (y_low * width + x_low) * channels +
channel_offset),
(T *)nram_tmp1, channels_num);
__bang_atomic_add(
(T *)nram_tmp2,
(T *)(offset_grad_input + (y_low * width + x_high) * channels +
channel_offset),
(T *)nram_tmp2, channels_num);
__bang_atomic_add(
(T *)nram_tmp3,
(T *)(offset_grad_input + (y_high * width + x_low) * channels +
channel_offset),
(T *)nram_tmp3, channels_num);
__bang_atomic_add(
(T *)nram_tmp4,
(T *)(offset_grad_input + (y_high * width + x_high) * channels +
channel_offset),
(T *)nram_tmp4, channels_num);
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align,
(T)0);
__asm__ volatile("sync;");
const T *src_offset1 = offset_input + y_low * width * channels +
x_low * channels + channel_offset;
const T *src_offset2 = offset_input + y_low * width * channels +
x_high * channels + channel_offset;
const T *src_offset3 = offset_input + y_high * width * channels +
x_low * channels + channel_offset;
const T *src_offset4 = offset_input + y_high * width * channels +
x_high * channels + channel_offset;
__memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T),
GDRAM2NRAM);
__memcpy(nram_ping_input + c_slice_align, src_offset2,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 2 * c_slice_align, src_offset3,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 3 * c_slice_align, src_offset4,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
T y_tmp = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T x_tmp = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0,
y_high_tmp = 0;
if (y_tmp < -1.0 || y_tmp > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y_tmp <= 0) {
y_tmp = 0;
}
y_low_tmp = int(y_tmp);
if (y_low_tmp >= height - 1) {
y_high_tmp = y_low_tmp = height - 1;
y_tmp = T(y_low_tmp);
} else {
y_high_tmp = y_low_tmp + 1;
}
bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp,
&w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp,
y_low_tmp, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong_input,
FOURSPLIT * pongc_slice_align, (T)0);
__asm__ volatile("sync;");
if (!is_empty) {
const T *src_offset1 = offset_input +
y_low_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset2 = offset_input +
y_low_tmp * width * channels +
x_high_tmp * channels + channel_offset;
const T *src_offset3 = offset_input +
y_high_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset4 = offset_input +
y_high_tmp * width * channels +
x_high_tmp * channels + channel_offset;
__memcpy_async(nram_pong_input, src_offset1,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong_input + pongc_slice_align,
src_offset2, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 2 * pongc_slice_align,
src_offset3, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 3 * pongc_slice_align,
src_offset4, pongc_slice * sizeof(T),
GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
y - y_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
y_high - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
y_low - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
const int32_t kernel_width =
c_slice_align / nram_sum_tmp_channel +
(int32_t)(c_slice_align % nram_sum_tmp_channel > 0);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
nram_sum_tmp_channel);
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
x - x_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
x_high - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
x_low - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
NFU_ALIGN_SIZE / sizeof(T));
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
pooled_width * pooled_height +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
T *nram_tmp = nram_ping_input;
nram_ping_input = nram_pong_input;
nram_pong_input = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
}
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolBackward(
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolBackward(
(half *)grad_output, (half *)input, (half *)rois, (half *)offset,
(half *)grad_input, (half *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale), sampling_ratio,
static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolBackward(
(float *)grad_output, (float *)input, (float *)rois, (float *)offset,
(float *)grad_input, (float *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolBackward<<<k_dim, k_type, queue>>>(
data_type, grad_output, input, rois, offset, grad_input, grad_offset,
channels, height, width, num_rois, pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
}
@@ -32,6 +32,7 @@
/****************************************************************************************
*
* NRAM partition backward:
* default kernel
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
@@ -39,18 +40,34 @@
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
* small channel kernel
* | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl |
* | nram_grad_output_br | grad_temp1 | grad_temp2 |
* | grad_temp3 | grad_temp4 | nram_loc_w |
* | nram_loc_h | nram_h_low | nram_w_low |
* | nram_h_high | nram_w_high | nram_h_low_temp |
* | nram_h_high_temp | nram_hw | nram_hh |
* | nram_lw | nram_lh | nram_h_low_ptr_offset |
* | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset |
* | nram_w1 | nram_w2 | nram_w3 |
* | nram_w4 | nram_grad_weight | nram_base_ptr |
* | nram_offset_temp | nram_offset1 | nram_offset2 |
* | nram_offset3 | nram_offset4 | nram_w_low_temp |
* | nram_spatial_shapes | nram_level_start_index | nram_h_stride |
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define ELE_COUNT 32
#define LEN_FLOAT sizeof(float)
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
    const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
    T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num,
    const int32_t &width, const int32_t &height, const int32_t &num_heads,
    const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
  const int32_t w_low = floorf(x);
@@ -100,11 +117,11 @@ __mlu_func__ void loadNeighborPointsData(
}
template <typename T>
__mlu_func__ void computeMsDeformAttn(
    T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
    T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
    T *data_col_nram, const T &weight, const size_t &deal_num,
    const int32_t &width, const int32_t &height, const T &x, const T &y) {
  const int32_t w_low = floorf(x);
  const int32_t h_low = floorf(y);
  const int32_t w_high = w_low + 1;
@@ -156,10 +173,15 @@ __mlu_func__ void bilinearInterpolation(
    __bang_add((T *)sample_point_value, (T *)sample_point_value,
               (T *)auxiliary_b, deal_num);
  }
__bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight,
deal_num);
__bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value,
deal_num);
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
    const char *data_value_gdram, const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
@@ -346,7 +368,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
          // compute
          if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
            computeMsDeformAttn(
                (T *)(ping_data_value_p1_nram +
                      ((level_idx * num_points + point_idx) % 2) *
                          ping_pong_gap),
@@ -359,15 +381,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
                (T *)(ping_data_value_p4_nram +
                      ((level_idx * num_points + point_idx) % 2) *
                          ping_pong_gap),
                (T *)auxiliary_a, (T *)auxiliary_b,
                (T *)(ping_data_col_nram +
                      data_col_ping_pong_idx * ping_pong_gap),
                weight, span_num_deal, spatial_w, spatial_h, x, y);
          }
          spatial_w = spatial_w_next_point;
@@ -500,7 +517,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
            // compute
            if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
              computeMsDeformAttn(
                  (T *)(ping_data_value_p1_nram +
                        ((level_idx * num_points + point_idx) % 2) *
                            ping_pong_gap),
@@ -513,15 +530,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
                  (T *)(ping_data_value_p4_nram +
                        ((level_idx * num_points + point_idx) % 2) *
                            ping_pong_gap),
                  (T *)auxiliary_a, (T *)auxiliary_b,
                  (T *)(ping_data_col_nram +
                        data_col_ping_pong_idx * ping_pong_gap),
                  weight, channels_align_rem, spatial_w, spatial_h, x, y);
            }
            spatial_w = spatial_w_next_point;
@@ -544,7 +556,494 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
  return;
}
__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
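  // Fill mask_ram with a repeating 0,1,0,1,... pattern; __bang_collect (and its
  // negation) later uses it to de-interleave the packed (x, y) sampling
  // coordinates into separate x and y vectors.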
int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int32_t i = 0; i < align_num; ++i) {
mask_ram[i] = i % 2;
}
__asm__ volatile("sync;");
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
__asm__ volatile("sync;");
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
#if __BANG_ARCH__ >= 300
if (coreId == 0x80) {
return;
}
size_t block_num_per_core, batch_start, deal_g, offset_g;
size_t block_num_rem = 0;
const size_t grid_total = num_queries * num_heads * num_levels * num_points;
if (batch_size >= taskDim) {
block_num_rem = batch_size % taskDim;
block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
: batch_size / taskDim;
batch_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
deal_g = grid_total;
offset_g = 0;
} else {
size_t skip_n = taskDim / batch_size;
batch_start = taskId / skip_n;
block_num_per_core = batch_start >= batch_size ? 0 : 1;
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
}
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num;
int32_t cut_channel_iter = 2;
const size_t spatial_size =
PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t level_start_index_size =
PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
int32_t channel = channels;
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
break;
} else {
channel = channels / cut_channel_iter;
cut_channel_iter += 2;
}
}
mult = channel;
const int32_t c_rep = channels / channel;
const int32_t c_rem = channels % channel;
const int32_t g_rep = deal_g / deal_num;
const int32_t g_rem = deal_g % deal_num;
// nram buffer alloc
char *data_spatial_shapes_nram = nram_buffer;
char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
char *input_tl = data_level_start_index_nram + level_start_index_size;
char *input_tr = input_tl + deal_num * mult * sizeof(T);
char *input_bl = input_tr + deal_num * mult * sizeof(T);
char *input_br = input_bl + deal_num * mult * sizeof(T);
char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
char *weight_br = weight_bl + deal_num * mult * sizeof(T);
char *mask_tl = weight_br + deal_num * mult * sizeof(T);
char *mask_tr = mask_tl + deal_num * sizeof(T);
char *mask_bl = mask_tr + deal_num * sizeof(T);
char *mask_br = mask_bl + deal_num * sizeof(T);
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
char *coord_x = input_bl;
char *coord_y = coord_x + deal_num * sizeof(T);
char *coord_x_low = input_tl;
char *coord_y_low = coord_x_low + deal_num * sizeof(T);
char *coord_x_low_int = weight_tl;
char *coord_y_low_int = weight_tr;
char *spatial_x = mask_tl;
char *spatial_y = mask_tr;
char *spatial_x_float = weight_bl;
char *spatial_y_float = weight_br;
char *spatial_x_temp = mask_bl;
char *spatial_y_temp = mask_br;
char *base_ptr_offset = weight_tl;
char *auxiliary_a = point_ram;
char *auxiliary_b = weight_bl;
__memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__asm__ volatile("sync;");
for (int32_t batch_idx = batch_start;
batch_idx < batch_start + block_num_per_core; ++batch_idx) {
for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
int32_t io_data_num = deal_num;
const int32_t grid_off_base =
batch_idx * grid_total + offset_g + grid_iter * deal_num;
if (grid_iter == g_rep) {
if (g_rem == 0) {
continue;
} else {
io_data_num = g_rem;
}
}
char *data_col_gdram_start =
data_col_gdram + (batch_idx * num_queries * num_heads * channels +
(offset_g + grid_iter * deal_num) /
(num_levels * num_points) * channels) *
sizeof(float);
// load data_sampling_loc
__memcpy_async(
grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
io_data_num * 2 * sizeof(float), GDRAM2NRAM);
genMask0101((float *)mask_ram, deal_num * 2);
__asm__ volatile("sync;");
// generate x and y coordinate vector
// generate spatial_x and spatial_y spatial vector
__bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // y
__bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_x
__bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
__bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // x
__bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_y
for (int32_t i = 0; i < num_levels; i++) {
__bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
((int32_t *)spatial_x_temp)[i]);
__bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
((int32_t *)spatial_y_temp)[i]);
}
__bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
num_levels * num_points, 0);
__bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
num_levels * num_points, 0);
// map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0,
// spatial_y]
__bang_cycle_mul((float *)coord_x, (float *)coord_x,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
deal_num);
__bang_cycle_mul((float *)coord_y, (float *)coord_y,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
const int32_t w_stride = num_heads * channels;
__bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
deal_num, 0);
__bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
deal_num, 0);
__bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
(int32_t *)spatial_x, deal_num, num_levels * num_points);
__bang_add((int32_t *)index_tl, (int32_t *)index_tl,
(int32_t *)coord_x_low_int, deal_num);
__bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
deal_num);
const int32_t deal_lp_num = deal_num / (num_levels * num_points);
const int32_t h_rep = deal_lp_num / num_heads;
const int32_t h_rem = deal_lp_num % num_heads;
const int32_t head_start =
((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
num_heads;
for (int32_t iter = 0; iter < num_heads; ++iter) {
((int32_t *)base_ptr_offset)[iter] =
((head_start + iter) % num_heads) * channels;
}
if (h_rep > 0) {
__memcpy((int32_t *)base_ptr_offset + num_heads,
(int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
}
if (h_rep > 0 && h_rem > 0) {
__memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
(int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
NRAM2NRAM);
}
__bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
num_levels * num_points);
__bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)base_ptr_offset, deal_num, deal_lp_num);
__bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
num_levels * num_points, deal_lp_num);
// calc index_bl
__bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
deal_num);
__bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
(int32_t *)auxiliary_a, deal_num,
num_levels * num_points);
// calc mask_tl, mask_tr, mask_bl, mask_br
__bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
(float)1.0, deal_num);
__bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
(float)1.0, deal_num);
// mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
deal_num);
__bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
__bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
deal_num);
// mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
deal_num);
__bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
// mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
__bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
(float)(-1.0), deal_num);
__bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)auxiliary_a, (float *)auxiliary_a,
(float *)auxiliary_b, deal_num);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
deal_num);
// mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high <
// spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// calc inner point num
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
deal_num);
__bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
deal_num);
__bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
deal_num);
__bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
deal_num);
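      // point_ram now encodes which of the four neighbors are in-bounds:
      // 7*tl + 5*tr + 3*bl + 1*br, so 16 = all four, 12 = top pair, 4 = bottom
      // pair, 10 = left pair, 6 = right pair (see the switch statement below).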
// calc interpolation weight
__bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
deal_num);
__bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
deal_num);
__bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
deal_num);
__bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
deal_num);
__bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
deal_num);
__bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
deal_num);
__bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
deal_num);
__bang_mul((float *)input_tl + deal_num, (float *)weight_br,
(float *)weight_tl, deal_num);
__bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
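  // Note: input_tl now temporarily holds the four bilinear weights of each
  // sampling point, stored back-to-back as (1 - lx) * (1 - ly), lx * (1 - ly),
  // (1 - lx) * ly and lx * ly, with lx = coord_x - coord_x_low and
  // ly = coord_y - coord_y_low.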
__asm__ volatile("sync;");
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
const int32_t w_rem = channel % ELE_COUNT;
if (w_rem != 0) {
const int32_t data_sz = 1 * sizeof(float);
const int32_t dst_str = channel * sizeof(float);
for (int32_t iter = w_rep; iter < channel; ++iter) {
__memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
}
}
if (w_rep != 0) {
for (int32_t i = 0; i < 4 * deal_num; i++) {
__bang_write_value((float *)weight_tl + i * channel, w_rep,
((float *)input_tl)[i]);
}
}
__asm__ volatile("sync;");
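  // Note: the block above broadcasts each per-point scalar weight across the
  // channel dimension into weight_tl, so it can be multiplied element-wise
  // with the C-sized value vectors loaded into input_tl/input_tr/input_bl/
  // input_br further down.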
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(float);
const int32_t c_str = deal_num * channel * sizeof(float);
const int32_t cs_str = num_heads * channels * sizeof(float);
for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
int32_t c_real_num = channel;
if (c_iter == c_rep) {
if (c_rem == 0) {
continue;
} else {
c_real_num = c_rem;
}
}
__bang_write_zero((float *)input_tl, 4 * deal_num * channel);
__asm__ volatile("sync;");
// load data_value
for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
num_levels];
const char *data_value_ptr =
data_value_gdram_start +
(level_start_id * num_heads * channels + c_iter * channel) *
sizeof(float);
switch (inner_point_num) {
case 16: // 4 points are cached.
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 12: // 2 points are cached. (top_left, top_right)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 4: // 2 points are cached. (bottom_left, bottom_right)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 10: // 2 points are cached. (top_left, bottom_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 6: // 2 points are cached. (top_right, bottom_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 7: // 1 point is cached. (top_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 5: // 1 point is cached. (top_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 3: // 1 point is cached. (bottom_left)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 1: // 1 point is cached. (bottom_right)
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
default:
continue;
}
}
__asm__ volatile("sync;");
// interpolation
__bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
4 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
2 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
deal_num * channel);
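    // Note: the four corner buffers sit back-to-back after input_tl, so the
    // single mul over 4 * deal_num * channel applies all four bilinear weights
    // at once, and the two adds fold bl/br into tl/tr and then tr into tl,
    // leaving the interpolated value in input_tl.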
// load attention weight
void *attn_weight = mask_tl;
__memcpy((float *)attn_weight,
(float *)data_attn_weight_gdram + grid_off_base,
io_data_num * sizeof(float), GDRAM2NRAM);
// calc data_col, muladd attention weight
__bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
channel);
__bang_cycle_mul((float *)input_tr, (float *)input_tr,
(float *)attn_weight, deal_num * channel, deal_num);
__bang_transpose((float *)input_tl, (float *)input_tr, channel,
deal_num);
__bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
io_data_num, 1, num_levels * num_points,
num_levels * num_points, 1);
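    // Note: the sumpool above adds up every group of num_levels * num_points
    // weighted samples into one C-sized vector per query/head, which the
    // strided store below writes back to data_col_gdram.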
// store
__memcpy((float *)data_col_gdram_start + c_iter * channel,
(float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
channels * sizeof(float), channel * sizeof(float),
(io_data_num / (num_levels * num_points)) - 1);
}
}
}
__asm__ volatile("sync;");
#endif
return;
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault<float>(
    const char *data_value_gdram, const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
@@ -552,7 +1051,7 @@ template __mlu_global__ void MLUKernelMsDeformAttnForward<float>(
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardDefault(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char *data_value_gdram,
    const char *data_spatial_shapes_gdram,
@@ -561,7 +1060,30 @@ void KernelMsDeformAttnForward(
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram) {
  MLUKernelMsDeformAttnForwardDefault<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForwardSmallChannel<float><<<k_dim, k_type, queue>>>(
      data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
      data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
      num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
@@ -584,15 +1106,15 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
    int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    __memcpy(grad_output_nram, data_value_ptr + offset1,
             deal_num_real * sizeof(T), GDRAM2NRAM);
    __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real);
    __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real);
    __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real);
    __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
                      (T *)top_grad_temp, deal_num_real);
  }
@@ -600,18 +1122,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
    int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    __memcpy(grad_output_nram_temp, data_value_ptr + offset2,
             deal_num_real * sizeof(T), GDRAM2NRAM);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
    __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real);
    __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real);
    __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
                      deal_num_real);
    __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
               deal_num_real);
    __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
                      (T *)top_grad_temp, deal_num_real);
  }
@@ -619,18 +1141,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
    int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    __memcpy(grad_output_nram_temp, data_value_ptr + offset3,
             deal_num_real * sizeof(T), GDRAM2NRAM);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real);
    __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
    __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
                      deal_num_real);
    __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
               deal_num_real);
    __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
                      (T *)top_grad_temp, deal_num_real);
  }
@@ -638,63 +1160,61 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
    int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    __memcpy(grad_output_nram_temp, data_value_ptr + offset4,
             deal_num_real * sizeof(T), GDRAM2NRAM);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
    __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
    __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
    __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
                      deal_num_real);
    __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
               deal_num_real);
    __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
                      (T *)top_grad_temp, deal_num_real);
  }
  __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real);
#if __BANG_ARCH__ >= 322
  recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
  const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT;
  recursiveSumPool(grad_output_nram, align_num_on_200,
                   deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
  __bang_reduce_sum(grad_output_nram, grad_output_nram,
                    NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
  __bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
                    (T *)grad_output_nram, 1);
  __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
  __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
  __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
  recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
  recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
                   ALIGN_NUM_FOR_REDUCE);
  __bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
  __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
                    (T *)grad_w_weight, 1);
  __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
  __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
  recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
  recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
                   ALIGN_NUM_FOR_REDUCE);
  __bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
  __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
                    (T *)grad_h_weight, 1);
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
    const float *data_value, const int32_t *spatial_shapes,
    const int32_t *data_level_start_index, const float *data_sampling_loc,
    const float *data_attn_weight, const float *grad_output,
@@ -708,8 +1228,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
  const int32_t split_num = 8;
  const int32_t spatial_shapes_size = 64;
  int32_t deal_num = PAD_DOWN(
      (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM);
  float *grad_output_nram = (float *)nram_buffer;
  float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
  float *grad_weight = (float *)nram_buffer + 2 * deal_num;
@@ -725,10 +1244,8 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
  int32_t num_per_core = total_num / taskDim;
  int32_t num_rem = total_num % taskDim;
  num_per_core = num_per_core + int32_t(taskId < num_rem);
  int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
                                            : (num_rem + taskId * num_per_core);
  int32_t end_per_core = start_per_core + num_per_core;
  const int32_t C_repeat = channels / deal_num;
  const int32_t C_tail = channels % deal_num;
@@ -758,7 +1275,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
    const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
    for (int32_t p_col = 0; p_col < num_points; ++p_col) {
      __memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
               2 * LEN_FLOAT, GDRAM2NRAM);
      const float loc_w = sampling_loc_nram[0];
      const float loc_h = sampling_loc_nram[1];
      const float weight = data_attn_weight[data_weight_ptr];
@@ -789,11 +1306,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
        for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
          base_ptr = m_col * channels + C_loop * deal_num;
          __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
          __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
          __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
          __memcpy(top_grad,
                   grad_output + grad_output_offset + C_loop * deal_num,
                   deal_num * LEN_FLOAT, GDRAM2NRAM);
          msDeformAttnCol2imBilinear(
              top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
              h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
@@ -806,10 +1324,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
        }
        if (C_tail != 0) {
          base_ptr = m_col * channels + C_repeat * deal_num;
          __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
          __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
          __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
          __memcpy(top_grad,
                   grad_output + grad_output_offset + C_repeat * deal_num,
                   C_tail * LEN_FLOAT, GDRAM2NRAM);
          msDeformAttnCol2imBilinear(
              top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
              h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
@@ -827,7 +1347,711 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
  }
}
void __mlu_func__ computeGridMaskAndOffset(
float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
float *nram_lw, float *nram_hh, float *nram_hw,
float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
// [num_levels, 2] --> [2, num_levels]
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
__bang_transpose(nram_loc_w, nram_grad_output_tl,
num_per_time_real * num_heads * num_levels, num_points);
__bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_spatial_shapes,
(int32_t *)nram_spatial_shapes, num_levels * 2, 0);
__bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
num_levels, 2);
__bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
num_levels);
__memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
num_levels * 2 * sizeof(float), NRAM2NRAM);
__bang_cycle_mul(nram_loc_w, nram_loc_w,
(float *)nram_spatial_shapes + num_levels, num_deal_grid,
num_levels);
__bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
__bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
// get mask. (h_im > -1 && w_im > -1 &&
// h_im < spatial_h && w_im < spatial_w)
__bang_cycle_lt(nram_w_low_temp, nram_loc_w,
(float *)(nram_spatial_shapes + num_levels), num_deal_grid,
num_levels);
__bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
__bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_temp, nram_w_low_temp,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
__bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
__bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
__bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
__bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
__bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
__bang_transpose(nram_h_low_ptr_offset, nram_h_low,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
num_deal_grid);
__bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
num_deal_grid);
__bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
__bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
__bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
__bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
__bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset1,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset2,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset3,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset4,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
// h_low >= 0 && w_low >= 0 mask2
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
__bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
__bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
__bang_and(mask2, mask1, mask2, num_deal_grid);
__bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
// h_low >= 0 && w_high <= width - 1 mask1
__bang_transpose(mask3, nram_w_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
num_levels * 2);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
num_deal_grid, num_levels);
__bang_transpose(mask4, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask1, mask1, mask4, num_deal_grid);
__bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
// h_high <= height - 1 && w_high <= width - 1 mask3
__bang_transpose(mask3, nram_h_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
num_levels);
__bang_transpose(nram_h_low_temp, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
__bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
// h_high <= height - 1 && w_low >= 0 mask4
__bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
__bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
__bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
#endif
}
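// Note: the four masks produced above mark which bilinear neighbour of each
// sampling point lies inside the feature map; loadValue and computeGradValue
// below rely on the pairing mask2 -> top-left (offset1), mask1 -> top-right
// (offset2), mask4 -> bottom-left (offset3), mask3 -> bottom-right (offset4).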
void __mlu_func__ loadValue(
float *nram_grad_output_tl, float *nram_grad_output_tr,
float *nram_grad_output_bl, float *nram_grad_output_br,
const float *data_value, const float *grad_output, float *grad_temp1,
float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_grad_weight,
int32_t *nram_level_start_index, int32_t offset_nram,
int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
int32_t value_offset_temp = 0;
__bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
__sync_io_move_compute();
__memcpy_async(
grad_temp2,
grad_output + (start_per_core + grid_loop * num_per_time_theory) *
num_heads * deal_num_real,
num_per_time_real * num_heads * deal_num_real * sizeof(float),
GDRAM2NRAM);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
const int32_t b_col =
(grid_offset + loop) / num_query / num_heads / num_levels / num_points;
const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
const int32_t level_start_id = nram_level_start_index[l_col];
value_offset_temp =
b_col * spatial_size * qid_stride + level_start_id * qid_stride;
if (mask2[loop]) {
__memcpy_async(
nram_grad_output_tl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset1[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask1[loop]) {
__memcpy_async(
nram_grad_output_tr + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset2[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask4[loop]) {
__memcpy_async(
nram_grad_output_bl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset3[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask3[loop]) {
__memcpy_async(
nram_grad_output_br + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset4[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
}
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
#endif
}
void __mlu_func__ computeGradValue(
float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
float *nram_offset2, float *nram_offset3, float *nram_offset4,
int32_t *nram_level_start_index, int32_t deal_num_real,
const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t num_query,
int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride, float *nram_grid_offset1,
float *nram_grid_offset2) {
#if __BANG_ARCH__ >= 322
__bang_transpose(grad_temp3, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
deal_num_real * num_per_time_real * num_heads);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
num_levels / num_points) *
spatial_size * qid_stride;
}
__bang_transpose(nram_grid_offset2, nram_grid_offset1,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_level_start_index, nram_level_start_index,
num_levels, 0);
__bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
qid_stride, num_levels);
__bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
num_deal_grid, num_levels);
__bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
num_per_time_real * num_heads * num_levels);
__bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask2[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset1[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask1[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset2[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask4[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset3[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask3[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset4[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
#endif
}
void __mlu_func__ computeGradAttnWeight(
float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *nram_grad_output_bl,
float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
const float *grad_attn_weight, float *nram_hw, float *nram_hh,
float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
int32_t grid_offset, float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_write_zero(grad_w_weight, 2 * offset_nram);
  // nram_grad_output_tl
__bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
// nram_grad_output_tr
__bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
num_deal_grid * deal_num_real);
  // nram_grad_output_bl
__bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
num_deal_grid * deal_num_real);
// nram_grad_output_br
__bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
num_deal_grid);
__bang_transpose(nram_grad_output_tr, nram_grad_output_br,
num_per_time_real * num_heads,
num_points * num_levels * deal_num_real);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
num_per_time_real * num_heads * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tr,
num_points * num_levels * deal_num_real,
num_per_time_real * num_heads);
__bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
num_deal_grid, deal_num_real);
recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
ALIGN_NUM);
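  // Note: the three ops below convert the 0/1 validity mask in
  // nram_h_high_temp into a 0x00000000 / 0xffffffff bit pattern and AND it
  // with the reduced gradient, zeroing the grad_attn_weight contribution of
  // sampling points that fall outside the feature map.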
__bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
0);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_atomic_add((float *)nram_grad_output_tr,
(float *)grad_attn_weight + grid_offset,
(float *)nram_grad_output_tr, num_deal_grid);
#endif
}
void __mlu_func__ computeGradSampingLoc(
const float *grad_sampling_loc, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_transpose(nram_grad_output_tl, grad_h_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
num_levels);
__bang_transpose(grad_h_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
__bang_transpose(nram_grad_output_tr, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp1, nram_grad_output_tr,
num_levels * num_points * deal_num_real,
num_per_time_real * num_heads);
__bang_mul(grad_h_weight, grad_h_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid,
deal_num_real);
__memcpy_async(grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_h_weight, (char *)grad_h_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_transpose(nram_grad_output_tl, grad_w_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)(nram_spatial_shapes + num_levels),
num_deal_grid * deal_num_real, num_levels);
__bang_transpose(grad_w_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
__bang_mul(grad_w_weight, grad_w_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid,
deal_num_real);
__memcpy(grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_w_weight, (char *)grad_w_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__memcpy(grad_w_weight + num_deal_grid, grad_h_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid);
__bang_atomic_add((float *)nram_grad_output_tl,
(float *)grad_sampling_loc + grid_offset * 2,
(float *)nram_grad_output_tl, 2 * num_deal_grid);
#endif
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
#if __BANG_ARCH__ > 322
const int32_t split_grid_num = 28;
const int32_t split_num_c = 8;
const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
const int32_t num_hlp = num_heads * num_levels * num_points;
int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) /
(split_num_c * C_align + split_grid_num) /
PAD_UP((num_hlp), ALIGN_NUM);
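  // Note: the budget above reserves, per grid element processed in one pass,
  // split_num_c (8) buffers of C_align floats plus split_grid_num (28)
  // single-float scratch slots, which is what the pointer layout below
  // consumes.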
int32_t deal_grid_num_theory = num_per_time_theory * num_hlp;
const int32_t offset_nram = num_per_time_theory * C_align * num_hlp;
const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM);
float *nram_grad_output_tl = (float *)nram_buffer;
float *nram_grad_output_tr = (float *)nram_buffer + offset_nram;
float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram;
float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram;
float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram;
float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram;
float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram;
float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram;
float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram;
float *nram_loc_h =
(float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc;
float *nram_h_low =
(float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc;
float *nram_w_low =
(float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc;
float *nram_h_high =
(float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc;
float *nram_w_high =
(float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc;
float *nram_h_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc;
float *nram_h_high_temp =
(float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc;
float *nram_hw =
(float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc;
float *nram_hh =
(float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc;
float *nram_lw =
(float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc;
float *nram_lh =
(float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc;
float *nram_h_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc;
float *nram_h_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc;
float *nram_w_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc;
float *nram_w_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc;
float *nram_w1 =
(float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc;
float *nram_w2 =
(float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc;
float *nram_w3 =
(float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc;
float *nram_w4 =
(float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc;
float *nram_grad_weight =
(float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc;
float *nram_base_ptr =
(float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc;
float *nram_offset_temp =
(float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc;
float *nram_offset1 =
(float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc;
float *nram_offset2 =
(float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc;
float *nram_offset3 =
(float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc;
float *nram_offset4 =
(float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc;
float *nram_w_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc;
int32_t *nram_spatial_shapes =
(int32_t *)((float *)nram_buffer + split_num_c * offset_nram +
28 * offset_nram_calc);
int32_t *nram_level_start_index =
(int32_t *)(nram_spatial_shapes + 2 * num_levels);
float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels);
const int32_t total_num = batch * num_query;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
num_per_time_theory =
num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core;
int32_t num_deal_grid = num_per_time_theory * num_hlp;
if (num_per_core == 0) return;
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
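  // Note: cores with taskId < num_rem take one extra item, e.g. with
  // total_num = 10 and taskDim = 4 the cores handle 3, 3, 2 and 2 items
  // starting at offsets 0, 3, 6 and 8.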
const int32_t qid_stride = num_heads * channels;
int32_t deal_num_real = channels;
const int32_t repeat_times = num_per_core / num_per_time_theory;
const int32_t tail_num = num_per_core % num_per_time_theory;
int32_t num_per_time_real = num_per_time_theory;
for (int32_t loop = 0; loop < num_heads; ++loop) {
nram_base_ptr[loop] = loop * channels;
}
const int32_t w_stride = num_heads * channels;
for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) {
int32_t grid_offset =
(start_per_core + grid_loop * num_per_time_theory) * num_hlp;
if (grid_loop == repeat_times) {
if (tail_num == 0) {
continue;
} else {
grid_offset =
(start_per_core + repeat_times * num_per_time_theory) * num_hlp;
num_per_time_real = tail_num;
num_deal_grid = tail_num * num_hlp;
}
}
__memcpy_async(nram_spatial_shapes, spatial_shapes,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_level_start_index, data_level_start_index,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
__memcpy(nram_grad_weight, data_attn_weight + grid_offset,
num_deal_grid * sizeof(float), GDRAM2NRAM);
computeGridMaskAndOffset(
nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
w_stride, qid_stride);
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
nram_grad_output_br, data_value, grad_output, grad_temp1,
grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
nram_level_start_index, offset_nram, start_per_core, grid_loop,
num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
num_deal_grid, num_query, num_levels, num_points, grid_offset,
spatial_size, qid_stride);
float *nram_grid_offset1 = nram_loc_h;
float *nram_grid_offset2 = nram_loc_w;
computeGradValue(
grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
nram_grid_offset1, nram_grid_offset2);
// compute grad_weight
float *grad_weight = grad_temp1;
float *grad_h_weight = grad_temp4;
float *grad_w_weight = grad_temp3;
computeGradAttnWeight(
grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
// compute grad_sampling_loc
computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl,
nram_grad_output_tr, grad_h_weight, grad_w_weight,
nram_spatial_shapes, grad_temp1, grad_temp2,
nram_grad_weight, num_deal_grid, deal_num_real,
num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
}
#endif
}
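// Summary of the tiled loop above: each pass copies one tile of sampling
// locations and attention weights into NRAM, derives the bilinear masks and
// the four neighbour offsets (computeGridMaskAndOffset), gathers the
// corresponding value and grad_output slices (loadValue), and then
// accumulates grad_value, grad_attn_weight and grad_sampling_loc for that
// tile. A tile roughly covers num_per_time_theory query groups (tail_num on
// the last pass), i.e. num_deal_grid = tile * num_heads * num_levels *
// num_points sampling points.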
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
@@ -836,7 +2060,23 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<<k_dim, k_type, queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
@@ -845,7 +2085,8 @@ void KernelMsDeformAttnBackward(
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<<k_dim, k_type,
                                                         queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
...
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
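// computeDynamicVoxelize converts point coordinates into integer voxel
// indices in place. A minimal sketch of the math it vectorizes, for a single
// point (x, y, z):
//   c_x = floor((x - coors_x_min) / voxel_x)
//   c_y = floor((y - coors_y_min) / voxel_y)
//   c_z = floor((z - coors_z_min) / voxel_z)
// If any of the three indices falls outside [0, grid_*), all three are set to
// -1, which the later stages treat as "point discarded".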
__mlu_func__ void computeDynamicVoxelize(
char *points_x, char *points_y, char *points_z, char *auxiliary_a,
char *auxiliary_b, char *auxiliary_c, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float voxel_x,
const float voxel_y, const float voxel_z, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t deal_num) {
// x - coors_x_min
__bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min,
deal_num);
// y - coors_y_min
__bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min,
deal_num);
// z - coors_z_min
__bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min,
deal_num);
// (x - coors_x_min) / voxel_x
__bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x,
deal_num);
// (y - coors_y_min) / voxel_y
__bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y,
deal_num);
// (z - coors_z_min) / voxel_z
__bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z,
deal_num);
// c_x = floor((x - coors_x_min) / voxel_x)
__bang_floor((float *)auxiliary_a, (float *)points_x, deal_num);
__bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0);
// c_y = floor((y - coors_y_min) / voxel_y)
__bang_floor((float *)auxiliary_a, (float *)points_y, deal_num);
__bang_float2int32((int32_t *)points_y, (float *)auxiliary_a, deal_num, 0);
// c_z = floor((z - coors_z_min) / voxel_z)
__bang_floor((float *)auxiliary_a, (float *)points_z, deal_num);
__bang_float2int32((int32_t *)points_z, (float *)auxiliary_a, deal_num, 0);
// c_x >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_x, (int32_t)0,
deal_num);
// c_x < grid_x
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x,
deal_num);
// 0 <= c_x < grid_x
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_y >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0,
deal_num);
// c_y < grid_y
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y,
deal_num);
// 0 <= c_y < grid_y
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
// c_z >= 0
__bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0,
deal_num);
// c_z < grid_z
__bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z,
deal_num);
// 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
(int32_t *)auxiliary_c, deal_num);
// 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z
__bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)auxiliary_b, deal_num);
__bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num);
__bang_mul((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_a,
deal_num);
__bang_mul_scalar((int32_t *)auxiliary_b, (int32_t *)auxiliary_c,
(int32_t)(-1), deal_num);
__bang_add((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_b,
deal_num);
__bang_mul((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_a,
deal_num);
__bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b,
deal_num);
}
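// computePoint2Voxel compares a tile of previously computed voxel indices
// against one point's (c_x, c_y, c_z): it accumulates in *num how many
// earlier points fall into the same voxel and records in *first_point the
// index of the first such point. Note that max_points is not used inside this
// helper; the per-voxel cap is applied by the caller.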
__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y,
char *coors_z, const int32_t c_x,
const int32_t c_y, const int32_t c_z,
const int32_t max_points, int32_t *num,
int32_t *first_point,
const int32_t deal_idx,
const int32_t deal_num) {
__bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num);
__bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num);
__bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_y,
deal_num);
__bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_z,
deal_num);
if (*num == 0) {
*num = (int32_t)__bang_count((float *)coors_x, deal_num);
if (*num > 0) {
*first_point =
(int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx;
}
} else {
*num += (int32_t)__bang_count((float *)coors_x, deal_num);
}
}
#endif
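// The kernel below splits num_points across the MLU cores (cores with
// taskId < num_points % taskDim take one extra point) and streams the x/y/z
// columns through a ping-pong NRAM buffer, so the strided GDRAM2NRAM loads,
// the voxelization compute and the NRAM2GDRAM stores of consecutive tiles can
// overlap. For example, with num_points = 100000 and taskDim = 16 each core
// handles 6250 points, processed deal_num points at a time.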
__mlu_global__ void MLUUnion1KernelDynamicVoxelize(
const float *points, int32_t *coors, const float voxel_x,
const float voxel_y, const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min, const float coors_x_max,
const float coors_y_max, const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t points_rem = num_points % taskDim;
const int32_t points_per_core =
taskId < points_rem ? num_points / taskDim + 1 : num_points / taskDim;
const int32_t points_start = taskId < points_rem
? taskId * points_per_core
: taskId * points_per_core + points_rem;
const int32_t split_num = 9;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE);
const int32_t repeat = points_per_core / deal_num;
const int32_t rem = points_per_core % deal_num;
const int32_t ping_pong_gap = 3 * deal_num * sizeof(float);
char *points_x = nram_buffer;
char *points_y = points_x + deal_num * sizeof(float);
char *points_z = points_y + deal_num * sizeof(float);
char *auxiliary_a = points_x + 2 * ping_pong_gap;
char *auxiliary_b = auxiliary_a + deal_num * sizeof(float);
char *auxiliary_c = auxiliary_b + deal_num * sizeof(float);
int32_t *coors_z_start = coors + points_start;
int32_t *coors_y_start = coors + num_points + points_start;
int32_t *coors_x_start = coors + num_points * 2 + points_start;
if (repeat > 0) {
__memcpy_async(points_x, points + points_start * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y, points + points_start * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z, points + points_start * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__asm__ volatile("sync;");
}
if (repeat > 1) {
__memcpy_async(points_x + ping_pong_gap,
points + (points_start + deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_y + ping_pong_gap,
points + (points_start + deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(points_z + ping_pong_gap,
points + (points_start + deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
computeDynamicVoxelize(points_x, points_y, points_z, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 2; ++i) {
__memcpy_async(coors_x_start + i * deal_num,
points_x + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + i * deal_num,
points_y + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + i * deal_num,
points_z + (i % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(points_x + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), deal_num - 1);
__memcpy_async(
points_y + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
__memcpy_async(
points_z + (i % 2) * ping_pong_gap,
points + (points_start + (i + 2) * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
deal_num - 1);
computeDynamicVoxelize(points_x + ((i + 1) % 2) * ping_pong_gap,
points_y + ((i + 1) % 2) * ping_pong_gap,
points_z + ((i + 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
__asm__ volatile("sync;");
}
if (repeat >= 2) {
__memcpy_async(coors_x_start + (repeat - 2) * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 2) * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 2) * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
__memcpy_async(points_x + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features,
sizeof(float), GDRAM2NRAM, sizeof(float),
num_features * sizeof(float), rem - 1);
__memcpy_async(
points_y + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 1,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
__memcpy_async(
points_z + (repeat % 2) * ping_pong_gap,
points + (points_start + repeat * deal_num) * num_features + 2,
sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
rem - 1);
}
if (repeat > 0) {
computeDynamicVoxelize(points_x + ((repeat - 1) % 2) * ping_pong_gap,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
grid_x, grid_y, grid_z, deal_num);
}
__asm__ volatile("sync;");
if (repeat > 0) {
__memcpy_async(coors_x_start + (repeat - 1) * deal_num,
points_x + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + (repeat - 1) * deal_num,
points_y + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + (repeat - 1) * deal_num,
points_z + ((repeat - 1) % 2) * ping_pong_gap,
deal_num * sizeof(int32_t), NRAM2GDRAM);
}
if (rem > 0) {
computeDynamicVoxelize(points_x + (repeat % 2) * ping_pong_gap,
points_y + (repeat % 2) * ping_pong_gap,
points_z + (repeat % 2) * ping_pong_gap, auxiliary_a,
auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
grid_y, grid_z, rem);
__asm__ volatile("sync;");
__memcpy_async(coors_x_start + repeat * deal_num,
points_x + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_y_start + repeat * deal_num,
points_y + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
__memcpy_async(coors_z_start + repeat * deal_num,
points_z + (repeat % 2) * ping_pong_gap,
rem * sizeof(int32_t), NRAM2GDRAM);
}
#endif
}
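// For every point, the kernel below scans all points with a smaller index
// (tiled through the same ping-pong scheme) to decide whether its voxel has
// been seen before: point_to_pointidx stores the index of the first point
// that hit the same voxel (or the point itself if it is the first), and
// point_to_voxelidx stores the point's slot inside that voxel, or -1 once the
// voxel already holds max_points points or the point was marked invalid.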
__mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors,
int32_t *point_to_pointidx,
int32_t *point_to_voxelidx,
const int32_t num_points,
const int32_t max_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
const int32_t split_num = 6;
const int32_t deal_num =
PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE);
const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t);
char *coors_x = nram_buffer;
char *coors_y = coors_x + deal_num * sizeof(int32_t);
char *coors_z = coors_y + deal_num * sizeof(int32_t);
int32_t *coors_z_start = coors;
int32_t *coors_y_start = coors + num_points;
int32_t *coors_x_start = coors + num_points * 2;
for (int32_t point_idx = taskId; point_idx < num_points;
point_idx += taskDim) {
if (coors_x_start[point_idx] == -1) {
point_to_pointidx[point_idx] = -1;
point_to_voxelidx[point_idx] = -1;
continue;
}
int32_t c_x = coors_x_start[point_idx];
int32_t c_y = coors_y_start[point_idx];
int32_t c_z = coors_z_start[point_idx];
int32_t deal_total_num = point_idx;
int32_t repeat = deal_total_num / deal_num;
int32_t rem = deal_total_num % deal_num;
int32_t num = 0;
int32_t first_point = -1;
if (repeat > 0) {
__memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y, coors_y_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z, coors_z_start, deal_num * sizeof(int32_t),
GDRAM2NRAM);
__asm__ volatile("sync;");
}
for (int32_t i = 0; i < repeat - 1; ++i) {
__memcpy_async(coors_x + ((i + 1) % 2) * ping_pong_gap,
coors_x_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_y + ((i + 1) % 2) * ping_pong_gap,
coors_y_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap,
coors_z_start + (i + 1) * deal_num,
deal_num * sizeof(int32_t), GDRAM2NRAM);
computePoint2Voxel(
coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap,
coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num,
&first_point, i * deal_num, deal_num);
__asm__ volatile("sync;");
}
if (rem > 0) {
__memcpy_async(coors_x + (repeat % 2) * ping_pong_gap,
coors_x_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_y + (repeat % 2) * ping_pong_gap,
coors_y_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
__memcpy_async(coors_z + (repeat % 2) * ping_pong_gap,
coors_z_start + repeat * deal_num, rem * sizeof(int32_t),
GDRAM2NRAM);
}
if (repeat > 0) {
computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap,
coors_y + ((repeat - 1) % 2) * ping_pong_gap,
coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y,
c_z, max_points, &num, &first_point,
(repeat - 1) * deal_num, deal_num);
}
__asm__ volatile("sync;");
if (rem > 0) {
computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap,
coors_y + (repeat % 2) * ping_pong_gap,
coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z,
max_points, &num, &first_point, repeat * deal_num,
rem);
__asm__ volatile("sync;");
}
if (num == 0) {
point_to_pointidx[point_idx] = point_idx;
} else if (num > 0) {
point_to_pointidx[point_idx] = first_point;
}
if (num < max_points) {
point_to_voxelidx[point_idx] = num;
} else {
point_to_voxelidx[point_idx] = -1;
}
}
#endif
}
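// This stage is intentionally serial: a single core walks the points in
// order, assigns a new voxel id to every point whose per-voxel slot is 0
// (up to max_voxels), reuses the id of the voxel's first point for the
// remaining points, counts the points per voxel, and finally writes the total
// number of voxels to *voxel_num.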
__mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel(
int32_t *point_to_pointidx, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel,
int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) {
#if __BANG_ARCH__ >= 322
if (coreId == 0) {
int32_t voxel_num_temp = 0;
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
int32_t point_pos_in_voxel = point_to_voxelidx[point_idx];
coor_to_voxelidx[point_idx] = -1;
if (point_pos_in_voxel == -1) {
continue;
} else if (point_pos_in_voxel == 0) {
int32_t voxel_idx = voxel_num_temp;
if (voxel_num_temp >= max_voxels) {
continue;
}
voxel_num_temp += 1;
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] = 1;
} else {
int32_t point_idx_temp = point_to_pointidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx_temp];
if (voxel_idx != -1) {
coor_to_voxelidx[point_idx] = voxel_idx;
num_points_per_voxel[voxel_idx] += 1;
}
}
}
*voxel_num = voxel_num_temp;
}
#endif
}
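// The last stage scatters data: each core copies the features of its points
// into voxels[voxel_idx][slot][:] with GDRAM-to-GDRAM memcpys, and the first
// point of every voxel additionally gathers the voxel's (z, y, x) coordinates
// from the three planes of temp_coors into coors[voxel_idx].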
__mlu_global__ void MLUUnion1KernelAssignVoxelsCoors(
const float *points, int32_t *temp_coors, int32_t *point_to_voxelidx,
int32_t *coor_to_voxelidx, float *voxels, int32_t *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
#if __BANG_ARCH__ >= 322
if (coreId == 0x80) {
return;
}
int32_t points_per_core = num_points / taskDim;
int32_t points_rem = num_points % taskDim;
int32_t points_start = taskId < points_rem
? taskId * (points_per_core + 1)
: taskId * points_per_core + points_rem;
int32_t points_end = taskId < points_rem ? points_start + points_per_core + 1
: points_start + points_per_core;
for (int32_t point_idx = points_start; point_idx < points_end; ++point_idx) {
int32_t num = point_to_voxelidx[point_idx];
int32_t voxel_idx = coor_to_voxelidx[point_idx];
if (num > -1 && voxel_idx > -1) {
float *voxels_offset =
voxels + voxel_idx * max_points * num_features + num * num_features;
const float *points_offset = points + point_idx * num_features;
__memcpy_async(voxels_offset, points_offset, num_features * sizeof(float),
GDRAM2GDRAM);
if (num == 0) {
int32_t *coors_offset = coors + voxel_idx * 3;
__memcpy_async(coors_offset, temp_coors + point_idx, sizeof(int32_t),
GDRAM2GDRAM, sizeof(int32_t),
num_points * sizeof(int32_t), 2);
}
}
}
__asm__ volatile("sync;");
#endif
}
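// The host wrappers below only forward their arguments to the kernels above.
// In the hard-voxelization op they are expected to be launched in pipeline
// order -- KernelDynamicVoxelize, KernelPoint2Voxel,
// KernelCalcPointsPerVoxel, then KernelAssignVoxelsCoors -- on the same
// queue; that ordering is handled by the caller, not here.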
void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points, void *coors,
const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min,
const float coors_y_min, const float coors_z_min,
const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int32_t grid_x,
const int32_t grid_y, const int32_t grid_z,
const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelDynamicVoxelize<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)coors, voxel_x, voxel_y, voxel_z, coors_x_min,
coors_y_min, coors_z_min, coors_x_max, coors_y_max, coors_z_max, grid_x,
grid_y, grid_z, num_points, num_features);
}
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *coors, void *point_to_pointidx,
void *point_to_voxelidx, const int32_t num_points,
const int32_t max_points) {
MLUUnion1KernelPoint2Voxel<<<k_dim, k_type, queue>>>(
(int32_t *)coors, (int32_t *)point_to_pointidx,
(int32_t *)point_to_voxelidx, num_points, max_points);
}
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *point_to_pointidx,
void *point_to_voxelidx, void *coor_to_voxelidx,
void *num_points_per_voxel, void *voxel_num,
const int32_t max_voxels,
const int32_t num_points) {
MLUUnion1KernelCalcPointsPerVoxel<<<k_dim, k_type, queue>>>(
(int32_t *)point_to_pointidx, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (int32_t *)num_points_per_voxel,
(int32_t *)voxel_num, max_voxels, num_points);
}
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points,
void *temp_coors, void *point_to_voxelidx,
void *coor_to_voxelidx, void *voxels, void *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features) {
MLUUnion1KernelAssignVoxelsCoors<<<k_dim, k_type, queue>>>(
(float *)points, (int32_t *)temp_coors, (int32_t *)point_to_voxelidx,
(int32_t *)coor_to_voxelidx, (float *)voxels, (int32_t *)coors,
max_points, num_points, num_features);
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
void ball_query_forward_mlu(int b, int n, int m, float min_radius,
float max_radius, int nsample, const Tensor new_xyz,
const Tensor xyz, Tensor idx) {
auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
new_xyz, new_xyz.suggest_memory_format());
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
xyz, new_xyz.suggest_memory_format());
auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
idx, new_xyz.suggest_memory_format());
MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc;
new_xyz_desc.set(new_xyz_contiguous);
xyz_desc.set(xyz_contiguous);
idx_desc.set(idx_contiguous);
auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous);
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
auto new_xyz_ptr = new_xyz_impl->cnnlMalloc();
auto xyz_ptr = xyz_impl->cnnlMalloc();
auto idx_ptr = idx_impl->cnnlMalloc();
auto handle = mluOpGetCurrentHandle();
mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
idx_ptr);
}
void ball_query_forward_impl(int b, int n, int m, float min_radius,
float max_radius, int nsample,
const Tensor new_xyz, const Tensor xyz,
Tensor idx);
REGISTER_DEVICE_IMPL(ball_query_forward_impl, MLU, ball_query_forward_mlu);
@@ -9,254 +9,59 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input, const void *rois,
const void *offset, void *output,
const int channels, const int height,
const int width, const int num_rois,
const int pooled_height, const int pooled_width,
const float spatial_scale,
const int sampling_ratio, const float gamma);
void KernelDeformRoIPoolBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma);
// policy function for forward and backward
static void policyFunc(const int bin_num, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit);
k_dim->x = core_limit;
k_dim->y = (bin_num_align / core_limit) > cluster_limit
? cluster_limit
: (bin_num_align / core_limit);
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale,
int sampling_ratio, float gamma) {
// Check dtype.
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"input type should be Float or Half, got ", input.scalar_type());
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
"rois should have the same type as input");
// Check shape.
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
"D.");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D.");
if (offset.defined() && offset.numel() > 0) {
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
"offset should have the same type as input");
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
offset.dim(), "D.");
TORCH_CHECK(
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
"but now offset.size(1) = ", offset.size(1), ".");
TORCH_CHECK((offset.size(2) == output.size(2)),
"offset.size(2) = ", offset.size(2),
"while output.size(2)) = ", output.size(2),
". They should be the same.");
TORCH_CHECK((offset.size(3) == output.size(3)),
"offset.size(3) = ", offset.size(3),
"while output.size(3)) = ", output.size(3),
". They should be the same.");
}
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale,
".");
// compute kernel params
auto height = input.size(2);
auto width = input.size(3);
auto channels = input.size(1);
auto num_rois = output.size(0);
if (output.numel() == 0) {
output = at::zeros({num_rois, channels, pooled_height, pooled_width},
input.options());
return;
}
// zero element check
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
input.size(0));
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
rois.numel());
if (input.numel() == 0 || output.numel() == 0) {
return;
}
// large tensor check
const size_t max_input_num = 2147483648; // 2^31, 2G num
TORCH_CHECK(input.numel() < max_input_num,
"input.numel() should be less than 2147483648, got ",
input.numel());
TORCH_CHECK(rois.numel() < max_input_num,
"rois.numel() should be less than 2147483648, got ",
rois.numel());
TORCH_CHECK(output.numel() < max_input_num,
"output.numel() should be less than 2147483648, got ",
output.numel());
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
"offset.numel() should be less than 2147483648, got ",
offset.numel());
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
auto rois_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
auto output_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);
MluOpTensorDescriptor input_desc, rois_desc, offset_desc, output_desc;
input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
rois_desc.set(rois_contiguous);
output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
mluOpTensorDescriptor_t offset_real_desc = NULL;
void *offset_ptr = NULL;
if (offset.defined() && offset.numel() > 0) {
auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
offset, offset.suggest_memory_format());
offset_desc.set(offset_contiguous);
offset_real_desc = offset_desc.desc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
offset_ptr = offset_impl->cnnlMalloc();
}
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
auto output_ptr = output_impl->cnnlMalloc();
// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolForward(
handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
sampling_ratio, gamma, output_desc.desc(), output_ptr);
output.copy_(output_contiguous);
}
void DeformRoIPoolBackwardMLUKernelLauncher(
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio, float gamma) {
// Check dtype.
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"input type should be Float or Half, got ", input.scalar_type());
TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(),
"grad_output should have the same type as input");
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
"rois should have the same type as input");
TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(),
"grad_input should have the same type as input");
// Check shape.
TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
grad_output.dim(), "D.");
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
"D.");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D.");
if (offset.defined() && offset.numel() > 0) {
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
"offset should have the same type as input");
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
offset.dim(), "D.");
TORCH_CHECK(
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
"but now offset.size(1) = ", offset.size(1), ".");
TORCH_CHECK((offset.size(2) == grad_output.size(2)),
"offset.size(2) = ", offset.size(2),
"while grad_output.size(2)) = ", grad_output.size(2),
". They should be the same.");
TORCH_CHECK((offset.size(3) == grad_output.size(3)),
"offset.size(3) = ", offset.size(3),
"while grad_output.size(3)) = ", grad_output.size(3),
". They should be the same.");
}
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale);
// Check relationship between tensor.
TORCH_CHECK((grad_output.size(0) == rois.size(0)),
"grad_output.size(0) = ", grad_output.size(0),
"while rois.size(0)) = ", rois.size(0),
". They should be the same.");
TORCH_CHECK((grad_output.size(1) == input.size(1)),
"grad_output.size(1) = ", grad_output.size(1),
"while input.size(1)) = ", input.size(1),
". They should be the same.");
TORCH_CHECK((grad_output.size(2) == pooled_height),
"grad_output.size(2) = ", grad_output.size(2),
"while pooled_height = ", pooled_height,
". They should be the same.");
TORCH_CHECK((grad_output.size(3) == pooled_width),
"grad_output.size(3) = ", grad_output.size(3),
"while pooled_width = ", pooled_width,
". They should be the same.");
// compute kernel params
auto batch = input.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
auto num_rois = grad_output.size(0);
// zero element check
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
input.size(0));
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
rois.numel());
if (input.numel() == 0 || grad_output.numel() == 0) {
return;
}
// large tensor check
const size_t max_input_num = 2147483648; // 2^31, 2G num
TORCH_CHECK(input.numel() < max_input_num,
"input.numel() should be less than 2147483648, got ",
input.numel());
TORCH_CHECK(rois.numel() < max_input_num,
"rois.numel() should be less than 2147483648, got ",
rois.numel());
TORCH_CHECK(grad_output.numel() < max_input_num,
"grad_output.numel() should be less than 2147483648, got ",
grad_output.numel());
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
"offset.numel() should be less than 2147483648, got ",
offset.numel());
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
auto grad_output_ =
@@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher(
memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
auto rois_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
auto grad_input_ =
torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format);
// get ptr of tensors
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc,
grad_input_desc, grad_offset_desc;
grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC);
input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
rois_desc.set(rois_contiguous);
grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
mluOpTensorDescriptor_t offset_real_desc = NULL;
void *offset_ptr = NULL;
if (offset.defined() && offset.numel() > 0) {
auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
offset, offset.suggest_memory_format());
offset_desc.set(offset_contiguous);
offset_real_desc = offset_desc.desc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
offset_ptr = offset_impl->cnnlMalloc();
}
mluOpTensorDescriptor_t grad_offset_real_desc = NULL;
void *grad_offset_ptr = NULL;
if (grad_offset.defined() && grad_offset.numel() > 0) {
auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_offset, grad_offset.suggest_memory_format());
grad_offset_desc.set(grad_offset_contiguous);
grad_offset_real_desc = grad_offset_desc.desc();
auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous);
grad_offset_ptr = grad_offset_impl->cnnlMalloc();
}
// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolBackward(
handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
grad_offset_ptr);
grad_input.copy_(grad_input_);
}
...
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
// Descriptors
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type) {
const std::map<std::string, mluOpDataType_t> mapping_type = {
{std::string("c10::Half"), MLUOP_DTYPE_HALF},
{std::string("float"), MLUOP_DTYPE_FLOAT},
{std::string("double"), MLUOP_DTYPE_DOUBLE},
{std::string("int8"), MLUOP_DTYPE_INT8},
{std::string("signed char"), MLUOP_DTYPE_INT8},
{std::string("short int"), MLUOP_DTYPE_INT16},
{std::string("short"), MLUOP_DTYPE_INT16},
{std::string("int"), MLUOP_DTYPE_INT32},
{std::string("long int"), MLUOP_DTYPE_INT64},
{std::string("long"), MLUOP_DTYPE_INT64},
{std::string("unsigned char"), MLUOP_DTYPE_UINT8},
{std::string("bool"), MLUOP_DTYPE_BOOL},
{std::string("c10::complex<c10::Half>"), MLUOP_DTYPE_COMPLEX_HALF},
{std::string("c10::complex<float>"), MLUOP_DTYPE_COMPLEX_FLOAT}};
if (mapping_type.find(std::string(data_type.name())) != mapping_type.end()) {
return mapping_type.find(std::string(data_type.name()))->second;
}
return MLUOP_DTYPE_INVALID;
}
// layout
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
auto suggest_memory_format = input.suggest_memory_format();
mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
switch (input.dim()) {
case 4:
layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast)
? MLUOP_LAYOUT_NHWC
: MLUOP_LAYOUT_NCHW;
break;
case 5:
layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast3d)
? MLUOP_LAYOUT_NDHWC
: MLUOP_LAYOUT_NCDHW;
break;
default:
layout = MLUOP_LAYOUT_ARRAY;
}
return layout;
}
void MluOpTensorDescriptor::set(Tensor t) {
mluOpDataType_t data_type = getMluOpDataType(t.dtype());
mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);
int t_dim = t.dim();
std::vector<int> dim_array;
if (t_dim == 0) {
// view a ScalarTensor (0-dim, 1-element tensor) as having size 1 by default
dim_array.push_back(1);
} else {
for (int i = 0; i < t_dim; i++) {
dim_array.push_back(static_cast<int>(t.sizes().vec()[i]));
}
}
set_desc(t, layout, data_type, dim_array);
}
void MluOpTensorDescriptor::set_with_layout(Tensor t,
mluOpTensorLayout_t layout) {
mluOpDataType_t data_type = getMluOpDataType(t.dtype());
int t_dim = t.dim();
std::vector<int> shape_info = checkUpperBoundAndCastTo<int>(t.sizes().vec());
std::vector<int> stride_info =
checkUpperBoundAndCastTo<int>(t.strides().vec());
if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC ||
layout == MLUOP_LAYOUT_NLC) {
convertShapeAndStride(shape_info, stride_info);
} else if (layout == MLUOP_LAYOUT_HWCN) {
auto convertDepthWiseConvShapeStride = [](const std::vector<int64_t>& vec,
std::vector<int>& target_vec,
std::vector<int>& stride_vec) {
// NCHW --> HWCN
target_vec[0] = static_cast<int>(vec[2]);
target_vec[1] = static_cast<int>(vec[3]);
target_vec[2] = static_cast<int>(vec[1]);
target_vec[3] = static_cast<int>(vec[0]);
// Calculate Stride just like contiguous of HWCN.
stride_vec[3] = 1;
stride_vec[2] = target_vec[3] * stride_vec[3];
stride_vec[1] = target_vec[2] * stride_vec[2];
stride_vec[0] = target_vec[1] * stride_vec[1];
};
convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info);
}
TORCH_CHECK(mluOpSetTensorDescriptorEx(
desc_, layout, data_type, t_dim, shape_info.data(),
stride_info.data()) == MLUOP_STATUS_SUCCESS,
"mluOpSetTensorDescriptorEx execution failed.");
}
void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
mluOpTensorLayout_t layout,
mluOpDataType_t dtype,
std::vector<int>& dims) {
int dimNb = dims.size();
mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data());
}
// Handles
std::once_flag mmcv_mluop_init_flag;
std::mutex mmcv_mluop_mutex;
static std::vector<MluOpHandle> mmcv_mluop_handles;
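// Handles are cached one per MLU device: the first call sizes the cache to
// torch_mlu::device_count(), and every call then binds the cached handle to
// the caller's current queue under a mutex, so launches on the same device
// share a single mluOpHandle_t.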
mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index) {
std::call_once(mmcv_mluop_init_flag,
[]() // Init mmcv_mluop_handles 1-device <-> 1-handle
{
c10::DeviceIndex num_devices = torch_mlu::device_count();
mmcv_mluop_handles.resize(num_devices);
});
if (device_index == -1) {
device_index = torch_mlu::current_device();
}
std::lock_guard<std::mutex> mmcv_mluop_guard(mmcv_mluop_mutex);
auto queue = torch_mlu::getCurrentQueue(device_index).queue();
mmcv_mluop_handles[device_index].setQueue(queue);
return mmcv_mluop_handles[device_index].handle;
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include "aten.h"
#include "mlu_op.h"
#include "pytorch_device_registry.hpp"
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 5
#define MLUOP_PATCHLEVEL 302
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
class MluOpTensorDescriptor {
public:
MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); };
~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }
void set(at::Tensor);
void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
mluOpTensorDescriptor_t desc() { return desc_; }
private:
mluOpTensorDescriptor_t desc_;
void set_desc(const at::Tensor&, mluOpTensorLayout_t, mluOpDataType_t,
std::vector<int>& dims);
};
mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1);
class MluOpHandle {
public:
MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); }
~MluOpHandle() {
if (handle) {
mluOpDestroy(handle);
handle = nullptr;
}
}
void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
mluOpHandle_t handle;
};
// Reorder tensor sizes and strides from channels_first order to channels_last
// (or channels_last_3d) order. This is not the same as PyTorch's logical
// layout: the reordered layout reflects the actual order of the data in
// memory.
// example: modify channels_last tensor dim to nhwc tensor desc.
// N C H W --> N H W C
// C*H*W 1 W C --> C*H*W W C 1
template <typename T>
void convertShapeAndStride(std::vector<T>& shape_info,
std::vector<T>& stride_info) {
TORCH_MLU_CHECK(shape_info.size() == stride_info.size(),
"shape size need equal to stride size.");
const int dim = shape_info.size();
std::vector<T> temp_shape_info(dim);
std::vector<T> temp_stride_info(dim);
temp_shape_info[0] = shape_info[0];
temp_stride_info[0] = stride_info[0];
for (size_t i = 0; i < dim - 1; ++i) {
const int index = (i + 1) % (dim - 1) + 1;
temp_shape_info[i + 1] = shape_info[index];
temp_stride_info[i + 1] = stride_info[index];
}
shape_info.assign(temp_shape_info.begin(), temp_shape_info.end());
stride_info.assign(temp_stride_info.begin(), temp_stride_info.end());
}
// Torch tensors provide int64_t shapes and strides, but the mlu-ops
// descriptor requires int32. Use this function to cast safely, or report an
// error if a value exceeds the int32 range.
template <typename DST_T, typename SRC_T>
std::vector<DST_T> checkUpperBoundAndCastTo(const std::vector<SRC_T>& input) {
std::vector<DST_T> output;
output.reserve(input.size());
for (const auto& val : input) {
if (val > std::numeric_limits<DST_T>::max()) {
TORCH_MLU_CHECK(false, "Requires dim size not greater than ",
std::numeric_limits<DST_T>::max(), ". But got ", val,
".");
}
output.push_back(static_cast<DST_T>(val));
}
return output;
}
@@ -14,7 +14,15 @@
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
typedef enum {
MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
MS_DEFORM_ATTN_FORWARD_DEFAULT =
1, /*!< MLUKernelMsDeformAttnForwardDefault */
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
} MsDeformAttnForwardPolicy;
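// msDeformAttnForwardPolicyFunc (defined below) chooses between these two
// kernels with a rough NRAM heuristic: the small-channel variant is used only
// when the per-level metadata (num_levels * num_points * 3 int32 values) fits
// in NRAM and channels lies in [16, 96] (and within nram_size / 12 floats);
// otherwise the default kernel is selected.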
void KernelMsDeformAttnForwardDefault(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
@@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
typedef enum {
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
} MsDeformAttnBackwardKernelPolicy;
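// The policy below estimates how many query groups of num_heads * num_levels
// * num_points sampling locations fit into NRAM at once. As a rough worked
// example of the formula that follows (example numbers only, assuming a
// 512 KB NRAM): with channels = 32, num_heads = 8, num_levels = 4,
// num_points = 4 (num_hlp = 128), num_per_time_theory is about
// 512K / 4 / (8 * 32 + 28) / 128, i.e. roughly 3, so the small-channel
// backward kernel is chosen; when the estimate drops below 1, the default
// kernel is used instead.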
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
const int32_t channels, const int32_t num_levels, const int32_t num_points,
const int32_t num_heads) {
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const int num_hlp = num_heads * num_levels * num_points;
int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
PAD_UP((num_hlp), 32);
if (num_per_time_theory >= 1) {
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
}
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
@@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_queries, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
const float* data_sampling_loc, const float* data_attn_weight,
const float* grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
// policy function
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
const int32_t num_keys, const int32_t num_heads, const int32_t channels,
const int32_t num_levels, const int32_t num_queries,
const int32_t num_points) {
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y =
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
@@ -46,6 +97,16 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
#else
*k_type = CNRT_FUNC_TYPE_UNION1;
#endif
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
channels < 16) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else {
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
}
}
// policy function for backward
@@ -196,7 +257,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
&k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
num_queries, num_points);
// get compute queue
auto queue = torch_mlu::getCurQueue();
@@ -222,15 +285,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
// launch kernel
switch (policy) {
default: {
VLOG(5) << "MsDeformAttnForward Policy not supported";
KernelMsDeformAttnForward( }; break;
case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardDefault(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardSmallChannel(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
}
output = output.view({batch_size, num_queries, num_heads * channels});
return output;
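The forward kernels consume value of shape (bs, num_keys, num_heads, channels), sampling locations of shape (bs, num_queries, num_heads, num_levels, num_points, 2) and attention weights of shape (bs, num_queries, num_heads, num_levels, num_points), and return (bs, num_queries, num_heads * channels). A minimal shape check can be done with mmcv's pure-PyTorch reference implementation; this sketch assumes it is importable as multi_scale_deformable_attn_pytorch from mmcv.ops.multi_scale_deform_attn.

import torch
from mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch

bs, num_queries, num_heads, channels = 2, 4, 8, 32
spatial_shapes = torch.tensor([[16, 16], [8, 8]], dtype=torch.long)
num_levels, num_points = spatial_shapes.size(0), 4
num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

value = torch.rand(bs, num_keys, num_heads, channels)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels,
                                num_points, 2)
# Weights are random here (not softmax-normalized); this is only a shape check.
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels,
                               num_points)

out = multi_scale_deformable_attn_pytorch(value, spatial_shapes,
                                          sampling_locations,
                                          attention_weights)
assert out.shape == (bs, num_queries, num_heads * channels)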
...@@ -391,14 +472,32 @@ void ms_deform_attn_mlu_backward(
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
MsDeformAttnBackwardKernelPolicy kernelPolicy =
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
num_heads);
switch (kernelPolicy) {
default: {
VLOG(5) << "NotImplemented.";
} break;
case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
KernelMsDeformAttnBackwardDefaultKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
KernelMsDeformAttnBackwardSmallChannelsKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
}
}
Tensor ms_deform_attn_impl_forward(const Tensor& value,
......
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
if (boxes.numel() == 0) {
return at::empty({0}, boxes.options().dtype(at::kLong));
}
int boxes_num = boxes.size(0);
auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt));
auto output_size = at::empty({1}, scores.options().dtype(at::kInt));
MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
boxes_desc.set(boxes_);
scores_desc.set(scores_);
output_desc.set(output);
// workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size);
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
auto boxes_ptr = boxes_impl->cnnlMalloc();
auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
auto scores_ptr = scores_impl->cnnlMalloc();
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
auto workspace_ptr = workspace_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output);
auto output_ptr = output_impl->cnnlMalloc();
auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
auto output_size_ptr = output_size_impl->cnnlMalloc();
mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
output_desc.desc(), output_ptr, (int *)output_size_ptr);
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
auto ret = output.to(boxes.options().dtype(at::kLong));
return ret.slice(0, 0, output_num);
}
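The MLU kernel above is reached through the regular Python wrapper once the boxes live on an MLU device. A minimal sketch of that entry point, assuming mmcv.ops.nms_rotated with (N, 5) boxes (cx, cy, w, h, angle in radians) and a scalar IoU threshold:

import torch
from mmcv.ops import nms_rotated

# (cx, cy, w, h, angle); angles in radians, as the wrapper expects.
dets = torch.tensor([[50., 50., 20., 10., 0.0],
                     [50., 50., 20., 10., 0.1],
                     [10., 10., 5., 5., 1.0]])
scores = torch.tensor([0.9, 0.8, 0.7])

# On an MLU build, dets.to('mlu') / scores.to('mlu') make the dispatcher
# call nms_rotated_mlu defined above instead of the CPU/CUDA path.
kept_dets, keep_inds = nms_rotated(dets, scores, iou_threshold=0.5)
print(keep_inds)  # indices of the boxes that survive NMS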
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/script.h>
#include <vector>
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
// The following code is copied from
// mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu so that the outputs are
// usable for network training. The outputs of this function have the correct
// shape but wrong values.
auto numAct = indices.size(0);
auto kernelVolume = kernelSize[0];
int sub_m = (int)_subM;
int transpose = (int)_transpose;
int batch = (int)batchSize;
auto coorDim = indices.size(1) - 1;
for (int i = 1; i < kernelSize.size(); ++i) {
kernelVolume *= kernelSize[i];
}
auto outputVolume = outSpatialShape[0];
for (int i = 1; i < outSpatialShape.size(); ++i) {
outputVolume *= outSpatialShape[i];
}
torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1,
indices.options().dtype(at::kInt));
torch::Tensor indiceNum =
at::zeros({kernelVolume}, indices.options().dtype(at::kInt));
int out_size = sub_m == 1
? numAct
: std::min(numAct * kernelVolume, batch * outputVolume);
torch::Tensor out_indices =
at::zeros({out_size, coorDim + 1}, indices.options().dtype(at::kInt));
auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indices, at::MemoryFormat::Contiguous);
auto indicePairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indicePairs, at::MemoryFormat::Contiguous);
auto indiceNum_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indiceNum, at::MemoryFormat::Contiguous);
auto out_indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
out_indices, at::MemoryFormat::Contiguous);
std::vector<int> input_space;
std::vector<int> filter_space;
std::vector<int> output_space;
std::vector<int> padding32;
std::vector<int> stride32;
std::vector<int> dilation32;
for (int i = 0; i < NDim; i++) {
input_space.push_back(spatialShape[i]);
filter_space.push_back(kernelSize[i]);
output_space.push_back(outSpatialShape[i]);
padding32.push_back(padding[i]);
stride32.push_back(stride[i]);
dilation32.push_back(dilation[i]);
}
MluOpTensorDescriptor indices_desc, out_indices_desc, indicePairs_desc,
indiceNum_desc;
indices_desc.set(indices_contiguous);
indicePairs_desc.set(indicePairs_contiguous);
indiceNum_desc.set(indiceNum_contiguous);
out_indices_desc.set(out_indices_contiguous);
{
mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
std::vector<int> dims;
dims = {numAct, coorDim + 1};
mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(),
dims.data());
dims = {kernelVolume, 2, numAct};
mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
dims.size(), dims.data());
dims = {kernelVolume};
mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(),
dims.data());
dims = {out_size, coorDim + 1};
mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
dims.size(), dims.data());
}
mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc);
mluOpSetSparseConvolutionDescriptor(
sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(),
dilation32.data(), input_space.data(), filter_space.data(),
output_space.data(), sub_m, transpose, 0);
auto handle = mluOpGetCurrentHandle();
size_t workspace_size = 0;
mluOpGetIndicePairsWorkspaceSize(
handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(),
out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size);
auto indice_workspace_size =
at::empty(workspace_size, indices.options().dtype(at::kByte));
auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous);
auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous);
auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous);
auto indiceNum_impl = torch_mlu::getMluTensorImpl(indiceNum_contiguous);
auto indice_workspace_impl =
torch_mlu::getMluTensorImpl(indice_workspace_size);
auto indices_ptr = indices_impl->cnnlMalloc();
auto out_indices_ptr = out_indices_impl->cnnlMalloc();
auto indicePairs_ptr = indicePairs_impl->cnnlMalloc();
auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();
mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(),
indices_ptr, indice_workspace_ptr, workspace_size,
indicePairs_desc.desc(), indicePairs_ptr,
out_indices_desc.desc(), out_indices_ptr,
indiceNum_desc.desc(), indiceNum_ptr);
int num_act_out = 0;
mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out);
mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc);
if (!sub_m) {
return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
} else {
return {indices, indicePairs, indiceNum};
}
}
torch::Tensor IndiceConvForwardMLUKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
int64_t _subM) {
auto indice_num_cpu = indiceNum.to({torch::kCPU});
auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
int indice_num_len = indiceNum.numel();
int64_t indice_num[indice_num_len];
for (int i = 0; i < indice_num_len; ++i) {
indice_num[i] = (int64_t)(((int *)indice_num_cpu_64)[i]);
}
// generate empty output
int C = filters.dim() == 4 ? filters.size(3) : filters.size(4);
torch::Tensor output =
at::zeros({numActOut, C}, features.options().dtype(at::kFloat));
// generate descriptor
auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
features, at::MemoryFormat::Contiguous);
auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
filters, at::MemoryFormat::Contiguous);
auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indicePairs, at::MemoryFormat::Contiguous);
auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
output, at::MemoryFormat::Contiguous);
MluOpTensorDescriptor features_desc, filters_desc, indice_pairs_desc,
output_desc;
features_desc.set(features_contiguous);
filters_desc.set(filters_contiguous);
indice_pairs_desc.set(indice_pairs_contiguous);
output_desc.set(output_contiguous);
// set layout
{
mluOpTensorLayout_t layout;
mluOpDataType_t dtype;
int dim;
int dims[8];
// features_desc
mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
// filters_desc
mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
// indice_pairs_desc
mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
dims);
mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
dtype, dim, dims);
// output_desc
mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims);
mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim,
dims);
}
auto handle = mluOpGetCurrentHandle();
size_t workspace_size = 0;
mluOpGetIndiceConvolutionForwardWorkspaceSize(
handle, features_desc.desc(), filters_desc.desc(),
indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut,
_inverse, _subM, &workspace_size);
auto workspace =
at::empty(workspace_size, features.options().dtype(at::kByte));
auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
auto features_ptr = features_impl->cnnlMalloc();
auto filters_ptr = filters_impl->cnnlMalloc();
auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
auto workspace_ptr = workspace_impl->cnnlMalloc();
// outputs
auto output_impl = torch_mlu::getMluTensorImpl(output);
auto output_ptr = output_impl->cnnlMalloc();
mluOpIndiceConvolutionForward(
handle, features_desc.desc(), features_ptr, filters_desc.desc(),
filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
numActOut, _inverse, _subM, workspace_ptr, workspace_size,
output_desc.desc(), output_ptr);
return output;
}
std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
auto indice_num_cpu = indiceNum.to({torch::kCPU});
auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
int indice_num_len = indiceNum.numel();
int64_t indice_num[indice_num_len];
for (int i = 0; i < indice_num_len; ++i) {
indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]);
}
// generate empty input_grad
torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)},
features.options().dtype(at::kFloat));
torch::Tensor filters_grad;
if (filters.dim() == 4) {
int h = filters.size(0);
int w = filters.size(1);
int c = filters.size(2);
int n = filters.size(3);
filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat));
} else if (filters.dim() == 5) {
int d = filters.size(0);
int h = filters.size(1);
int w = filters.size(2);
int c = filters.size(3);
int n = filters.size(4);
filters_grad =
at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat));
}
auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
features, at::MemoryFormat::Contiguous);
auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
filters, at::MemoryFormat::Contiguous);
auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
outGrad, at::MemoryFormat::Contiguous);
auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indicePairs, at::MemoryFormat::Contiguous);
auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
features, at::MemoryFormat::Contiguous);
auto filters_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
filters, at::MemoryFormat::Contiguous);
MluOpTensorDescriptor features_desc, output_grad_desc, filters_desc,
indice_pairs_desc, input_grad_desc, filters_grad_desc;
features_desc.set(features_contiguous);
filters_desc.set(filters_contiguous);
output_grad_desc.set(output_grad_contiguous);
indice_pairs_desc.set(indice_pairs_contiguous);
input_grad_desc.set(input_grad_contiguous);
filters_grad_desc.set(filters_grad_contiguous);
// need to set desc layout with mluOp functions
{
mluOpTensorLayout_t layout;
mluOpDataType_t dtype;
int dim;
int dims[8];
// features_desc
mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
// filters_desc
mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
if (dim == 4) {
mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype,
dim, dims);
} else {
mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
}
// output_grad_desc
mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim,
dims);
mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
// indice_pairs_desc
mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
dims);
mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
dtype, dim, dims);
// input_grad_desc
mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim,
dims);
mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
dim, dims);
}
auto handle = mluOpGetCurrentHandle();
size_t data_workspace_size = 0;
mluOpGetIndiceConvolutionBackwardDataWorkspaceSize(
handle, output_grad_desc.desc(), filters_desc.desc(),
indice_pairs_desc.desc(), input_grad_desc.desc(), indice_num, _inverse,
&data_workspace_size);
size_t filters_workspace_size = 0;
mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
handle, features_desc.desc(), output_grad_desc.desc(),
indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse,
_subM, &filters_workspace_size);
auto indice_convbpdata_workspace =
at::empty(data_workspace_size, features.options().dtype(at::kByte));
auto indice_convbpfilter_workspace =
at::empty(filters_workspace_size, filters.options().dtype(at::kByte));
auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous);
auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
auto indice_convbpdata_workspace_impl =
torch_mlu::getMluTensorImpl(indice_convbpdata_workspace);
auto indice_convbpfilter_workspace_impl =
torch_mlu::getMluTensorImpl(indice_convbpfilter_workspace);
auto features_ptr = features_impl->cnnlMalloc();
auto filters_ptr = filters_impl->cnnlMalloc();
auto output_grad_ptr = output_grad_impl->cnnlMalloc();
auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
auto indice_convbpdata_workspace_ptr =
indice_convbpdata_workspace_impl->cnnlMalloc();
auto indice_convbpfilter_workspace_ptr =
indice_convbpfilter_workspace_impl->cnnlMalloc();
// outputs
auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad);
auto input_grad_ptr = input_grad_impl->cnnlMalloc();
auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad);
auto filters_grad_ptr = filters_grad_impl->cnnlMalloc();
mluOpIndiceConvolutionBackwardData(
handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(),
filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
_inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size,
input_grad_desc.desc(), input_grad_ptr);
mluOpIndiceConvolutionBackwardFilter(
handle, features_desc.desc(), features_ptr, output_grad_desc.desc(),
output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
_inverse, _subM, indice_convbpfilter_workspace_ptr,
filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr);
std::vector<torch::Tensor> result;
result.push_back(input_grad);
result.push_back(filters_grad);
return result;
}
torch::Tensor indice_conv_forward_mlu(torch::Tensor features,
torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse,
int64_t _subM) {
return IndiceConvForwardMLUKernelLauncher(
features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
}
std::vector<torch::Tensor> indice_conv_backward_mlu(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
return IndiceConvBackwardMLUKernelLauncher(
features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
}
torch::Tensor indice_conv_forward_impl(torch::Tensor features,
torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse,
int64_t _subM);
std::vector<torch::Tensor> indice_conv_backward_impl(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);
REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu);
REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu);
template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<2>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<3>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
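The launchers and template instantiations above back mmcv's sparse convolution modules. A minimal usage sketch, assuming SparseConvTensor and SubMConv3d are exposed from mmcv.ops as in the CUDA path and that an MLU (or CUDA) device is available, since these ops have no CPU implementation:

import torch
from mmcv.ops import SparseConvTensor, SubMConv3d

device = 'mlu'  # or 'cuda'

# 4 active voxels with 16-channel features; indices are (batch_idx, z, y, x).
features = torch.rand(4, 16, device=device)
indices = torch.tensor([[0, 1, 2, 3],
                        [0, 4, 5, 6],
                        [0, 7, 8, 9],
                        [0, 2, 2, 2]], dtype=torch.int32, device=device)
x = SparseConvTensor(features, indices, spatial_shape=[16, 16, 16],
                     batch_size=1)

conv = SubMConv3d(16, 32, kernel_size=3, indice_key='subm1').to(device)
# On MLU the forward pass goes through get_indice_pairs_forward_mlu and
# IndiceConvForwardMLUKernelLauncher defined above.
y = conv(x)
print(y.features.shape)  # (num_active_out, 32)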
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelDynamicVoxelize(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *points, void *coors, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int32_t grid_x, const int32_t grid_y,
const int32_t grid_z, const int32_t num_points, const int32_t num_features);
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *coors, void *point_to_pointidx,
void *point_to_voxelidx, const int32_t num_points,
const int32_t max_points);
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *point_to_pointidx,
void *point_to_voxelidx, void *coor_to_voxelidx,
void *num_points_per_voxel, void *voxel_num,
const int32_t max_voxels,
const int32_t num_points);
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points,
void *temp_coors, void *point_to_voxelidx,
void *coor_to_voxelidx, void *voxels, void *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features);
// policy function
static void policyFuncDefault(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
const int num_points) {
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = MIN((num_points + k_dim->x - 1) / k_dim->x,
torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
// policy function
static void policyFuncCalcPointsPerVoxel(cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type,
const int num_points) {
k_dim->x = 1;
k_dim->y = 1;
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_BLOCK;
}
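The two policy functions above only decide how many cores to launch: policyFuncDefault spreads num_points over the cores of one UNION1 task with a ceil-division capped at the cluster count, while the per-voxel counting step runs as a single BLOCK task. A small Python sketch of the first rule, where the core and cluster counts stand in for the corresponding device attributes:

def union1_task_dims(num_points: int, cores_per_cluster: int,
                     cluster_count: int):
    """Mirror of policyFuncDefault: one UNION1 task, y = ceil-div capped."""
    dim_x = cores_per_cluster
    dim_y = min((num_points + dim_x - 1) // dim_x, cluster_count)
    return dim_x, dim_y, 1

# e.g. 10000 points on a device with 4 cores per cluster and 8 clusters
print(union1_task_dims(10000, 4, 8))  # (4, 8, 1)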
int HardVoxelizeForwardMLUKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3) {
// check datatype
TORCH_CHECK(points.scalar_type() == at::kFloat,
"points type should be Float, got ", points.scalar_type(), ".");
TORCH_CHECK(voxels.scalar_type() == at::kFloat,
"voxels type should be Float, got ", voxels.scalar_type(), ".");
TORCH_CHECK(coors.scalar_type() == at::kInt,
"coors type should be Int, got ", coors.scalar_type(), ".");
TORCH_CHECK(num_points_per_voxel.scalar_type() == at::kInt,
"num_points_per_voxel type should be Int, got ",
num_points_per_voxel.scalar_type(), ".");
// check shape
TORCH_CHECK(points.dim() == 2, "points should be a 2d tensor, got ",
points.dim(), "D.");
TORCH_CHECK(voxels.dim() == 3, "voxels should be a 3d tensor, got ",
voxels.dim(), "D.");
TORCH_CHECK(coors.dim() == 2, "coors should be a 2d tensor, got ",
coors.dim(), "D.");
TORCH_CHECK(num_points_per_voxel.dim() == 1,
"num_points_per_voxel should be a 1d tensor, got ",
num_points_per_voxel.dim(), "D.");
const int num_points = points.size(0);
const int num_features = points.size(1);
TORCH_CHECK(points.size(0) == num_points,
"the 1st dimensions of points should be num_points, got ",
points.size(0), ".");
TORCH_CHECK(points.size(1) == num_features,
"the 2nd dimensions of points should be num_features, got ",
points.size(1), ".");
TORCH_CHECK(voxels.size(0) == max_voxels,
"the 1st dimensions of voxels should be max_voxels, got ",
voxels.size(0), ".");
TORCH_CHECK(voxels.size(1) == max_points,
"the 2nd dimensions of voxels should be max_points, got ",
voxels.size(1), ".");
TORCH_CHECK(voxels.size(2) == num_features,
"the 3rd dimensions of voxels should be num_features, got ",
voxels.size(2), ".");
TORCH_CHECK(coors.size(0) == max_voxels,
"the 1st dimensions of coors should be max_voxels, got ",
coors.size(0), ".");
TORCH_CHECK(coors.size(1) == 3,
"the 2nd dimensions of coors should be 3, got ", coors.size(1),
".");
TORCH_CHECK(num_points_per_voxel.size(0) == max_voxels,
"the 1st dimensions of num_points_per_voxel should be max_voxels, got ",
num_points_per_voxel.size(0), ".");
// large tensor check
const size_t max_input_size = 2147483648;
TORCH_CHECK(points.numel() < max_input_size,
"points element num should be less than 2^31, got ",
points.numel(), ".");
TORCH_CHECK(voxels.numel() < max_input_size,
"voxels element num should be less than 2^31, got ",
voxels.numel(), ".");
TORCH_CHECK(coors.numel() < max_input_size,
"coors element num should be less than 2^31, got ", coors.numel(),
".");
// check zero element
if (max_points == 0 || max_voxels == 0) {
return 0;
}
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto points_ = points.contiguous();
auto points_impl = torch_mlu::getMluTensorImpl(points_);
auto points_ptr = points_impl->cnnlMalloc();
auto voxels_ = voxels.contiguous();
auto voxels_impl = torch_mlu::getMluTensorImpl(voxels_);
auto voxels_ptr = voxels_impl->cnnlMalloc();
auto coors_ = coors.contiguous();
auto coors_impl = torch_mlu::getMluTensorImpl(coors_);
auto coors_ptr = coors_impl->cnnlMalloc();
auto num_points_per_voxel_ = num_points_per_voxel.contiguous();
auto num_points_per_voxel_impl =
torch_mlu::getMluTensorImpl(num_points_per_voxel_);
auto num_points_per_voxel_ptr = num_points_per_voxel_impl->cnnlMalloc();
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncDefault(&k_dim, &k_type, num_points);
// 1. link point to corresponding voxel coors
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
auto temp_coors =
at::zeros({NDim, num_points}, points.options().dtype(at::kInt))
.contiguous();
auto temp_coors_impl = torch_mlu::getMluTensorImpl(temp_coors);
auto temp_coors_ptr = temp_coors_impl->cnnlMalloc();
KernelDynamicVoxelize(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min,
coors_z_min, coors_x_max, coors_y_max, coors_z_max,
grid_x, grid_y, grid_z, num_points, num_features);
// 2. map point to the idx of the corresponding voxel, find duplicate coor
auto point_to_pointidx = at::zeros(
{
num_points,
},
points.options().dtype(at::kInt))
.contiguous();
auto point_to_pointidx_impl = torch_mlu::getMluTensorImpl(point_to_pointidx);
auto point_to_pointidx_ptr = point_to_pointidx_impl->cnnlMalloc();
auto point_to_voxelidx = at::zeros(
{
num_points,
},
points.options().dtype(at::kInt))
.contiguous();
auto point_to_voxelidx_impl = torch_mlu::getMluTensorImpl(point_to_voxelidx);
auto point_to_voxelidx_ptr = point_to_voxelidx_impl->cnnlMalloc();
KernelPoint2Voxel(k_dim, k_type, queue, temp_coors_ptr, point_to_pointidx_ptr,
point_to_voxelidx_ptr, num_points, max_points);
// calculate task dimension
cnrtDim3_t k_dim_calc_points_per_voxel;
cnrtFunctionType_t k_type_calc_points_per_voxel;
policyFuncCalcPointsPerVoxel(&k_dim_calc_points_per_voxel,
&k_type_calc_points_per_voxel, num_points);
// 3. determine voxel num and voxel's coor index
auto coor_to_voxelidx = at::zeros(
{
num_points,
},
points.options().dtype(at::kInt))
.contiguous();
auto coor_to_voxelidx_impl = torch_mlu::getMluTensorImpl(coor_to_voxelidx);
auto coor_to_voxelidx_ptr = coor_to_voxelidx_impl->cnnlMalloc();
auto voxel_num = at::zeros(
{
1,
},
points.options().dtype(at::kInt))
.contiguous();
auto voxel_num_impl = torch_mlu::getMluTensorImpl(voxel_num);
auto voxel_num_ptr = voxel_num_impl->cnnlMalloc();
KernelCalcPointsPerVoxel(
k_dim_calc_points_per_voxel, k_type_calc_points_per_voxel, queue,
point_to_pointidx_ptr, point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
num_points_per_voxel_ptr, voxel_num_ptr, max_voxels, num_points);
// 4. copy point features and coors of each voxels to voxels
KernelAssignVoxelsCoors(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
voxels_ptr, coors_ptr, max_points, num_points,
num_features);
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
return voxel_num_int;
}
int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim) {
return HardVoxelizeForwardMLUKernelLauncher(
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
max_points, max_voxels, NDim);
};
int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim);
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MLU,
hard_voxelize_forward_mlu);
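The registered hard_voxelize_forward_impl is what mmcv's Voxelization layer calls when max_num_points > 0. A minimal sketch, assuming mmcv.ops.Voxelization takes (voxel_size, point_cloud_range, max_num_points, max_voxels) and that the run happens on a device with a voxelization implementation:

import torch
from mmcv.ops import Voxelization

voxel_layer = Voxelization(
    voxel_size=[0.5, 0.5, 0.5],
    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
    max_num_points=32,   # > 0, so hard voxelization is used
    max_voxels=20000)

points = torch.rand(1000, 4)  # (x, y, z, reflectance)
# On an MLU build, points.to('mlu') routes the call to
# hard_voxelize_forward_mlu / HardVoxelizeForwardMLUKernelLauncher above.
voxels, coors, num_points_per_voxel = voxel_layer(points)
print(voxels.shape, coors.shape, num_points_per_voxel.shape)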
...@@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
const Tensor labels, const float iou_threshold);
#endif
#ifdef MMCV_WITH_MLU
Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores,
const float iou_threshold);
#endif
// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
...@@ -36,6 +41,10 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
return nms_rotated_npu(dets, scores, labels, iou_threshold);
#else
AT_ERROR("Not compiled with NPU support");
#endif
#ifdef MMCV_WITH_MLU
} else if (dets.device().type() == at::kMLU) {
return nms_rotated_mlu(dets, scores, iou_threshold);
#endif
}
......
...@@ -35,6 +35,26 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
padding, dilation, outPadding, _subM, _transpose);
};
template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_forward_mlu(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
return GetIndicePairsForwardMLUKernelLauncher<NDim>(
indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
padding, dilation, outPadding, _subM, _transpose);
}
template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsBackwardCUDAKernelLauncher(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
...@@ -71,6 +91,12 @@ std::vector<torch::Tensor> get_indice_pairs_forward(
padding, dilation, outPadding, _subM, _transpose);
#else
AT_ERROR("get_indice_pairs is not compiled with GPU support");
#endif
#ifdef MMCV_WITH_MLU
} else if (indices.device().type() == at::kMLU) {
return get_indice_pairs_forward_mlu<NDim>(
indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
padding, dilation, outPadding, _subM, _transpose);
#endif
} else {
AT_ERROR("get_indice_pairs is not implemented on CPU");
......
...@@ -410,8 +410,9 @@ def nms_rotated(dets: Tensor,
input_labels = scores.new_empty(0, dtype=torch.int)
else:
input_labels = labels
if dets.device.type in ('npu', 'mlu'):
order = scores.new_empty(0, dtype=torch.long)
if dets.device.type == 'npu':
coefficient = 57.29578 # 180 / PI
for i in range(dets.size()[0]):
dets_cw[i][4] *= coefficient # radians to angle
......
...@@ -211,6 +211,7 @@ def get_extensions():
include_dirs = []
extra_objects = []
is_rocm_pytorch = False
try:
from torch.utils.cpp_extension import ROCM_HOME
...@@ -238,16 +239,98 @@ def get_extensions():
torch.is_mlu_available()) or \
os.getenv('FORCE_MLU', '0') == '1':
from torch_mlu.utils.cpp_extension import MLUExtension
def get_mluops_version(file_path):
with open(file_path) as f:
for line in f:
if re.search('MLUOP_MAJOR', line):
major = line.strip().split(' ')[2]
if re.search('MLUOP_MINOR', line):
minor = line.strip().split(' ')[2]
if re.search('MLUOP_PATCHLEVEL', line):
patchlevel = line.strip().split(' ')[2]
mluops_version = f'v{major}.{minor}.{patchlevel}'
return mluops_version
mmcv_mluops_version = get_mluops_version(
'./mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h')
mlu_ops_path = os.getenv('MMCV_MLU_OPS_PATH')
if mlu_ops_path:
exists_mluops_version = get_mluops_version(
mlu_ops_path + '/bangc-ops/mlu_op.h')
if exists_mluops_version != mmcv_mluops_version:
print('the version of mlu-ops provided is %s,'
' while %s is needed.' %
(exists_mluops_version, mmcv_mluops_version))
exit()
try:
if os.path.exists('mlu-ops'):
if os.path.islink('mlu-ops'):
os.remove('mlu-ops')
os.symlink(mlu_ops_path, 'mlu-ops')
elif os.path.abspath('mlu-ops') != mlu_ops_path:
os.symlink(mlu_ops_path, 'mlu-ops')
else:
os.symlink(mlu_ops_path, 'mlu-ops')
except Exception:
raise FileExistsError(
'mlu-ops already exists, please move it out, '
'or rename or remove it.')
else:
if not os.path.exists('mlu-ops'):
import requests
mluops_url = 'https://github.com/Cambricon/mlu-ops/' + \
'archive/refs/tags/' + mmcv_mluops_version + '.zip'
req = requests.get(mluops_url)
with open('./mlu-ops.zip', 'wb') as f:
try:
f.write(req.content)
except Exception:
raise ImportError('failed to download mlu-ops')
from zipfile import BadZipFile, ZipFile
with ZipFile('./mlu-ops.zip', 'r') as archive:
try:
archive.extractall()
dir_name = archive.namelist()[0].split('/')[0]
os.rename(dir_name, 'mlu-ops')
except BadZipFile:
print('invalid mlu-ops.zip file')
else:
exists_mluops_version = get_mluops_version(
'./mlu-ops/bangc-ops/mlu_op.h')
if exists_mluops_version != mmcv_mluops_version:
print('the version of provided mlu-ops is %s,'
' while %s is needed.' %
(exists_mluops_version, mmcv_mluops_version))
exit()
define_macros += [('MMCV_WITH_MLU', None)]
mlu_args = os.getenv('MMCV_MLU_ARGS', '-DNDEBUG ')
mluops_includes = []
mluops_includes.append('-I' +
os.path.abspath('./mlu-ops/bangc-ops'))
mluops_includes.append(
'-I' + os.path.abspath('./mlu-ops/bangc-ops/kernels'))
extra_compile_args['cncc'] = [mlu_args] + \
mluops_includes if mlu_args else mluops_includes
extra_compile_args['cxx'] += ['-fno-gnu-unique']
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + \
glob.glob(
'./mlu-ops/bangc-ops/core/**/*.cpp', recursive=True) + \
glob.glob(
'./mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \
glob.glob(
'./mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
extra_objects = glob.glob(
'./mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
include_dirs.append(os.path.abspath('./mlu-ops/bangc-ops'))
elif (hasattr(torch.backends, 'mps')
and torch.backends.mps.is_available()) or os.getenv(
'FORCE_MPS', '0') == '1':
...@@ -309,6 +392,7 @@ def get_extensions():
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_objects=extra_objects,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)
return extensions
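The version check in this setup step keys off the MLUOP_MAJOR/MLUOP_MINOR/MLUOP_PATCHLEVEL defines in bangc-ops/mlu_op.h. A small standalone sketch of the same parsing rule, applied to a hypothetical header excerpt rather than a real file:

import re

def get_mluops_version_from_text(text):
    """Same parsing rule as get_mluops_version: token index 2 of the define."""
    major = minor = patchlevel = None
    for line in text.splitlines():
        if re.search('MLUOP_MAJOR', line):
            major = line.strip().split(' ')[2]
        if re.search('MLUOP_MINOR', line):
            minor = line.strip().split(' ')[2]
        if re.search('MLUOP_PATCHLEVEL', line):
            patchlevel = line.strip().split(' ')[2]
    return f'v{major}.{minor}.{patchlevel}'

# hypothetical excerpt of bangc-ops/mlu_op.h
header = ('#define MLUOP_MAJOR 0\n'
          '#define MLUOP_MINOR 5\n'
          '#define MLUOP_PATCHLEVEL 302\n')
print(get_mluops_version_from_text(header))  # v0.5.302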
......