Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4b00d1ba
Commit
4b00d1ba
authored
Nov 23, 2025
by
zhuwenwen
Browse files
add VLLM_USE_CAT_MLA to use fused cat and mla
parent
be22412f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
100 additions
and
23 deletions
+100
-23
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+52
-0
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+42
-23
No files found.
vllm/attention/ops/flashmla.py
View file @
4b00d1ba
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -144,6 +145,57 @@ def flash_mla_with_kvcache(
...
@@ -144,6 +145,57 @@ def flash_mla_with_kvcache(
return
out
,
softmax_lse
return
out
,
softmax_lse
def
flash_mla_with_kvcache_q_nope_pe
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
kv_dtype
=
"fp8_e4m3"
if
kv_cache_dtype
==
"fp8"
else
kv_cache_dtype
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_q_nope_pe_mla
(
q_nope
,
q_pe
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
k_scale
,
kv_dtype
,
)
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_nope_pe
(
q_nope
,
q_pe
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
)
return
out
,
softmax_lse
#
#
# TODO: Add fake functions
# TODO: Add fake functions
#
#
...
...
vllm/envs.py
View file @
4b00d1ba
...
@@ -184,6 +184,7 @@ if TYPE_CHECKING:
...
@@ -184,6 +184,7 @@ if TYPE_CHECKING:
VLLM_USE_PP_BALANCE
:
bool
=
False
VLLM_USE_PP_BALANCE
:
bool
=
False
VLLM_USE_ZERO_MTP
:
bool
=
False
VLLM_USE_ZERO_MTP
:
bool
=
False
VLLM_USE_CUDA_GRAPH_SIZES
:
bool
=
False
VLLM_USE_CUDA_GRAPH_SIZES
:
bool
=
False
VLLM_USE_CAT_MLA
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1194,6 +1195,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1194,6 +1195,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDA_GRAPH_SIZES"
:
"VLLM_USE_CUDA_GRAPH_SIZES"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_CUDA_GRAPH_SIZES'
,
'True'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'VLLM_USE_CUDA_GRAPH_SIZES'
,
'True'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use fused cat and mla
"VLLM_USE_CAT_MLA"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_CAT_MLA'
,
'True'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
4b00d1ba
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionType
,
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
is_quantized_kv_cache
)
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
get_mla_metadata
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -167,31 +168,49 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -167,31 +168,49 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
envs
.
VLLM_USE_OPT_CAT
:
if
not
envs
.
VLLM_USE_CAT_MLA
:
if
q_nope
.
shape
[
0
]
<
1024
:
if
envs
.
VLLM_USE_OPT_CAT
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
if
q_nope
.
shape
[
0
]
<
1024
:
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
.
unsqueeze
(
1
)
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
not
envs
.
VLLM_USE_CAT_MLA
:
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
else
:
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
o
,
_
=
flash_mla_with_kvcache_q_nope_pe
(
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
o
,
_
=
flash_mla_with_kvcache
(
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
q
=
q
,
block_table
=
attn_metadata
.
decode
.
block_table
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
block_table
=
attn_metadata
.
decode
.
block_table
,
head_dim_v
=
self
.
kv_lora_rank
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
num_splits
=
attn_metadata
.
decode
.
num_splits
,
tile_scheduler_metadata
,
softmax_scale
=
self
.
scale
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
causal
=
True
,
softmax_scale
=
self
.
scale
,
k_scale
=
k_scale
,
causal
=
True
,
kv_cache_dtype
=
kv_cache_dtype
,
k_scale
=
k_scale
,
)
kv_cache_dtype
=
kv_cache_dtype
,
)
return
self
.
_v_up_proj
(
o
)
return
self
.
_v_up_proj
(
o
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment