Unverified Commit 5f122293 authored by BinZheng, committed by GitHub

[Enhancement] ms_deform_attn performance optimization (#2616)



* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* ms_opt_v2_1

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

* [Feature] ms_deform_attn performance optimization V2

---------
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
parent ec639323
@@ -47,17 +47,17 @@ typedef enum {
 } MsDeformAttnBackwardKernelPolicy;
 MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
-    const int32_t channels, const int32_t num_levels,
-    const int32_t num_points) {
+    const int32_t channels, const int32_t num_levels, const int32_t num_points,
+    const int32_t num_heads) {
   const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const uint64_t max_num = nram_size / sizeof(float);
-  const uint64_t deal_num =
-      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
-  if (max_num >= deal_num) {
+  const int num_hlp = num_heads * num_levels * num_points;
+  int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
+                             3 * num_levels * sizeof(int32_t)) /
+                            sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
+                            PAD_UP((num_hlp), 32);
+  if (num_per_time_theory >= 1) {
     return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
   }
   return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
 }
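For reference, the new backward policy is an NRAM-budget check: the small-channel kernel is chosen only when at least one padded (head, level, point) group of the working set fits into per-core NRAM. Below is a minimal host-only C++ sketch of that heuristic; the 512 KB NRAM size, the pad_up() helper, and the example Deformable-DETR shape are assumptions standing in for the real torch_mlu::getDeviceAttr query, the PAD_UP macro, and the caller's tensors.

// Host-side sketch of the new backward-policy heuristic (illustration only).
#include <cstdint>
#include <cstdio>

static int32_t pad_up(int32_t x, int32_t align) {
  // Round x up to the nearest multiple of align, like the PAD_UP macro.
  return ((x + align - 1) / align) * align;
}

static bool use_small_channel_backward(int32_t nram_size, int32_t channels,
                                       int32_t num_levels, int32_t num_points,
                                       int32_t num_heads) {
  const int32_t num_hlp = num_heads * num_levels * num_points;
  // Same formula as the diff: a rough count of how many padded
  // (head, level, point) groups fit into NRAM after reserving per-level scratch.
  const int32_t num_per_time_theory =
      (nram_size - num_levels * (int32_t)sizeof(float) -
       3 * num_levels * (int32_t)sizeof(int32_t)) /
      (int32_t)sizeof(float) / (8 * pad_up(channels, 32) + 28) /
      pad_up(num_hlp, 32);
  return num_per_time_theory >= 1;
}

int main() {
  const int32_t nram_size = 512 * 1024;  // hypothetical NRAM bytes per core
  // Example shape: 32 channels per head, 4 levels, 4 points, 8 heads.
  const bool small = use_small_channel_backward(nram_size, 32, 4, 4, 8);
  printf("backward policy: %s\n", small ? "SMALL_CHANNEL" : "DEFAULT");
  return 0;
}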
@@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
   int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
   if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
-  } else if (channels > nram_size / 12 / sizeof(float)) {
+  } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
+             channels < 16) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
   } else {
     return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
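In the forward path only the gating condition changes: in addition to the existing NRAM bound, the small-channel kernel is now restricted to 16 <= channels <= 96. A standalone restatement of that check might look like the sketch below; nram_size is passed in as a parameter rather than queried from the device, and the enum is repeated here only to make the snippet self-contained.

// Standalone restatement of the updated forward policy (illustration only).
#include <cstdint>

enum MsDeformAttnForwardPolicy {
  MS_DEFORM_ATTN_FORWARD_DEFAULT,
  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL,
};

MsDeformAttnForwardPolicy forward_policy(int32_t nram_size, int32_t channels,
                                         int32_t num_levels,
                                         int32_t num_points) {
  if (num_levels * num_points * 3 * (int32_t)sizeof(int32_t) > nram_size) {
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else if (channels > nram_size / 12 / (int32_t)sizeof(float) ||
             channels > 96 || channels < 16) {
    // New in this patch: channel counts outside [16, 96] fall back to the
    // default kernel.
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  } else {
    return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
  }
}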
@@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward(
   CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
               << ", " << k_dim.y << ", " << k_dim.z << ">>>";
   MsDeformAttnBackwardKernelPolicy kernelPolicy =
-      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
+      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
+                                     num_heads);
   switch (kernelPolicy) {
     default: {
       VLOG(5) << "NotImplemented.";
......