Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da5e7b12
Unverified
Commit
da5e7b12
authored
Jan 24, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 24, 2026
Browse files
[MLA] Fuse cat and quant for fp8 kv-cache (#32950)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
719ac592
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
20 deletions
+41
-20
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+41
-20
No files found.
vllm/model_executor/layers/attention/mla_attention.py
View file @
da5e7b12
...
@@ -202,13 +202,16 @@ from vllm._aiter_ops import rocm_aiter_ops
...
@@ -202,13 +202,16 @@ from vllm._aiter_ops import rocm_aiter_ops
from
vllm.config
import
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
vllm_is_batch_invariant
,
)
)
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
)
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
get_and_maybe_dequant_weights
,
get_and_maybe_dequant_weights
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -287,6 +290,37 @@ def dynamic_per_batched_tensor_quant(
...
@@ -287,6 +290,37 @@ def dynamic_per_batched_tensor_quant(
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@CustomOp.register("mla_decode_concat_quant_fp8")
class _DecodeConcatQuantFP8(QuantFP8):
    """
    FP8 quantization op that joins the two decode query pieces first.

    Each forward variant concatenates ``decode_ql_nope`` and ``decode_q_pe``
    along the last dimension and then quantizes the result with the matching
    ``QuantFP8`` implementation. When the custom op is disabled,
    ``forward_native`` is compiled via torch.compile, which fuses the
    cat/reshape/quant/view sequence together.
    """

    def _make_forward(quant_fn):  # noqa: N805
        """Build a forward method that concatenates, then calls ``quant_fn``."""

        def forward(
            self,
            decode_ql_nope: torch.Tensor,
            decode_q_pe: torch.Tensor,
            scale: torch.Tensor,
            scale_ub: torch.Tensor | None = None,
        ) -> torch.Tensor:
            pieces = (decode_ql_nope, decode_q_pe)
            joined = torch.cat(pieces, dim=-1)
            # The quant function is handed a 2-D input: dim 0 is kept and
            # every remaining dimension is folded into the second one.
            num_rows = joined.shape[0]
            flat = joined.reshape(num_rows, -1)
            # Only the quantized tensor is used; the second return value
            # of quant_fn is discarded.
            quantized, _ = quant_fn(self, flat, scale, scale_ub)
            # Hand back the result in the shape of the concatenated input.
            return quantized.view(joined.shape)

        return forward

    # One wrapper per inherited QuantFP8 forward implementation.
    forward_native = _make_forward(QuantFP8.forward_native)  # type: ignore[arg-type]
    forward_cuda = _make_forward(QuantFP8.forward_cuda)  # type: ignore[arg-type]
    forward_hip = _make_forward(QuantFP8.forward_hip)  # type: ignore[arg-type]
CUDNN_WORKSPACE_SIZE
=
12800
CUDNN_WORKSPACE_SIZE
=
12800
...
@@ -1398,6 +1432,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1398,6 +1432,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
.
cp_kv_cache_interleave_size
:
int
=
(
self
.
cp_kv_cache_interleave_size
:
int
=
(
get_current_vllm_config
().
parallel_config
.
cp_kv_cache_interleave_size
get_current_vllm_config
().
parallel_config
.
cp_kv_cache_interleave_size
)
)
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
def
_flash_attn_varlen_diff_headdims
(
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
...
@@ -2048,29 +2087,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -2048,29 +2087,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
if
fp8_attention
:
if
fp8_attention
:
ql_nope_shape
=
decode_ql_nope
.
shape
q_pe_shape
=
decode_q_pe
.
shape
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
decode_q_shape
=
(
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
ql_nope_shape
[
0
],
decode_ql_nope
,
decode_q_pe
,
layer
.
_q_scale
ql_nope_shape
[
1
],
ql_nope_shape
[
2
]
+
q_pe_shape
[
2
],
)
# Using empty and copy since torch.cat introduces significant overhead.
decode_q0
=
torch
.
empty
(
decode_q_shape
,
device
=
decode_ql_nope
.
device
,
dtype
=
decode_ql_nope
.
dtype
,
)
decode_q0
[...,
:
ql_nope_shape
[
2
]].
copy_
(
decode_ql_nope
)
decode_q0
[...,
ql_nope_shape
[
2
]
:].
copy_
(
decode_q_pe
)
decode_q
,
_
=
ops
.
scaled_fp8_quant
(
decode_q0
.
view
(
decode_q_shape
[
0
],
-
1
),
layer
.
_q_scale
,
)
)
decode_q
=
decode_q
.
view
(
decode_q_shape
)
else
:
else
:
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment