Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31021d81
Commit
31021d81
authored
Dec 05, 2025
by
zhuwenwen
Browse files
update mla interface
parent
f6aa3d19
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
7 deletions
+12
-7
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+12
-7
No files found.
vllm/attention/ops/flashmla.py
View file @
31021d81
...
@@ -77,9 +77,9 @@ def get_mla_metadata(
...
@@ -77,9 +77,9 @@ def get_mla_metadata(
- num_splits: (batch_size + 1), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
"""
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
return
flash_mla_cuda
.
get_mla_metadata
(
cache_seqlens
,
return
flash_mla_cuda
.
get_mla_metadata
(
num_q_tokens_per_head_k
,
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_q
,
num_heads_
k
)
is_fp8_kvcache
,
top
k
)
else
:
else
:
return
torch
.
ops
.
_flashmla_C
.
get_mla_decoding_metadata
(
return
torch
.
ops
.
_flashmla_C
.
get_mla_decoding_metadata
(
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_q
,
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_q
,
...
@@ -160,8 +160,9 @@ def flash_mla_with_kvcache(
...
@@ -160,8 +160,9 @@ def flash_mla_with_kvcache(
else
:
else
:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
q
,
k_cache
,
block_table
,
cache_seqlens
,
head_dim_v
,
tile_scheduler_metadata
,
causal
,
tile_scheduler_metadata
,
num_splits
)
num_splits
,
softmax_scale
,
causal
,
is_fp8_kvcache
,
indices
)
else
:
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
...
@@ -196,8 +197,12 @@ def flash_mla_sparse_prefill(
...
@@ -196,8 +197,12 @@ def flash_mla_sparse_prefill(
- max_logits: [s_q, h_q], float
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
"""
results
=
torch
.
ops
.
_flashmla_C
.
sparse_prefill_fwd
(
q
,
kv
,
indices
,
if
current_platform
.
is_rocm
():
sm_scale
,
d_v
)
return
flash_mla_cuda
.
sparse_prefill_fwd
(
q
,
kv
,
indices
,
sm_scale
,
d_v
)
else
:
results
=
torch
.
ops
.
_flashmla_C
.
sparse_prefill_fwd
(
q
,
kv
,
indices
,
sm_scale
,
d_v
)
return
results
return
results
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment