Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0cd103e7
Unverified
Commit
0cd103e7
authored
Oct 11, 2025
by
Huamin Li
Committed by
GitHub
Oct 11, 2025
Browse files
CP: make correct_attn_out robust to 4‑D views and fix Triton arg binding (#26509)
Signed-off-by:
Huamin Li
<
3ericli@gmail.com
>
parent
5be7ca1b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
8 deletions
+46
-8
vllm/attention/ops/common.py
vllm/attention/ops/common.py
+46
-8
No files found.
vllm/attention/ops/common.py
View file @
0cd103e7
...
...
@@ -117,14 +117,52 @@ def correct_attn_out(
if
ctx
is
None
:
ctx
=
CPTritonContext
()
lse
=
torch
.
empty_like
(
lses
[
0
])
grid
=
(
out
.
shape
[
0
],
out
.
shape
[
1
],
1
)
regular_args
=
(
out
,
out
,
lses
,
lse
,
*
out
.
stride
(),
*
lses
.
stride
(),
cp_rank
)
const_args
=
{
"HEAD_DIM"
:
out
.
shape
[
-
1
],
"N_ROUNDED"
:
lses
.
shape
[
0
],
}
# --- Normalize to 3D views ---
if
out
.
ndim
==
4
and
out
.
shape
[
1
]
==
1
:
out
=
out
.
squeeze
(
1
)
assert
out
.
ndim
==
3
,
f
"expected out [B,H,D] or [B,1,H,D], got
{
tuple
(
out
.
shape
)
}
"
if
lses
.
ndim
==
4
and
lses
.
shape
[
-
1
]
==
1
:
lses
=
lses
.
squeeze
(
-
1
)
if
lses
.
ndim
==
4
and
lses
.
shape
[
1
]
==
1
:
lses
=
lses
.
squeeze
(
1
)
assert
lses
.
ndim
==
3
,
(
f
"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
f
"got
{
tuple
(
lses
.
shape
)
}
"
)
B
,
H
,
D
=
out
.
shape
N
=
lses
.
shape
[
0
]
# Strides after we normalized shapes to 3-D views. The kernel computes
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
# have the same B/H stride layout as a slice of `lses`.
o_sB
,
o_sH
,
o_sD
=
out
.
stride
()
l_sN
,
l_sB
,
l_sH
=
lses
.
stride
()
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
lse
=
torch
.
empty_strided
(
(
B
,
H
),
(
l_sB
,
l_sH
),
device
=
lses
.
device
,
dtype
=
lses
.
dtype
)
# Kernel launch config
grid
=
(
B
,
H
,
1
)
regular_args
=
(
out
,
out
,
lses
,
lse
,
o_sB
,
o_sH
,
o_sD
,
l_sN
,
l_sB
,
l_sH
,
cp_rank
,
)
const_args
=
{
"HEAD_DIM"
:
D
,
"N_ROUNDED"
:
N
}
ctx
.
call_kernel
(
_correct_attn_cp_out_kernel
,
grid
,
*
regular_args
,
**
const_args
)
return
out
,
lse
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment