Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8294773e
Unverified
Commit
8294773e
authored
Feb 27, 2025
by
qli88
Committed by
GitHub
Feb 27, 2025
Browse files
[core] Perf improvement for DSv3 on AMD GPUs (#13718)
Signed-off-by:
qli88
<
qiang.li2@amd.com
>
parent
cd813c6d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
210 additions
and
25 deletions
+210
-25
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+72
-20
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+10
-5
vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
...Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
+128
-0
No files found.
vllm/attention/backends/mla/common.py
View file @
8294773e
...
@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
...
@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
True
except
ImportError
:
except
ImportError
:
# For rocm use upstream flash attention
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
False
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
is_hip
=
current_platform
.
is_rocm
()
class
MLACommonBackend
(
AttentionBackend
):
class
MLACommonBackend
(
AttentionBackend
):
...
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
functools
.
partial
(
flash_attn_varlen_func
,
...
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
value
=
0
)
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
q
=
q
,
attn_output
,
attn_softmax_lse
=
self
.
triton_fa_func
(
k
=
k
,
q
,
v
=
v_padded
,
k
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
v_padded
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
None
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
prefill_metadata
.
query_start_loc
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
prefill_metadata
.
max_query_len
,
causal
=
False
,
# Context is unmasked
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
return_softmax_lse
=
True
,
False
,
# causal
)
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
elif
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
if
output
is
None
:
if
output
is
None
:
output
=
attn_output
output
=
attn_output
...
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
value
=
0
)
if
has_context
:
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
if
not
current_platform
.
is_cuda
():
output
=
self
.
triton_fa_func
(
raise
NotImplementedError
(
q
,
"Chunked Prefill for MLA is not currently supported on"
k
,
"non-cuda platforms"
)
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always returns 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
...
@@ -1389,7 +1439,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1389,7 +1439,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
return_softmax_lse
=
True
,
return_softmax_lse
=
has_context
,
)
)
else
:
else
:
output
=
self
.
flash_attn_varlen_func
(
output
=
self
.
flash_attn_varlen_func
(
...
@@ -1402,10 +1452,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1402,10 +1452,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
)
if
has_context
:
if
has_context
:
suffix_output
,
suffix_lse
=
output
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
...
vllm/attention/ops/triton_decode_attention.py
View file @
8294773e
...
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
...
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
page_size
,
page_size
,
logit_cap
,
logit_cap
,
):
):
BLOCK
=
64
BLOCK
=
64
if
not
is_hip_
else
8
NUM_KV_SPLITS
=
num_kv_splits
NUM_KV_SPLITS
=
num_kv_splits
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
...
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
...
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
grid
=
(
batch
,
head_num
,
NUM_KV_SPLITS
)
grid
=
(
batch
,
head_num
,
NUM_KV_SPLITS
)
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
-
2
]
num_warps
=
4
if
kv_group_num
==
1
else
2
num_warps
=
4
if
kv_group_num
!=
1
:
num_warps
=
1
if
is_hip_
else
2
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
...
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
...
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
)
)
extra_kargs
=
{}
extra_kargs
=
{}
num_stages
=
2
if
is_hip_
:
if
is_hip_
:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs
=
{
extra_kargs
=
{
"waves_per_eu"
:
4
,
"waves_per_eu"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
"kpack"
:
2
}
}
num_stages
=
1
_fwd_grouped_kernel_stage1
[
grid
](
_fwd_grouped_kernel_stage1
[
grid
](
q
,
q
,
...
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
...
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
4
,
num_warps
=
4
,
num_stages
=
2
,
num_stages
=
num_stages
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
Lv
=
Lv
,
**
extra_kargs
,
**
extra_kargs
,
...
...
vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
0 → 100644
View file @
8294773e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment