Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85b72cb7
Unverified
Commit
85b72cb7
authored
May 09, 2025
by
Michael Goin
Committed by
GitHub
May 09, 2025
Browse files
Revert "[BugFix][AMD] Compatible patch for latest AITER(05/07/2025)" (#17910)
parent
6e5595ca
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
54 deletions
+23
-54
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+6
-6
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+12
-37
vllm/attention/ops/rocm_aiter_mla.py
vllm/attention/ops/rocm_aiter_mla.py
+1
-6
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+4
-5
No files found.
vllm/attention/backends/mla/common.py
View file @
85b72cb7
...
...
@@ -1213,9 +1213,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
,
k
,
v
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
...
...
@@ -1267,9 +1267,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
,
k
,
v
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
...
...
vllm/attention/backends/rocm_aiter_mla.py
View file @
85b72cb7
...
...
@@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend):
@
dataclass
class
AiterMLAMetadata
(
MLACommonMetadata
):
# The following
5
tensors are for current version of AITER MLA
# The following
4
tensors are for current version of AITER MLA
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata):
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
):
prefill_metadata
=
super
().
prefill_metadata
...
...
@@ -78,7 +74,6 @@ class AiterMLAMetadata(MLACommonMetadata):
prefill_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
prefill_metadata
.
block_table_bound
=
self
.
block_table_bound
prefill_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
self
.
_cached_prefill_metadata
=
self
.
__class__
(
...
...
@@ -98,7 +93,6 @@ class AiterMLAMetadata(MLACommonMetadata):
decode_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
decode_metadata
.
block_table_bound
=
self
.
block_table_bound
decode_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
self
.
_cached_decode_metadata
=
self
.
__class__
(
...
...
@@ -142,7 +136,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indptr
:
list
[
int
]
=
[
0
]
self
.
paged_kv_last_page_lens
:
list
[
int
]
=
[]
self
.
total_blocks
=
0
self
.
qo_indptr
:
list
[
int
]
=
[
0
]
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
...
...
@@ -215,7 +208,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
self
.
qo_indptr
.
append
(
self
.
qo_indptr
[
-
1
]
+
1
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
...
...
@@ -234,8 +226,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_lens
.
extend
([
0
]
*
cuda_graph_pad_size
)
last_qo_indptr
=
self
.
qo_indptr
[
-
1
]
self
.
qo_indptr
.
extend
([
last_qo_indptr
]
*
cuda_graph_pad_size
)
# For current version of AITER MLA
if
len
(
self
.
paged_kv_indptr
)
>
0
:
...
...
@@ -255,22 +245,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
1
,
device
=
device
,
dtype
=
torch
.
int
)
qo_indptr
=
torch
.
tensor
(
self
.
qo_indptr
,
device
=
device
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_lens_tensor
=
None
block_table_bound_tensor
=
None
qo_indptr
=
None
metadata
.
paged_kv_indptr
=
paged_kv_indptr_tensor
metadata
.
paged_kv_indices
=
paged_kv_indices_tensor
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens_tensor
metadata
.
block_table_bound
=
block_table_bound_tensor
metadata
.
qo_indptr
=
qo_indptr
return
metadata
...
...
@@ -279,17 +263,14 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
kv_indices
,
kv_indptr
,
last_page_lens
,
qo_indptr
=
\
get_aiter_mla_metadata
(
kv_indices
,
kv_indptr
,
last_page_lens
=
get_aiter_mla_metadata
(
max_batch_size
=
max_batch_size
,
block_size
=
self
.
runner
.
block_size
,
max_block_per_batch
=
\
self
.
runner
.
get_max_block_per_batch
(),
max_block_per_batch
=
self
.
runner
.
get_max_block_per_batch
(),
device
=
self
.
runner
.
device
)
self
.
_paged_kv_indices_tensor
=
kv_indices
self
.
_paged_kv_indptr_tensor
=
kv_indptr
self
.
_paged_kv_last_page_lens_tensor
=
last_page_lens
self
.
_qo_indptr_tensor
=
qo_indptr
with
super
().
graph_capture
(
max_batch_size
):
yield
...
...
@@ -297,7 +278,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
del
self
.
_paged_kv_indices_tensor
del
self
.
_paged_kv_indptr_tensor
del
self
.
_paged_kv_last_page_lens_tensor
del
self
.
_qo_indptr_tensor
def
graph_capture_get_metadata_for_batch
(
self
,
...
...
@@ -311,12 +291,10 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
paged_kv_indices
=
self
.
_paged_kv_indices_tensor
paged_kv_last_page_lens
=
self
.
_paged_kv_last_page_lens_tensor
[:
batch_size
]
qo_indptr
=
self
.
_qo_indptr_tensor
[:
batch_size
+
1
]
metadata
.
paged_kv_indptr
=
paged_kv_indptr
metadata
.
paged_kv_indices
=
paged_kv_indices
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens
metadata
.
qo_indptr
=
qo_indptr
return
metadata
...
...
@@ -333,7 +311,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers
[
"paged_kv_last_page_lens"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_last_page_lens
input_buffers
[
'qo_indptr'
]
=
attn_metadata
.
qo_indptr
return
input_buffers
...
...
@@ -353,8 +330,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers
[
"paged_kv_last_page_lens"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_last_page_lens
,
non_blocking
=
True
)
input_buffers
[
"qo_indptr"
].
copy_
(
attn_metadata
.
decode_metadata
.
qo_indptr
,
non_blocking
=
True
)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
...
...
@@ -395,9 +370,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
softmax_scale
:
float
,
return_softmax_lse
:
bool
,
**
kwargs
)
->
Union
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
q
=
q
,
k
=
k
,
v
=
v
,
softmax_scale
=
softmax_scale
,
return_lse
=
return_softmax_lse
,
**
kwargs
,
)
...
...
@@ -417,7 +394,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
empty
(
B
,
o
=
torch
.
zeros
(
B
,
self
.
num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
...
...
@@ -426,8 +403,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
qo_indptr
,
attn_metadata
.
max_query_len
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_lens
)
...
...
vllm/attention/ops/rocm_aiter_mla.py
View file @
85b72cb7
...
...
@@ -20,8 +20,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
paged_kv_last_page_lens
=
torch
.
full
((
max_batch_size
,
),
block_size
,
dtype
=
torch
.
int32
)
qo_indptr
=
torch
.
zeros
(
max_batch_size
+
1
,
dtype
=
torch
.
int
,
device
=
device
)
return
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_lens
,
qo_indptr
return
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_lens
def
aiter_mla_decode_fwd
(
...
...
@@ -29,8 +28,6 @@ def aiter_mla_decode_fwd(
kv_buffer
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
sm_scale
:
float
,
qo_indptr
:
torch
.
Tensor
,
max_seqlen_qo
:
int
,
kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -63,11 +60,9 @@ def mla_decode_fwd_impl(
mla_decode_fwd
(
q
,
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
o
,
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
max_seqlen_qo
,
sm_scale
=
sm_scale
,
logit_cap
=
logit_cap
)
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
85b72cb7
...
...
@@ -123,11 +123,10 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
fmoe_fp8_blockscale_g1u1
(
out_asm
,
a1
,
w1
,
w2
,
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
topk
,
a1_scale
.
t
().
contiguous
(),
w1_scale
.
view
(
local_E
,
-
1
),
w2_scale
.
view
(
local_E
,
-
1
),
*
block_shape
,
smooth_scale
)
num_valid_ids
,
topk
,
w1_scale
.
view
(
local_E
,
-
1
),
w2_scale
.
view
(
local_E
,
-
1
),
a1_scale
.
t
().
contiguous
(),
*
block_shape
,
smooth_scale
)
return
out_asm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment