Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e102623
Unverified
Commit
3e102623
authored
Dec 22, 2025
by
Pavani Majety
Committed by
GitHub
Dec 22, 2025
Browse files
Revert "[SM100] Enable fp8 compute for prefill MLA (#30746)" (#31197)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
612d5ffd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
116 deletions
+18
-116
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+3
-4
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+1
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+14
-112
No files found.
tests/v1/attention/test_mla_backends.py
View file @
3e102623
...
...
@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.attention.backends.mla.common
import
QueryLenSupport
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
MLA
AttentionSpec
from
vllm.v1.kv_cache_interface
import
Full
AttentionSpec
BACKENDS_TO_TEST
=
[
AttentionBackendEnum
.
CUTLASS_MLA
,
...
...
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def
run_attention_backend
(
backend
:
AttentionBackendEnum
,
kv_cache_spec
:
MLA
AttentionSpec
,
kv_cache_spec
:
Full
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
,
device
:
torch
.
device
,
...
...
@@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache
=
kv_cache_per_block_size
[
block_size
]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec
=
MLA
AttentionSpec
(
backend_kv_cache_spec
=
Full
AttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
...
...
@@ -748,7 +748,6 @@ def test_backend_correctness(
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
(),
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
backend_output
=
run_attention_backend
(
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
3e102623
...
...
@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
routed_scaling_factor
=
None
,
tile_tokens_dim
=
None
,
routing_method_type
=
routing_method_type
,
do_finalize
=
True
,
)[
0
]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
3e102623
...
...
@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls
if
metadata_cls
is
not
None
else
MLACommonMetadata
)
self
.
kv_cache_spec
=
kv_cache_spec
self
.
q_data_type
=
(
current_platform
.
fp8_dtype
()
if
(
kv_cache_spec
is
not
None
and
"fp8"
in
kv_cache_spec
.
cache_dtype_str
)
else
vllm_config
.
model_config
.
dtype
)
scheduler_config
=
vllm_config
.
scheduler_config
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr
kv_indptr
=
qo_indptr
.
clone
()
# Prepare main prefill
self
.
_fi_prefill_main
.
plan
(
qo_indptr
=
qo_indptr
,
...
...
@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_
type
,
q_data_type
=
self
.
model_config
.
d
type
,
)
# Prepare context prefills
...
...
@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_
type
,
q_data_type
=
self
.
model_config
.
d
type
,
)
prefill
.
prefill_main
=
self
.
_fi_prefill_main
...
...
@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
q_data_type
=
self
.
q_data_type
,
)
if
self
.
_use_cudnn_prefill
:
...
...
@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
attn_out
def
_run_prefill_new_tokens_fa
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
logger
.
debug_once
(
"Running FlashAttention prefill new tokens"
,
scope
=
"local"
)
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
...
...
@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def
_run_prefill_new_tokens_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
logger
.
debug_once
(
"Running FlashInfer prefill new tokens"
,
scope
=
"local"
)
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
assert
prefill
.
prefill_main
is
not
None
if
fp8_attention
:
logger
.
debug_once
(
"Running Flashinfer prefill in FP8"
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
q
=
q
.
to
(
fp8_dtype
)
k
=
k
.
to
(
fp8_dtype
)
v
=
v
.
to
(
fp8_dtype
)
ret
=
prefill
.
prefill_main
.
run
(
q
=
q
,
k
=
k
,
...
...
@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
ret
def
_run_prefill_new_tokens_cudnn
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
logger
.
debug_once
(
"Running Cudnn prefill new tokens"
,
scope
=
"local"
)
assert
isinstance
(
prefill
,
CudnnPrefillMetadata
)
assert
prefill
.
query_seq_lens
is
not
None
assert
fp8_attention
is
False
,
"Cudnn prefill does not support fp8 attention"
output
,
lse
=
cudnn_batch_prefill_with_kv_cache
(
q
=
q
,
k_cache
=
k
,
...
...
@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
output
def
_run_prefill_context_chunk_fa
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
logger
.
debug_once
(
"Running FlashAttention prefill context chunk"
,
scope
=
"local"
)
assert
prefill
.
chunked_context
is
not
None
assert
fp8_attention
is
False
,
(
"FlashAttention prefill does not support fp8 attention"
)
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
...
...
@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def
_run_prefill_context_chunk_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
logger
.
debug_once
(
"Running FlashInfer prefill context chunk"
,
scope
=
"local"
)
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
if
fp8_attention
:
logger
.
debug_once
(
"Running FlashInfer prefill in FP8"
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
q
=
q
.
to
(
fp8_dtype
)
k
=
k
.
to
(
fp8_dtype
)
v
=
v
.
to
(
fp8_dtype
)
attn_out
,
lse
=
prefill
.
prefill_chunks
[
chunk_idx
].
run
(
q
=
q
,
k
=
k
,
...
...
@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
attn_out
,
lse
.
transpose
(
0
,
1
).
contiguous
()
def
_run_prefill_context_chunk_cudnn
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
logger
.
debug_once
(
"Running Cudnn prefill context chunk"
,
scope
=
"local"
)
assert
isinstance
(
prefill
,
CudnnPrefillMetadata
)
assert
prefill
.
chunked_context
is
not
None
assert
prefill
.
chunked_context
.
seq_lens
[
chunk_idx
]
is
not
None
assert
prefill
.
query_seq_lens
is
not
None
assert
fp8_attention
is
False
,
"Cudnn prefill does not support fp8 attention"
return
cudnn_batch_prefill_with_kv_cache
(
q
=
q
,
k_cache
=
k
,
...
...
@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def
_run_prefill_new_tokens_trtllm_ragged
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
logger
.
debug_once
(
"Running TRT-LLM ragged prefill new tokens"
,
scope
=
"local"
)
"""TRT-LLM ragged attention for new tokens (causal)."""
from
flashinfer.prefill
import
trtllm_ragged_attention_deepseek
assert
prefill
.
query_seq_lens
is
not
None
assert
prefill
.
workspace_buffer
is
not
None
if
fp8_attention
:
logger
.
debug_once
(
"Running TRT-LLM ragged prefill in FP8"
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
q
=
q
.
to
(
fp8_dtype
)
k
=
k
.
to
(
fp8_dtype
)
v
=
v
.
to
(
fp8_dtype
)
ret
=
trtllm_ragged_attention_deepseek
(
query
=
q
,
key
=
k
,
...
...
@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
ret
def
_run_prefill_context_chunk_trtllm_ragged
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
,
fp8_attention
:
bool
,
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
logger
.
debug_once
(
"Running TRT-LLM ragged prefill context chunk"
,
scope
=
"local"
)
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from
flashinfer.prefill
import
trtllm_ragged_attention_deepseek
...
...
@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
prefill
.
workspace_buffer
.
fill_
(
0
)
if
fp8_attention
:
logger
.
debug_once
(
"Running TRT-LLM ragged prefill context chunk in FP8"
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
q
=
q
.
to
(
fp8_dtype
)
k
=
k
.
to
(
fp8_dtype
)
v
=
v
.
to
(
fp8_dtype
)
attn_out
,
lse
=
trtllm_ragged_attention_deepseek
(
query
=
q
,
key
=
k
,
...
...
@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
fp8_attention
:
bool
,
):
assert
attn_metadata
.
prefill
is
not
None
prefill_metadata
=
attn_metadata
.
prefill
...
...
@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q
=
q
,
k
=
k
,
v
=
v
,
fp8_attention
=
fp8_attention
,
)
if
output
is
None
:
...
...
@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
dcp_world_size
:
int
,
fp8_attention
:
bool
,
):
assert
k_scale
is
None
,
"DCP not support scaled kvcache now."
assert
attn_metadata
.
prefill
is
not
None
...
...
@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q
=
q
,
k
=
k
,
v
=
v
,
fp8_attention
=
fp8_attention
,
)
if
output
is
None
:
...
...
@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
fp8_attention
:
bool
=
False
,
)
->
None
:
# TODO (zyongye): Prefill function here
assert
attn_metadata
.
prefill
is
not
None
...
...
@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k
=
k
,
v
=
v
,
return_softmax_lse
=
has_context
,
fp8_attention
=
fp8_attention
,
)
if
has_context
:
...
...
@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata
,
k_scale
=
None
,
dcp_world_size
=
self
.
dcp_world_size
,
fp8_attention
=
fp8_attention
,
)
)
else
:
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
,
fp8_attention
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
)
# unpad if necessary
...
...
@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata
,
layer
.
_k_scale
,
output
=
output
[
num_decode_tokens
:],
fp8_attention
=
fp8_attention
,
)
if
has_decode
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment