Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9726e645
Unverified
Commit 9726e645 authored Nov 29, 2025 by Augusto Yao
Committed by GitHub Nov 29, 2025
Browse files
bugfix: correct attn output with base 2 or e (#28840)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
parent
3fd1fb0b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing 3 changed files with 47 additions and 13 deletions
+47
-13
vllm/attention/ops/common.py
vllm/attention/ops/common.py
+32
-10
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+9
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+6
-1
No files found.
vllm/attention/ops/common.py
View file @
9726e645
...
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
...
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
lse_idx
,
lse_idx
,
HEAD_DIM
:
tl
.
constexpr
,
HEAD_DIM
:
tl
.
constexpr
,
N_ROUNDED
:
tl
.
constexpr
,
N_ROUNDED
:
tl
.
constexpr
,
IS_BASE_E
:
tl
.
constexpr
,
):
):
"""
"""
Apply the all-gathered lses to correct each local rank's attention
Apply the all-gathered lses to correct each local rank's attention
...
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
...
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
lse_max
=
tl
.
max
(
lse
,
axis
=
0
)
lse_max
=
tl
.
max
(
lse
,
axis
=
0
)
lse_max
=
tl
.
where
(
lse_max
==
-
float
(
"inf"
),
0
,
lse_max
)
lse_max
=
tl
.
where
(
lse_max
==
-
float
(
"inf"
),
0
,
lse_max
)
lse
-=
lse_max
lse
-=
lse_max
if
IS_BASE_E
:
lse_exp
=
tl
.
exp
(
lse
)
lse_exp
=
tl
.
exp
(
lse
)
lse_acc
=
tl
.
sum
(
lse_exp
,
axis
=
0
)
lse_acc
=
tl
.
sum
(
lse_exp
,
axis
=
0
)
lse
=
tl
.
log
(
lse_acc
)
lse
=
tl
.
log
(
lse_acc
)
else
:
lse_exp
=
tl
.
exp2
(
lse
)
lse_acc
=
tl
.
sum
(
lse_exp
,
axis
=
0
)
lse
=
tl
.
log2
(
lse_acc
)
lse
+=
lse_max
lse
+=
lse_max
lse_offsets
=
batch_idx
*
lses_stride_B
+
head_idx
*
lses_stride_H
lse_offsets
=
batch_idx
*
lses_stride_B
+
head_idx
*
lses_stride_H
...
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
...
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
-
float
(
"inf"
),
-
float
(
"inf"
),
lse_finally
,
lse_finally
,
)
)
factor
=
tl
.
exp
(
lse_finally
)
factor
=
tl
.
exp
(
lse_finally
)
if
IS_BASE_E
else
tl
.
exp2
(
lse_finally
)
output
=
tl
.
load
(
outputs_ptr
+
output_offsets
)
output
=
tl
.
load
(
outputs_ptr
+
output_offsets
)
output
=
output
*
factor
output
=
output
*
factor
...
@@ -102,7 +108,11 @@ class CPTritonContext:
...
@@ -102,7 +108,11 @@ class CPTritonContext:
def
correct_attn_out
(
def
correct_attn_out
(
out
:
torch
.
Tensor
,
lses
:
torch
.
Tensor
,
cp_rank
:
int
,
ctx
:
CPTritonContext
out
:
torch
.
Tensor
,
lses
:
torch
.
Tensor
,
cp_rank
:
int
,
ctx
:
CPTritonContext
,
is_lse_base_on_e
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Correct the attention output using the all-gathered lses.
"""Correct the attention output using the all-gathered lses.
...
@@ -163,8 +173,7 @@ def correct_attn_out(
...
@@ -163,8 +173,7 @@ def correct_attn_out(
l_sH
,
l_sH
,
cp_rank
,
cp_rank
,
)
)
const_args
=
{
"HEAD_DIM"
:
D
,
"N_ROUNDED"
:
N
}
const_args
=
{
"HEAD_DIM"
:
D
,
"N_ROUNDED"
:
N
,
"IS_BASE_E"
:
is_lse_base_on_e
}
ctx
.
call_kernel
(
_correct_attn_cp_out_kernel
,
grid
,
*
regular_args
,
**
const_args
)
ctx
.
call_kernel
(
_correct_attn_cp_out_kernel
,
grid
,
*
regular_args
,
**
const_args
)
return
out
,
lse
return
out
,
lse
...
@@ -174,6 +183,7 @@ def _cp_lse_common(
...
@@ -174,6 +183,7 @@ def _cp_lse_common(
cp_attn_lse
:
torch
.
Tensor
,
cp_attn_lse
:
torch
.
Tensor
,
cp_group
:
GroupCoordinator
,
cp_group
:
GroupCoordinator
,
ctx
:
CPTritonContext
|
None
=
None
,
ctx
:
CPTritonContext
|
None
=
None
,
is_lse_base_on_e
=
True
,
):
):
"""
"""
cp_attn_out: [ B, H, D ]
cp_attn_out: [ B, H, D ]
...
@@ -193,7 +203,13 @@ def _cp_lse_common(
...
@@ -193,7 +203,13 @@ def _cp_lse_common(
cp_attn_lse
=
cp_attn_lse
.
contiguous
()
cp_attn_lse
=
cp_attn_lse
.
contiguous
()
lses
=
cp_group
.
all_gather
(
cp_attn_lse
,
dim
=
0
).
view_as
(
lses
)
lses
=
cp_group
.
all_gather
(
cp_attn_lse
,
dim
=
0
).
view_as
(
lses
)
out
,
lse
=
correct_attn_out
(
cp_attn_out
,
lses
,
cp_group
.
rank_in_group
,
ctx
)
out
,
lse
=
correct_attn_out
(
cp_attn_out
,
lses
,
cp_group
.
rank_in_group
,
ctx
,
is_lse_base_on_e
=
is_lse_base_on_e
,
)
return
out
,
lse
return
out
,
lse
...
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
...
@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
cp_group
:
GroupCoordinator
,
cp_group
:
GroupCoordinator
,
ctx
:
CPTritonContext
|
None
=
None
,
ctx
:
CPTritonContext
|
None
=
None
,
return_lse
:
bool
=
False
,
return_lse
:
bool
=
False
,
is_lse_base_on_e
=
True
,
):
):
"""
"""
cp_attn_out: [ B, H, D ]
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
cp_attn_lse: [ B, H ]
"""
"""
out
,
lse
=
_cp_lse_common
(
cp_attn_out
,
cp_attn_lse
,
cp_group
,
ctx
=
ctx
)
out
,
lse
=
_cp_lse_common
(
cp_attn_out
,
cp_attn_lse
,
cp_group
,
ctx
=
ctx
,
is_lse_base_on_e
=
is_lse_base_on_e
)
out
=
cp_group
.
reduce_scatter
(
out
,
dim
=
1
)
out
=
cp_group
.
reduce_scatter
(
out
,
dim
=
1
)
if
return_lse
:
if
return_lse
:
...
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
...
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
cp_group
:
GroupCoordinator
,
cp_group
:
GroupCoordinator
,
ctx
:
CPTritonContext
|
None
=
None
,
ctx
:
CPTritonContext
|
None
=
None
,
return_lse
:
bool
=
False
,
return_lse
:
bool
=
False
,
is_lse_base_on_e
=
True
,
):
):
"""
"""
cp_attn_out: [ B, H, D ]
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
cp_attn_lse: [ B, H ]
"""
"""
out
,
lse
=
_cp_lse_common
(
cp_attn_out
,
cp_attn_lse
,
cp_group
,
ctx
=
ctx
)
out
,
lse
=
_cp_lse_common
(
cp_attn_out
,
cp_attn_lse
,
cp_group
,
ctx
=
ctx
,
is_lse_base_on_e
=
is_lse_base_on_e
)
out
=
cp_group
.
all_reduce
(
out
)
out
=
cp_group
.
all_reduce
(
out
)
if
return_lse
:
if
return_lse
:
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
9726e645
...
@@ -249,7 +249,11 @@ class BatchDCPPrefillWrapper:
...
@@ -249,7 +249,11 @@ class BatchDCPPrefillWrapper:
return_lse
=
True
,
return_lse
=
True
,
)
)
output_context
,
lse_context
=
cp_lse_ag_out_rs
(
output_context
,
lse_context
=
cp_lse_ag_out_rs
(
output_context_tmp
,
lse_context_tmp
,
get_dcp_group
(),
return_lse
=
True
output_context_tmp
,
lse_context_tmp
,
get_dcp_group
(),
return_lse
=
True
,
is_lse_base_on_e
=
False
,
)
)
lse_context
=
lse_context
.
transpose
(
0
,
1
).
contiguous
()
lse_context
=
lse_context
.
transpose
(
0
,
1
).
contiguous
()
...
@@ -1335,7 +1339,10 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1335,7 +1339,10 @@ class FlashInferImpl(AttentionImpl):
return_lse
=
True
,
return_lse
=
True
,
)
)
output
[:
num_decode_tokens
]
=
cp_lse_ag_out_rs
(
output
[:
num_decode_tokens
]
=
cp_lse_ag_out_rs
(
output_tmp
,
lse
,
get_dcp_group
()
output_tmp
,
lse
,
get_dcp_group
(),
is_lse_base_on_e
=
False
,
)
)
else
:
else
:
decode_wrapper
.
run
(
decode_wrapper
.
run
(
...
...
vllm/v1/attention/backends/mla/common.py
View file @
9726e645
...
@@ -2057,7 +2057,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -2057,7 +2057,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# correct dcp attn_out with lse.
# correct dcp attn_out with lse.
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
attn_out
=
cp_lse_ag_out_rs
(
attn_out
,
lse
,
get_dcp_group
())
attn_out
=
cp_lse_ag_out_rs
(
attn_out
,
lse
,
get_dcp_group
(),
is_lse_base_on_e
=
not
self
.
_use_fi_prefill
,
)
# v_up projection
# v_up projection
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_decode_tokens
])
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_decode_tokens
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment