Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77210184
Commit
77210184
authored
Dec 17, 2025
by
zhuwenwen
Browse files
update flash_mla_with_kvcache_fp8 interface and k_cache
parent
347fc09c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
8 deletions
+5
-8
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+1
-2
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+2
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-5
No files found.
vllm/attention/backends/flashmla.py
View file @
77210184
...
...
@@ -17,7 +17,6 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm
import
envs
...
...
@@ -239,7 +238,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
float8_e4m3fn
)
.
unsqueeze
(
-
2
)
,
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
to
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
head_dim_v
=
self
.
kv_lora_rank
,
...
...
vllm/attention/ops/flashmla.py
View file @
77210184
...
...
@@ -73,6 +73,7 @@ def get_mla_decoding_metadata_dense_fp8(
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
num_heads_k
:
int
,
num_heads_q
:
int
=
16
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
...
...
@@ -87,7 +88,7 @@ def get_mla_decoding_metadata_dense_fp8(
"""
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
num_heads_k
,
num_heads_q
)
def
flash_mla_with_kvcache
(
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
77210184
...
...
@@ -12,7 +12,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -183,10 +182,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
float8_e4m3fn
)
.
unsqueeze
(
-
2
)
,
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
to
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
...
...
@@ -213,7 +211,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment