Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9f57e73
Commit
a9f57e73
authored
Dec 16, 2025
by
zhuwenwen
Browse files
add VLLM_USE_FLASH_MLA_FP8 to use mla fp8
set VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT=1
parent
8548cf87
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
226 additions
and
67 deletions
+226
-67
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+74
-33
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-2
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+74
-0
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+4
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+67
-27
No files found.
vllm/attention/backends/flashmla.py
View file @
a9f57e73
...
...
@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState
)
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm
import
envs
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
...
...
@@ -87,13 +90,20 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size
)
if
m
.
num_decode_tokens
>
0
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_decoding_metadata_dense_fp8
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
return
m
...
...
@@ -108,6 +118,15 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
# Run a dummy `get_mla_metadata` so we can get the right shapes
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
torch
.
ones
(
...
...
@@ -128,6 +147,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size
,
is_encoder_decoder_model
)
assert
metadata
.
num_decode_tokens
>
0
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
...
...
@@ -222,6 +248,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
q_scale
=
None
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
...
...
@@ -233,6 +260,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
decode_meta
.
decode_tile_scheduler_metadata
,
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
...
@@ -246,5 +288,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/mla/common.py
View file @
a9f57e73
...
...
@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_q_scale
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output
\ No newline at end of file
vllm/attention/ops/flashmla.py
View file @
a9f57e73
...
...
@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k
)
def
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
num_heads_k
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
def
flash_mla_with_kvcache
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
...
...
@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return
out
,
softmax_lse
def
flash_mla_with_kvcache_fp8
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
descale_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_q
,
descale_k
,
)
return
out
,
softmax_lse
#
# TODO: Add fake functions
#
...
...
vllm/envs.py
View file @
a9f57e73
...
...
@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
...
@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA_FP8"
,
"0"
))),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/model_loader/utils.py
View file @
a9f57e73
...
...
@@ -255,8 +255,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
#
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
#
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
...
...
@@ -300,8 +300,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
#
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
#
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
...
...
vllm/v1/attention/backends/mla/common.py
View file @
a9f57e73
...
...
@@ -1199,6 +1199,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_q_scale
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output_padded
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
a9f57e73
...
...
@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -71,6 +73,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_decoding_metadata_dense_fp8
(
seq_lens
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
else
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_metadata
(
seq_lens
,
...
...
@@ -162,12 +172,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
q_scale
=
None
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment