Unverified commit 3a6d5cbe, authored by Wentao Ye, committed by GitHub
Browse files

[Perf] Optimize dcp allocate tensor (#33102)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent f5d7049c
......@@ -195,14 +195,10 @@ def _cp_lse_common(
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty(
(cp_group.world_size,) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device,
)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
(cp_group.world_size,) + cp_attn_lse.shape
)
out, lse = correct_attn_out(
cp_attn_out,
lses,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment