Commit 76187a5e authored by Xiaowei Ren, committed by GitHub

fix a sync race error of softmax_lse in CP+THD+P2P (#1624)



fix a race error of softmax_lse
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
parent 3bcd7f6f
@@ -1359,16 +1359,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
 if i > 1:
     flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
-if use_fused_attention:
-    # [b, np, sq, 1] -> [b, np, sq] or
-    # [t, np, 1] -> [t, np]
-    softmax_lse_per_step[i - 1].squeeze_(-1)
-    if softmax_lse_in_packed_format:
-        softmax_lse_per_step[i - 1] = (
-            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
-        )
 with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
+    if use_fused_attention:
+        # [b, np, sq, 1] -> [b, np, sq] or
+        # [t, np, 1] -> [t, np]
+        softmax_lse_per_step[i - 1].squeeze_(-1)
+        if softmax_lse_in_packed_format:
+            softmax_lse_per_step[i - 1] = (
+                softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
+            )
     if fp8:
         out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
     if i == 1:
......
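The hunk above moves the softmax_lse post-processing inside the `torch.cuda.stream(...)` block. A plausible reading, not stated in the commit message, is that `transpose(0, 1).contiguous()` launches a copy kernel on whichever stream is current, so running it outside the block could read `softmax_lse_per_step[i - 1]` before the attention kernel on `flash_attn_streams[(i - 1) % 2]` had finished writing it. The sketch below illustrates that pattern under stated assumptions: `postprocess_lse`, the `side` stream, and the tensor sizes are hypothetical and are not taken from Transformer Engine. It falls back to CPU, where no streams exist, so the shape logic can still be checked.

```python
import contextlib

import torch


def postprocess_lse(lse, packed, stream=None):
    """Hypothetical helper: drop the trailing dim and, for packed layout,
    transpose [t, np] -> [np, t], all on the stream that produced `lse`."""
    # On CPU there are no CUDA streams, so use a no-op context instead.
    ctx = torch.cuda.stream(stream) if stream is not None else contextlib.nullcontext()
    with ctx:
        # squeeze only changes metadata; no kernel is launched here.
        lse = lse.squeeze(-1)
        if packed:
            # contiguous() copies data, so it must be ordered after the producer.
            lse = lse.transpose(0, 1).contiguous()
    return lse


device = "cuda" if torch.cuda.is_available() else "cpu"
side = torch.cuda.Stream() if device == "cuda" else None

# Stand-in for the attention kernel: produce a [t, np, 1] tensor on the side stream.
producer_ctx = torch.cuda.stream(side) if side is not None else contextlib.nullcontext()
with producer_ctx:
    lse = torch.arange(24, dtype=torch.float32, device=device).reshape(6, 4, 1)

# Post-process on the same stream, so stream order guarantees producer-before-consumer.
result = postprocess_lse(lse, packed=True, stream=side)

# Before reading the result on the current stream, wait for the side stream.
if side is not None:
    torch.cuda.current_stream().wait_stream(side)

print(tuple(result.shape))
```

Keeping producer and consumer on one stream relies on in-order execution within a stream; the alternative would be an explicit event or `wait_stream` between the two streams before the copy.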