"tests/vscode:/vscode.git/clone" did not exist on "daa1273b14da5bdf643aa4b1bcbef3985b1edd75"
Unverified Commit 3a6d5cbe authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize dcp allocate tensor (#33102)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent f5d7049c
...@@ -195,14 +195,10 @@ def _cp_lse_common( ...@@ -195,14 +195,10 @@ def _cp_lse_common(
if ctx is None: if ctx is None:
ctx = CPTritonContext() ctx = CPTritonContext()
lses = torch.empty(
(cp_group.world_size,) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device,
)
cp_attn_lse = cp_attn_lse.contiguous() cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
(cp_group.world_size,) + cp_attn_lse.shape
)
out, lse = correct_attn_out( out, lse = correct_attn_out(
cp_attn_out, cp_attn_out,
lses, lses,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment