Unverified Commit 5f122293 authored by BinZheng, committed by GitHub

[Enhancement] ms_deform_attn performance optimization (#2616)



* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* ms_opt_v2_1

* [Feature] ms_deform_attn performance optimization V2

---------
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
parent ec639323
@@ -32,6 +32,7 @@
/****************************************************************************************
*
* NRAM partition backward:
* default kernel
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
@@ -39,11 +40,26 @@
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
* small channel kernel
* | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl |
* | nram_grad_output_br | grad_temp1 | grad_temp2 |
* | grad_temp3 | grad_temp4 | nram_loc_w |
* | nram_loc_h | nram_h_low | nram_w_low |
* | nram_h_high | nram_w_high | nram_h_low_temp |
* | nram_h_high_temp | nram_hw | nram_hh |
* | nram_lw | nram_lh | nram_h_low_ptr_offset |
* | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset |
* | nram_w1 | nram_w2 | nram_w3 |
* | nram_w4 | nram_grad_weight | nram_base_ptr |
* | nram_offset_temp | nram_offset1 | nram_offset2 |
* | nram_offset3 | nram_offset4 | nram_w_low_temp |
* | nram_spatial_shapes | nram_level_start_index | nram_h_stride |
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define ELE_COUNT 32
#define LEN_FLOAT sizeof(float)
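// TWELVE_SPLIT matches the 12 span_num_deal-sized buffers carved out of NRAM
// in the forward small-channel kernel: 4 ping value buffers + 1 ping col
// buffer, the same 5 again for pong, plus 2 auxiliary buffers. A sizing
// sketch (a restatement of the PAD_DOWN computation used below, for
// illustration only):
//   span_num_deal = PAD_DOWN((MAX_NRAM_SIZE - meta_bytes) / TWELVE_SPLIT
//                            / sizeof(T), NFU_ALIGN_SIZE);
// where meta_bytes covers the spatial shapes, level start indices, sampling
// locations and attention weights kept resident in NRAM.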
__nram__ char nram_buffer[MAX_NRAM_SIZE];
@@ -540,6 +556,17 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
return;
}
__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int32_t i = 0; i < align_num; ++i) {
mask_ram[i] = i % 2;
}
__asm__ volatile("sync;");
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
__asm__ volatile("sync;");
}
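// genMask0101 writes the repeating pattern 0,1,0,1,... so that __bang_collect
// can gather the odd-indexed (y) elements of interleaved (x, y) pairs;
// negating the mask gathers the even-indexed (x) elements. A scalar reference
// for that deinterleave (an illustrative host-side sketch, not part of this
// kernel):
static inline void deinterleaveXY(const float *xy, float *x, float *y,
                                  int32_t n) {
  for (int32_t i = 0; i < n; ++i) {
    x[i] = xy[2 * i];      // slots where mask == 0
    y[i] = xy[2 * i + 1];  // slots where mask == 1
  }
}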
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
@@ -548,467 +575,471 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
#if __BANG_ARCH__ >= 300
if (coreId == 0x80) {
return;
}
size_t block_num_per_core, batch_start, deal_g, offset_g;
size_t block_num_rem = 0;
const size_t grid_total = num_queries * num_heads * num_levels * num_points;
if (batch_size >= taskDim) {
block_num_rem = batch_size % taskDim;
block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
: batch_size / taskDim;
batch_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
deal_g = grid_total;
offset_g = 0;
} else {
size_t skip_n = taskDim / batch_size;
batch_start = taskId / skip_n;
block_num_per_core = batch_start >= batch_size ? 0 : 1;
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
}
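  // Partition example (illustrative): batch_size = 2, taskDim = 8 gives
  // skip_n = 4, so cores 0..3 share batch 0 and cores 4..7 share batch 1.
  // Each core takes a deal_g-sized slice of the grid, padded up to a multiple
  // of num_levels * num_points so a slice never splits one query's sampling
  // points, and the last core of each group takes the remainder:
  //   deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
  //   offset_g = (taskId % skip_n) * deal_g;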
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num;
int32_t cut_channel_iter = 2;
const size_t spatial_size =
PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t level_start_index_size =
PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
size_t sampling_loc_size =
PAD_UP(num_levels * num_points * 2 * sizeof(T), NFU_ALIGN_SIZE);
size_t attn_weight_size =
PAD_UP(num_levels * num_points * sizeof(T), NFU_ALIGN_SIZE);
size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
sampling_loc_size - attn_weight_size) /
TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
int32_t load_loc_weight_idx = 0;
int32_t load_loc_weight_seg = 1;
if (channels_seg_num == 0) {
span_num_deal = PAD_UP(channels, NFU_ALIGN_SIZE);
attn_weight_size =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
TWELVE_SPLIT * span_num_deal * sizeof(T)) /
3,
num_levels * num_points * sizeof(T));
attn_weight_size = PAD_DOWN(attn_weight_size, NFU_ALIGN_SIZE);
sampling_loc_size = attn_weight_size * 2;
load_loc_weight_seg =
attn_weight_size / (num_levels * num_points * sizeof(T));
int32_t channel = channels;
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
break;
} else {
channel = channels / cut_channel_iter;
cut_channel_iter += 2;
}
}
mult = channel;
#if __BANG_ARCH__ < 322
const size_t align_num = NFU_ALIGN_SIZE;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
#endif
const int32_t c_rep = channels / channel;
const int32_t c_rem = channels % channel;
const int32_t g_rep = deal_g / deal_num;
const int32_t g_rem = deal_g % deal_num;
// nram buffer alloc
char *data_spatial_shapes_nram = nram_buffer;
char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
char *data_sampling_loc_nram =
data_level_start_index_nram + level_start_index_size;
char *data_attn_weight_nram = data_sampling_loc_nram + sampling_loc_size;
char *ping_data_value_p1_nram = data_attn_weight_nram + attn_weight_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
size_t data_col_ping_pong_idx = 0;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
const int32_t idx_start = taskId < block_num_rem
? taskId * block_num_per_core
: taskId * block_num_per_core + block_num_rem;
char *input_tl = data_level_start_index_nram + level_start_index_size;
char *input_tr = input_tl + deal_num * mult * sizeof(T);
char *input_bl = input_tr + deal_num * mult * sizeof(T);
char *input_br = input_bl + deal_num * mult * sizeof(T);
char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
char *weight_br = weight_bl + deal_num * mult * sizeof(T);
char *mask_tl = weight_br + deal_num * mult * sizeof(T);
char *mask_tr = mask_tl + deal_num * sizeof(T);
char *mask_bl = mask_tr + deal_num * sizeof(T);
char *mask_br = mask_bl + deal_num * sizeof(T);
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
char *coord_x = input_bl;
char *coord_y = coord_x + deal_num * sizeof(T);
char *coord_x_low = input_tl;
char *coord_y_low = coord_x_low + deal_num * sizeof(T);
char *coord_x_low_int = weight_tl;
char *coord_y_low_int = weight_tr;
char *spatial_x = mask_tl;
char *spatial_y = mask_tr;
char *spatial_x_float = weight_bl;
char *spatial_y_float = weight_br;
char *spatial_x_temp = mask_bl;
char *spatial_y_temp = mask_br;
char *base_ptr_offset = weight_tl;
char *auxiliary_a = point_ram;
char *auxiliary_b = weight_bl;
__memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__asm__ volatile("sync;");
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
for (int32_t batch_idx = batch_start;
batch_idx < batch_start + block_num_per_core; ++batch_idx) {
for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
int32_t io_data_num = deal_num;
const int32_t grid_off_base =
batch_idx * grid_total + offset_g + grid_iter * deal_num;
if (grid_iter == g_rep) {
if (g_rem == 0) {
continue;
} else {
io_data_num = g_rem;
}
}
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + (batch_idx * num_queries * num_heads * channels +
(offset_g + grid_iter * deal_num) /
(num_levels * num_points) * channels) *
sizeof(float);
if (load_loc_weight_seg == 1 ||
(load_loc_weight_idx % load_loc_weight_seg) == 0) {
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram +
cur_idx * num_levels * num_points * sizeof(T);
const int32_t load_loc_weight_size =
(block_num_per_core - load_loc_weight_idx) < load_loc_weight_seg
? block_num_per_core - load_loc_weight_idx
: load_loc_weight_seg;
// load data_sampling_loc
__memcpy_async(
data_sampling_loc_nram, data_sampling_loc_gdram_start,
load_loc_weight_size * num_levels * num_points * 2 * sizeof(T),
GDRAM2NRAM);
__memcpy_async(data_attn_weight_nram, data_attn_weight_gdram_start,
load_loc_weight_size * num_levels * num_points * sizeof(T),
GDRAM2NRAM);
__memcpy_async(
    grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
    io_data_num * 2 * sizeof(float), GDRAM2NRAM);
genMask0101((float *)mask_ram, deal_num * 2);
__asm__ volatile("sync;");
}
const int32_t load_loc_weight_offset =
(load_loc_weight_idx % load_loc_weight_seg) * num_levels * num_points;
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
// generate x and y coordinate vector
// generate spatial_x and spatial_y spatial vector
__bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // y
__bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_x
__bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
__bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
deal_num * 2); // x
__bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
(float *)mask_ram,
num_levels * 2); // spatial_y
for (int32_t i = 0; i < num_levels; i++) {
__bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
((int32_t *)spatial_x_temp)[i]);
__bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
((int32_t *)spatial_y_temp)[i]);
}
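// Scalar equivalent of the broadcast above (illustrative): each level's
// (h, w) pair is replicated once per sampling point so later ops can run
// elementwise over the whole grid:
//   for (int32_t l = 0; l < num_levels; ++l)
//     for (int32_t p = 0; p < num_points; ++p) {
//       spatial_y[l * num_points + p] = spatial_shapes[2 * l];
//       spatial_x[l * num_points + p] = spatial_shapes[2 * l + 1];
//     }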
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// the last point needs no data load; fall through to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
spatial_h_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
spatial_w_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
__bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
num_levels * num_points, 0);
__bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
num_levels * num_points, 0);
// map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0, spatial_y]
__bang_cycle_mul((float *)coord_x, (float *)coord_x,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
deal_num);
__bang_cycle_mul((float *)coord_y, (float *)coord_y,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
const int32_t w_stride = num_heads * channels;
__bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
deal_num, 0);
__bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
deal_num, 0);
__bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
(int32_t *)spatial_x, deal_num, num_levels * num_points);
__bang_add((int32_t *)index_tl, (int32_t *)index_tl,
(int32_t *)coord_x_low_int, deal_num);
__bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
deal_num);
const int32_t deal_lp_num = deal_num / (num_levels * num_points);
const int32_t h_rep = deal_lp_num / num_heads;
const int32_t h_rem = deal_lp_num % num_heads;
const int32_t head_start =
((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
num_heads;
for (int32_t iter = 0; iter < num_heads; ++iter) {
((int32_t *)base_ptr_offset)[iter] =
((head_start + iter) % num_heads) * channels;
}
if (h_rep > 0) {
__memcpy((int32_t *)base_ptr_offset + num_heads,
(int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
}
if (h_rep > 0 && h_rem > 0) {
__memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
(int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
NRAM2NRAM);
}
__bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
num_levels * num_points);
__bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
(int32_t *)base_ptr_offset, deal_num, deal_lp_num);
__bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
num_levels * num_points, deal_lp_num);
// calc index_bl
__bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
deal_num);
__bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
(int32_t *)auxiliary_a, deal_num,
num_levels * num_points);
// calc mask_tl, mask_tr, mask_bl, mask_br
__bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
(float)1.0, deal_num);
__bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
(float)1.0, deal_num);
// mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
deal_num);
__bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
deal_num);
__bang_cycle_le((float *)mask_br, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
__bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
deal_num);
// mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
__bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
deal_num);
__bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
(float *)spatial_x_float, deal_num,
num_levels * num_points);
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
__bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
deal_num);
// mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
__bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
(float)(-1.0), deal_num);
__bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
(float *)spatial_y_float, deal_num,
num_levels * num_points);
__bang_and((float *)auxiliary_a, (float *)auxiliary_a,
(float *)auxiliary_b, deal_num);
__bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
deal_num);
// mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// encode valid corners: point_ram = 7 * mask_tl + 5 * mask_tr + 3 * mask_bl + mask_br
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
deal_num);
__bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
deal_num);
__bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
deal_num);
__bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
deal_num);
__bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
deal_num);
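// With the weights 7/5/3/1, every corner combination that bilinear sampling
// can actually produce decodes to a unique sum, consumed by the switch over
// inner_point_num below:
//   16 = 7+5+3+1 : all four corners valid
//   12 = 7+5 : top_left + top_right      10 = 7+3 : top_left + bottom_left
//    6 = 5+1 : top_right + bottom_right   4 = 3+1 : bottom_left + bottom_right
//    7 / 5 / 3 / 1 : exactly one corner (tl / tr / bl / br)
// All other sums fall into the default branch and load nothing.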
// calc interpolation weight
__bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
deal_num);
__bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
deal_num);
__bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
deal_num);
__bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
deal_num);
__bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
deal_num);
__bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
deal_num);
__bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
deal_num);
__bang_mul((float *)input_tl + deal_num, (float *)weight_br,
(float *)weight_tl, deal_num);
__bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
__asm__ volatile("sync;");
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
const int32_t w_rem = channel % ELE_COUNT;
if (w_rem != 0) {
const int32_t data_sz = 1 * sizeof(float);
const int32_t dst_str = channel * sizeof(float);
for (int32_t iter = w_rep; iter < channel; ++iter) {
__memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
#if __BANG_ARCH__ >= 322
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_rem, (T)0);
#else
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
#endif
// load data
// level_idx = 0, point_idx = 0
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
if (w_rep != 0) {
for (int32_t i = 0; i < 4 * deal_num; i++) {
__bang_write_value((float *)weight_tl + i * channel, w_rep,
((float *)input_tl)[i]);
}
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// the last point needs no data load; fall through to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
spatial_h_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
spatial_w_next_point =
((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(float);
const int32_t c_str = deal_num * channel * sizeof(float);
const int32_t cs_str = num_heads * channels * sizeof(float);
for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
int32_t c_real_num = channel;
if (c_iter == c_rep) {
if (c_rem == 0) {
continue;
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2];
loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
level_idx * num_points +
point_idx + 1) *
2 +
1];
weight_next_point =
((T *)data_attn_weight_nram)[load_loc_weight_offset +
level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
c_real_num = c_rem;
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
#if __BANG_ARCH__ >= 322
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_rem, spatial_w, spatial_h, x, y);
#else
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, channels_align_rem, spatial_w, spatial_h, x, y);
#endif
__bang_write_zero((float *)input_tl, 4 * deal_num * channel);
__asm__ volatile("sync;");
// load data_value
for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
const int32_t level_start_id =
((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
num_levels];
const char *data_value_ptr =
data_value_gdram_start +
(level_start_id * num_heads * channels + c_iter * channel) *
sizeof(float);
switch (inner_point_num) {
case 16: // 4 points are cached.
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 12: // 2 points are cached. (top_left, top_right)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 4: // 2 points are cached. (bottom_left, bottom_right)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM, c_str,
cs_str, 1);
break;
case 10: // 2 points are cached. (top_left, bottom_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 6: // 2 points are cached. (top_right, bottom_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 7: // 1 point is cached. (top_left)
__memcpy_async((float *)input_tl + p_idx * channel,
(float *)data_value_ptr + tl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 5: // 1 point is cached. (top_right)
__memcpy_async(
(float *)input_tr + p_idx * channel,
(float *)data_value_ptr + tl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 3: // 1 point is cached. (bottom_left)
__memcpy_async((float *)input_bl + p_idx * channel,
(float *)data_value_ptr + bl_offset,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
case 1: // 1 point is cached. (bottom_right)
__memcpy_async(
(float *)input_br + p_idx * channel,
(float *)data_value_ptr + bl_offset + num_heads * channels,
c_real_num * sizeof(float), GDRAM2NRAM);
break;
default:
continue;
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
__asm__ volatile("sync;");
// interpolation
__bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
4 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
2 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
deal_num * channel);
// load attention weight
void *attn_weight = mask_tl;
__memcpy((float *)attn_weight,
(float *)data_attn_weight_gdram + grid_off_base,
io_data_num * sizeof(float), GDRAM2NRAM);
// calc data_col, muladd attention weight
__bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
channel);
__bang_cycle_mul((float *)input_tr, (float *)input_tr,
(float *)attn_weight, deal_num * channel, deal_num);
__bang_transpose((float *)input_tl, (float *)input_tr, channel,
deal_num);
__bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
io_data_num, 1, num_levels * num_points,
num_levels * num_points, 1);
// store
__memcpy((float *)data_col_gdram_start + c_iter * channel,
(float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
channels * sizeof(float), channel * sizeof(float),
(io_data_num / (num_levels * num_points)) - 1);
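// Net effect of the transpose/cycle_mul/sumpool sequence above (scalar
// sketch): every output channel accumulates its attention-weighted samples,
//   out[q][c] = sum over (l, p) of attn_weight[q][l][p] * val[q][l][p][c],
// where val is the bilinearly interpolated value gathered for that point.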
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
load_loc_weight_idx += 1;
}
__asm__ volatile("sync;");
#endif
return;
}
@@ -1316,294 +1347,496 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
}
}
template <typename T>
void __mlu_func__
loadData(const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
T *grad_output_nram_bl, T *grad_output_nram_br,
const T *data_value_ptr, const int32_t &width, const int32_t &height,
const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
const int32_t &h_high_ptr_offset, const int32_t &base_ptr) {
#if __BANG_ARCH__ > 322
if (h_low >= 0 && w_low >= 0)
{
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy_async(grad_output_nram_tl, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
}
if (h_low >= 0 && w_high <= width - 1)
{
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy_async(grad_output_nram_tr, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
}
if (h_high <= height - 1 && w_low >= 0)
{
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy_async(grad_output_nram_bl, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
}
if (h_high <= height - 1 && w_high <= width - 1)
{
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy_async(grad_output_nram_br, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
}
__sync_io();
void __mlu_func__ computeGridMaskAndOffset(
float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
float *nram_lw, float *nram_hh, float *nram_hw,
float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
// deinterleave sampling_loc (x, y): [num_deal_grid, 2] --> [2, num_deal_grid]
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
__bang_transpose(nram_loc_w, nram_grad_output_tl,
num_per_time_real * num_heads * num_levels, num_points);
__bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_spatial_shapes,
(int32_t *)nram_spatial_shapes, num_levels * 2, 0);
__bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
num_levels, 2);
__bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
num_levels);
__memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
num_levels * 2 * sizeof(float), NRAM2NRAM);
__bang_cycle_mul(nram_loc_w, nram_loc_w,
(float *)nram_spatial_shapes + num_levels, num_deal_grid,
num_levels);
__bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
__bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
// get mask. (h_im > -1 && w_im > -1 &&
// h_im < spatial_h && w_im < spatial_w)
__bang_cycle_lt(nram_w_low_temp, nram_loc_w,
(float *)(nram_spatial_shapes + num_levels), num_deal_grid,
num_levels);
__bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
num_deal_grid, num_levels);
__bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
__bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
__bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
num_deal_grid);
__bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_temp, nram_w_low_temp,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
NRAM2NRAM);
__bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
__bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
__bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
__bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
__bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
__bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
__bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
__bang_transpose(nram_h_low_ptr_offset, nram_h_low,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
num_deal_grid, num_levels);
__bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
num_per_time_real * num_heads * num_levels);
__memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
num_deal_grid);
__bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
num_deal_grid);
__bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
__bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
__bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
__bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
__bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset1,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset2,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset3,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
__bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
num_deal_grid);
__bang_transpose(nram_offset_temp, nram_offset4,
num_per_time_real * num_heads, num_levels * num_points);
__bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
num_deal_grid, num_heads);
__bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
num_per_time_real * num_heads);
// h_low >= 0 && w_low >= 0 mask2
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
__bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
__bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
__bang_and(mask2, mask1, mask2, num_deal_grid);
__bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
// h_low >= 0 && w_high <= width - 1 mask1
__bang_transpose(mask3, nram_w_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
num_levels * 2);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
num_deal_grid, num_levels);
__bang_transpose(mask4, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask1, mask1, mask4, num_deal_grid);
__bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
// h_high <= height - 1 && w_high <= width - 1 mask3
__bang_transpose(mask3, nram_h_high,
num_per_time_real * num_heads * num_levels, num_points);
__bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
num_levels);
__bang_transpose(nram_h_low_temp, mask3, num_points,
num_per_time_real * num_heads * num_levels);
__bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
__bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
// h_high <= height - 1 && w_low >= 0 mask4
__bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
__bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
__bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
#endif
}
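// Scalar reference for one sampling point's four neighbor offsets, matching
// what the vectorized code above computes (an illustrative host-side sketch;
// the names are local to this example):
static inline void neighborOffsets(float h_im, float w_im, int32_t h_stride,
                                   int32_t qid_stride, int32_t base_ptr,
                                   int32_t *offsets) {
  const int32_t h_low = (int32_t)floorf(h_im);
  const int32_t w_low = (int32_t)floorf(w_im);
  const int32_t h_low_off = h_low * h_stride;         // nram_h_low_ptr_offset
  const int32_t h_high_off = h_low_off + h_stride;    // nram_h_high_ptr_offset
  const int32_t w_low_off = w_low * qid_stride;       // nram_w_low_ptr_offset
  const int32_t w_high_off = w_low_off + qid_stride;  // nram_w_high_ptr_offset
  offsets[0] = h_low_off + w_low_off + base_ptr;      // top_left
  offsets[1] = h_low_off + w_high_off + base_ptr;     // top_right
  offsets[2] = h_high_off + w_low_off + base_ptr;     // bottom_left
  offsets[3] = h_high_off + w_high_off + base_ptr;    // bottom_right
}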
template <typename T>
void __mlu_func__ computeData(
const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
const int32_t &deal_num_real, T *grad_h_weight, T *grad_w_weight,
T *top_grad_temp, T *top_grad, const T &data_attn_weight, const T &hw,
const T &hh, const T &lw, const T &lh, const T &w1, const T &w2,
const T &w3, const T &w4) {
#if __BANG_ARCH__ > 322
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
if (h_low >= 0 && w_low >= 0) {
__bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tl, (float)(-hw),
grad_h_weight, deal_num_real, deal_num_real);
__bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tl, (float)(-hh),
grad_w_weight, deal_num_real, deal_num_real);
__bang_mul_scalar(grad_output_nram_tl_temp, top_grad_temp, w1,
deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_tl, grad_output_nram_tl, w1,
deal_num_real);
}
if (h_low >= 0 && w_high <= width - 1) {
__bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tr, (float)(-lw),
grad_h_weight, deal_num_real, deal_num_real);
__bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tr, (float)(hh),
grad_w_weight, deal_num_real, deal_num_real);
__bang_mul_scalar(grad_output_nram_tr_temp, top_grad_temp, w2,
deal_num_real);
__bang_mul_scalar(grad_output_nram_tr, grad_output_nram_tr, w2,
deal_num_real);
__bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_tr,
deal_num_real);
}
if (h_high <= height - 1 && w_low >= 0) {
__bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_bl, (float)(hw),
grad_h_weight, deal_num_real, deal_num_real);
__bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_bl, (float)(-lh),
grad_w_weight, deal_num_real, deal_num_real);
__bang_mul_scalar(grad_output_nram_bl_temp, top_grad_temp, w3,
deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_bl, grad_output_nram_bl, w3,
deal_num_real);
__bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_bl,
deal_num_real);
void __mlu_func__ loadValue(
float *nram_grad_output_tl, float *nram_grad_output_tr,
float *nram_grad_output_bl, float *nram_grad_output_br,
const float *data_value, const float *grad_output, float *grad_temp1,
float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
float *nram_offset1, float *nram_offset2, float *nram_offset3,
float *nram_offset4, float *nram_grad_weight,
int32_t *nram_level_start_index, int32_t offset_nram,
int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
int32_t value_offset_temp = 0;
__bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
__sync_io_move_compute();
__memcpy_async(
grad_temp2,
grad_output + (start_per_core + grid_loop * num_per_time_theory) *
num_heads * deal_num_real,
num_per_time_real * num_heads * deal_num_real * sizeof(float),
GDRAM2NRAM);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
const int32_t b_col =
(grid_offset + loop) / num_query / num_heads / num_levels / num_points;
const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
const int32_t level_start_id = nram_level_start_index[l_col];
value_offset_temp =
b_col * spatial_size * qid_stride + level_start_id * qid_stride;
if (mask2[loop]) {
__memcpy_async(
nram_grad_output_tl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset1[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask1[loop]) {
__memcpy_async(
nram_grad_output_tr + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset2[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask4[loop]) {
__memcpy_async(
nram_grad_output_bl + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset3[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
if (mask3[loop]) {
__memcpy_async(
nram_grad_output_br + loop * deal_num_real,
data_value + value_offset_temp + int32_t(nram_offset4[loop]),
deal_num_real * sizeof(float), GDRAM2NRAM);
}
}
if (h_high <= height - 1 && w_high <= width - 1) {
__bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_br, (float)(lw),
grad_h_weight, deal_num_real, deal_num_real);
__bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_br, (float)(lh),
grad_w_weight, deal_num_real, deal_num_real);
__bang_mul_scalar(grad_output_nram_br_temp, top_grad_temp, w4,
deal_num_real);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_br, grad_output_nram_br, w4,
deal_num_real);
__bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_br,
deal_num_real);
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__bang_mul(grad_output_nram_tl, grad_output_nram_tl, top_grad, deal_num_real);
recursiveSumPool(grad_output_nram_tl, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
__sync_io_move_compute();
#endif
}
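// Index decode used in loadValue above (illustrative restatement): the flat
// grid index g enumerates [batch, query, head, level, point] in row-major
// order, so
//   b_col = g / (num_query * num_heads * num_levels * num_points)
//   l_col = (g / num_points) % num_levels
// select the batch and level whose value-tensor slice the masked
// __memcpy_async calls gather from.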
template <typename T>
void __mlu_func__ storeData(
const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tl_temp,
T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
const int32_t &h_high_ptr_offset, const int32_t &base_ptr, T *grad_value,
T *grad_w_weight, T *grad_h_weight, T *grad_sampling_loc,
T *grad_attn_weight) {
#if __BANG_ARCH__ > 322
if (h_low >= 0 && w_low >= 0)
{
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__bang_atomic_add((T *)grad_output_nram_tl_temp,
(T *)(grad_value + offset1),
(T *)grad_output_nram_tl_temp, deal_num_real);
void __mlu_func__ computeGradValue(
float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
float *nram_offset2, float *nram_offset3, float *nram_offset4,
int32_t *nram_level_start_index, int32_t deal_num_real,
const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, const int32_t num_query,
int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
const int32_t qid_stride, float *nram_grid_offset1,
float *nram_grid_offset2) {
#if __BANG_ARCH__ >= 322
__bang_transpose(grad_temp3, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
deal_num_real * num_per_time_real * num_heads);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
num_levels / num_points) *
spatial_size * qid_stride;
}
if (h_low >= 0 && w_high <= width - 1)
{
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__bang_atomic_add((T *)grad_output_nram_tr_temp,
(T *)(grad_value + offset2),
(T *)grad_output_nram_tr_temp, deal_num_real);
__bang_transpose(nram_grid_offset2, nram_grid_offset1,
num_per_time_real * num_heads * num_levels, num_points);
__bang_int322float((float *)nram_level_start_index, nram_level_start_index,
num_levels, 0);
__bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
qid_stride, num_levels);
__bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
num_deal_grid, num_levels);
__bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
num_per_time_real * num_heads * num_levels);
__bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
__bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask2[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset1[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
if (h_high <= height - 1 && w_low >= 0)
{
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__bang_atomic_add((T *)grad_output_nram_bl_temp,
(T *)(grad_value + offset3),
(T *)grad_output_nram_bl_temp, deal_num_real);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask1[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset2[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask4[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset3[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
if (h_high <= height - 1 && w_high <= width - 1)
{
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__bang_atomic_add((T *)grad_output_nram_br_temp,
(T *)(grad_value + offset4),
(T *)grad_output_nram_br_temp, deal_num_real);
__bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
if (mask3[loop]) {
__bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
(float *)(grad_value + int32_t(nram_offset4[loop])),
(float *)(grad_temp3 + loop * deal_num_real),
deal_num_real);
}
}
__bang_atomic_add((T *)grad_output_nram_tl, (T *)grad_attn_weight,
(T *)grad_output_nram_tl, 1);
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
#endif
}
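// Scatter pattern of computeGradValue (a scalar sketch of the masked atomic
// adds above): for each grid element g and each valid corner k in {1..4},
//   for (int32_t c = 0; c < deal_num_real; ++c)
//     atomicAdd(&grad_value[offset_k[g] + c],
//               top_grad[g][c] * attn_weight[g] * w_k[g]);
// the per-corner masks skip neighbors that fall outside the feature map.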
template <typename T>
void __mlu_func__ msDeformAttnCol2imBilinearSmallChannels(
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
const T &w2, const T &w3, const T &w4, const int32_t &h_low,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
T *grad_w_weight, T *grad_value, T *grad_output_nram_tl,
T *grad_output_nram_tr, T *grad_output_nram_bl, T *grad_output_nram_br,
T *grad_output_nram_tl_temp, T *grad_output_nram_tr_temp,
T *grad_output_nram_bl_temp, T *grad_output_nram_br_temp,
T *grad_sampling_loc, T *grad_attn_weight, const int32_t &deal_num_real,
const T *data_value_ptr)
{
loadData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
data_value_ptr, width, height, deal_num_real, h_low_ptr_offset,
w_low_ptr_offset, w_high_ptr_offset, h_high_ptr_offset, base_ptr);
computeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
grad_output_nram_tl_temp, grad_output_nram_tr_temp,
grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
deal_num_real, grad_h_weight, grad_w_weight, top_grad_temp,
top_grad, data_attn_weight, hw, hh, lw, lh, w1, w2, w3, w4);
storeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
grad_output_nram_tl_temp, grad_output_nram_tr_temp,
grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
deal_num_real, h_low_ptr_offset, w_low_ptr_offset,
w_high_ptr_offset, h_high_ptr_offset, base_ptr, grad_value,
            grad_w_weight, grad_h_weight, grad_sampling_loc, grad_attn_weight);
}
void __mlu_func__ computeGradAttnWeight(
float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *nram_grad_output_bl,
float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
const float *grad_attn_weight, float *nram_hw, float *nram_hh,
float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
int32_t grid_offset, float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_write_zero(grad_w_weight, 2 * offset_nram);
  // nram_grad_output_tl
__bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
num_deal_grid * deal_num_real, num_deal_grid);
// nram_grad_output_tr
__bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
num_deal_grid * deal_num_real);
  // nram_grad_output_bl
__bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
num_deal_grid * deal_num_real);
// nram_grad_output_br
__bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
num_deal_grid);
__bang_transpose(nram_grad_output_tr, nram_grad_output_br,
num_per_time_real * num_heads,
num_points * num_levels * deal_num_real);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
num_per_time_real * num_heads * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tr,
num_points * num_levels * deal_num_real,
num_per_time_real * num_heads);
__bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
num_deal_grid, deal_num_real);
recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
ALIGN_NUM);
__bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
0);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_atomic_add((float *)nram_grad_output_tr,
(float *)grad_attn_weight + grid_offset,
(float *)nram_grad_output_tr, num_deal_grid);
#endif
}
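// A scalar sketch (assumed semantics, illustrative names) of the branch-free
// masking idiom above: __bang_float2int32 turns the 0/1 float mask into
// integers, __bang_lut_s32 maps them through the 2-entry table
// {0x0, 0xFFFFFFFF}, and __bang_band ANDs the result with the gradient bits,
// zeroing contributions from out-of-range sampling points with no per-element
// branch.
#include <cstdint>
#include <cstring>
static void maskedZeroRef(float *data, const float *mask01, int32_t n) {
  static const uint32_t lut[2] = {0x00000000u, 0xFFFFFFFFu};
  for (int32_t i = 0; i < n; ++i) {
    uint32_t bits;
    std::memcpy(&bits, &data[i], sizeof(bits));
    bits &= lut[(int32_t)mask01[i]];  // 0 -> 0.0f, 1 -> unchanged
    std::memcpy(&data[i], &bits, sizeof(bits));
  }
}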
template <typename T>
void __mlu_func__ msDeformAttnCol2imImpl(
T *top_grad_temp, T *top_grad, T *grad_h_weight, T *grad_w_weight,
T *grad_value, T *grad_output_nram_tl, T *grad_output_nram_tr,
T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
T *grad_output_nram_br_temp, T *grad_sampling_loc, T *grad_attn_weight,
T *nram_sampling_loc, T *nram_attn_weight, const int32_t &load_num,
const int32_t &tail, const int32_t &i_repeat, const int32_t &num_points,
const int32_t &start_per_core, const int32_t &num_levels,
const int32_t &num_heads, const int32_t &num_query,
const int32_t &spatial_size, const int32_t &qid_stride,
int32_t *level_start_index_nram, const int32_t &channels,
const T *data_value, const T *grad_output, int32_t *spatial_shapes_nram) {
#if __BANG_ARCH__ > 322
int32_t weight_pos = 0;
int32_t sampling_loc_pos = 0;
for (int32_t p = 0; p < tail; ++p) {
int32_t grid_offset = start_per_core + i_repeat * load_num + p;
const int32_t l_col = grid_offset % num_levels;
const int32_t m_col = grid_offset / num_levels % num_heads;
const int32_t q_col = grid_offset / num_levels / num_heads % num_query;
const int32_t b_col = grid_offset / num_query / num_heads / num_levels;
const int32_t value_offset = b_col * spatial_size * qid_stride;
const int32_t level_start_id = level_start_index_nram[l_col];
const int32_t grad_attn_weight_out = grid_offset * num_points;
const int32_t spatial_h_ptr = l_col << 1;
const int32_t grad_output_offset =
b_col * num_query * qid_stride + q_col * qid_stride + m_col * channels;
__memcpy(top_grad, grad_output + grad_output_offset, channels * LEN_FLOAT,
GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[spatial_h_ptr];
const int32_t spatial_w = spatial_shapes_nram[spatial_h_ptr + 1];
const int32_t h_stride = spatial_w * qid_stride;
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_sampling_loc_out = grid_offset * num_points << 1;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
const float loc_w = nram_sampling_loc[sampling_loc_pos];
const float loc_h = nram_sampling_loc[sampling_loc_pos + 1];
const float weight = nram_attn_weight[weight_pos];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * qid_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + qid_stride;
const float w1 = hh * hw;
const float w2 = hh * lw;
const float w3 = lh * hw;
const float w4 = lh * lw;
const int32_t base_ptr = m_col * channels;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram_tl, PAD_UP(channels, ALIGN_NUM));
msDeformAttnCol2imBilinearSmallChannels(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
grad_output_nram_br, grad_output_nram_tl_temp,
grad_output_nram_tr_temp, grad_output_nram_bl_temp,
grad_output_nram_br_temp,
grad_sampling_loc + grad_sampling_loc_out + (p_col << 1),
grad_attn_weight + grad_attn_weight_out + p_col, channels,
data_value_ptr);
}
weight_pos += 1;
sampling_loc_pos += 2;
    }
  }
#endif
}
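// Hedged helper (not in the patch) showing how the flat grid index above
// unpacks; the modulo chain mirrors l_col / m_col / q_col / b_col in
// msDeformAttnCol2imImpl, with level fastest and batch slowest.
struct GridCoord { int32_t b, q, m, l; };
static GridCoord decodeGrid(int32_t grid_offset, int32_t num_query,
                            int32_t num_heads, int32_t num_levels) {
  GridCoord g;
  g.l = grid_offset % num_levels;
  g.m = grid_offset / num_levels % num_heads;
  g.q = grid_offset / num_levels / num_heads % num_query;
  g.b = grid_offset / num_levels / num_heads / num_query;
  return g;
}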
void __mlu_func__ computeGradSampingLoc(
const float *grad_sampling_loc, float *nram_grad_output_tl,
float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
int32_t num_per_time_real, const int32_t num_heads,
const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
float *nram_h_high_temp) {
#if __BANG_ARCH__ >= 322
__bang_transpose(nram_grad_output_tl, grad_h_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
num_levels);
__bang_transpose(grad_h_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
for (int32_t m = 0; m < deal_num_real; ++m) {
__memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
}
__sync_io_move_compute();
__bang_transpose(nram_grad_output_tr, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp1, nram_grad_output_tr,
num_levels * num_points * deal_num_real,
num_per_time_real * num_heads);
__bang_mul(grad_h_weight, grad_h_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid,
deal_num_real);
__memcpy_async(grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_h_weight, (char *)grad_h_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_transpose(nram_grad_output_tl, grad_w_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)(nram_spatial_shapes + num_levels),
num_deal_grid * deal_num_real, num_levels);
__bang_transpose(grad_w_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
__bang_mul(grad_w_weight, grad_w_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid,
deal_num_real);
__memcpy(grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_w_weight, (char *)grad_w_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__memcpy(grad_w_weight + num_deal_grid, grad_h_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid);
__bang_atomic_add((float *)nram_grad_output_tl,
(float *)grad_sampling_loc + grid_offset * 2,
(float *)nram_grad_output_tl, 2 * num_deal_grid);
#endif
}
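// Scalar sketch (assumed semantics) of the recursiveSumPool reduction used
// above: each of the `rows` rows of `cols` floats collapses to its sum, with
// the `rows` results packed at the front of the buffer for the atomic adds.
static void sumPoolRowsRef(float *buf, int32_t rows, int32_t cols) {
  for (int32_t r = 0; r < rows; ++r) {
    float s = 0.f;
    for (int32_t c = 0; c < cols; ++c) s += buf[r * cols + c];
    buf[r] = s;  // safe in place: position r has already been consumed
  }
}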
......@@ -1616,117 +1849,195 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
#if __BANG_ARCH__ > 322
const int32_t split_num = 12;
const int32_t split_grid_num = 28;
const int32_t split_num_c = 8;
const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
float *grad_output_nram_tl = (float *)nram_buffer;
float *grad_output_nram_tr = (float *)nram_buffer + C_align;
float *grad_output_nram_bl = (float *)nram_buffer + 2 * C_align;
float *grad_output_nram_br = (float *)nram_buffer + 3 * C_align;
float *grad_output_nram_tl_temp = (float *)nram_buffer + 4 * C_align;
float *grad_output_nram_tr_temp = (float *)nram_buffer + 5 * C_align;
float *grad_output_nram_bl_temp = (float *)nram_buffer + 6 * C_align;
float *grad_output_nram_br_temp = (float *)nram_buffer + 7 * C_align;
float *grad_h_weight = (float *)nram_buffer + 8 * C_align;
float *grad_w_weight = (float *)nram_buffer + 9 * C_align;
float *top_grad_temp = (float *)nram_buffer + 10 * C_align;
float *top_grad = (float *)nram_buffer + 11 * C_align;
int32_t *spatial_shapes_nram =
(int32_t *)((float *)nram_buffer + split_num * C_align);
int32_t *level_start_index_nram =
(int32_t *)(spatial_shapes_nram + PAD_UP(num_levels * 2, ALIGN_NUM));
float *nram_remain = (float *)((int32_t *)level_start_index_nram +
PAD_UP(num_levels, ALIGN_NUM));
// calc load num
const int32_t weight_num2nram =
(MAX_NRAM_SIZE / LEN_FLOAT - split_num * C_align -
3 * PAD_UP(num_levels, ALIGN_NUM)) /
3 / num_points;
int32_t load_num = weight_num2nram;
const int32_t total_num = batch * num_query * num_heads * num_levels;
const int32_t num_hlp = num_heads * num_levels * num_points;
int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) /
(split_num_c * C_align + split_grid_num) /
PAD_UP((num_hlp), ALIGN_NUM);
int32_t deal_grid_num_theory = num_per_time_theory * num_hlp;
const int32_t offset_nram = num_per_time_theory * C_align * num_hlp;
const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM);
float *nram_grad_output_tl = (float *)nram_buffer;
float *nram_grad_output_tr = (float *)nram_buffer + offset_nram;
float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram;
float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram;
float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram;
float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram;
float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram;
float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram;
float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram;
float *nram_loc_h =
(float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc;
float *nram_h_low =
(float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc;
float *nram_w_low =
(float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc;
float *nram_h_high =
(float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc;
float *nram_w_high =
(float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc;
float *nram_h_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc;
float *nram_h_high_temp =
(float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc;
float *nram_hw =
(float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc;
float *nram_hh =
(float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc;
float *nram_lw =
(float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc;
float *nram_lh =
(float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc;
float *nram_h_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc;
float *nram_h_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc;
float *nram_w_low_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc;
float *nram_w_high_ptr_offset =
(float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc;
float *nram_w1 =
(float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc;
float *nram_w2 =
(float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc;
float *nram_w3 =
(float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc;
float *nram_w4 =
(float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc;
float *nram_grad_weight =
(float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc;
float *nram_base_ptr =
(float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc;
float *nram_offset_temp =
(float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc;
float *nram_offset1 =
(float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc;
float *nram_offset2 =
(float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc;
float *nram_offset3 =
(float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc;
float *nram_offset4 =
(float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc;
float *nram_w_low_temp =
(float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc;
int32_t *nram_spatial_shapes =
(int32_t *)((float *)nram_buffer + split_num_c * offset_nram +
28 * offset_nram_calc);
int32_t *nram_level_start_index =
(int32_t *)(nram_spatial_shapes + 2 * num_levels);
float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels);
const int32_t total_num = batch * num_query;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
if (num_per_core == 0) {
return;
}
const int32_t start_per_core = num_rem > taskId
? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
num_per_time_theory =
num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core;
int32_t num_deal_grid = num_per_time_theory * num_hlp;
if (num_per_core == 0) return;
int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
: (num_rem + taskId * num_per_core);
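  // Balanced split of total_num (batch * num_query) items across taskDim
  // cores: each core gets floor(total_num / taskDim) items, the first
  // num_rem cores take one extra, and start_per_core is the running prefix
  // sum, so cores cover disjoint contiguous index ranges.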
const int32_t qid_stride = num_heads * channels;
int32_t deal_num_real = channels;
  // load spatial_shapes and data_level_start_index to nram
__memcpy_async(spatial_shapes_nram, spatial_shapes,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(level_start_index_nram, data_level_start_index,
num_levels * sizeof(int32_t), GDRAM2NRAM);
const int32_t repeat_times = num_per_core / num_per_time_theory;
const int32_t tail_num = num_per_core % num_per_time_theory;
const int32_t start_l_col = start_per_core % num_levels;
const int32_t start_m_col = start_per_core / num_levels % num_heads;
const int32_t start_q_col =
start_per_core / num_levels / num_heads % num_query;
const int32_t start_b_col =
start_per_core / num_query / num_heads / num_levels;
const int32_t repeat = num_per_core / load_num;
const int32_t tail = num_per_core % load_num;
float *nram_sampling_loc = nram_remain;
float *nram_attn_weight = nram_sampling_loc + 2 * load_num * num_points;
const int32_t attn_weight_offset =
start_b_col * num_query * num_heads * num_levels * num_points +
start_q_col * num_heads * num_levels * num_points +
start_m_col * num_levels * num_points + start_l_col * num_points;
const int32_t sampling_loc_offset =
start_b_col * num_query * num_heads * num_levels * num_points * 2 +
start_q_col * num_heads * num_levels * num_points * 2 +
start_m_col * num_levels * num_points * 2 + start_l_col * num_points * 2;
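  // Both offsets are row-major linearizations over
  // [batch, num_query, num_heads, num_levels, num_points]:
  //   attn_weight_offset  = (((b * num_query + q) * num_heads + m)
  //                           * num_levels + l) * num_points
  //   sampling_loc_offset = 2 * attn_weight_offset  // (x, y) pair per point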
if (repeat > 0) {
    for (int32_t i_repeat = 0; i_repeat < repeat; ++i_repeat) {
      // load weight and sampling_loc to nram
__memcpy_async(nram_sampling_loc,
data_sampling_loc + sampling_loc_offset +
i_repeat * load_num * 2 * num_points,
2 * load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
__memcpy(nram_attn_weight,
data_attn_weight + attn_weight_offset +
i_repeat * load_num * num_points,
load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imImpl(
top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
grad_output_nram_br, grad_output_nram_tl_temp,
grad_output_nram_tr_temp, grad_output_nram_bl_temp,
grad_output_nram_br_temp, grad_sampling_loc, grad_attn_weight,
nram_sampling_loc, nram_attn_weight, load_num, load_num, i_repeat,
num_points, start_per_core, num_levels, num_heads, num_query,
spatial_size, qid_stride, level_start_index_nram, channels,
data_value, grad_output, spatial_shapes_nram);
}
int32_t num_per_time_real = num_per_time_theory;
for (int32_t loop = 0; loop < num_heads; ++loop) {
nram_base_ptr[loop] = loop * channels;
}
  if (tail > 0) {
    // load weight and sampling_loc to nram
__memcpy_async(nram_sampling_loc,
data_sampling_loc + sampling_loc_offset +
repeat * load_num * 2 * num_points,
tail * num_points * 2 * LEN_FLOAT, GDRAM2NRAM);
__memcpy(
nram_attn_weight,
data_attn_weight + attn_weight_offset + repeat * load_num * num_points,
tail * num_points * LEN_FLOAT, GDRAM2NRAM);
msDeformAttnCol2imImpl(
top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
grad_output_nram_br, grad_output_nram_tl_temp, grad_output_nram_tr_temp,
grad_output_nram_bl_temp, grad_output_nram_br_temp, grad_sampling_loc,
grad_attn_weight, nram_sampling_loc, nram_attn_weight, load_num, tail,
repeat, num_points, start_per_core, num_levels, num_heads, num_query,
spatial_size, qid_stride, level_start_index_nram, channels, data_value,
grad_output, spatial_shapes_nram);
const int32_t w_stride = num_heads * channels;
for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) {
int32_t grid_offset =
(start_per_core + grid_loop * num_per_time_theory) * num_hlp;
if (grid_loop == repeat_times) {
if (tail_num == 0) {
continue;
} else {
grid_offset =
(start_per_core + repeat_times * num_per_time_theory) * num_hlp;
num_per_time_real = tail_num;
num_deal_grid = tail_num * num_hlp;
}
}
__memcpy_async(nram_spatial_shapes, spatial_shapes,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_level_start_index, data_level_start_index,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
__memcpy(nram_grad_weight, data_attn_weight + grid_offset,
num_deal_grid * sizeof(float), GDRAM2NRAM);
computeGridMaskAndOffset(
nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
w_stride, qid_stride);
float *mask1 = nram_h_low_ptr_offset;
float *mask2 = nram_h_high_ptr_offset;
float *mask3 = nram_w_low_ptr_offset;
float *mask4 = nram_w_high_ptr_offset;
loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
nram_grad_output_br, data_value, grad_output, grad_temp1,
grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
nram_level_start_index, offset_nram, start_per_core, grid_loop,
num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
num_deal_grid, num_query, num_levels, num_points, grid_offset,
spatial_size, qid_stride);
float *nram_grid_offset1 = nram_loc_h;
float *nram_grid_offset2 = nram_loc_w;
computeGradValue(
grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
nram_grid_offset1, nram_grid_offset2);
// compute grad_weight
float *grad_weight = grad_temp1;
float *grad_h_weight = grad_temp4;
float *grad_w_weight = grad_temp3;
computeGradAttnWeight(
grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
// compute grad_sampling_loc
computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl,
nram_grad_output_tr, grad_h_weight, grad_w_weight,
nram_spatial_shapes, grad_temp1, grad_temp2,
nram_grad_weight, num_deal_grid, deal_num_real,
num_per_time_real, num_heads, num_levels, num_points,
grid_offset, nram_h_high_temp);
}
#endif
}
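// Stage pipeline of the small-channel backward kernel above, per grid_loop
// chunk of num_per_time_real queries (as the call sequence shows):
//   1. copy spatial shapes, level offsets, sampling locations and attention
//      weights into NRAM;
//   2. computeGridMaskAndOffset: corner coordinates, weights w1..w4,
//      in-bounds masks and flat value offsets for every grid point;
//   3. loadValue: gather the tl/tr/bl/br corner value vectors;
//   4. computeGradValue: masked atomic scatter into grad_value;
//   5. computeGradAttnWeight: masked channel reduction, atomically added to
//      grad_attn_weight;
//   6. computeGradSampingLoc: scaled by (spatial_w, spatial_h) and
//      atomically added to grad_sampling_loc.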
......@@ -1739,6 +2050,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
......
......@@ -47,17 +47,17 @@ typedef enum {
} MsDeformAttnBackwardKernelPolicy;
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
const int32_t channels, const int32_t num_levels,
const int32_t num_points) {
const int32_t channels, const int32_t num_levels, const int32_t num_points,
const int32_t num_heads) {
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const uint64_t max_num = nram_size / sizeof(float);
const uint64_t deal_num =
12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
if (max_num >= deal_num) {
const int num_hlp = num_heads * num_levels * num_points;
int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
PAD_UP((num_hlp), 32);
if (num_per_time_theory >= 1) {
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
}
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}
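// A hedged worked example of the capacity test above (all sizes assumed,
// none taken from this patch):
#include <cstdint>
#include <cstdio>
#define PAD_UP_EX(x, y) (((x) + (y)-1) / (y) * (y))  // local stand-in for PAD_UP
int main() {
  const int nram_size = 512 * 1024;  // assumed per-core NRAM size in bytes
  const int channels = 32, num_heads = 8, num_levels = 4, num_points = 4;
  const int num_hlp = num_heads * num_levels * num_points;  // 128
  const int num_per_time_theory =
      (nram_size - num_levels * (int)sizeof(float) -
       3 * num_levels * (int)sizeof(int32_t)) /
      (int)sizeof(float) / (8 * PAD_UP_EX(channels, 32) + 28) /
      PAD_UP_EX(num_hlp, 32);
  // prints 3, i.e. >= 1, so MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL is chosen
  printf("num_per_time_theory = %d\n", num_per_time_theory);
  return 0;
}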
......@@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float)) {
} else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
channels < 16) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else {
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
......@@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward(
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
MsDeformAttnBackwardKernelPolicy kernelPolicy =
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
num_heads);
switch (kernelPolicy) {
default: {
VLOG(5) << "NotImplemented.";
......