Commit a0939977
authored Nov 16, 2022 by ZShaopeng, committed by Zaida Zhou on Nov 23, 2022
[Feature] Support MultiScaleDeformableAttn with cambricon MLU backend
parent 193de43b
Showing 7 changed files with 1393 additions and 43 deletions
docs/en/understand_mmcv/ops.md (+1, -1)
docs/zh_cn/understand_mmcv/ops.md (+1, -1)
mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp (+33, -0)
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu (+853, -0)
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp (+420, -0)
mmcv/ops/multi_scale_deform_attn.py (+5, -3)
tests/test_ops/test_ms_deformable_attn.py (+80, -38)
docs/en/understand_mmcv/ops.md
...
@@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
 | MergeCells               |     |  √  |     |     |
 | MinAreaPolygon           |     |  √  |     |     |
 | ModulatedDeformConv2d    |  √  |  √  |     |     |
-| MultiScaleDeformableAttn |     |  √  |     |     |
+| MultiScaleDeformableAttn |     |  √  |  √  |     |
 | NMS                      |  √  |  √  |  √  |     |
 | NMSRotated               |  √  |  √  |     |     |
 | NMSQuadri                |  √  |  √  |     |     |
...
docs/zh_cn/understand_mmcv/ops.md
...
@@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | MergeCells               |     |  √  |     |     |
 | MinAreaPolygon           |     |  √  |     |     |
 | ModulatedDeformConv2d    |  √  |  √  |     |     |
-| MultiScaleDeformableAttn |     |  √  |     |     |
+| MultiScaleDeformableAttn |     |  √  |  √  |     |
 | NMS                      |  √  |  √  |  √  |     |
 | NMSRotated               |  √  |  √  |     |     |
 | NMSQuadri                |  √  |  √  |     |     |
...
mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp
...
@@ -362,4 +362,37 @@ __mlu_func__ inline void convertFloat2half(half *dst, float *src,
#endif
}
/*!
 * @brief recursiveSumPool.
 * @param[in,out] dst
 *   Pointer to NRAM that stores the input and output data.
 * @param[in] low_dim
 *   The length of the inner (low) dimension, which is preserved by the reduction.
 * @param[in] high_dim
 *   The number of rows along the outer (high) dimension to be summed.
 * @param[in] kernel_limit
 *   The maximum number of high-dim rows reduced by one sumpool call.
 ******************************************************************************/
template <typename T>
__mlu_func__ void recursiveSumPool(T *dst, int low_dim, int high_dim,
                                   int kernel_limit) {
  for (; high_dim > 1;) {
    int repeat_s = high_dim / kernel_limit;
    int remain_s = high_dim % kernel_limit;

    if (remain_s) {
      __bang_sumpool((T *)dst, (T *)dst, low_dim, 1, remain_s, 1, remain_s, 1,
                     1);
    }
    if (repeat_s) {
      __bang_sumpool((T *)dst + (remain_s > 0 ? low_dim : 0),
                     (T *)dst + remain_s * low_dim, low_dim,
                     kernel_limit * repeat_s, 1, kernel_limit, 1, 1,
                     kernel_limit);
    }
    high_dim = repeat_s + (bool)remain_s;
  }
  return;
}
#endif // COMMON_MLU_HELPER_HPP_
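As a host-side sanity reference (not part of the commit), a NumPy sketch of the reduction recursiveSumPool above performs, inferred from its loop structure:

import numpy as np


def recursive_sum_pool_reference(dst, low_dim, high_dim, kernel_limit):
    """Sum `high_dim` rows of length `low_dim` into row 0, chunking the outer
    dimension by at most `kernel_limit` rows per pass, mirroring the
    repeat_s / remain_s split in recursiveSumPool."""
    buf = np.asarray(dst, dtype=np.float64)[:low_dim * high_dim]
    buf = buf.reshape(high_dim, low_dim).copy()
    while high_dim > 1:
        repeat_s, remain_s = divmod(high_dim, kernel_limit)
        rows = []
        if remain_s:
            # first pass folds the leftover rows into one row
            rows.append(buf[:remain_s].sum(axis=0))
        for r in range(repeat_s):
            start = remain_s + r * kernel_limit
            rows.append(buf[start:start + kernel_limit].sum(axis=0))
        buf = np.stack(rows)
        high_dim = repeat_s + bool(remain_s)
    return buf[0]  # same values the kernel leaves in dst[0:low_dim]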
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
0 → 100644
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <math.h>
/****************************************************************************************
*
* NRAM partition forward:
* | spatial_shapes | data_value_p1_ping | data_value_p2_ping |
* | data_value_p3_ping | data_value_p4_ping | data_col_ping |
* | data_value_p1_pong | data_value_p2_pong | data_value_p3_pong |
* | data_value_p4_pong | data_col_pong | auxiliary_a |
* | auxiliary_b |
* | 128bytes | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size |
*
****************************************************************************************/
/****************************************************************************************
*
* NRAM partition backward:
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 64
#define ALIGN_NUM_FOR_REDUCE 32
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t deal_num,
const int32_t &width, const int32_t &height, const int32_t &num_heads,
const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = width * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int32_t base_ptr_offset = head_idx * channels;
// top-left point
if (h_low >= 0 && w_low >= 0) {
const int32_t v1_offset =
h_low_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p1_nram, data_value_gdram + v1_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
const int32_t v2_offset =
h_low_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p2_nram, data_value_gdram + v2_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
const int32_t v3_offset =
h_high_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p3_nram, data_value_gdram + v3_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
const int32_t v4_offset =
h_high_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p4_nram, data_value_gdram + v4_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void bilinearInterpolation(
T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
const size_t deal_num, const int32_t &width, const int32_t &height,
const T &x, const T &y) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const T lw = x - w_low;
const T lh = y - h_low;
const T hw = 1 - lw;
const T hh = 1 - lh;
const T w1 = hh * hw;
const T w2 = hh * lw;
const T w3 = lh * hw;
const T w4 = lh * lw;
__bang_write_value((T *)sample_point_value, deal_num, (T)0);
// top-left point
if (h_low >= 0 && w_low >= 0) {
// sample_point_value += v1 * w1
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p1_nram, (T)w1,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
// sample_point_value += v2 * w2
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p2_nram, (T)w2,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
// sample_point_value += v3 * w3
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p3_nram, (T)w3,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
// sample_point_value += v4 * w4
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p4_nram, (T)w4,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForward(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
if (coreId == 0x80) {
return;
}
const size_t spatial_size = PAD_UP(2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size) / TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const size_t align_num = NFU_ALIGN_SIZE;
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
char *data_spatial_shapes_nram = nram_buffer;
char *ping_data_value_p1_nram = data_spatial_shapes_nram + spatial_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
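  // Two NRAM buffer sets separated by ping_pong_gap: neighbor data for the
  // next sampling point is prefetched into one set while the current point is
  // interpolated from the other; the output column likewise alternates sets so
  // its async store can overlap the next channel segment.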
size_t data_col_ping_pong_idx = 0;
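  // Split the batch_size * num_queries * num_heads (query, head) blocks as
  // evenly as possible across the taskDim MLU cores.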
int32_t block_num_per_core = (batch_size * num_queries * num_heads) / taskDim;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t idx_start = taskId < (block_num_rem + 1)
? taskId * (block_num_per_core + 1)
: taskId * block_num_per_core + block_num_rem;
block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram + cur_idx * num_levels * num_points * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, span_num_deal, spatial_w,
spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
span_num_deal);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, span_num_deal);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, channels_align_rem,
spatial_w, spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
channels_align_rem);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, channels_align_rem);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
}
__asm__ volatile("sync;");
return;
}
template __mlu_global__ void MLUKernelMsDeformAttnForward<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForward<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template <typename T>
void __mlu_func__ msDeformAttnCol2imBilinear(
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
const T &w2, const T &w3, const T &w4, const int32_t &h_low,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
T *grad_w_weight, T *grad_value, T *grad_output_nram, T *grad_weight,
T *grad_sampling_loc, T *grad_attn_weight, T *grad_output_nram_temp,
const int32_t &deal_num, const int32_t &deal_num_real,
const T *data_value_ptr) {
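  // For each in-range bilinear corner: accumulate grad_value with atomic adds,
  // and build the per-channel products that are later reduced (via
  // recursiveSumPool) into scalar gradients for the attention weight and the
  // sampling location.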
if (h_low >= 0 && w_low >= 0) {
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
(T *)top_grad_temp, deal_num_real);
}
if (h_low >= 0 && w_high <= width - 1) {
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num);
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_low >= 0) {
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_high <= width - 1) {
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
(T *)top_grad_temp, deal_num_real);
}
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / sizeof(float);
recursiveSumPool(grad_output_nram, align_num_on_200,
deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_output_nram, grad_output_nram,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
(T *)grad_output_nram, 1);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_w_weight, grad_w_weight,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_h_weight, grad_h_weight,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
if (coreId == 0x80) {
return;
}
const int32_t split_num = 8;
const int32_t spatial_shapes_size = 64;
int32_t deal_num = PAD_DOWN(
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / sizeof(float),
ALIGN_NUM);
float *grad_output_nram = (float *)nram_buffer;
float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
float *grad_weight = (float *)nram_buffer + 2 * deal_num;
float *grad_h_weight = (float *)nram_buffer + 3 * deal_num;
float *grad_w_weight = (float *)nram_buffer + 4 * deal_num;
float *top_grad = (float *)nram_buffer + 5 * deal_num;
float *top_grad_temp = (float *)nram_buffer + 6 * deal_num;
int32_t *spatial_shapes_nram =
(int32_t *)((float *)nram_buffer + 7 * deal_num);
float *sampling_loc_nram =
(float *)nram_buffer + 7 * deal_num + 2 * sizeof(int32_t);
const int32_t total_num = batch * num_query * num_heads * num_levels;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
int32_t start_per_core =
num_rem > taskId
? (taskId * num_per_core)
: ((num_per_core + 1) * num_rem + (taskId - num_rem) * num_per_core);
int32_t end_per_core = start_per_core + num_per_core;
const int32_t C_repeat = channels / deal_num;
const int32_t C_tail = channels % deal_num;
const int32_t qid_stride = num_heads * channels;
int32_t base_ptr = 0;
for (int32_t num_loop = start_per_core; num_loop < end_per_core; ++num_loop) {
const int32_t l_col = num_loop % num_levels;
const int32_t m_col = num_loop / num_levels % num_heads;
const int32_t q_col = num_loop / num_levels / num_heads % num_query;
const int32_t b_col = num_loop / num_query / num_heads / num_levels;
int32_t data_weight_ptr = num_loop * num_points;
int32_t data_loc_w_ptr = data_weight_ptr << 1;
const int32_t value_offset = b_col * spatial_size * num_heads * channels;
const int32_t level_start_id = data_level_start_index[l_col];
int32_t spatial_h_ptr = l_col << 1;
int32_t grad_output_offset = b_col * num_query * num_heads * channels +
q_col * num_heads * channels +
m_col * channels;
__memcpy(spatial_shapes_nram, spatial_shapes + spatial_h_ptr,
2 * sizeof(int32_t), GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[0];
const int32_t spatial_w = spatial_shapes_nram[1];
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_attn_weight_out = num_loop * num_points;
const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
__memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
2 * sizeof(float), GDRAM2NRAM);
const float loc_w = sampling_loc_nram[0];
const float loc_h = sampling_loc_nram[1];
const float weight = data_attn_weight[data_weight_ptr];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = spatial_w * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
float w1 = hh * hw;
float w2 = hh * lw;
float w3 = lh * hw;
float w4 = lh * lw;
for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
base_ptr = m_col * channels + C_loop * deal_num;
__bang_write_zero(grad_weight, 3 * deal_num);
__bang_write_zero(grad_output_nram, deal_num);
__memcpy(top_grad,
grad_output + grad_output_offset + C_loop * deal_num,
deal_num * sizeof(float), GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, deal_num, data_value_ptr);
}
if (C_tail != 0) {
base_ptr = m_col * channels + C_repeat * deal_num;
__bang_write_zero(grad_output_nram, 8 * deal_num);
__memcpy(top_grad,
grad_output + grad_output_offset + C_repeat * deal_num,
C_tail * sizeof(float), GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, C_tail, data_value_ptr);
}
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
void KernelMsDeformAttnBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackward<<<k_dim, k_type, queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
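For orientation (not part of the commit): the per-sampling-point math that loadNeighborPointsData and bilinearInterpolation above vectorize over channels is ordinary bilinear interpolation with out-of-range corners contributing zero. A NumPy sketch of that reference behaviour:

import numpy as np


def bilinear_sample(value, x, y):
    """value: (H, W, C) feature map, (x, y) in pixel coordinates.
    Corners outside the map contribute zero, matching the kernel's guards."""
    h, w, c = value.shape
    w_low, h_low = int(np.floor(x)), int(np.floor(y))
    w_high, h_high = w_low + 1, h_low + 1
    lw, lh = x - w_low, y - h_low
    hw, hh = 1.0 - lw, 1.0 - lh
    out = np.zeros(c, dtype=value.dtype)
    corners = [(h_low, w_low, hh * hw), (h_low, w_high, hh * lw),
               (h_high, w_low, lh * hw), (h_high, w_high, lh * lw)]
    for hi, wi, weight in corners:
        if 0 <= hi < h and 0 <= wi < w:
            out += weight * value[hi, wi]
    return out  # the forward kernel then scales this by the attention weight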
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
0 → 100644
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelMsDeformAttnForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char *data_value_gdram,
    const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram);

void KernelMsDeformAttnBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float *data_value,
    const int32_t *spatial_shapes, const int32_t *data_level_start_index,
    const float *data_sampling_loc, const float *data_attn_weight,
    const float *grad_output, const int32_t batch_size, const int32_t num_keys,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_queries, const int32_t num_points, float *grad_value,
    float *grad_sampling_loc, float *grad_attn_weight);
// policy function
static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                              const int batch_size, const int num_queries,
                              const int num_heads) {
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y =
      MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
          torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
  k_dim->z = 1;
#if __BANG_ARCH__ == 520
  *k_type = CNRT_FUNC_TYPE_BLOCK;
#else
  *k_type = CNRT_FUNC_TYPE_UNION1;
#endif
}
// policy function for backward
static void policyFuncBackward(const int32_t batch_size,
                               const int32_t num_queries,
                               const int32_t num_heads,
                               const int32_t num_levels,
                               cnrtFunctionType_t *k_type, cnrtDim3_t *k_dim) {
  size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->x = core_limit;
  int32_t total_num = batch_size * num_queries * num_heads * num_levels;
  size_t total_num_align = CEIL_ALIGN(total_num, core_limit);
  k_dim->y = (total_num_align / core_limit) > cluster_limit
                 ? cluster_limit
                 : (total_num_align / core_limit);
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
Tensor ms_deform_attn_mlu_forward(const Tensor &value,
                                  const Tensor &spatial_shapes,
                                  const Tensor &level_start_index,
                                  const Tensor &sampling_loc,
                                  const Tensor &attn_weight,
                                  const int im2col_step) {
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");

  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ",
              spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ",
              level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ",
              sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ",
              attn_weight.scalar_type(), ".");

  // check shape
  TORCH_CHECK(value.dim() == 4, "value should be a 4d tensor, got ",
              value.dim(), "D.");
  TORCH_CHECK(spatial_shapes.dim() == 2,
              "spatial_shapes should be a 2d tensor, got ",
              spatial_shapes.dim(), "D.");
  TORCH_CHECK(level_start_index.dim() == 1,
              "level_start_index should be a 1d tensor, got ",
              level_start_index.dim(), "D.");
  TORCH_CHECK(sampling_loc.dim() == 6,
              "sampling_loc should be a 6d tensor, got ", sampling_loc.dim(),
              "D.");
  TORCH_CHECK(attn_weight.dim() == 5, "attn_weight should be a 5d tensor, got ",
              attn_weight.dim(), "D.");

  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);

  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK(sampling_loc.size(5) == 2,
              "the 6th dimensions of sampling_loc should be 2, got ",
              sampling_loc.size(5), ".");
  TORCH_CHECK((sampling_loc.size(0) == batch_size),
              "the 1st dimensions of sampling_loc should be batch_size, ",
              "but now the 1st dimension of sampling_loc is ",
              sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((attn_weight.size(0) == batch_size),
              "the 1st dimensions of attn_weight should be batch_size, ",
              "but now the 1st dimension of attn_weight is ",
              attn_weight.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((sampling_loc.size(2) == num_heads),
              "the 3rd dimensions of sampling_loc should be num_heads, ",
              "but now the 3rd dimension of sampling_loc is ",
              sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((attn_weight.size(2) == num_heads),
              "the 3rd dimensions of attn_weight should be num_heads, ",
              "but now the 3rd dimension of attn_weight is ",
              attn_weight.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((level_start_index.size(0) == num_levels),
              "the 1st dimensions of level_start_index should be num_levels, ",
              "but now the 1st dimension of level_start_index is ",
              level_start_index.size(0), ", and num_levels is ", num_levels,
              ".");
  TORCH_CHECK((sampling_loc.size(3) == num_levels),
              "the 4th dimensions of sampling_loc should be num_levels, ",
              "but now the 4th dimension of sampling_loc is ",
              sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(3) == num_levels),
              "the 4th dimensions of attn_weight should be num_levels, ",
              "but now the 4th dimension of attn_weight is ",
              attn_weight.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(1) == num_queries),
              "the 2nd dimensions of attn_weight should be num_queries, ",
              "but now the 2nd dimension of attn_weight is ",
              attn_weight.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK((attn_weight.size(4) == num_points),
              "the 5th dimensions of attn_weight should be num_points, ",
              "but now the 5th dimension of attn_weight is ",
              attn_weight.size(4), ", and num_points is ", num_points, ".");

  auto output = at::zeros({batch_size, num_queries, num_heads, channels},
                          value.options());

  // large tensor check
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(value.numel() < max_input_size,
              "value element num should be less than 2^31, got ",
              value.numel(), ".");
  TORCH_CHECK(sampling_loc.numel() < max_input_size,
              "sampling_loc element num should be less than 2^31, got ",
              sampling_loc.numel(), ".");
  TORCH_CHECK(output.numel() < max_input_size,
              "output element num should be less than 2^31, got ",
              output.numel(), ".");

  // check zero element
  TORCH_CHECK(batch_size != 0, "batch_size should not be zero");
  TORCH_CHECK(num_heads != 0, "num_heads should not be zero");
  TORCH_CHECK(channels != 0, "channels should not be zero");
  TORCH_CHECK(num_queries != 0, "num_queries should not be zero");

  if (num_keys == 0 || num_levels == 0 || num_points == 0) {
    return output;
  }

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  auto spatial_shapes_ = spatial_shapes.to(at::kInt);
  auto level_start_index_ = level_start_index.to(at::kInt);

  // get ptr of tensors
  auto value_impl = torch_mlu::getMluTensorImpl(value);
  auto value_ptr = value_impl->cnnlMalloc();
  auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_);
  auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
  auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_);
  auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
  auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
  auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
  auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
  auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();

  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());

  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";

  KernelMsDeformAttnForward(
      k_dim, k_type, queue, data_type, (char *)value_ptr,
      (char *)spatial_shapes_ptr, (char *)level_start_index_ptr,
      (char *)sampling_loc_ptr, (char *)attn_weight_ptr, batch_size, num_keys,
      num_heads, channels, num_levels, num_queries, num_points,
      (char *)output_ptr);

  output = output.view({batch_size, num_queries, num_heads * channels});
  return output;
}
void ms_deform_attn_mlu_backward(
    const Tensor &value, const Tensor &spatial_shapes,
    const Tensor &level_start_index, const Tensor &sampling_loc,
    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
    const int im2col_step) {
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");

  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ",
              spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ",
              level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ",
              sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ",
              attn_weight.scalar_type(), ".");
  TORCH_CHECK((grad_output.scalar_type() == at::kFloat),
              "grad_output type should be Float, got ",
              grad_output.scalar_type(), ".");

  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);

  // Check shape.
  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK((level_start_index.size(0) == num_levels),
              "the 1st dimensions of level_start_index should be num_levels, ",
              "but now the 1st dimension of level_start_index is ",
              level_start_index.size(0), ", and num_levels is ", num_levels,
              ".");
  TORCH_CHECK((sampling_loc.size(0) == batch_size),
              "the 1st dimensions of sampling_loc should be batch_size, ",
              "but now the 1st dimension of sampling_loc is ",
              sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((sampling_loc.size(2) == num_heads),
              "the 3rd dimensions of sampling_loc should be num_heads, ",
              "but now the 3rd dimension of sampling_loc is ",
              sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((sampling_loc.size(3) == num_levels),
              "the 4th dimensions of sampling_loc should be num_levels, ",
              "but now the 4th dimension of sampling_loc is ",
              sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK(sampling_loc.size(5) == 2,
              "the 6th dimensions of sampling_loc should be 2, got ",
              sampling_loc.size(5), ".");
  TORCH_CHECK((attn_weight.size(0) == batch_size),
              "the 1st dimensions of attn_weight should be batch_size, ",
              "but now the 1st dimension of attn_weight is ",
              attn_weight.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((attn_weight.size(1) == num_queries),
              "the 2nd dimensions of attn_weight should be num_queries, ",
              "but now the 2nd dimension of attn_weight is ",
              attn_weight.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK((attn_weight.size(2) == num_heads),
              "the 3rd dimensions of attn_weight should be num_heads, ",
              "but now the 3rd dimension of attn_weight is ",
              attn_weight.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((attn_weight.size(3) == num_levels),
              "the 4th dimensions of attn_weight should be num_levels, ",
              "but now the 4th dimension of attn_weight is ",
              attn_weight.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(4) == num_points),
              "the 5th dimensions of attn_weight should be num_points, ",
              "but now the 5th dimension of attn_weight is ",
              attn_weight.size(4), ", and num_points is ", num_points, ".");
  TORCH_CHECK((grad_output.size(0) == batch_size),
              "the 1st dimensions of grad_output should be batch_size, ",
              "but now the 1st dimension of grad_output is ",
              grad_output.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((grad_output.size(1) == num_queries),
              "the 2nd dimensions of grad_output should be num_queries, ",
              "but now the 2nd dimension of grad_output is ",
              grad_output.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK(
      (grad_output.size(2) == num_heads * channels),
      "the 3rd dimensions of grad_output should be num_heads * channels, ",
      "but now the 3rd dimension of grad_output is ", grad_output.size(2),
      ", and num_heads * channels is ", num_heads * channels, ".");

  // check zero element
  TORCH_CHECK(batch_size != 0, "The batch_size is zero.");
  TORCH_CHECK(channels != 0, "The channels is zero.");
  TORCH_CHECK(num_keys != 0, "The num_keys is zero.");
  TORCH_CHECK(num_heads != 0, "The num_heads is zero.");
  TORCH_CHECK(num_queries != 0, "The num_queries is zero.");

  if (num_levels == 0 || num_points == 0) {
    return;
  }

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncBackward(batch_size, num_queries, num_heads, num_levels, &k_type,
                     &k_dim);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  // get ptr of tensors
  auto value_impl = torch_mlu::getMluTensorImpl(value);
  auto value_ptr = value_impl->cnnlMalloc();
  auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes);
  auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
  auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index);
  auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
  auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
  auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
  auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
  auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto grad_value_impl = torch_mlu::getMluTensorImpl(grad_value);
  auto grad_value_ptr = grad_value_impl->cnnlMalloc();
  auto grad_sampling_loc_impl = torch_mlu::getMluTensorImpl(grad_sampling_loc);
  auto grad_sampling_loc_ptr = grad_sampling_loc_impl->cnnlMalloc();
  auto grad_attn_weight_impl = torch_mlu::getMluTensorImpl(grad_attn_weight);
  auto grad_attn_weight_ptr = grad_attn_weight_impl->cnnlMalloc();

  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());

  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";

  KernelMsDeformAttnBackward(
      k_dim, k_type, queue, data_type, (float *)value_ptr,
      (int32_t *)spatial_shapes_ptr, (int32_t *)level_start_index_ptr,
      (float *)sampling_loc_ptr, (float *)attn_weight_ptr,
      (float *)grad_output_ptr, batch_size, num_keys, num_heads, channels,
      num_levels, num_queries, num_points, (float *)grad_value_ptr,
      (float *)grad_sampling_loc_ptr, (float *)grad_attn_weight_ptr);
}
Tensor ms_deform_attn_impl_forward(const Tensor &value,
                                   const Tensor &spatial_shapes,
                                   const Tensor &level_start_index,
                                   const Tensor &sampling_loc,
                                   const Tensor &attn_weight,
                                   const int im2col_step);

void ms_deform_attn_impl_backward(
    const Tensor &value, const Tensor &spatial_shapes,
    const Tensor &level_start_index, const Tensor &sampling_loc,
    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
    const int im2col_step);

REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU,
                     ms_deform_attn_mlu_forward);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU,
                     ms_deform_attn_mlu_backward);
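A small Python sketch (illustrative only) of how policyFuncForward above picks the launch dimensions; the per-cluster core count and cluster count are assumptions standing in for torch_mlu::getDeviceAttr:

def policy_forward_dims(batch_size, num_queries, num_heads,
                        cores_per_cluster, cluster_count):
    """Mirror of policyFuncForward: x = cores per cluster, y = clusters needed
    to cover batch * queries * heads blocks, capped at the device's cluster
    count."""
    total_blocks = batch_size * num_queries * num_heads
    x = cores_per_cluster
    y = min((total_blocks + x - 1) // x, cluster_count)
    return x, y, 1


# e.g. batch=2, 100 queries, 8 heads on a 4-core-per-cluster, 8-cluster device
print(policy_forward_dims(2, 100, 8, 4, 8))  # -> (4, 8, 1)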
mmcv/ops/multi_scale_deform_attn.py
...
...
@@ -12,6 +12,7 @@ from mmengine.registry import MODELS
from
mmengine.utils
import
deprecated_api_warning
from
torch.autograd.function
import
Function
,
once_differentiable
from
mmcv.utils
import
IS_CUDA_AVAILABLE
,
IS_MLU_AVAILABLE
from
..utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
...
...
@@ -26,7 +27,7 @@ class MultiScaleDeformableAttnFunction(Function):
                sampling_locations: torch.Tensor,
                attention_weights: torch.Tensor,
                im2col_step: torch.Tensor) -> torch.Tensor:
        """GPU version of multi-scale deformable attention.
        """GPU/MLU version of multi-scale deformable attention.

        Args:
            value (torch.Tensor): The value has shape
...
...
@@ -63,7 +64,7 @@ class MultiScaleDeformableAttnFunction(Function):
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        """GPU version of backward function.
        """GPU/MLU version of backward function.

        Args:
            grad_output (torch.Tensor): Gradient of output tensor of forward.
...
...
@@ -346,7 +347,8 @@ class MultiScaleDeformableAttention(BaseModule):
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
        if ((IS_CUDA_AVAILABLE and value.is_cuda)
                or (IS_MLU_AVAILABLE and value.is_mlu)):
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
...
...
tests/test_ops/test_ms_deformable_attn.py
View file @
a0939977
...
...
@@ -5,6 +5,7 @@ import torch
from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

_USING_PARROTS = True
try:
...
...
@@ -14,22 +15,25 @@ except ImportError:
    _USING_PARROTS = False


@pytest.mark.parametrize('device_type', [
@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support'))
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device_type):
def test_multiscale_deformable_attention(device):
    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads,
        MultiScaleDeformableAttention(
            embed_dims=256,
            num_heads=7,
        )
    device = torch.device(device_type)
    device = torch.device(device)
    msda = MultiScaleDeformableAttention(
        embed_dims=3, num_levels=2, num_heads=3)
    msda.init_weights()
...
...
@@ -70,20 +74,19 @@ def test_forward_multi_scale_deformable_attn_pytorch():
        attention_weights.double()).detach()


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(-2, keepdim=True)
...
...
@@ -93,8 +96,9 @@ def test_forward_equal_with_pytorch_double():
        attention_weights.double()).detach().cpu()
    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.double(), shapes, level_start_index, sampling_locations.double(),
        attention_weights.double(), im2col_step).detach().cpu()
        value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
        sampling_locations.cuda().double(),
        attention_weights.cuda().double(), im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
...
...
@@ -103,20 +107,28 @@ def test_forward_equal_with_pytorch_double():
    assert max_rel_err < 1e-15


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(-2, keepdim=True)
...
...
@@ -124,19 +136,37 @@ def test_forward_equal_with_pytorch_float():
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations,
        attention_weights, im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
    output_device = MultiScaleDeformableAttnFunction.apply(
        value.to(device), shapes.to(device), level_start_index.to(device),
        sampling_locations.to(device), attention_weights.to(device),
        im2col_step).detach().cpu()
    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_device - output_pytorch).abs().max()
    max_rel_err = ((output_device - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE,
            reason='MLU does not support for 64-bit floating point')),
    torch.half
])
@pytest.mark.parametrize('channels', [
    4,
    30,
...
...
@@ -146,20 +176,22 @@ def test_forward_equal_with_pytorch_float():
    1025,
])
def test_gradient_numerical(channels,
                            device,
                            dtype,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):
    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    value = torch.rand(N, S, M, channels).to(device) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
    attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(-2, keepdim=True)
...
...
@@ -170,13 +202,23 @@ def test_gradient_numerical(channels,
    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight
    if device == 'cuda':
        dtype = torch.double
        eps = 1e-6
    elif device == 'mlu':
        dtype = torch.float
        eps = 1e-4
    if _USING_PARROTS:
        assert gradcheck(
            func, (value.double(), shapes, level_start_index,
                   sampling_locations.double(), attention_weights.double(),
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            no_grads=[shapes, level_start_index])
            no_grads=[shapes, level_start_index],
            eps=eps)
    else:
        assert gradcheck(
            func, (value.double(), shapes, level_start_index,
                   sampling_locations.double(), attention_weights.double(),
                   im2col_step))
        assert gradcheck(
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            eps=eps,
            atol=1e-2)
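The dtype/eps switch above (torch.double with eps=1e-6 on CUDA, torch.float with eps=1e-4 on MLU) reflects how gradcheck estimates gradients with central differences: in float32 a step of 1e-6 is close to machine epsilon, so the difference quotient is dominated by rounding noise and a larger step is needed. A toy illustration of the effect, not part of this commit and run on plain CPU PyTorch:

import torch

def central_diff(f, x, eps):
    # central-difference estimate of df/dx
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x64 = torch.tensor(0.5, dtype=torch.double)
x32 = torch.tensor(0.5, dtype=torch.float)
print(abs(central_diff(torch.sin, x64, 1e-6) - torch.cos(x64)))  # tiny error in float64
print(abs(central_diff(torch.sin, x32, 1e-6) - torch.cos(x32)))  # large error: rounding noise dominates
print(abs(central_diff(torch.sin, x32, 1e-4) - torch.cos(x32)))  # much smaller with the bigger step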