Unverified Commit 5f122293 authored by BinZheng's avatar BinZheng Committed by GitHub
Browse files

[Enhancement] ms_deform_attn performance optimization (#2616)



* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* ms_opt_v2_1

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

---------
Co-authored-by: default avatardongchengwei <dongchengwei@cambricon.com>
parent ec639323
......@@ -47,17 +47,17 @@ typedef enum {
} MsDeformAttnBackwardKernelPolicy;
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
const int32_t channels, const int32_t num_levels,
const int32_t num_points) {
const int32_t channels, const int32_t num_levels, const int32_t num_points,
const int32_t num_heads) {
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const uint64_t max_num = nram_size / sizeof(float);
const uint64_t deal_num =
12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
if (max_num >= deal_num) {
const int num_hlp = num_heads * num_levels * num_points;
int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
3 * num_levels * sizeof(int32_t)) /
sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
PAD_UP((num_hlp), 32);
if (num_per_time_theory >= 1) {
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
}
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}
......@@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float)) {
} else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
channels < 16) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else {
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
......@@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward(
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
MsDeformAttnBackwardKernelPolicy kernelPolicy =
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
num_heads);
switch (kernelPolicy) {
default: {
VLOG(5) << "NotImplemented.";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment