Unverified Commit fdc052e8 authored by BinZheng's avatar BinZheng Committed by GitHub
Browse files

[Enhance] Optimize the performance of ms_deform_attn for MLU device (#2510)

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization
parent f76de907
......@@ -42,15 +42,16 @@
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 64
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define LEN_FLOAT sizeof(float)
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t deal_num,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num,
const int32_t &width, const int32_t &height, const int32_t &num_heads,
const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
const int32_t w_low = floorf(x);
......@@ -100,11 +101,11 @@ __mlu_func__ void loadNeighborPointsData(
}
template <typename T>
__mlu_func__ void bilinearInterpolation(
__mlu_func__ void computeMsDeformAttn(
T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
const size_t deal_num, const int32_t &width, const int32_t &height,
const T &x, const T &y) {
T *data_col_nram, const T &weight, const size_t &deal_num,
const int32_t &width, const int32_t &height, const T &x, const T &y) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
......@@ -156,10 +157,15 @@ __mlu_func__ void bilinearInterpolation(
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
__bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight,
deal_num);
__bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value,
deal_num);
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForward(
__mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
......@@ -346,7 +352,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
......@@ -359,15 +365,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, span_num_deal, spatial_w,
spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
span_num_deal);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, span_num_deal);
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
......@@ -500,7 +501,459 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_align_rem, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
}
__asm__ volatile("sync;");
return;
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
if (coreId == 0x80) {
return;
}
const size_t spatial_size =
PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t level_start_index_size =
PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
size_t sampling_loc_size =
PAD_UP(num_levels * num_points * 2 * sizeof(T), NFU_ALIGN_SIZE);
size_t attn_weight_size =
PAD_UP(num_levels * num_points * sizeof(T), NFU_ALIGN_SIZE);
size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
sampling_loc_size - attn_weight_size) /
TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
int32_t load_loc_weight_idx = 0;
int32_t load_loc_weight_seg = 1;
if (channels_seg_num == 0) {
span_num_deal = PAD_UP(channels, NFU_ALIGN_SIZE);
attn_weight_size =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
TWELVE_SPLIT * span_num_deal * sizeof(T)) /
3,
num_levels * num_points * sizeof(T));
attn_weight_size = PAD_DOWN(attn_weight_size, NFU_ALIGN_SIZE);
sampling_loc_size = attn_weight_size * 2;
load_loc_weight_seg =
attn_weight_size / (num_levels * num_points * sizeof(T));
}
#if __BANG_ARCH__ < 322
const size_t align_num = NFU_ALIGN_SIZE;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
#endif
char *data_spatial_shapes_nram = nram_buffer;
char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
char *data_sampling_loc_nram =
data_level_start_index_nram + level_start_index_size;
char *data_attn_weight_nram = data_sampling_loc_nram + sampling_loc_size;
char *ping_data_value_p1_nram = data_attn_weight_nram + attn_weight_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
size_t data_col_ping_pong_idx = 0;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
const int32_t idx_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
__memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
if (load_loc_weight_seg == 1 ||
(load_loc_weight_idx % load_loc_weight_seg) == 0) {
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram +
cur_idx * num_levels * num_points * sizeof(T);
const int32_t load_loc_weight_size =
(block_num_per_core - load_loc_weight_idx) < load_loc_weight_seg
? block_num_per_core - load_loc_weight_idx
: load_loc_weight_seg;
__memcpy_async(
data_sampling_loc_nram, data_sampling_loc_gdram_start,
load_loc_weight_size * num_levels * num_points * 2 * sizeof(T),
GDRAM2NRAM);
__memcpy_async(data_attn_weight_nram, data_attn_weight_gdram_start,
load_loc_weight_size * num_levels * num_points * sizeof(T),
GDRAM2NRAM);
__asm__ volatile("sync;");
}
const int32_t load_loc_weight_offset =
(load_loc_weight_idx % load_loc_weight_seg) * num_levels * num_points;
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
spatial_h_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
spatial_w_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
#if __BANG_ARCH__ >= 322
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_rem, (T)0);
#else
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
#endif
// load data
// level_idx = 0, point_idx = 0
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
spatial_h_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
spatial_w_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
#if __BANG_ARCH__ >= 322
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
......@@ -513,15 +966,29 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, channels_align_rem,
spatial_w, spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
channels_align_rem);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, channels_align_rem);
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_rem, spatial_w, spatial_h, x, y);
#else
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_align_rem, spatial_w, spatial_h, x, y);
#endif
}
spatial_w = spatial_w_next_point;
......@@ -539,12 +1006,36 @@ __mlu_global__ void MLUKernelMsDeformAttnForward(
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
load_loc_weight_idx += 1;
}
__asm__ volatile("sync;");
return;
}
template __mlu_global__ void MLUKernelMsDeformAttnForward<float>(
// Explicit instantiation of the default forward kernel for float, so the
// host-side launcher below can reference it from this translation unit.
template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault<float>(
    const char *data_value_gdram, const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram);
// Host-side entry point: enqueues the default (large-channel) forward kernel
// of multi-scale deformable attention on the given MLU queue.
//
// Parameters:
//   k_dim / k_type / queue : CNRT launch geometry, task type, and queue the
//                            kernel is dispatched on (asynchronous launch).
//   d_type                 : declared data type of the tensors. NOTE(review):
//                            it is not inspected here — only the float
//                            instantiation is launched; presumably callers
//                            guarantee float input. Confirm against the
//                            dispatching code.
//   *_gdram                : device (GDRAM) pointers to the value, spatial
//                            shapes, level start indices, sampling locations,
//                            attention weights, and the output column buffer.
//   batch_size ... num_points : tensor extents forwarded verbatim to the
//                            kernel.
void KernelMsDeformAttnForwardDefault(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char *data_value_gdram,
    const char *data_spatial_shapes_gdram,
    const char *data_level_start_index_gdram,
    const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram) {
  // Launch the float specialization; all arguments pass through unchanged.
  MLUKernelMsDeformAttnForwardDefault<float><<<k_dim, k_type, queue>>>(
      data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
      data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
      num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
......@@ -552,7 +1043,7 @@ template __mlu_global__ void MLUKernelMsDeformAttnForward<float>(
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForward(
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
......@@ -561,7 +1052,7 @@ void KernelMsDeformAttnForward(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForward<float><<<k_dim, k_type, queue>>>(
MLUKernelMsDeformAttnForwardSmallChannel<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
......@@ -584,15 +1075,15 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num);
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
(T *)top_grad_temp, deal_num_real);
}
......@@ -600,18 +1091,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real);
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
deal_num);
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
(T *)top_grad_temp, deal_num_real);
}
......@@ -619,18 +1110,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
deal_num);
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
(T *)top_grad_temp, deal_num_real);
}
......@@ -638,63 +1129,61 @@ void __mlu_func__ msDeformAttnCol2imBilinear(
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
deal_num);
deal_num_real);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
deal_num_real);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
(T *)top_grad_temp, deal_num_real);
}
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num);
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / sizeof(float);
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT;
recursiveSumPool(grad_output_nram, align_num_on_200,
deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_output_nram, grad_output_nram,
NFU_ALIGN_SIZE / sizeof(float));
NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
(T *)grad_output_nram, 1);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_w_weight, grad_w_weight,
NFU_ALIGN_SIZE / sizeof(float));
__bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_h_weight, grad_h_weight,
NFU_ALIGN_SIZE / sizeof(float));
__bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT);
#endif
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
......@@ -708,8 +1197,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const int32_t split_num = 8;
const int32_t spatial_shapes_size = 64;
int32_t deal_num = PAD_DOWN(
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / sizeof(float),
ALIGN_NUM);
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM);
float *grad_output_nram = (float *)nram_buffer;
float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
float *grad_weight = (float *)nram_buffer + 2 * deal_num;
......@@ -725,10 +1213,8 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
int32_t start_per_core =
num_rem > taskId
? (taskId * num_per_core)
: ((num_per_core + 1) * num_rem + (taskId - num_rem) * num_per_core);
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
int32_t end_per_core = start_per_core + num_per_core;
const int32_t C_repeat = channels / deal_num;
const int32_t C_tail = channels % deal_num;
......@@ -758,7 +1244,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
__memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
2 * sizeof(float), GDRAM2NRAM);
2 * LEN_FLOAT, GDRAM2NRAM);
const float loc_w = sampling_loc_nram[0];
const float loc_h = sampling_loc_nram[1];
const float weight = data_attn_weight[data_weight_ptr];
......@@ -789,11 +1275,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
base_ptr = m_col * channels + C_loop * deal_num;
__bang_write_zero(grad_weight, 3 * deal_num);
__bang_write_zero(grad_output_nram, deal_num);
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_loop * deal_num,
deal_num * sizeof(float), GDRAM2NRAM);
deal_num * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
......@@ -806,10 +1293,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
}
if (C_tail != 0) {
base_ptr = m_col * channels + C_repeat * deal_num;
__bang_write_zero(grad_output_nram, 8 * deal_num);
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM));
__memcpy(top_grad,
grad_output + grad_output_offset + C_repeat * deal_num,
C_tail * sizeof(float), GDRAM2NRAM);
C_tail * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
......@@ -827,7 +1316,422 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
}
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
template <typename T>
void __mlu_func__
loadData(const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
         const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
         T *grad_output_nram_bl, T *grad_output_nram_br,
         const T *data_value_ptr, const int32_t &width, const int32_t &height,
         const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
         const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
         const int32_t &h_high_ptr_offset, const int32_t &base_ptr) {
#if __BANG_ARCH__ > 322
  // Stage the (up to) four bilinear-neighbour channel vectors of data_value
  // from GDRAM into NRAM. A neighbour is copied only when it lies inside the
  // feature map; the matching NRAM buffer is left untouched otherwise.
  const bool has_top = h_low >= 0;
  const bool has_bottom = h_high <= height - 1;
  const bool has_left = w_low >= 0;
  const bool has_right = w_high <= width - 1;
  const int32_t copy_size = deal_num_real * sizeof(T);
  if (has_top && has_left) {
    __memcpy_async(
        grad_output_nram_tl,
        data_value_ptr + h_low_ptr_offset + w_low_ptr_offset + base_ptr,
        copy_size, GDRAM2NRAM);
  }
  if (has_top && has_right) {
    __memcpy_async(
        grad_output_nram_tr,
        data_value_ptr + h_low_ptr_offset + w_high_ptr_offset + base_ptr,
        copy_size, GDRAM2NRAM);
  }
  if (has_bottom && has_left) {
    __memcpy_async(
        grad_output_nram_bl,
        data_value_ptr + h_high_ptr_offset + w_low_ptr_offset + base_ptr,
        copy_size, GDRAM2NRAM);
  }
  if (has_bottom && has_right) {
    __memcpy_async(
        grad_output_nram_br,
        data_value_ptr + h_high_ptr_offset + w_high_ptr_offset + base_ptr,
        copy_size, GDRAM2NRAM);
  }
  // Block until every async GDRAM->NRAM copy issued above has completed.
  __sync_io();
#endif
}
template <typename T>
// Computes, for one sampling point, the gradients w.r.t. the attention
// weight (accumulated into grad_output_nram_tl), the sampling location
// (grad_w_weight / grad_h_weight) and the per-neighbour value gradients
// (grad_output_nram_*_temp, consumed later by storeData).
// On entry grad_output_nram_tl/tr/bl/br hold the four bilinear-neighbour
// value vectors loaded by loadData; top_grad holds the incoming gradient.
// hh/hw/lh/lw are the bilinear complements and w1..w4 the corner weights.
void __mlu_func__ computeData(
    const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
    const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
    T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
    T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
    const int32_t &deal_num_real, T *grad_h_weight, T *grad_w_weight,
    T *top_grad_temp, T *top_grad, const T &data_attn_weight, const T &hw,
    const T &hh, const T &lw, const T &lh, const T &w1, const T &w2,
    const T &w3, const T &w4) {
#if __BANG_ARCH__ > 322
  // top_grad_temp = top_grad * attn_weight; shared factor for the
  // grad_value and grad_sampling_loc terms below.
  __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
  // Each in-bounds corner contributes:
  //   grad_h_weight += (+/- complement) * value  (d value / d h_im)
  //   grad_w_weight += (+/- complement) * value  (d value / d w_im)
  //   *_temp = top_grad_temp * wk                (grad_value scatter term)
  //   grad_output_nram_tl accumulates wk * value (bilinear sample, for
  //                                               grad_attn_weight)
  if (h_low >= 0 && w_low >= 0) {
    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tl, (float)(-hw),
                  grad_h_weight, deal_num_real, deal_num_real);
    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tl, (float)(-hh),
                  grad_w_weight, deal_num_real, deal_num_real);
    __bang_mul_scalar(grad_output_nram_tl_temp, top_grad_temp, w1,
                      deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram_tl, grad_output_nram_tl, w1,
                      deal_num_real);
  }
  if (h_low >= 0 && w_high <= width - 1) {
    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tr, (float)(-lw),
                  grad_h_weight, deal_num_real, deal_num_real);
    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tr, (float)(hh),
                  grad_w_weight, deal_num_real, deal_num_real);
    __bang_mul_scalar(grad_output_nram_tr_temp, top_grad_temp, w2,
                      deal_num_real);
    __bang_mul_scalar(grad_output_nram_tr, grad_output_nram_tr, w2,
                      deal_num_real);
    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_tr,
               deal_num_real);
  }
  if (h_high <= height - 1 && w_low >= 0) {
    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_bl, (float)(hw),
                  grad_h_weight, deal_num_real, deal_num_real);
    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_bl, (float)(-lh),
                  grad_w_weight, deal_num_real, deal_num_real);
    __bang_mul_scalar(grad_output_nram_bl_temp, top_grad_temp, w3,
                      deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram_bl, grad_output_nram_bl, w3,
                      deal_num_real);
    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_bl,
               deal_num_real);
  }
  if (h_high <= height - 1 && w_high <= width - 1) {
    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_br, (float)(lw),
                  grad_h_weight, deal_num_real, deal_num_real);
    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_br, (float)(lh),
                  grad_w_weight, deal_num_real, deal_num_real);
    __bang_mul_scalar(grad_output_nram_br_temp, top_grad_temp, w4,
                      deal_num_real);
    // for calc grad_attn_weight
    __bang_mul_scalar(grad_output_nram_br, grad_output_nram_br, w4,
                      deal_num_real);
    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_br,
               deal_num_real);
  }
  // grad_attn_weight = sum_c(top_grad * bilinear_sample); element 0 holds
  // the reduced scalar after recursiveSumPool.
  __bang_mul(grad_output_nram_tl, grad_output_nram_tl, top_grad, deal_num_real);
  recursiveSumPool(grad_output_nram_tl, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
  // Scale coordinate gradients by width/height: the caller computes
  // w_im = loc_w * spatial_w - 0.5, so d(w_im)/d(loc_w) = width (same for h).
  __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
  __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
  recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
  __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
  __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
  recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#endif
}
template <typename T>
void __mlu_func__ storeData(
    const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
    const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tl_temp,
    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
    T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
    const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
    const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
    const int32_t &h_high_ptr_offset, const int32_t &base_ptr, T *grad_value,
    T *grad_w_weight, T *grad_h_weight, T *grad_sampling_loc,
    T *grad_attn_weight) {
#if __BANG_ARCH__ > 322
  // Scatter the per-neighbour value gradients (computed by computeData into
  // the *_temp buffers) back into grad_value in GDRAM. Atomic adds are used
  // because several cores may accumulate into the same spatial location.
  const bool has_top = h_low >= 0;
  const bool has_bottom = h_high <= height - 1;
  const bool has_left = w_low >= 0;
  const bool has_right = w_high <= width - 1;
  if (has_top && has_left) {
    __bang_atomic_add(
        (T *)grad_output_nram_tl_temp,
        (T *)(grad_value + h_low_ptr_offset + w_low_ptr_offset + base_ptr),
        (T *)grad_output_nram_tl_temp, deal_num_real);
  }
  if (has_top && has_right) {
    __bang_atomic_add(
        (T *)grad_output_nram_tr_temp,
        (T *)(grad_value + h_low_ptr_offset + w_high_ptr_offset + base_ptr),
        (T *)grad_output_nram_tr_temp, deal_num_real);
  }
  if (has_bottom && has_left) {
    __bang_atomic_add(
        (T *)grad_output_nram_bl_temp,
        (T *)(grad_value + h_high_ptr_offset + w_low_ptr_offset + base_ptr),
        (T *)grad_output_nram_bl_temp, deal_num_real);
  }
  if (has_bottom && has_right) {
    __bang_atomic_add(
        (T *)grad_output_nram_br_temp,
        (T *)(grad_value + h_high_ptr_offset + w_high_ptr_offset + base_ptr),
        (T *)grad_output_nram_br_temp, deal_num_real);
  }
  // Element 0 of each buffer holds a scalar reduced by computeData:
  // grad_output_nram_tl -> grad_attn_weight, grad_w/h_weight -> the two
  // components of grad_sampling_loc.
  __bang_atomic_add((T *)grad_output_nram_tl, (T *)grad_attn_weight,
                    (T *)grad_output_nram_tl, 1);
  __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
                    (T *)grad_w_weight, 1);
  __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
                    (T *)grad_h_weight, 1);
#endif
}
template <typename T>
// Backward pass for a single bilinear sampling point in the small-channels
// kernel, organised as a three-stage pipeline:
//   loadData    - async-copy the four neighbour value vectors into NRAM,
//   computeData - produce grad_value/grad_sampling_loc/grad_attn_weight terms,
//   storeData   - atomically accumulate the results back to GDRAM.
// Argument order mirrors the three callees exactly; do not reorder.
void __mlu_func__ msDeformAttnCol2imBilinearSmallChannels(
    T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
    const T &w2, const T &w3, const T &w4, const int32_t &h_low,
    const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
    const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
    const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
    const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
    const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
    T *grad_w_weight, T *grad_value, T *grad_output_nram_tl,
    T *grad_output_nram_tr, T *grad_output_nram_bl, T *grad_output_nram_br,
    T *grad_output_nram_tl_temp, T *grad_output_nram_tr_temp,
    T *grad_output_nram_bl_temp, T *grad_output_nram_br_temp,
    T *grad_sampling_loc, T *grad_attn_weight, const int32_t &deal_num_real,
    const T *data_value_ptr)
{
  loadData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
           grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
           data_value_ptr, width, height, deal_num_real, h_low_ptr_offset,
           w_low_ptr_offset, w_high_ptr_offset, h_high_ptr_offset, base_ptr);
  computeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
              grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
              grad_output_nram_tl_temp, grad_output_nram_tr_temp,
              grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
              deal_num_real, grad_h_weight, grad_w_weight, top_grad_temp,
              top_grad, data_attn_weight, hw, hh, lw, lh, w1, w2, w3, w4);
  storeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
            grad_output_nram_tl_temp, grad_output_nram_tr_temp,
            grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
            deal_num_real, h_low_ptr_offset, w_low_ptr_offset,
            w_high_ptr_offset, h_high_ptr_offset, base_ptr, grad_value,
            grad_w_weight, grad_h_weight, grad_sampling_loc, grad_attn_weight);
}
template <typename T>
// Processes `tail` grid cells (flattened (batch, query, head, level) indices)
// of the i_repeat-th batch of work assigned to this core. Sampling locations
// and attention weights for the whole batch are already resident in NRAM
// (nram_sampling_loc / nram_attn_weight); spatial shapes and level start
// indices are in spatial_shapes_nram / level_start_index_nram.
void __mlu_func__ msDeformAttnCol2imImpl(
    T *top_grad_temp, T *top_grad, T *grad_h_weight, T *grad_w_weight,
    T *grad_value, T *grad_output_nram_tl, T *grad_output_nram_tr,
    T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
    T *grad_output_nram_br_temp, T *grad_sampling_loc, T *grad_attn_weight,
    T *nram_sampling_loc, T *nram_attn_weight, const int32_t &load_num,
    const int32_t &tail, const int32_t &i_repeat, const int32_t &num_points,
    const int32_t &start_per_core, const int32_t &num_levels,
    const int32_t &num_heads, const int32_t &num_query,
    const int32_t &spatial_size, const int32_t &qid_stride,
    int32_t *level_start_index_nram, const int32_t &channels,
    const T *data_value, const T *grad_output, int32_t *spatial_shapes_nram) {
#if __BANG_ARCH__ > 322
  // Running cursors into the NRAM-resident weight / location arrays; they
  // advance monotonically across all grid cells and points of this batch.
  int32_t weight_pos = 0;
  int32_t sampling_loc_pos = 0;
  for (int32_t p = 0; p < tail; ++p) {
    // Decompose the flat grid index; level varies fastest, batch slowest.
    int32_t grid_offset = start_per_core + i_repeat * load_num + p;
    const int32_t l_col = grid_offset % num_levels;
    const int32_t m_col = grid_offset / num_levels % num_heads;
    const int32_t q_col = grid_offset / num_levels / num_heads % num_query;
    const int32_t b_col = grid_offset / num_query / num_heads / num_levels;
    const int32_t value_offset = b_col * spatial_size * qid_stride;
    const int32_t level_start_id = level_start_index_nram[l_col];
    const int32_t grad_attn_weight_out = grid_offset * num_points;
    const int32_t spatial_h_ptr = l_col << 1;
    const int32_t grad_output_offset =
        b_col * num_query * qid_stride + q_col * qid_stride + m_col * channels;
    // Incoming gradient slice for this (batch, query, head).
    __memcpy(top_grad, grad_output + grad_output_offset, channels * LEN_FLOAT,
             GDRAM2NRAM);
    const int32_t spatial_h = spatial_shapes_nram[spatial_h_ptr];
    const int32_t spatial_w = spatial_shapes_nram[spatial_h_ptr + 1];
    const int32_t h_stride = spatial_w * qid_stride;
    const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
    const float *data_value_ptr = data_value + value_ptr_offset;
    float *grad_value_ptr = grad_value + value_ptr_offset;
    const int32_t grad_sampling_loc_out = grid_offset * num_points << 1;
    for (int32_t p_col = 0; p_col < num_points; ++p_col) {
      const float loc_w = nram_sampling_loc[sampling_loc_pos];
      const float loc_h = nram_sampling_loc[sampling_loc_pos + 1];
      const float weight = nram_attn_weight[weight_pos];
      // Map normalized coordinates to pixel space (align-corners=False).
      const float h_im = loc_h * spatial_h - 0.5;
      const float w_im = loc_w * spatial_w - 0.5;
      // Points entirely outside the feature map contribute no gradient.
      if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
        const int32_t h_low = floorf(h_im);
        const int32_t w_low = floorf(w_im);
        const int32_t h_high = h_low + 1;
        const int32_t w_high = w_low + 1;
        const float lh = h_im - h_low;
        const float lw = w_im - w_low;
        const float hh = 1.0 - lh;
        const float hw = 1.0 - lw;
        const int32_t h_low_ptr_offset = h_low * h_stride;
        const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
        const int32_t w_low_ptr_offset = w_low * qid_stride;
        const int32_t w_high_ptr_offset = w_low_ptr_offset + qid_stride;
        // Bilinear corner weights (tl, tr, bl, br).
        const float w1 = hh * hw;
        const float w2 = hh * lw;
        const float w3 = lh * hw;
        const float w4 = lh * lw;
        const int32_t base_ptr = m_col * channels;
        // Accumulators must start from zero for every sampling point.
        __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
        __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
        __bang_write_zero(grad_output_nram_tl, PAD_UP(channels, ALIGN_NUM));
        msDeformAttnCol2imBilinearSmallChannels(
            top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
            h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
            h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
            weight, grad_h_weight, grad_w_weight, grad_value_ptr,
            grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
            grad_output_nram_br, grad_output_nram_tl_temp,
            grad_output_nram_tr_temp, grad_output_nram_bl_temp,
            grad_output_nram_br_temp,
            grad_sampling_loc + grad_sampling_loc_out + (p_col << 1),
            grad_attn_weight + grad_attn_weight_out + p_col, channels,
            data_value_ptr);
      }
      // Cursors advance even for out-of-range points to stay aligned with
      // the packed NRAM layout.
      weight_pos += 1;
      sampling_loc_pos += 2;
    }
  }
#endif
}
// Backward kernel specialised for small channel counts: the whole channel
// dimension fits in NRAM, so sampling locations / attention weights are
// batch-loaded and each bilinear neighbour is kept resident. Work items are
// flattened (batch, query, head, level) cells, distributed evenly over cores.
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
    const float *data_value, const int32_t *spatial_shapes,
    const int32_t *data_level_start_index, const float *data_sampling_loc,
    const float *data_attn_weight, const float *grad_output,
    const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_query,
    const int32_t num_points, float *grad_value, float *grad_sampling_loc,
    float *grad_attn_weight) {
#if __BANG_ARCH__ > 322
  // NRAM layout: 12 channel-aligned float buffers, then the int32 level
  // metadata, then the remaining space for batched loc/weight staging.
  const int32_t split_num = 12;
  const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
  float *grad_output_nram_tl = (float *)nram_buffer;
  float *grad_output_nram_tr = (float *)nram_buffer + C_align;
  float *grad_output_nram_bl = (float *)nram_buffer + 2 * C_align;
  float *grad_output_nram_br = (float *)nram_buffer + 3 * C_align;
  float *grad_output_nram_tl_temp = (float *)nram_buffer + 4 * C_align;
  float *grad_output_nram_tr_temp = (float *)nram_buffer + 5 * C_align;
  float *grad_output_nram_bl_temp = (float *)nram_buffer + 6 * C_align;
  float *grad_output_nram_br_temp = (float *)nram_buffer + 7 * C_align;
  float *grad_h_weight = (float *)nram_buffer + 8 * C_align;
  float *grad_w_weight = (float *)nram_buffer + 9 * C_align;
  float *top_grad_temp = (float *)nram_buffer + 10 * C_align;
  float *top_grad = (float *)nram_buffer + 11 * C_align;
  int32_t *spatial_shapes_nram =
      (int32_t *)((float *)nram_buffer + split_num * C_align);
  int32_t *level_start_index_nram =
      (int32_t *)(spatial_shapes_nram + PAD_UP(num_levels * 2, ALIGN_NUM));
  float *nram_remain = (float *)((int32_t *)level_start_index_nram +
                                 PAD_UP(num_levels, ALIGN_NUM));
  // calc load num: how many grid cells' (loc, weight) data fit in the
  // remaining NRAM (2 loc floats + 1 weight float per point per cell).
  const int32_t weight_num2nram =
      (MAX_NRAM_SIZE / LEN_FLOAT - split_num * C_align -
       3 * PAD_UP(num_levels, ALIGN_NUM)) /
      3 / num_points;
  int32_t load_num = weight_num2nram;
  // Evenly split the flattened grid over cores; the first num_rem cores
  // take one extra cell.
  const int32_t total_num = batch * num_query * num_heads * num_levels;
  int32_t num_per_core = total_num / taskDim;
  int32_t num_rem = total_num % taskDim;
  num_per_core = num_per_core + int32_t(taskId < num_rem);
  if (num_per_core == 0) {
    return;
  }
  const int32_t start_per_core = num_rem > taskId
                                     ? (taskId * num_per_core)
                                     : (num_rem + taskId * num_per_core);
  const int32_t qid_stride = num_heads * channels;
  // load spatial_shapes and data_level_start_index to nram
  // NOTE(review): these copies are async and no explicit __sync_io precedes
  // their first use inside msDeformAttnCol2imImpl; presumably the later
  // synchronous __memcpy orders the IO queue — confirm BANG semantics.
  __memcpy_async(spatial_shapes_nram, spatial_shapes,
                 num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
  __memcpy_async(level_start_index_nram, data_level_start_index,
                 num_levels * sizeof(int32_t), GDRAM2NRAM);
  // Decompose this core's first grid cell to offset into the packed
  // sampling-loc / attn-weight arrays.
  const int32_t start_l_col = start_per_core % num_levels;
  const int32_t start_m_col = start_per_core / num_levels % num_heads;
  const int32_t start_q_col =
      start_per_core / num_levels / num_heads % num_query;
  const int32_t start_b_col =
      start_per_core / num_query / num_heads / num_levels;
  const int32_t repeat = num_per_core / load_num;
  const int32_t tail = num_per_core % load_num;
  float *nram_sampling_loc = nram_remain;
  float *nram_attn_weight = nram_sampling_loc + 2 * load_num * num_points;
  const int32_t attn_weight_offset =
      start_b_col * num_query * num_heads * num_levels * num_points +
      start_q_col * num_heads * num_levels * num_points +
      start_m_col * num_levels * num_points + start_l_col * num_points;
  const int32_t sampling_loc_offset =
      start_b_col * num_query * num_heads * num_levels * num_points * 2 +
      start_q_col * num_heads * num_levels * num_points * 2 +
      start_m_col * num_levels * num_points * 2 + start_l_col * num_points * 2;
  // Full batches of load_num cells, then one partial batch of `tail` cells.
  if (repeat > 0) {
    for (int32_t i_repeat = 0; i_repeat < repeat; ++i_repeat)
    { // load weight and sampling_loc to nram
      __memcpy_async(nram_sampling_loc,
                     data_sampling_loc + sampling_loc_offset +
                         i_repeat * load_num * 2 * num_points,
                     2 * load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
      __memcpy(nram_attn_weight,
               data_attn_weight + attn_weight_offset +
                   i_repeat * load_num * num_points,
               load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
      msDeformAttnCol2imImpl(
          top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
          grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
          grad_output_nram_br, grad_output_nram_tl_temp,
          grad_output_nram_tr_temp, grad_output_nram_bl_temp,
          grad_output_nram_br_temp, grad_sampling_loc, grad_attn_weight,
          nram_sampling_loc, nram_attn_weight, load_num, load_num, i_repeat,
          num_points, start_per_core, num_levels, num_heads, num_query,
          spatial_size, qid_stride, level_start_index_nram, channels,
          data_value, grad_output, spatial_shapes_nram);
    }
  }
  if (tail > 0)
  { // load weight and sampling_loc to nram
    __memcpy_async(nram_sampling_loc,
                   data_sampling_loc + sampling_loc_offset +
                       repeat * load_num * 2 * num_points,
                   tail * num_points * 2 * LEN_FLOAT, GDRAM2NRAM);
    __memcpy(
        nram_attn_weight,
        data_attn_weight + attn_weight_offset + repeat * load_num * num_points,
        tail * num_points * LEN_FLOAT, GDRAM2NRAM);
    msDeformAttnCol2imImpl(
        top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
        grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
        grad_output_nram_br, grad_output_nram_tl_temp, grad_output_nram_tr_temp,
        grad_output_nram_bl_temp, grad_output_nram_br_temp, grad_sampling_loc,
        grad_attn_weight, nram_sampling_loc, nram_attn_weight, load_num, tail,
        repeat, num_points, start_per_core, num_levels, num_heads, num_query,
        spatial_size, qid_stride, level_start_index_nram, channels, data_value,
        grad_output, spatial_shapes_nram);
  }
#endif
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
......@@ -835,8 +1739,32 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
// Host-side launcher for the default ms_deform_attn backward kernel.
// Forwards all tensor pointers and shape scalars unchanged to the device
// kernel using the given launch dimensions, function type and queue.
void KernelMsDeformAttnBackwardDefaultKernel(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const float *data_value,
    const int32_t *spatial_shapes, const int32_t *data_level_start_index,
    const float *data_sampling_loc, const float *data_attn_weight,
    const float *grad_output, const int32_t batch, const int32_t spatial_size,
    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
    const int32_t num_query, const int32_t num_points, float *grad_value,
    float *grad_sampling_loc, float *grad_attn_weight) {
  // d_type is accepted for interface symmetry but not used here: the
  // backward kernel operates on float data only.
  MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<<k_dim, k_type, queue>>>(
      data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
      data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
      num_levels, num_query, num_points, grad_value, grad_sampling_loc,
      grad_attn_weight);
}
void KernelMsDeformAttnBackward(
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
......@@ -845,7 +1773,8 @@ void KernelMsDeformAttnBackward(
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackward<<<k_dim, k_type, queue>>>(
MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<<k_dim, k_type,
queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
......
......@@ -14,7 +14,15 @@
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelMsDeformAttnForward(
// Forward kernel dispatch policy: each value selects one MLU forward kernel
// implementation (chosen at runtime from the problem shape and NRAM budget).
typedef enum {
  MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
  MS_DEFORM_ATTN_FORWARD_DEFAULT =
      1, /*!< MLUKernelMsDeformAttnForwardDefault */
  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
      2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
} MsDeformAttnForwardPolicy;
void KernelMsDeformAttnForwardDefault(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
......@@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnBackward(
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
// Backward kernel dispatch policy, selected by msDeformAttnBackwardPolicyFunc
// based on whether the small-channels kernel's working set fits in NRAM.
typedef enum {
  MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0, /*!< generic fallback kernel */
  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1, /*!< optimized small-channel kernel */
} MsDeformAttnBackwardKernelPolicy;
// Picks the backward kernel variant: use the small-channels kernel whenever
// its NRAM working set fits on one core, otherwise fall back to the default.
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
    const int32_t channels, const int32_t num_levels,
    const int32_t num_points) {
  // Per-core NRAM capacity, expressed in float elements.
  const int32_t nram_bytes = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const uint64_t nram_floats = nram_bytes / sizeof(float);
  // Footprint of the small-channels kernel: 12 channel-aligned buffers plus
  // level metadata plus per-point loc/weight scratch (mirrors the kernel's
  // NRAM partitioning).
  const uint64_t required_floats =
      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
  return (nram_floats >= required_floats)
             ? MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL
             : MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
......@@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_queries, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
const float* data_sampling_loc, const float* data_attn_weight,
const float* grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
// policy function
static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
const int batch_size, const int num_queries,
const int num_heads) {
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
const int32_t num_keys, const int32_t num_heads, const int32_t channels,
const int32_t num_levels, const int32_t num_queries,
const int32_t num_points) {
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y =
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
......@@ -46,6 +97,15 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
#else
*k_type = CNRT_FUNC_TYPE_UNION1;
#endif
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float)) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else {
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
}
}
// policy function for backward
......@@ -196,7 +256,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
&k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
num_queries, num_points);
// get compute queue
auto queue = torch_mlu::getCurQueue();
......@@ -222,15 +284,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForward(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
switch (policy) {
default: {
VLOG(5) << "MsDeformAttnForward Policy not supported";
}; break;
case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardDefault(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardSmallChannel(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
}
output = output.view({batch_size, num_queries, num_heads * channels});
return output;
......@@ -391,14 +471,31 @@ void ms_deform_attn_mlu_backward(
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnBackward(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
MsDeformAttnBackwardKernelPolicy kernelPolicy =
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
switch (kernelPolicy) {
default: {
VLOG(5) << "NotImplemented.";
} break;
case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
KernelMsDeformAttnBackwardDefaultKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
KernelMsDeformAttnBackwardSmallChannelsKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
}
}
Tensor ms_deform_attn_impl_forward(const Tensor& value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment