Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f8b3732
Unverified
Commit
4f8b3732
authored
May 14, 2025
by
qli88
Committed by
GitHub
May 13, 2025
Browse files
[BugFix][AMD] Compatible patch for AITER lib after 04/20 (#17912)
Signed-off-by:
Qiang Li
<
qiang.li2@amd.com
>
parent
7b2f28de
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
17 deletions
+54
-17
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+37
-12
vllm/attention/ops/rocm_aiter_mla.py
vllm/attention/ops/rocm_aiter_mla.py
+12
-1
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+5
-4
No files found.
vllm/attention/backends/rocm_aiter_mla.py
View file @
4f8b3732
...
@@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend):
...
@@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend):
@
dataclass
@
dataclass
class
AiterMLAMetadata
(
MLACommonMetadata
):
class
AiterMLAMetadata
(
MLACommonMetadata
):
# The following
4
tensors are for current version of AITER MLA
# The following
5
tensors are for current version of AITER MLA
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# The indptr of the paged kv cache, shape: [batch_size + 1]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
...
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
# the paged kv cache, shape: [batch_size]
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
paged_kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
prefill_metadata
(
self
):
def
prefill_metadata
(
self
):
prefill_metadata
=
super
().
prefill_metadata
prefill_metadata
=
super
().
prefill_metadata
...
@@ -74,6 +78,7 @@ class AiterMLAMetadata(MLACommonMetadata):
...
@@ -74,6 +78,7 @@ class AiterMLAMetadata(MLACommonMetadata):
prefill_metadata
\
prefill_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
prefill_metadata
.
block_table_bound
=
self
.
block_table_bound
prefill_metadata
.
block_table_bound
=
self
.
block_table_bound
prefill_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
# update the cache
self
.
_cached_prefill_metadata
=
self
.
__class__
(
self
.
_cached_prefill_metadata
=
self
.
__class__
(
...
@@ -93,6 +98,7 @@ class AiterMLAMetadata(MLACommonMetadata):
...
@@ -93,6 +98,7 @@ class AiterMLAMetadata(MLACommonMetadata):
decode_metadata
\
decode_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
decode_metadata
.
block_table_bound
=
self
.
block_table_bound
decode_metadata
.
block_table_bound
=
self
.
block_table_bound
decode_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
# update the cache
self
.
_cached_decode_metadata
=
self
.
__class__
(
self
.
_cached_decode_metadata
=
self
.
__class__
(
...
@@ -136,6 +142,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -136,6 +142,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indptr
:
list
[
int
]
=
[
0
]
self
.
paged_kv_indptr
:
list
[
int
]
=
[
0
]
self
.
paged_kv_last_page_lens
:
list
[
int
]
=
[]
self
.
paged_kv_last_page_lens
:
list
[
int
]
=
[]
self
.
total_blocks
=
0
self
.
total_blocks
=
0
self
.
qo_indptr
:
list
[
int
]
=
[
0
]
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
prefix_cache_hit
:
bool
):
...
@@ -208,6 +215,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -208,6 +215,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
block_table_bound
)
self
.
qo_indptr
.
append
(
self
.
qo_indptr
[
-
1
]
+
1
)
last_page_len
=
seq_len
%
self
.
block_size
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
if
last_page_len
==
0
:
...
@@ -226,6 +234,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -226,6 +234,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
cuda_graph_pad_size
)
self
.
paged_kv_last_page_lens
.
extend
([
0
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_lens
.
extend
([
0
]
*
cuda_graph_pad_size
)
last_qo_indptr
=
self
.
qo_indptr
[
-
1
]
self
.
qo_indptr
.
extend
([
last_qo_indptr
]
*
cuda_graph_pad_size
)
# For current version of AITER MLA
# For current version of AITER MLA
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
...
@@ -245,16 +255,22 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -245,16 +255,22 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
1
,
1
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
int
)
dtype
=
torch
.
int
)
qo_indptr
=
torch
.
tensor
(
self
.
qo_indptr
,
device
=
device
,
dtype
=
torch
.
int
)
else
:
else
:
paged_kv_indices_tensor
=
None
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_lens_tensor
=
None
paged_kv_last_page_lens_tensor
=
None
block_table_bound_tensor
=
None
block_table_bound_tensor
=
None
qo_indptr
=
None
metadata
.
paged_kv_indptr
=
paged_kv_indptr_tensor
metadata
.
paged_kv_indptr
=
paged_kv_indptr_tensor
metadata
.
paged_kv_indices
=
paged_kv_indices_tensor
metadata
.
paged_kv_indices
=
paged_kv_indices_tensor
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens_tensor
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens_tensor
metadata
.
block_table_bound
=
block_table_bound_tensor
metadata
.
block_table_bound
=
block_table_bound_tensor
metadata
.
qo_indptr
=
qo_indptr
return
metadata
return
metadata
...
@@ -263,14 +279,17 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
...
@@ -263,14 +279,17 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@
contextmanager
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
def
graph_capture
(
self
,
max_batch_size
:
int
):
kv_indices
,
kv_indptr
,
last_page_lens
=
get_aiter_mla_metadata
(
kv_indices
,
kv_indptr
,
last_page_lens
,
qo_indptr
=
\
get_aiter_mla_metadata
(
max_batch_size
=
max_batch_size
,
max_batch_size
=
max_batch_size
,
block_size
=
self
.
runner
.
block_size
,
block_size
=
self
.
runner
.
block_size
,
max_block_per_batch
=
self
.
runner
.
get_max_block_per_batch
(),
max_block_per_batch
=
\
self
.
runner
.
get_max_block_per_batch
(),
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
self
.
_paged_kv_indices_tensor
=
kv_indices
self
.
_paged_kv_indices_tensor
=
kv_indices
self
.
_paged_kv_indptr_tensor
=
kv_indptr
self
.
_paged_kv_indptr_tensor
=
kv_indptr
self
.
_paged_kv_last_page_lens_tensor
=
last_page_lens
self
.
_paged_kv_last_page_lens_tensor
=
last_page_lens
self
.
_qo_indptr_tensor
=
qo_indptr
with
super
().
graph_capture
(
max_batch_size
):
with
super
().
graph_capture
(
max_batch_size
):
yield
yield
...
@@ -278,6 +297,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
...
@@ -278,6 +297,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
del
self
.
_paged_kv_indices_tensor
del
self
.
_paged_kv_indices_tensor
del
self
.
_paged_kv_indptr_tensor
del
self
.
_paged_kv_indptr_tensor
del
self
.
_paged_kv_last_page_lens_tensor
del
self
.
_paged_kv_last_page_lens_tensor
del
self
.
_qo_indptr_tensor
def
graph_capture_get_metadata_for_batch
(
def
graph_capture_get_metadata_for_batch
(
self
,
self
,
...
@@ -291,10 +311,12 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
...
@@ -291,10 +311,12 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
paged_kv_indices
=
self
.
_paged_kv_indices_tensor
paged_kv_indices
=
self
.
_paged_kv_indices_tensor
paged_kv_last_page_lens
=
self
.
_paged_kv_last_page_lens_tensor
[:
paged_kv_last_page_lens
=
self
.
_paged_kv_last_page_lens_tensor
[:
batch_size
]
batch_size
]
qo_indptr
=
self
.
_qo_indptr_tensor
[:
batch_size
+
1
]
metadata
.
paged_kv_indptr
=
paged_kv_indptr
metadata
.
paged_kv_indptr
=
paged_kv_indptr
metadata
.
paged_kv_indices
=
paged_kv_indices
metadata
.
paged_kv_indices
=
paged_kv_indices
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens
metadata
.
qo_indptr
=
qo_indptr
return
metadata
return
metadata
...
@@ -311,6 +333,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
...
@@ -311,6 +333,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers
[
input_buffers
[
"paged_kv_last_page_lens"
]
=
attn_metadata
.
\
"paged_kv_last_page_lens"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_last_page_lens
decode_metadata
.
paged_kv_last_page_lens
input_buffers
[
'qo_indptr'
]
=
attn_metadata
.
qo_indptr
return
input_buffers
return
input_buffers
...
@@ -330,6 +353,8 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
...
@@ -330,6 +353,8 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers
[
"paged_kv_last_page_lens"
].
copy_
(
input_buffers
[
"paged_kv_last_page_lens"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_last_page_lens
,
attn_metadata
.
decode_metadata
.
paged_kv_last_page_lens
,
non_blocking
=
True
)
non_blocking
=
True
)
input_buffers
[
"qo_indptr"
].
copy_
(
attn_metadata
.
decode_metadata
.
qo_indptr
,
non_blocking
=
True
)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
...
@@ -370,11 +395,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -370,11 +395,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
softmax_scale
:
float
,
return_softmax_lse
:
bool
,
softmax_scale
:
float
,
return_softmax_lse
:
bool
,
**
kwargs
)
->
Union
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
**
kwargs
)
->
Union
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
output
=
self
.
flash_attn_varlen_func
(
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
,
k
=
k
,
k
,
v
=
v
,
v
,
softmax_scale
=
softmax_scale
,
return_lse
=
return_softmax_lse
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -394,7 +417,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -394,7 +417,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
B
=
q_nope
.
shape
[
0
]
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
zeros
(
B
,
o
=
torch
.
empty
(
B
,
self
.
num_heads
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
dtype
=
q
.
dtype
,
...
@@ -403,6 +426,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -403,6 +426,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
qo_indptr
,
attn_metadata
.
max_query_len
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_lens
)
attn_metadata
.
paged_kv_last_page_lens
)
...
...
vllm/attention/ops/rocm_aiter_mla.py
View file @
4f8b3732
...
@@ -20,7 +20,8 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
...
@@ -20,7 +20,8 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
paged_kv_last_page_lens
=
torch
.
full
((
max_batch_size
,
),
paged_kv_last_page_lens
=
torch
.
full
((
max_batch_size
,
),
block_size
,
block_size
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
return
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_lens
qo_indptr
=
torch
.
zeros
(
max_batch_size
+
1
,
dtype
=
torch
.
int
,
device
=
device
)
return
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_lens
,
qo_indptr
def
aiter_mla_decode_fwd
(
def
aiter_mla_decode_fwd
(
...
@@ -28,6 +29,8 @@ def aiter_mla_decode_fwd(
...
@@ -28,6 +29,8 @@ def aiter_mla_decode_fwd(
kv_buffer
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
sm_scale
:
float
,
sm_scale
:
float
,
qo_indptr
:
torch
.
Tensor
,
max_seqlen_qo
:
int
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -38,6 +41,8 @@ def aiter_mla_decode_fwd(
...
@@ -38,6 +41,8 @@ def aiter_mla_decode_fwd(
kv_buffer
.
view
(
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
o
,
o
,
qo_indptr
,
max_seqlen_qo
,
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
...
@@ -49,6 +54,8 @@ def mla_decode_fwd_impl(
...
@@ -49,6 +54,8 @@ def mla_decode_fwd_impl(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
max_seqlen_qo
:
int
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -60,9 +67,11 @@ def mla_decode_fwd_impl(
...
@@ -60,9 +67,11 @@ def mla_decode_fwd_impl(
mla_decode_fwd
(
q
,
mla_decode_fwd
(
q
,
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
o
,
o
,
qo_indptr
,
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
max_seqlen_qo
,
sm_scale
=
sm_scale
,
sm_scale
=
sm_scale
,
logit_cap
=
logit_cap
)
logit_cap
=
logit_cap
)
...
@@ -71,6 +80,8 @@ def mla_decode_fwd_fake(
...
@@ -71,6 +80,8 @@ def mla_decode_fwd_fake(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
max_seqlen_qo
:
int
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
4f8b3732
...
@@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
...
@@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
fmoe_fp8_blockscale_g1u1
(
out_asm
,
a1
,
w1
,
w2
,
sorted_token_ids
,
fmoe_fp8_blockscale_g1u1
(
out_asm
,
a1
,
w1
,
w2
,
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
topk
,
w1_scale
.
view
(
local_E
,
-
1
),
num_valid_ids
,
topk
,
w2_scale
.
view
(
local_E
,
-
1
),
a1_scale
.
t
().
contiguous
(),
a1_scale
.
t
().
contiguous
(),
*
block_shape
,
w1_scale
.
view
(
local_E
,
-
1
),
smooth_scale
)
w2_scale
.
view
(
local_E
,
-
1
),
*
block_shape
,
smooth_scale
)
return
out_asm
return
out_asm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment