Unverified Commit 5f122293 authored by BinZheng, committed by GitHub

[Enhancement] ms_deform_attn performance optimization (#2616)



* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* [Feature] ms_deform_attn performance optimization V2
---------
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
parent ec639323
@@ -32,6 +32,7 @@
/****************************************************************************************
 *
 * NRAM partition backward:
 * default kernel
 * | grad_output_nram | grad_output_nram_temp | grad_weight |
 * | grad_h_weight | grad_w_weight | top_grad |
 * | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
@@ -39,11 +40,26 @@
 * | deal_size | deal_size | deal_size |
 * | deal_size | deal_size | 64bytes |
 *
 * small channel kernel
 * | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl |
 * | nram_grad_output_br | grad_temp1 | grad_temp2 |
 * | grad_temp3 | grad_temp4 | nram_loc_w |
 * | nram_loc_h | nram_h_low | nram_w_low |
 * | nram_h_high | nram_w_high | nram_h_low_temp |
 * | nram_h_high_temp | nram_hw | nram_hh |
 * | nram_lw | nram_lh | nram_h_low_ptr_offset |
 * | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset |
 * | nram_w1 | nram_w2 | nram_w3 |
 * | nram_w4 | nram_grad_weight | nram_base_ptr |
 * | nram_offset_temp | nram_offset1 | nram_offset2 |
 * | nram_offset3 | nram_offset4 | nram_w_low_temp |
 * | nram_spatial_shapes | nram_level_start_index | nram_h_stride |
 ****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 32
#define ALIGN_NUM_FOR_REDUCE 32
#define ELE_COUNT 32
#define LEN_FLOAT sizeof(float)

__nram__ char nram_buffer[MAX_NRAM_SIZE];

@@ -540,6 +556,17 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
  return;
}
__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
  int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
  for (int32_t i = 0; i < align_num; ++i) {
    mask_ram[i] = i % 2;
  }
  __asm__ volatile("sync;");
  __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
           NFU_ALIGN_SIZE, 0, size / align_num - 2);
  __asm__ volatile("sync;");
}
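
// Note (illustrative sketch, not part of this commit): genMask0101 writes one
// NFU_ALIGN_SIZE-sized seed block of 0,1,0,1,... with scalar stores, then
// replicates it across the buffer with a single strided NRAM2NRAM __memcpy,
// so the scalar loop is paid only once. A plain C reference of the result,
// assuming size is a multiple of align_num:
//
//   void genMask0101_ref(float *mask_ram, int32_t size) {
//     for (int32_t i = 0; i < size; ++i) {
//       mask_ram[i] = (float)(i % 2);  // 0, 1, 0, 1, ...
//     }
//   }
//
// The forward small-channel kernel uses this mask with __bang_collect to
// split the interleaved (x, y) pairs of data_sampling_loc into separate
// coordinate vectors.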
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
    const char *data_value_gdram, const char *data_spatial_shapes_gdram,
@@ -548,467 +575,471 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char *data_col_gdram) {
#if __BANG_ARCH__ >= 300
  if (coreId == 0x80) {
    return;
  }
  size_t block_num_per_core, batch_start, deal_g, offset_g;
  size_t block_num_rem = 0;
  const size_t grid_total = num_queries * num_heads * num_levels * num_points;
  if (batch_size >= taskDim) {
    block_num_rem = batch_size % taskDim;
    block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
                                                : batch_size / taskDim;
    batch_start = taskId < block_num_rem
                      ? taskId * block_num_per_core
                      : taskId * block_num_per_core + block_num_rem;
    deal_g = grid_total;
    offset_g = 0;
  } else {
    size_t skip_n = taskDim / batch_size;
    batch_start = taskId / skip_n;
    block_num_per_core = batch_start >= batch_size ? 0 : 1;
    deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
    size_t id = taskId % skip_n;
    offset_g = id * deal_g;
    deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
  }
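  // Note (illustrative, not part of this commit): when batch_size < taskDim,
  // skip_n cores cooperate on one batch and split its grid_total sampling
  // points. E.g. with batch_size = 2 and taskDim = 8: skip_n = 4, so taskId 5
  // gets batch_start = 5 / 4 = 1 and id = 5 % 4 = 1, i.e. the second quarter
  // of batch 1's grid; the last core of each group takes the remainder.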
  const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
  int32_t deal_num;
  int32_t cut_channel_iter = 2;
  const size_t spatial_size =
      PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
  const size_t level_start_index_size =
      PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);

  int32_t channel = channels;
  int32_t mult;
  while (true) {
    deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
               (8 * channel + 7) / sizeof(T);
    deal_num = PAD_DOWN(deal_num, float_align);
    deal_num = PAD_DOWN(deal_num, num_levels * num_points);
    if (deal_num > 0) {
      break;
    } else {
      channel = channels / cut_channel_iter;
      cut_channel_iter += 2;
    }
  }
  mult = channel;
  const int32_t c_rep = channels / channel;
  const int32_t c_rem = channels % channel;

  const int32_t g_rep = deal_g / deal_num;
  const int32_t g_rem = deal_g % deal_num;
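  // Note (illustrative, not part of this commit): each grid point needs
  // 8 channel-wide buffers (input_{tl,tr,bl,br} and weight_{tl,tr,bl,br})
  // plus 7 single-element buffers (mask_{tl,tr,bl,br}, point_ram, index_tl,
  // index_bl), hence the (8 * channel + 7) divisor above. If even one aligned
  // point does not fit, the channel tile is cut to channels / 2,
  // channels / 4, channels / 6, ... until deal_num > 0; c_rep / c_rem then
  // iterate the channel tiles and g_rep / g_rem the grid tiles.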
  // nram buffer alloc
  char *data_spatial_shapes_nram = nram_buffer;
  char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
  char *input_tl = data_level_start_index_nram + level_start_index_size;
  char *input_tr = input_tl + deal_num * mult * sizeof(T);
  char *input_bl = input_tr + deal_num * mult * sizeof(T);
  char *input_br = input_bl + deal_num * mult * sizeof(T);
  char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
  char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
  char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
  char *weight_br = weight_bl + deal_num * mult * sizeof(T);
  char *mask_tl = weight_br + deal_num * mult * sizeof(T);
  char *mask_tr = mask_tl + deal_num * sizeof(T);
  char *mask_bl = mask_tr + deal_num * sizeof(T);
  char *mask_br = mask_bl + deal_num * sizeof(T);
  char *point_ram = mask_br + deal_num * sizeof(T);
  char *index_tl = point_ram + deal_num * sizeof(T);
  char *index_bl = index_tl + deal_num * sizeof(T);

  // nram space reuse
  char *grid_ram = weight_tl;
  char *mask_ram = weight_bl;
  char *coord_x = input_bl;
  char *coord_y = coord_x + deal_num * sizeof(T);
  char *coord_x_low = input_tl;
  char *coord_y_low = coord_x_low + deal_num * sizeof(T);
  char *coord_x_low_int = weight_tl;
  char *coord_y_low_int = weight_tr;
  char *spatial_x = mask_tl;
  char *spatial_y = mask_tr;
  char *spatial_x_float = weight_bl;
  char *spatial_y_float = weight_br;
  char *spatial_x_temp = mask_bl;
  char *spatial_y_temp = mask_br;
  char *base_ptr_offset = weight_tl;
  char *auxiliary_a = point_ram;
  char *auxiliary_b = weight_bl;

  __memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
                 num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
  __memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
                 num_levels * sizeof(int32_t), GDRAM2NRAM);
  __asm__ volatile("sync;");
  for (int32_t batch_idx = batch_start;
       batch_idx < batch_start + block_num_per_core; ++batch_idx) {
    for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
      int32_t io_data_num = deal_num;
      const int32_t grid_off_base =
          batch_idx * grid_total + offset_g + grid_iter * deal_num;
      if (grid_iter == g_rep) {
        if (g_rem == 0) {
          continue;
        } else {
          io_data_num = g_rem;
        }
      }

      char *data_col_gdram_start =
          data_col_gdram + (batch_idx * num_queries * num_heads * channels +
                            (offset_g + grid_iter * deal_num) /
                                (num_levels * num_points) * channels) *
                               sizeof(float);
      // load data_sampling_loc
      __memcpy_async(
          grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
          io_data_num * 2 * sizeof(float), GDRAM2NRAM);
      genMask0101((float *)mask_ram, deal_num * 2);
      __asm__ volatile("sync;");

      // generate x and y coordinate vector
      // generate spatial_x and spatial_y spatial vector
      __bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
                     deal_num * 2);  // y
      __bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
                     (float *)mask_ram,
                     num_levels * 2);  // spatial_x
      __bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
      __bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
                     deal_num * 2);  // x
      __bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
                     (float *)mask_ram,
                     num_levels * 2);  // spatial_y

      for (int32_t i = 0; i < num_levels; i++) {
        __bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
                           ((int32_t *)spatial_x_temp)[i]);
        __bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
                           ((int32_t *)spatial_y_temp)[i]);
      }
      __bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
                            num_levels * num_points, 0);
      __bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
                            num_levels * num_points, 0);

      // map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0,
      // spatial_y]
      __bang_cycle_mul((float *)coord_x, (float *)coord_x,
                       (float *)spatial_x_float, deal_num,
                       num_levels * num_points);
      __bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
                        deal_num);
      __bang_cycle_mul((float *)coord_y, (float *)coord_y,
                       (float *)spatial_y_float, deal_num,
                       num_levels * num_points);
      __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
                        deal_num);

      __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
      __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);

      // calc index_tl
      const int32_t w_stride = num_heads * channels;
      __bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
                            deal_num, 0);
      __bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
                            deal_num, 0);
      __bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
                       (int32_t *)spatial_x, deal_num, num_levels * num_points);
      __bang_add((int32_t *)index_tl, (int32_t *)index_tl,
                 (int32_t *)coord_x_low_int, deal_num);
      __bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
                        deal_num);

      const int32_t deal_lp_num = deal_num / (num_levels * num_points);
      const int32_t h_rep = deal_lp_num / num_heads;
      const int32_t h_rem = deal_lp_num % num_heads;
      const int32_t head_start =
          ((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
          num_heads;
      for (int32_t iter = 0; iter < num_heads; ++iter) {
        ((int32_t *)base_ptr_offset)[iter] =
            ((head_start + iter) % num_heads) * channels;
      }
      if (h_rep > 0) {
        __memcpy((int32_t *)base_ptr_offset + num_heads,
                 (int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
                 NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
      }
      if (h_rep > 0 && h_rem > 0) {
        __memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
                 (int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
                 NRAM2NRAM);
      }
      __bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
                       num_levels * num_points);
      __bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
                       (int32_t *)base_ptr_offset, deal_num, deal_lp_num);
      __bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
                       num_levels * num_points, deal_lp_num);

      // calc index_bl
      __bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
                        deal_num);
      __bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
                       (int32_t *)auxiliary_a, deal_num,
                       num_levels * num_points);

      // calc mask_tl, mask_tr, mask_bl, mask_br
      __bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
                        (float)1.0, deal_num);
      __bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
                        (float)1.0, deal_num);
      // mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
      __bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
                       deal_num);
      __bang_cycle_le((float *)mask_br, (float *)coord_x_low,
                      (float *)spatial_x_float, deal_num,
                      num_levels * num_points);
      __bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
                 deal_num);

      __bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
                       deal_num);
      __bang_cycle_le((float *)mask_br, (float *)coord_y_low,
                      (float *)spatial_y_float, deal_num,
                      num_levels * num_points);
      __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
                 deal_num);
      __bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
                 deal_num);
      // mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
      __bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
                       deal_num);
      __bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
                      (float *)spatial_x_float, deal_num,
                      num_levels * num_points);
      __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
                 deal_num);
      __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
                 deal_num);
      // mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
      __bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
                       (float)(-1.0), deal_num);
      __bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
                      (float *)spatial_y_float, deal_num,
                      num_levels * num_points);
      __bang_and((float *)auxiliary_a, (float *)auxiliary_a,
                 (float *)auxiliary_b, deal_num);
      __bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
                 deal_num);
      // mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high <
      // spatial_y
      __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
                 deal_num);

      // calc inner point num
      __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
                        deal_num);
      __bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
                        deal_num);
      __bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
                 deal_num);
      __bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
                        deal_num);
      __bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
                 deal_num);
      __bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
                 deal_num);
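      // Note (illustrative, not part of this commit): point_ram encodes which
      // of the four bilinear neighbors lie inside the feature map as
      //   7 * mask_tl + 5 * mask_tr + 3 * mask_bl + 1 * mask_br,
      // which gives a distinct code per geometrically possible combination:
      //   16 -> all four corners valid        7 -> top_left only
      //   12 -> top_left + top_right          5 -> top_right only
      //   10 -> top_left + bottom_left        3 -> bottom_left only
      //    6 -> top_right + bottom_right      1 -> bottom_right only
      //    4 -> bottom_left + bottom_right    0 -> nothing to load
      // The data_value load loop below switches on this code so that only the
      // GDRAM reads that are actually needed get issued.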
      // calc interpolation weight
      __bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
                 deal_num);
      __bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
                 deal_num);
      __bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
                        deal_num);
      __bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
                        deal_num);
      __bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
                 deal_num);
      __bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
                 deal_num);
      __bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
                 deal_num);
      __bang_mul((float *)input_tl + deal_num, (float *)weight_br,
                 (float *)weight_tl, deal_num);
      __bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
                 (float *)weight_tr, deal_num);
      __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
                 (float *)weight_tr, deal_num);
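      // Note (illustrative, not part of this commit): with lw = x - floor(x)
      // and lh = y - floor(y), the four products stored at input_tl are the
      // standard bilinear weights, in scalar form:
      //   float hw = 1.0f - lw, hh = 1.0f - lh;
      //   float w_tl = hw * hh, w_tr = lw * hh;
      //   float w_bl = hw * lh, w_br = lw * lh;
      // Here weight_bl/weight_br temporarily hold hw/hh and
      // weight_tl/weight_tr hold lw/lh before the products are formed.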
// compute __asm__ volatile("sync;");
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
computeMsDeformAttn(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b,
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
weight, span_num_deal, spatial_w, spatial_h, x, y);
}
spatial_w = spatial_w_next_point; // extend weight
spatial_h = spatial_h_next_point; const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
weight = weight_next_point; const int32_t w_rem = channel % ELE_COUNT;
x = x_next_point; if (w_rem != 0) {
y = y_next_point; const int32_t data_sz = 1 * sizeof(float);
__asm__ volatile("sync;"); const int32_t dst_str = channel * sizeof(float);
for (int32_t iter = w_rep; iter < channel; ++iter) {
__memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
} }
} }
// store if (w_rep != 0) {
__memcpy_async( for (int32_t i = 0; i < 4 * deal_num; i++) {
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T), __bang_write_value((float *)weight_tl + i * channel, w_rep,
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap, ((float *)input_tl)[i]);
span_num_deal * sizeof(T), NRAM2GDRAM); }
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
      }
      __asm__ volatile("sync;");
      const char *data_value_gdram_start =
          data_value_gdram +
          batch_idx * num_keys * num_heads * channels * sizeof(float);
      const int32_t c_str = deal_num * channel * sizeof(float);
      const int32_t cs_str = num_heads * channels * sizeof(float);

      for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
        int32_t c_real_num = channel;
        if (c_iter == c_rep) {
          if (c_rem == 0) {
            continue;
          } else {
            c_real_num = c_rem;
          }
        }

        __bang_write_zero((float *)input_tl, 4 * deal_num * channel);
        __asm__ volatile("sync;");
        // load data_value
        for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
          const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
          const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
          const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
          const int32_t level_start_id =
              ((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
                                                       num_levels];
          const char *data_value_ptr =
              data_value_gdram_start +
              (level_start_id * num_heads * channels + c_iter * channel) *
                  sizeof(float);

          switch (inner_point_num) {
            case 16:  // 4 points are cached.
              __memcpy_async((float *)input_tl + p_idx * channel,
                             (float *)data_value_ptr + tl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
                             cs_str, 1);
              __memcpy_async((float *)input_bl + p_idx * channel,
                             (float *)data_value_ptr + bl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
                             cs_str, 1);
              break;
            case 12:  // 2 points are cached. (top_left, top_right)
              __memcpy_async((float *)input_tl + p_idx * channel,
                             (float *)data_value_ptr + tl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
                             cs_str, 1);
              break;
            case 4:  // 2 points are cached. (bottom_left, bottom_right)
              __memcpy_async((float *)input_bl + p_idx * channel,
                             (float *)data_value_ptr + bl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
                             cs_str, 1);
              break;
            case 10:  // 2 points are cached. (top_left, bottom_left)
              __memcpy_async((float *)input_tl + p_idx * channel,
                             (float *)data_value_ptr + tl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM);
              __memcpy_async((float *)input_bl + p_idx * channel,
                             (float *)data_value_ptr + bl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            case 6:  // 2 points are cached. (top_right, bottom_right)
              __memcpy_async(
                  (float *)input_tr + p_idx * channel,
                  (float *)data_value_ptr + tl_offset + num_heads * channels,
                  c_real_num * sizeof(float), GDRAM2NRAM);
              __memcpy_async(
                  (float *)input_br + p_idx * channel,
                  (float *)data_value_ptr + bl_offset + num_heads * channels,
                  c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            case 7:  // 1 point is cached. (top_left)
              __memcpy_async((float *)input_tl + p_idx * channel,
                             (float *)data_value_ptr + tl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            case 5:  // 1 point is cached. (top_right)
              __memcpy_async(
                  (float *)input_tr + p_idx * channel,
                  (float *)data_value_ptr + tl_offset + num_heads * channels,
                  c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            case 3:  // 1 point is cached. (bottom_left)
              __memcpy_async((float *)input_bl + p_idx * channel,
                             (float *)data_value_ptr + bl_offset,
                             c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            case 1:  // 1 point is cached. (bottom_right)
              __memcpy_async(
                  (float *)input_br + p_idx * channel,
                  (float *)data_value_ptr + bl_offset + num_heads * channels,
                  c_real_num * sizeof(float), GDRAM2NRAM);
              break;
            default:
              continue;
          }
        }
__asm__ volatile("sync;");
// interpolation
__bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
4 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
2 * deal_num * channel);
__bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
deal_num * channel);
// load attention weight
void *attn_weight = mask_tl;
__memcpy((float *)attn_weight,
(float *)data_attn_weight_gdram + grid_off_base,
io_data_num * sizeof(float), GDRAM2NRAM);
// calc data_col, muladd attention weight
__bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
channel);
__bang_cycle_mul((float *)input_tr, (float *)input_tr,
(float *)attn_weight, deal_num * channel, deal_num);
__bang_transpose((float *)input_tl, (float *)input_tr, channel,
deal_num);
__bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
io_data_num, 1, num_levels * num_points,
num_levels * num_points, 1);
// store
__memcpy((float *)data_col_gdram_start + c_iter * channel,
(float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
channels * sizeof(float), channel * sizeof(float),
(io_data_num / (num_levels * num_points)) - 1);
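        // Note (illustrative, not part of this commit): after the weighted
        // corners are summed into input_tl, each point's channel vector is
        // scaled by its attention weight and __bang_sumpool collapses the
        // num_levels * num_points points of one (query, head) into a single
        // channel vector. Scalar reference, assuming float accumulation:
        //   for (int c = 0; c < channel; ++c) {
        //     float acc = 0.f;
        //     for (int p = 0; p < num_levels * num_points; ++p)
        //       acc += attn_weight[p] * interp[p][c];
        //     data_col[c] = acc;
        //   }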
      }
    }
  }
  __asm__ volatile("sync;");
#endif
  return;
}
@@ -1316,294 +1347,496 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
  }
}
void __mlu_func__ computeGridMaskAndOffset(
    float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
    float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
    float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
    float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
    float *nram_lw, float *nram_hh, float *nram_hw,
    float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
    float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
    float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
    float *nram_offset1, float *nram_offset2, float *nram_offset3,
    float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
    int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
    const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
    const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
  // [num_levels, 2] --> [2, num_levels]
  __bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
  __bang_transpose(nram_loc_w, nram_grad_output_tl,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_int322float((float *)nram_spatial_shapes,
                     (int32_t *)nram_spatial_shapes, num_levels * 2, 0);
  __bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
                   num_levels, 2);
  __bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
                    num_levels);
  __memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
                 num_levels * 2 * sizeof(float), NRAM2NRAM);
  __bang_cycle_mul(nram_loc_w, nram_loc_w,
                   (float *)nram_spatial_shapes + num_levels, num_deal_grid,
                   num_levels);
  __bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
                   num_deal_grid, num_levels);
  __bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
  __bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
  // get mask. (h_im > -1 && w_im > -1 &&
  // h_im < spatial_h && w_im < spatial_w)
  __bang_cycle_lt(nram_w_low_temp, nram_loc_w,
                  (float *)(nram_spatial_shapes + num_levels), num_deal_grid,
                  num_levels);
  __bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
                  num_deal_grid, num_levels);
  __bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
  __bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
  __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
             num_deal_grid);
  __bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
  __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
             num_deal_grid);
  __bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
                   num_per_time_real * num_heads * num_levels);
  __memcpy_async(nram_h_high_temp, nram_w_low_temp,
                 num_deal_grid * sizeof(float), NRAM2NRAM);
  __bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
                   num_per_time_real * num_heads * num_levels);
  __memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
                 NRAM2NRAM);
  __bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
                   num_per_time_real * num_heads * num_levels);
  __memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
                 NRAM2NRAM);
  __bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
  __bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
  __bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
  __bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
  __bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
  __bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
  __bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
  __bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
  __bang_transpose(nram_h_low_ptr_offset, nram_h_low,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
                   num_deal_grid, num_levels);
  __bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
                   num_deal_grid, num_levels);
  __bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
                   num_per_time_real * num_heads * num_levels);
  __memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
                 num_deal_grid * sizeof(float), NRAM2NRAM);
  __bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
                   num_per_time_real * num_heads * num_levels);
  __memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
                 num_deal_grid * sizeof(float), NRAM2NRAM);
  __bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
                    num_deal_grid);
  __bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
                    num_deal_grid);
  __bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
  __bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
  __bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
  __bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
  __bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
             num_deal_grid);
  __bang_transpose(nram_offset_temp, nram_offset1,
                   num_per_time_real * num_heads, num_levels * num_points);
  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
                   num_deal_grid, num_heads);
  __bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
                   num_per_time_real * num_heads);
  __bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
             num_deal_grid);
  __bang_transpose(nram_offset_temp, nram_offset2,
                   num_per_time_real * num_heads, num_levels * num_points);
  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
                   num_deal_grid, num_heads);
  __bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
                   num_per_time_real * num_heads);
  __bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
             num_deal_grid);
  __bang_transpose(nram_offset_temp, nram_offset3,
                   num_per_time_real * num_heads, num_levels * num_points);
  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
                   num_deal_grid, num_heads);
  __bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
                   num_per_time_real * num_heads);
  __bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
             num_deal_grid);
  __bang_transpose(nram_offset_temp, nram_offset4,
                   num_per_time_real * num_heads, num_levels * num_points);
  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
                   num_deal_grid, num_heads);
  __bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
                   num_per_time_real * num_heads);
  // h_low >= 0 && w_low >= 0 mask2
  float *mask1 = nram_h_low_ptr_offset;
  float *mask2 = nram_h_high_ptr_offset;
  float *mask3 = nram_w_low_ptr_offset;
  float *mask4 = nram_w_high_ptr_offset;
  __bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
  __bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
  __bang_and(mask2, mask1, mask2, num_deal_grid);
  __bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
  // h_low >= 0 && w_high <= width - 1 mask1
  __bang_transpose(mask3, nram_w_high,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
                    num_levels * 2);
  __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
                  num_deal_grid, num_levels);
  __bang_transpose(mask4, mask3, num_points,
                   num_per_time_real * num_heads * num_levels);
  __bang_and(mask1, mask1, mask4, num_deal_grid);
  __bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
  // h_high <= height - 1 && w_high <= width - 1 mask3
  __bang_transpose(mask3, nram_h_high,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
                  num_levels);
  __bang_transpose(nram_h_low_temp, mask3, num_points,
                   num_per_time_real * num_heads * num_levels);
  __bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
  __bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
  // h_high <= height - 1 && w_low >= 0 mask4
  __bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
  __bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
  __bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
#endif
}
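
// Note (illustrative, not part of this commit): computeGridMaskAndOffset is
// the vectorized counterpart of the scalar prologue of the default backward
// kernel. Per sampling point it effectively computes:
//   h_low = floorf(h_im);            w_low = floorf(w_im);
//   lh = h_im - h_low;               lw = w_im - w_low;
//   hh = 1 - lh;                     hw = 1 - lw;
//   w1 = hh * hw;  w2 = hh * lw;  w3 = lh * hw;  w4 = lh * lw;
//   offset1 = h_low * h_stride + w_low * qid_stride + base_ptr;  // top_left
// plus the boundary masks mask1..mask4 that gate the four corner accesses.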
void __mlu_func__ loadValue(
    float *nram_grad_output_tl, float *nram_grad_output_tr,
    float *nram_grad_output_bl, float *nram_grad_output_br,
    const float *data_value, const float *grad_output, float *grad_temp1,
    float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
    float *nram_offset1, float *nram_offset2, float *nram_offset3,
    float *nram_offset4, float *nram_grad_weight,
    int32_t *nram_level_start_index, int32_t offset_nram,
    int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
    int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
    int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
    const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
    const int32_t qid_stride) {
#if __BANG_ARCH__ >= 322
  int32_t value_offset_temp = 0;
  __bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
  __sync_io_move_compute();
  __memcpy_async(
      grad_temp2,
      grad_output + (start_per_core + grid_loop * num_per_time_theory) *
                        num_heads * deal_num_real,
      num_per_time_real * num_heads * deal_num_real * sizeof(float),
      GDRAM2NRAM);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    const int32_t b_col =
        (grid_offset + loop) / num_query / num_heads / num_levels / num_points;
    const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
    const int32_t level_start_id = nram_level_start_index[l_col];
    value_offset_temp =
        b_col * spatial_size * qid_stride + level_start_id * qid_stride;
    if (mask2[loop]) {
      __memcpy_async(
          nram_grad_output_tl + loop * deal_num_real,
          data_value + value_offset_temp + int32_t(nram_offset1[loop]),
          deal_num_real * sizeof(float), GDRAM2NRAM);
    }
    if (mask1[loop]) {
      __memcpy_async(
          nram_grad_output_tr + loop * deal_num_real,
          data_value + value_offset_temp + int32_t(nram_offset2[loop]),
          deal_num_real * sizeof(float), GDRAM2NRAM);
    }
    if (mask4[loop]) {
      __memcpy_async(
          nram_grad_output_bl + loop * deal_num_real,
          data_value + value_offset_temp + int32_t(nram_offset3[loop]),
          deal_num_real * sizeof(float), GDRAM2NRAM);
    }
    if (mask3[loop]) {
      __memcpy_async(
          nram_grad_output_br + loop * deal_num_real,
          data_value + value_offset_temp + int32_t(nram_offset4[loop]),
          deal_num_real * sizeof(float), GDRAM2NRAM);
    }
  }
  for (int32_t m = 0; m < deal_num_real; ++m) {
    __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
                   num_deal_grid * sizeof(float), NRAM2NRAM);
  }
  __sync_io_move_compute();
#endif
}
void __mlu_func__ computeGradValue(
    float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
    float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
    float *nram_offset2, float *nram_offset3, float *nram_offset4,
    int32_t *nram_level_start_index, int32_t deal_num_real,
    const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
    float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
    const int32_t num_levels, const int32_t num_points, const int32_t num_query,
    int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
    const int32_t qid_stride, float *nram_grid_offset1,
    float *nram_grid_offset2) {
#if __BANG_ARCH__ >= 322
  __bang_transpose(grad_temp3, grad_temp1,
                   deal_num_real * num_per_time_real * num_heads,
                   num_levels * num_points);
  __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
                   deal_num_real);
  __bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
                   num_deal_grid * deal_num_real,
                   deal_num_real * num_per_time_real * num_heads);
  __bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
                   deal_num_real * num_per_time_real * num_heads);
  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
                   num_deal_grid * deal_num_real, num_deal_grid);
  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
                               num_levels / num_points) *
                              spatial_size * qid_stride;
  }
  __bang_transpose(nram_grid_offset2, nram_grid_offset1,
                   num_per_time_real * num_heads * num_levels, num_points);
  __bang_int322float((float *)nram_level_start_index, nram_level_start_index,
                     num_levels, 0);
  __bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
                    qid_stride, num_levels);
  __bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
                   num_deal_grid, num_levels);
  __bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
                   num_per_time_real * num_heads * num_levels);
  __bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
  __bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
  __bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
  __bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    if (mask2[loop]) {
      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
                        (float *)(grad_value + int32_t(nram_offset1[loop])),
                        (float *)(grad_temp3 + loop * deal_num_real),
                        deal_num_real);
    }
  }
  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w2,
                   num_deal_grid * deal_num_real, num_deal_grid);
  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    if (mask1[loop]) {
      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
                        (float *)(grad_value + int32_t(nram_offset2[loop])),
                        (float *)(grad_temp3 + loop * deal_num_real),
                        deal_num_real);
    }
  }
  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
                   num_deal_grid * deal_num_real, num_deal_grid);
  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    if (mask4[loop]) {
      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
                        (float *)(grad_value + int32_t(nram_offset3[loop])),
                        (float *)(grad_temp3 + loop * deal_num_real),
                        deal_num_real);
    }
  }

  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
                   num_deal_grid * deal_num_real, num_deal_grid);
  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
    if (mask3[loop]) {
      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
                        (float *)(grad_value + int32_t(nram_offset4[loop])),
                        (float *)(grad_temp3 + loop * deal_num_real),
                        deal_num_real);
    }
  }
#endif
}
template <typename T> void __mlu_func__ computeGradAttnWeight(
void __mlu_func__ msDeformAttnCol2imBilinearSmallChannels( float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1, float *nram_grad_output_tr, float *nram_grad_output_bl,
const T &w2, const T &w3, const T &w4, const int32_t &h_low, float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high, const float *grad_attn_weight, float *nram_hw, float *nram_hh,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset, float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset, float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh, int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight, const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
T *grad_w_weight, T *grad_value, T *grad_output_nram_tl, int32_t grid_offset, float *nram_h_high_temp) {
T *grad_output_nram_tr, T *grad_output_nram_bl, T *grad_output_nram_br, #if __BANG_ARCH__ >= 322
T *grad_output_nram_tl_temp, T *grad_output_nram_tr_temp, __bang_write_zero(grad_w_weight, 2 * offset_nram);
T *grad_output_nram_bl_temp, T *grad_output_nram_br_temp,
T *grad_sampling_loc, T *grad_attn_weight, const int32_t &deal_num_real, // grad_output_nram_tl
const T *data_value_ptr) __bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
deal_num_real);
{ __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
loadData(h_low, w_low, h_high, w_high, grad_output_nram_tl, num_deal_grid * deal_num_real, num_deal_grid);
grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br, __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
data_value_ptr, width, height, deal_num_real, h_low_ptr_offset, num_deal_grid * deal_num_real);
w_low_ptr_offset, w_high_ptr_offset, h_high_ptr_offset, base_ptr); __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
computeData(h_low, w_low, h_high, w_high, grad_output_nram_tl, num_deal_grid * deal_num_real, num_deal_grid);
grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br, __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
grad_output_nram_tl_temp, grad_output_nram_tr_temp, num_deal_grid * deal_num_real);
grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height, __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
deal_num_real, grad_h_weight, grad_w_weight, top_grad_temp, num_deal_grid * deal_num_real, num_deal_grid);
top_grad, data_attn_weight, hw, hh, lw, lh, w1, w2, w3, w4); // nram_grad_output_tr
storeData(h_low, w_low, h_high, w_high, grad_output_nram_tl, __bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
grad_output_nram_tl_temp, grad_output_nram_tr_temp, deal_num_real);
grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height, __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
deal_num_real, h_low_ptr_offset, w_low_ptr_offset, num_deal_grid * deal_num_real, num_deal_grid);
w_high_ptr_offset, h_high_ptr_offset, base_ptr, grad_value, __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
grad_w_weight, grad_h_weight, grad_sampling_loc, grad_attn_weight); num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
num_deal_grid * deal_num_real);
// nram_grad_output_tl
__bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
num_deal_grid * deal_num_real);
// nram_grad_output_br
__bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
num_deal_grid * deal_num_real, num_deal_grid);
__bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
num_deal_grid);
__bang_transpose(nram_grad_output_tr, nram_grad_output_br,
num_per_time_real * num_heads,
num_points * num_levels * deal_num_real);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
num_per_time_real * num_heads * deal_num_real);
__bang_transpose(nram_grad_output_br, nram_grad_output_tr,
num_points * num_levels * deal_num_real,
num_per_time_real * num_heads);
__bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
num_deal_grid, deal_num_real);
recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
ALIGN_NUM);
__bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
0);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_atomic_add((float *)nram_grad_output_tr,
(float *)grad_attn_weight + grid_offset,
(float *)nram_grad_output_tr, num_deal_grid);
#endif
} }
template <typename T> void __mlu_func__ computeGradSampingLoc(
void __mlu_func__ msDeformAttnCol2imImpl( const float *grad_sampling_loc, float *nram_grad_output_tl,
T *top_grad_temp, T *top_grad, T *grad_h_weight, T *grad_w_weight, float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
T *grad_value, T *grad_output_nram_tl, T *grad_output_nram_tr, int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp, float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp, int32_t num_per_time_real, const int32_t num_heads,
T *grad_output_nram_br_temp, T *grad_sampling_loc, T *grad_attn_weight, const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
T *nram_sampling_loc, T *nram_attn_weight, const int32_t &load_num, float *nram_h_high_temp) {
const int32_t &tail, const int32_t &i_repeat, const int32_t &num_points, #if __BANG_ARCH__ >= 322
const int32_t &start_per_core, const int32_t &num_levels, __bang_transpose(nram_grad_output_tl, grad_h_weight,
const int32_t &num_heads, const int32_t &num_query, num_per_time_real * num_heads * num_levels * deal_num_real,
const int32_t &spatial_size, const int32_t &qid_stride, num_points);
int32_t *level_start_index_nram, const int32_t &channels, __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
const T *data_value, const T *grad_output, int32_t *spatial_shapes_nram) { (float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
#if __BANG_ARCH__ > 322 num_levels);
int32_t weight_pos = 0; __bang_transpose(grad_h_weight, nram_grad_output_tl,
int32_t sampling_loc_pos = 0; num_points * deal_num_real,
for (int32_t p = 0; p < tail; ++p) { num_per_time_real * num_heads * num_levels);
int32_t grid_offset = start_per_core + i_repeat * load_num + p; for (int32_t m = 0; m < deal_num_real; ++m) {
const int32_t l_col = grid_offset % num_levels; __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
const int32_t m_col = grid_offset / num_levels % num_heads; num_deal_grid * sizeof(float), NRAM2NRAM);
const int32_t q_col = grid_offset / num_levels / num_heads % num_query;
const int32_t b_col = grid_offset / num_query / num_heads / num_levels;
const int32_t value_offset = b_col * spatial_size * qid_stride;
const int32_t level_start_id = level_start_index_nram[l_col];
const int32_t grad_attn_weight_out = grid_offset * num_points;
const int32_t spatial_h_ptr = l_col << 1;
const int32_t grad_output_offset =
b_col * num_query * qid_stride + q_col * qid_stride + m_col * channels;
__memcpy(top_grad, grad_output + grad_output_offset, channels * LEN_FLOAT,
GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[spatial_h_ptr];
const int32_t spatial_w = spatial_shapes_nram[spatial_h_ptr + 1];
const int32_t h_stride = spatial_w * qid_stride;
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_sampling_loc_out = grid_offset * num_points << 1;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
const float loc_w = nram_sampling_loc[sampling_loc_pos];
const float loc_h = nram_sampling_loc[sampling_loc_pos + 1];
const float weight = nram_attn_weight[weight_pos];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * qid_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + qid_stride;
const float w1 = hh * hw;
const float w2 = hh * lw;
const float w3 = lh * hw;
const float w4 = lh * lw;
const int32_t base_ptr = m_col * channels;
__bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
__bang_write_zero(grad_output_nram_tl, PAD_UP(channels, ALIGN_NUM));
msDeformAttnCol2imBilinearSmallChannels(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
grad_output_nram_br, grad_output_nram_tl_temp,
grad_output_nram_tr_temp, grad_output_nram_bl_temp,
grad_output_nram_br_temp,
grad_sampling_loc + grad_sampling_loc_out + (p_col << 1),
grad_attn_weight + grad_attn_weight_out + p_col, channels,
data_value_ptr);
}
weight_pos += 1;
sampling_loc_pos += 2;
}
} }
__sync_io_move_compute();
__bang_transpose(nram_grad_output_tr, grad_temp1,
deal_num_real * num_per_time_real * num_heads,
num_levels * num_points);
__bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
deal_num_real);
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
num_deal_grid * deal_num_real,
deal_num_real * num_per_time_real * num_heads);
__bang_transpose(grad_temp1, nram_grad_output_tr,
num_levels * num_points * deal_num_real,
num_per_time_real * num_heads);
__bang_mul(grad_h_weight, grad_h_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid,
deal_num_real);
__memcpy_async(grad_h_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__nram__ int table[2] = {0, (int)0xffffffff};
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_h_weight, (char *)grad_h_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__bang_transpose(nram_grad_output_tl, grad_w_weight,
num_per_time_real * num_heads * num_levels * deal_num_real,
num_points);
__bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
(float *)(nram_spatial_shapes + num_levels),
num_deal_grid * deal_num_real, num_levels);
__bang_transpose(grad_w_weight, nram_grad_output_tl,
num_points * deal_num_real,
num_per_time_real * num_heads * num_levels);
__bang_mul(grad_w_weight, grad_w_weight, grad_temp1,
num_deal_grid * deal_num_real);
__bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid,
deal_num_real);
__memcpy(grad_w_weight, nram_grad_output_tl,
num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM);
recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM);
__bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
num_deal_grid, 64);
__bang_band((char *)grad_w_weight, (char *)grad_w_weight,
(char *)nram_h_high_temp, num_deal_grid * sizeof(float));
__memcpy(grad_w_weight + num_deal_grid, grad_h_weight,
num_deal_grid * sizeof(float), NRAM2NRAM);
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid);
__bang_atomic_add((float *)nram_grad_output_tl,
(float *)grad_sampling_loc + grid_offset * 2,
(float *)nram_grad_output_tl, 2 * num_deal_grid);
#endif #endif
} }
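Note: the tensorized helpers above (computeGradValue, computeGradAttnWeight, computeGradSampingLoc) batch over every (query, head, level, point) tuple at once, but per sampling point they realize the usual bilinear-interpolation backward. Below is a minimal scalar sketch of one point's contribution for a single channel; all names are illustrative, not the kernel's API:

// Scalar reference for one in-bounds sampling point (illustrative only).
// v_tl..v_br are the four neighboring values; lh/lw are the fractional
// offsets of the sample inside its cell; hh = 1 - lh, hw = 1 - lw.
struct PointGrad {
  float grad_v[4];   // gradient to the four corner values
  float grad_attn;   // gradient to the attention weight
  float grad_loc_h;  // gradient to the normalized row coordinate
  float grad_loc_w;  // gradient to the normalized column coordinate
};

PointGrad bilinear_backward(float v_tl, float v_tr, float v_bl, float v_br,
                            float lh, float lw, float top_grad,
                            float attn_weight, int spatial_h, int spatial_w) {
  const float hh = 1.f - lh, hw = 1.f - lw;
  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const float g = top_grad * attn_weight;
  PointGrad out;
  out.grad_v[0] = w1 * g;  // top-left
  out.grad_v[1] = w2 * g;  // top-right
  out.grad_v[2] = w3 * g;  // bottom-left
  out.grad_v[3] = w4 * g;  // bottom-right
  // d(output)/d(attn_weight) is the interpolated value times top_grad.
  out.grad_attn = (w1 * v_tl + w2 * v_tr + w3 * v_bl + w4 * v_br) * top_grad;
  // d(interp)/d(lh) and d(interp)/d(lw), chained through the
  // normalization h_im = loc_h * spatial_h - 0.5 (likewise for w).
  const float dh = -hw * v_tl - lw * v_tr + hw * v_bl + lw * v_br;
  const float dw = -hh * v_tl + hh * v_tr - lh * v_bl + lh * v_br;
  out.grad_loc_h = dh * g * spatial_h;
  out.grad_loc_w = dw * g * spatial_w;
  return out;
}

The signs of dh and dw match the __bang_sub/__bang_add pattern in computeGradAttnWeight, and the trailing multiplication by spatial_h/spatial_w corresponds to the __bang_cycle_mul by nram_spatial_shapes in computeGradSampingLoc.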
@@ -1616,117 +1849,195 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
     const int32_t num_points, float *grad_value, float *grad_sampling_loc,
     float *grad_attn_weight) {
 #if __BANG_ARCH__ > 322
-  const int32_t split_num = 12;
+  const int32_t split_grid_num = 28;
+  const int32_t split_num_c = 8;
   const int32_t C_align = PAD_UP(channels, ALIGN_NUM);
-  float *grad_output_nram_tl = (float *)nram_buffer;
-  float *grad_output_nram_tr = (float *)nram_buffer + C_align;
-  float *grad_output_nram_bl = (float *)nram_buffer + 2 * C_align;
-  float *grad_output_nram_br = (float *)nram_buffer + 3 * C_align;
-  float *grad_output_nram_tl_temp = (float *)nram_buffer + 4 * C_align;
-  float *grad_output_nram_tr_temp = (float *)nram_buffer + 5 * C_align;
-  float *grad_output_nram_bl_temp = (float *)nram_buffer + 6 * C_align;
-  float *grad_output_nram_br_temp = (float *)nram_buffer + 7 * C_align;
-  float *grad_h_weight = (float *)nram_buffer + 8 * C_align;
-  float *grad_w_weight = (float *)nram_buffer + 9 * C_align;
-  float *top_grad_temp = (float *)nram_buffer + 10 * C_align;
-  float *top_grad = (float *)nram_buffer + 11 * C_align;
-  int32_t *spatial_shapes_nram =
-      (int32_t *)((float *)nram_buffer + split_num * C_align);
-  int32_t *level_start_index_nram =
-      (int32_t *)(spatial_shapes_nram + PAD_UP(num_levels * 2, ALIGN_NUM));
-  float *nram_remain = (float *)((int32_t *)level_start_index_nram +
-                                 PAD_UP(num_levels, ALIGN_NUM));
-  // calc load num
-  const int32_t weight_num2nram =
-      (MAX_NRAM_SIZE / LEN_FLOAT - split_num * C_align -
-       3 * PAD_UP(num_levels, ALIGN_NUM)) /
-      3 / num_points;
-  int32_t load_num = weight_num2nram;
-  const int32_t total_num = batch * num_query * num_heads * num_levels;
+  const int32_t num_hlp = num_heads * num_levels * num_points;
+  int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) -
+                                 3 * num_levels * sizeof(int32_t)) /
+                                sizeof(float) /
+                                (split_num_c * C_align + split_grid_num) /
+                                PAD_UP((num_hlp), ALIGN_NUM);
+  int32_t deal_grid_num_theory = num_per_time_theory * num_hlp;
+
+  const int32_t offset_nram = num_per_time_theory * C_align * num_hlp;
+  const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM);
+  float *nram_grad_output_tl = (float *)nram_buffer;
+  float *nram_grad_output_tr = (float *)nram_buffer + offset_nram;
+  float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram;
+  float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram;
+  float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram;
+  float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram;
+  float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram;
+  float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram;
+  float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram;
+  float *nram_loc_h =
+      (float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc;
+  float *nram_h_low =
+      (float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc;
+  float *nram_w_low =
+      (float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc;
+  float *nram_h_high =
+      (float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc;
+  float *nram_w_high =
+      (float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc;
+  float *nram_h_low_temp =
+      (float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc;
+  float *nram_h_high_temp =
+      (float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc;
+  float *nram_hw =
+      (float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc;
+  float *nram_hh =
+      (float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc;
+  float *nram_lw =
+      (float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc;
+  float *nram_lh =
+      (float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc;
+  float *nram_h_low_ptr_offset =
+      (float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc;
+  float *nram_h_high_ptr_offset =
+      (float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc;
+  float *nram_w_low_ptr_offset =
+      (float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc;
+  float *nram_w_high_ptr_offset =
+      (float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc;
+  float *nram_w1 =
+      (float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc;
+  float *nram_w2 =
+      (float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc;
+  float *nram_w3 =
+      (float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc;
+  float *nram_w4 =
+      (float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc;
+  float *nram_grad_weight =
+      (float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc;
+  float *nram_base_ptr =
+      (float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc;
+  float *nram_offset_temp =
+      (float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc;
+  float *nram_offset1 =
+      (float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc;
+  float *nram_offset2 =
+      (float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc;
+  float *nram_offset3 =
+      (float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc;
+  float *nram_offset4 =
+      (float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc;
+  float *nram_w_low_temp =
+      (float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc;
+  int32_t *nram_spatial_shapes =
+      (int32_t *)((float *)nram_buffer + split_num_c * offset_nram +
+                  28 * offset_nram_calc);
+  int32_t *nram_level_start_index =
+      (int32_t *)(nram_spatial_shapes + 2 * num_levels);
+  float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels);
+
+  const int32_t total_num = batch * num_query;
   int32_t num_per_core = total_num / taskDim;
   int32_t num_rem = total_num % taskDim;
   num_per_core = num_per_core + int32_t(taskId < num_rem);
-  if (num_per_core == 0) {
-    return;
-  }
-  const int32_t start_per_core = num_rem > taskId
-                                     ? (taskId * num_per_core)
-                                     : (num_rem + taskId * num_per_core);
+  num_per_time_theory =
+      num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core;
+  int32_t num_deal_grid = num_per_time_theory * num_hlp;
+
+  if (num_per_core == 0) return;
+  int32_t start_per_core = num_rem > taskId ? (taskId * num_per_core)
+                                            : (num_rem + taskId * num_per_core);
   const int32_t qid_stride = num_heads * channels;
-  // load spatial_shapes and data_level_start_index to nram
-  __memcpy_async(spatial_shapes_nram, spatial_shapes,
-                 num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
-  __memcpy_async(level_start_index_nram, data_level_start_index,
-                 num_levels * sizeof(int32_t), GDRAM2NRAM);
-  const int32_t start_l_col = start_per_core % num_levels;
-  const int32_t start_m_col = start_per_core / num_levels % num_heads;
-  const int32_t start_q_col =
-      start_per_core / num_levels / num_heads % num_query;
-  const int32_t start_b_col =
-      start_per_core / num_query / num_heads / num_levels;
-  const int32_t repeat = num_per_core / load_num;
-  const int32_t tail = num_per_core % load_num;
-  float *nram_sampling_loc = nram_remain;
-  float *nram_attn_weight = nram_sampling_loc + 2 * load_num * num_points;
-  const int32_t attn_weight_offset =
-      start_b_col * num_query * num_heads * num_levels * num_points +
-      start_q_col * num_heads * num_levels * num_points +
-      start_m_col * num_levels * num_points + start_l_col * num_points;
-  const int32_t sampling_loc_offset =
-      start_b_col * num_query * num_heads * num_levels * num_points * 2 +
-      start_q_col * num_heads * num_levels * num_points * 2 +
-      start_m_col * num_levels * num_points * 2 + start_l_col * num_points * 2;
-  if (repeat > 0) {
-    for (int32_t i_repeat = 0; i_repeat < repeat; ++i_repeat)
-    { // load weight and sampling_loc to nram
-      __memcpy_async(nram_sampling_loc,
-                     data_sampling_loc + sampling_loc_offset +
-                         i_repeat * load_num * 2 * num_points,
-                     2 * load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
-      __memcpy(nram_attn_weight,
-               data_attn_weight + attn_weight_offset +
-                   i_repeat * load_num * num_points,
-               load_num * num_points * LEN_FLOAT, GDRAM2NRAM);
-      msDeformAttnCol2imImpl(
-          top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
-          grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
-          grad_output_nram_br, grad_output_nram_tl_temp,
-          grad_output_nram_tr_temp, grad_output_nram_bl_temp,
-          grad_output_nram_br_temp, grad_sampling_loc, grad_attn_weight,
-          nram_sampling_loc, nram_attn_weight, load_num, load_num, i_repeat,
-          num_points, start_per_core, num_levels, num_heads, num_query,
-          spatial_size, qid_stride, level_start_index_nram, channels,
-          data_value, grad_output, spatial_shapes_nram);
-    }
-  }
-  if (tail > 0)
-  { // load weight and sampling_loc to nram
-    __memcpy_async(nram_sampling_loc,
-                   data_sampling_loc + sampling_loc_offset +
-                       repeat * load_num * 2 * num_points,
-                   tail * num_points * 2 * LEN_FLOAT, GDRAM2NRAM);
-    __memcpy(
-        nram_attn_weight,
-        data_attn_weight + attn_weight_offset + repeat * load_num * num_points,
-        tail * num_points * LEN_FLOAT, GDRAM2NRAM);
-    msDeformAttnCol2imImpl(
-        top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
-        grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
-        grad_output_nram_br, grad_output_nram_tl_temp, grad_output_nram_tr_temp,
-        grad_output_nram_bl_temp, grad_output_nram_br_temp, grad_sampling_loc,
-        grad_attn_weight, nram_sampling_loc, nram_attn_weight, load_num, tail,
-        repeat, num_points, start_per_core, num_levels, num_heads, num_query,
-        spatial_size, qid_stride, level_start_index_nram, channels, data_value,
-        grad_output, spatial_shapes_nram);
-  }
+  int32_t deal_num_real = channels;
+  const int32_t repeat_times = num_per_core / num_per_time_theory;
+  const int32_t tail_num = num_per_core % num_per_time_theory;
+
+  int32_t num_per_time_real = num_per_time_theory;
+
+  for (int32_t loop = 0; loop < num_heads; ++loop) {
+    nram_base_ptr[loop] = loop * channels;
+  }
+  const int32_t w_stride = num_heads * channels;
+  for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) {
+    int32_t grid_offset =
+        (start_per_core + grid_loop * num_per_time_theory) * num_hlp;
+    if (grid_loop == repeat_times) {
+      if (tail_num == 0) {
+        continue;
+      } else {
+        grid_offset =
+            (start_per_core + repeat_times * num_per_time_theory) * num_hlp;
+        num_per_time_real = tail_num;
+        num_deal_grid = tail_num * num_hlp;
+      }
+    }
+
+    __memcpy_async(nram_spatial_shapes, spatial_shapes,
+                   num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
+    __memcpy_async(nram_level_start_index, data_level_start_index,
+                   num_levels * sizeof(int32_t), GDRAM2NRAM);
+    __memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
+                   num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
+    __memcpy(nram_grad_weight, data_attn_weight + grid_offset,
+             num_deal_grid * sizeof(float), GDRAM2NRAM);
+    computeGridMaskAndOffset(
+        nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
+        nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
+        nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
+        nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
+        nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
+        nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
+        nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
+        num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
+        w_stride, qid_stride);
+    float *mask1 = nram_h_low_ptr_offset;
+    float *mask2 = nram_h_high_ptr_offset;
+    float *mask3 = nram_w_low_ptr_offset;
+    float *mask4 = nram_w_high_ptr_offset;
+    loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
+              nram_grad_output_br, data_value, grad_output, grad_temp1,
+              grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
+              nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
+              nram_level_start_index, offset_nram, start_per_core, grid_loop,
+              num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
+              num_deal_grid, num_query, num_levels, num_points, grid_offset,
+              spatial_size, qid_stride);
+    float *nram_grid_offset1 = nram_loc_h;
+    float *nram_grid_offset2 = nram_loc_w;
+    computeGradValue(
+        grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
+        mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
+        nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
+        nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
+        num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
+        nram_grid_offset1, nram_grid_offset2);
+    // compute grad_weight
+    float *grad_weight = grad_temp1;
+    float *grad_h_weight = grad_temp4;
+    float *grad_w_weight = grad_temp3;
+    computeGradAttnWeight(
+        grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
+        nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
+        grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
+        nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
+        deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
+        grid_offset, nram_h_high_temp);
+    // compute grad_sampling_loc
+    computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl,
+                          nram_grad_output_tr, grad_h_weight, grad_w_weight,
+                          nram_spatial_shapes, grad_temp1, grad_temp2,
+                          nram_grad_weight, num_deal_grid, deal_num_real,
+                          num_per_time_real, num_heads, num_levels, num_points,
+                          grid_offset, nram_h_high_temp);
+  }
 #endif
 }
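Note: the out-of-bounds handling in the rewritten kernel is branch-free. nram_h_high_temp carries a 0/1 validity flag per sampling point; __bang_lut_s32 maps each flag through the two-entry table {0, 0xffffffff}, and __bang_band then ANDs the resulting bit pattern over the float gradients, so invalid points contribute exact zeros. A host-side C++ sketch of the same masking idea, with plain loops standing in for the vector intrinsics (names illustrative):

#include <cstdint>
#include <cstring>

// Emulates the {0, 0xffffffff} lookup plus bitwise AND the kernel uses to
// zero out invalid lanes without a per-element branch on the compute path.
void mask_invalid(float *data, const int32_t *valid /* 0 or 1 */, int n) {
  static const uint32_t table[2] = {0u, 0xffffffffu};
  for (int i = 0; i < n; ++i) {
    uint32_t bits;
    std::memcpy(&bits, &data[i], sizeof(bits));  // type-pun via memcpy
    bits &= table[valid[i]];  // all-ones keeps the value, zero clears it
    std::memcpy(&data[i], &bits, sizeof(bits));
  }
}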
@@ -1739,6 +2050,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
     const int32_t channels, const int32_t num_levels, const int32_t num_query,
     const int32_t num_points, float *grad_value, float *grad_sampling_loc,
     float *grad_attn_weight);
 __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel(
     const float *data_value, const int32_t *spatial_shapes,
     const int32_t *data_level_start_index, const float *data_sampling_loc,
...
@@ -47,17 +47,17 @@ typedef enum {
 } MsDeformAttnBackwardKernelPolicy;
 
 MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
-    const int32_t channels, const int32_t num_levels,
-    const int32_t num_points) {
+    const int32_t channels, const int32_t num_levels, const int32_t num_points,
+    const int32_t num_heads) {
   const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const uint64_t max_num = nram_size / sizeof(float);
-  const uint64_t deal_num =
-      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
-
-  if (max_num >= deal_num) {
+  const int num_hlp = num_heads * num_levels * num_points;
+  int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
+                             3 * num_levels * sizeof(int32_t)) /
+                            sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
+                            PAD_UP((num_hlp), 32);
+  if (num_per_time_theory >= 1) {
     return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
   }
   return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
 }
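Note: the updated policy mirrors the kernel's NRAM budget: eight channel-aligned partitions (split_num_c) of PAD_UP(channels, 32) floats plus 28 per-point scratch arrays (split_grid_num), after reserving the per-level metadata (spatial shapes, level start indices, strides). A worked instance of the arithmetic under assumed sizes; the real nram_size comes from torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore):

#include <cstdio>
#define PAD_UP(x, n) (((x) + (n)-1) / (n) * (n))

int main() {
  // Assumed configuration: channels = 32, num_levels = 4, num_points = 4,
  // num_heads = 8, nram_size = 512 KB (illustrative values only).
  const int channels = 32, num_levels = 4, num_points = 4, num_heads = 8;
  const long nram_size = 512 * 1024;
  const int num_hlp = num_heads * num_levels * num_points;  // 128
  long num_per_time_theory =
      (nram_size - num_levels * sizeof(float) - 3 * num_levels * sizeof(int)) /
      sizeof(float) / (8 * PAD_UP(channels, 32) + 28) / PAD_UP(num_hlp, 32);
  // (524224 / 4) / 284 / 128 = 131056 / 284 / 128 = 461 / 128 = 3
  std::printf("num_per_time_theory = %ld -> %s kernel\n", num_per_time_theory,
              num_per_time_theory >= 1 ? "small-channel" : "default");
  return 0;
}

With these numbers num_per_time_theory comes out to 3, so the small-channel backward kernel is selected; the default kernel is used only when not even one (head, level, point) batch per query fits in NRAM.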
@@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
   int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
   if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
-  } else if (channels > nram_size / 12 / sizeof(float)) {
+  } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
+             channels < 16) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
   } else {
     return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
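Note: the forward dispatch now also bounds the channel count, so the small-channel path is taken only for 16 <= channels <= 96 (and only when twelve channel rows still fit in NRAM). A standalone predicate mirroring the condition, for illustration only:

#include <cstdint>

// Assumed helper restating the dispatch test above, not part of the library.
bool use_small_channel_forward(int channels, int num_levels, int num_points,
                               long nram_size) {
  if (num_levels * num_points * 3 * (long)sizeof(int32_t) > nram_size)
    return false;  // lookup tables alone would overflow NRAM
  return channels <= nram_size / 12 / (long)sizeof(float) &&
         channels >= 16 && channels <= 96;
}

For example, channels = 64 keeps the small-channel forward kernel, while channels = 8 or channels = 128 now fall back to the default kernel.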
@@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward(
   CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
               << ", " << k_dim.y << ", " << k_dim.z << ">>>";
   MsDeformAttnBackwardKernelPolicy kernelPolicy =
-      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
+      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
+                                     num_heads);
   switch (kernelPolicy) {
     default: {
       VLOG(5) << "NotImplemented.";
...