Unverified Commit 0cd103e7 authored by Huamin Li's avatar Huamin Li Committed by GitHub
Browse files

CP: make correct_attn_out robust to 4‑D views and fix Triton arg binding (#26509)


Signed-off-by: default avatarHuamin Li <3ericli@gmail.com>
parent 5be7ca1b
...@@ -117,14 +117,52 @@ def correct_attn_out( ...@@ -117,14 +117,52 @@ def correct_attn_out(
if ctx is None: if ctx is None:
ctx = CPTritonContext() ctx = CPTritonContext()
lse = torch.empty_like(lses[0]) # --- Normalize to 3D views ---
if out.ndim == 4 and out.shape[1] == 1:
grid = (out.shape[0], out.shape[1], 1) out = out.squeeze(1)
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank) assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
const_args = {
"HEAD_DIM": out.shape[-1], if lses.ndim == 4 and lses.shape[-1] == 1:
"N_ROUNDED": lses.shape[0], lses = lses.squeeze(-1)
} if lses.ndim == 4 and lses.shape[1] == 1:
lses = lses.squeeze(1)
assert lses.ndim == 3, (
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
f"got {tuple(lses.shape)}"
)
B, H, D = out.shape
N = lses.shape[0]
# Strides after we normalized shapes to 3-D views. The kernel computes
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
# have the same B/H stride layout as a slice of `lses`.
o_sB, o_sH, o_sD = out.stride()
l_sN, l_sB, l_sH = lses.stride()
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
lse = torch.empty_strided(
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
# Kernel launch config
grid = (B, H, 1)
regular_args = (
out,
out,
lses,
lse,
o_sB,
o_sH,
o_sD,
l_sN,
l_sB,
l_sH,
cp_rank,
)
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
return out, lse return out, lse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment