Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1fb632fd
Unverified
Commit
1fb632fd
authored
Dec 08, 2025
by
Lain
Committed by
GitHub
Dec 08, 2025
Browse files
[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum (#29795)
Signed-off-by:
Siyuan Fu
<
siyuanf@nvidia.com
>
parent
6af70e11
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
13 deletions
+22
-13
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+21
-12
No files found.
vllm/distributed/device_communicators/cuda_communicator.py
View file @
1fb632fd
...
@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape
,
dtype
=
input_tensor
.
dtype
,
device
=
input_tensor
.
device
output_shape
,
dtype
=
input_tensor
.
dtype
,
device
=
input_tensor
.
device
)
)
if
sizes
is
not
None
:
if
sizes
is
not
None
and
sizes
.
count
(
sizes
[
0
])
!=
len
(
sizes
)
:
pynccl_comm
.
reduce_scatterv
(
output
,
input_tensor
,
sizes
=
sizes
)
pynccl_comm
.
reduce_scatterv
(
output
,
input_tensor
,
sizes
=
sizes
)
else
:
else
:
pynccl_comm
.
reduce_scatter
(
output
,
input_tensor
)
pynccl_comm
.
reduce_scatter
(
output
,
input_tensor
)
...
...
vllm/v1/attention/backends/mla/common.py
View file @
1fb632fd
...
@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if
fp8_attention
:
if
fp8_attention
:
ql_nope_shape
=
decode_ql_nope
.
shape
ql_nope_shape
=
decode_ql_nope
.
shape
decode_ql_nope
,
_
=
ops
.
scaled_fp8_quant
(
decode_ql_nope
.
reshape
(
[
ql_nope_shape
[
0
],
ql_nope_shape
[
1
]
*
ql_nope_shape
[
2
]]
),
layer
.
_q_scale
,
)
decode_ql_nope
=
decode_ql_nope
.
reshape
(
ql_nope_shape
)
q_pe_shape
=
decode_q_pe
.
shape
q_pe_shape
=
decode_q_pe
.
shape
decode_q_pe
,
_
=
ops
.
scaled_fp8_quant
(
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
decode_q_pe
.
reshape
([
q_pe_shape
[
0
],
q_pe_shape
[
1
]
*
q_pe_shape
[
2
]]),
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
layer
.
_q_scale
,
decode_q_shape
=
(
ql_nope_shape
[
0
],
ql_nope_shape
[
1
],
ql_nope_shape
[
2
]
+
q_pe_shape
[
2
],
)
# Using empty and copy since torch.cat introduces significant overhead.
decode_q0
=
torch
.
empty
(
decode_q_shape
,
device
=
decode_ql_nope
.
device
,
dtype
=
decode_ql_nope
.
dtype
,
)
)
decode_q_pe
=
decode_q_pe
.
reshape
(
q_pe_shape
)
decode_q0
[...,
:
ql_nope_shape
[
2
]].
copy_
(
decode_ql_nope
)
decode_q0
[...,
ql_nope_shape
[
2
]
:].
copy_
(
decode_q_pe
)
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
decode_q
,
_
=
ops
.
scaled_fp8_quant
(
decode_q0
.
view
(
decode_q_shape
[
0
],
-
1
),
layer
.
_q_scale
,
)
decode_q
=
decode_q
.
view
(
decode_q_shape
)
else
:
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
assert
not
fp8_attention
,
"DCP not support fp8 kvcache now."
assert
not
fp8_attention
,
"DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment