Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f67d4d9
Unverified
Commit
0f67d4d9
authored
Oct 24, 2025
by
Ming Yang
Committed by
GitHub
Oct 24, 2025
Browse files
[Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)
Signed-off-by:
Ming Yang
<
minos.future@gmail.com
>
parent
7e1d697b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
113 additions
and
2 deletions
+113
-2
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+107
-2
No files found.
vllm/envs.py
View file @
0f67d4d9
...
...
@@ -183,6 +183,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
:
int
|
None
=
None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
:
int
=
480
VLLM_USE_CUDNN_PREFILL
:
bool
=
False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
:
bool
=
False
VLLM_ENABLE_CUDAGRAPH_GC
:
bool
=
False
VLLM_LOOPBACK_IP
:
str
=
""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
bool
=
False
...
...
@@ -1250,6 +1251,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDNN_PREFILL"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_CUDNN_PREFILL"
,
"0"
))
),
# Controls whether to use TRT-LLM ragged DeepSeek prefill
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL"
,
"0"
))
),
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
# If set to 0/False, use the default attention backend in flashinfer.
# If not set, auto-detect the attention backend in flashinfer.
...
...
@@ -1481,6 +1486,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS"
,
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"
,
"VLLM_USE_CUDNN_PREFILL"
,
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL"
,
"VLLM_USE_TRTLLM_ATTENTION"
,
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"
,
"VLLM_ROCM_USE_AITER"
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
0f67d4d9
...
...
@@ -371,6 +371,7 @@ class MLACommonPrefillMetadata:
query_start_loc
:
torch
.
Tensor
max_query_len
:
int
chunked_context
:
ChunkedContextMetadata
|
None
=
None
query_seq_lens
:
torch
.
Tensor
|
None
=
None
@
dataclass
...
...
@@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
class
ChunkedContextMetadata
(
MLACommonPrefillMetadata
.
ChunkedContextMetadata
):
seq_lens
:
torch
.
Tensor
query_seq_lens
:
torch
.
Tensor
|
None
=
None
cudnn_workspace
:
torch
.
Tensor
|
None
=
None
...
...
@@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
not
envs
.
VLLM_DISABLE_FLASHINFER_PREFILL
and
flashinfer_available
and
not
envs
.
VLLM_USE_CUDNN_PREFILL
and
not
envs
.
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and
current_platform
.
is_device_capability
(
100
)
)
...
...
@@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
)
def use_trtllm_ragged_deepseek_prefill() -> bool:
    """Report whether the TRT-LLM ragged DeepSeek prefill path is enabled.

    Three conditions must all hold. They are checked in this order, and
    later checks are skipped as soon as one fails:

    1. FlashInfer is available (``flashinfer_available``).
    2. The ``VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL`` env flag is set.
    3. The current platform reports device capability ``100``.
    """
    if not flashinfer_available:
        return False
    if not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL:
        return False
    # NOTE(review): 100 presumably encodes compute capability 10.0 — confirm
    # against the platform helper's contract.
    return current_platform.is_device_capability(100)
# Currently 394MB; this can be tuned based on the GEMM sizes used.
# Chosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
...
...
@@ -593,6 +603,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
_use_cudnn_prefill
=
use_cudnn_prefill
()
self
.
_use_fi_prefill
=
use_flashinfer_prefill
()
self
.
_use_trtllm_ragged_prefill
=
use_trtllm_ragged_deepseek_prefill
()
self
.
prefill_metadata_cls
=
(
FlashInferPrefillMetadata
if
self
.
_use_fi_prefill
...
...
@@ -613,6 +624,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
get_per_layer_parameters
(
vllm_config
,
layer_names
,
MLACommonImpl
)
)
if
self
.
_use_trtllm_ragged_prefill
:
self
.
_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
device
)
if
self
.
_use_cudnn_prefill
:
self
.
cudnn_workspace
=
torch
.
empty
(
CUDNN_WORKSPACE_SIZE
*
scheduler_config
.
max_num_seqs
,
...
...
@@ -934,6 +950,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
prefill_metadata
.
cudnn_workspace
=
self
.
cudnn_workspace
if
self
.
_use_trtllm_ragged_prefill
:
prefill_metadata
.
query_seq_lens
=
(
prefill_query_start_loc
[
1
:]
-
prefill_query_start_loc
[:
-
1
]
)
decode_metadata
=
None
if
num_decodes
>
0
:
decode_metadata
=
self
.
_build_decode
(
...
...
@@ -1230,6 +1251,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_fi
self
.
_run_prefill_new_tokens
=
self
.
_run_prefill_new_tokens_fi
self
.
_pad_v
=
False
elif
use_trtllm_ragged_deepseek_prefill
():
logger
.
debug_once
(
"Using TRT-LLM ragged DeepSeek prefill for MLA"
)
self
.
_run_prefill_context_chunk
=
(
self
.
_run_prefill_context_chunk_trtllm_ragged
)
self
.
_run_prefill_new_tokens
=
self
.
_run_prefill_new_tokens_trtllm_ragged
self
.
_pad_v
=
False
elif
use_cudnn_prefill
():
logger
.
debug_once
(
"Using CUDNN prefill for MLA"
)
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_cudnn
...
...
@@ -1326,6 +1354,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
):
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
assert
prefill
.
prefill_main
is
not
None
ret
=
prefill
.
prefill_main
.
run
(
q
=
q
,
k
=
k
,
...
...
@@ -1334,7 +1363,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if
isinstance
(
ret
,
tuple
):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return
ret
[
0
],
ret
[
1
].
transpose
(
0
,
1
).
contiguous
()
return
ret
...
...
@@ -1384,12 +1412,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
attn_out
,
lse
=
prefill
.
prefill_chunks
[
chunk_idx
].
run
(
q
=
q
,
k
=
k
,
v
=
v
,
return_lse
=
True
,
)
# Convert from (q_len, num_heads) to (num_heads, q_len)
return
attn_out
,
lse
.
transpose
(
0
,
1
).
contiguous
()
...
...
@@ -1418,6 +1448,81 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
is_cuda_graph_compatible
=
True
,
)
def _run_prefill_new_tokens_trtllm_ragged(
    self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
    """Run causal TRT-LLM ragged attention over the newly scheduled tokens.

    The query and key/value sides describe the same ragged batch: both use
    ``prefill.query_seq_lens`` / ``prefill.query_start_loc`` and
    ``prefill.max_query_len``.

    Args:
        prefill: Prefill metadata; ``query_seq_lens`` must be populated.
        q: Query tensor passed straight through to the kernel.
        k: Key tensor passed straight through to the kernel.
        v: Value tensor passed straight through to the kernel.
        return_softmax_lse: Forwarded as the kernel's ``return_lse`` flag.

    Returns:
        Whatever the kernel returns. When that is a tuple, its second
        element (the LSE) is transposed on dims 0/1 and made contiguous.
    """
    # Imported lazily so that the module loads without FlashInfer installed.
    from flashinfer.prefill import trtllm_ragged_attention_deepseek

    new_token_lens = prefill.query_seq_lens
    assert new_token_lens is not None

    # New tokens attend only to themselves, so Q and KV share offsets/lengths.
    start_offsets = prefill.query_start_loc
    longest_query = prefill.max_query_len

    result = trtllm_ragged_attention_deepseek(
        query=q,
        key=k,
        value=v,
        workspace_buffer=self._workspace_buffer,
        seq_lens=new_token_lens,
        max_q_len=longest_query,
        max_kv_len=longest_query,
        bmm1_scale=self.scale,
        bmm2_scale=1.0,
        o_sf_scale=1.0,
        batch_size=new_token_lens.shape[0],
        window_left=-1,
        cum_seq_lens_q=start_offsets,
        cum_seq_lens_kv=start_offsets,
        enable_pdl=False,
        is_causal=True,
        return_lse=return_softmax_lse,
    )

    if not isinstance(result, tuple):
        return result

    # Convert from (q_len, num_heads) to (num_heads, q_len)
    lse = result[1].transpose(0, 1).contiguous()
    return result[0], lse
def _run_prefill_context_chunk_trtllm_ragged(
    self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
    """Run non-causal TRT-LLM ragged attention against one context chunk.

    The query side uses ``prefill.query_start_loc`` / ``max_query_len``.
    The key/value side uses the ``chunk_idx`` entries of
    ``prefill.chunked_context`` (``seq_lens``, ``max_seq_lens``,
    ``cu_seq_lens``).

    Args:
        prefill: Prefill metadata; ``chunked_context`` must be populated.
        chunk_idx: Index of the context chunk to attend to.
        q: Query tensor; dims 0 and 1 size the output buffer.
        k: Key tensor for this chunk.
        v: Value tensor for this chunk; dim 2 sizes the output buffer.

    Returns:
        ``(attn_out, lse)`` where ``lse`` has been transposed on dims 0/1
        and made contiguous.
    """
    # Imported lazily so that the module loads without FlashInfer installed.
    from flashinfer.prefill import trtllm_ragged_attention_deepseek

    chunked = prefill.chunked_context
    assert chunked is not None
    chunk_seq_lens = chunked.seq_lens[chunk_idx]
    assert chunk_seq_lens is not None

    # Zero-initialised output buffer handed to the kernel via ``out=``.
    out_shape = (q.shape[0], q.shape[1], v.shape[2])
    out = torch.zeros(out_shape, device=q.device, dtype=q.dtype)

    # The workspace is cleared before every chunk call.
    # NOTE(review): presumably the kernel requires a zeroed workspace — confirm.
    workspace = self._workspace_buffer
    workspace.fill_(0)

    attn_out, lse = trtllm_ragged_attention_deepseek(
        query=q,
        key=k,
        value=v,
        workspace_buffer=workspace,
        seq_lens=chunk_seq_lens,
        max_q_len=prefill.max_query_len,
        max_kv_len=chunked.max_seq_lens[chunk_idx],
        bmm1_scale=self.scale,
        bmm2_scale=1.0,
        o_sf_scale=1.0,
        batch_size=chunk_seq_lens.shape[0],
        window_left=-1,
        cum_seq_lens_q=prefill.query_start_loc,
        cum_seq_lens_kv=chunked.cu_seq_lens[chunk_idx],
        enable_pdl=False,
        is_causal=False,
        return_lse=True,
        out=out,
    )

    # Convert from (q_len, num_heads) to (num_heads, q_len)
    lse_heads_first = lse.transpose(0, 1).contiguous()
    return attn_out, lse_heads_first
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
get_layer_weight
(
layer
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment