Unverified Commit 74eb7c33 authored by Xiaowei Ren, committed by GitHub

fix bwd error of context parallelism implementation with FA v2 (#498)



fix bwd error with FA v2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d20ba9fb
@@ -572,6 +572,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
             # [b, np, sq] -> [b, np, 2, sq//2]
             softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
+            softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
         # [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
         out = out.view(*q.shape)
         dout = dout.view(*q.shape)
@@ -659,7 +660,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                     out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                     dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
                     _flash_attn_backward(
-                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse_[..., 1, :],
+                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
                         dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                         ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                         ctx.dropout_p, ctx.softmax_scale, False,
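The change hoists the second-half slice of the softmax log-sum-exp tensor out of the backward call site and materializes it once with .contiguous(). A plausible reading of the fix: softmax_lse_[..., 1, :] is a strided view over the full LSE storage, and FA v2's _flash_attn_backward kernel misbehaves when handed a non-contiguous LSE buffer. Below is a minimal sketch of the layout issue, not taken from the commit; the tensor sizes are illustrative.

import torch

# softmax_lse has shape [b, np, sq]; context parallelism views the
# sequence dimension as two halves: [b, np, sq] -> [b, np, 2, sq//2].
b, np_, sq = 2, 8, 16  # illustrative sizes, not from the source
softmax_lse = torch.randn(b, np_, sq)
softmax_lse_ = softmax_lse.view(b, np_, 2, sq // 2)

# Indexing the half dimension yields a strided view over the original
# storage, not a contiguous tensor.
second_half = softmax_lse_[..., 1, :]
print(second_half.is_contiguous())  # False

# The fix materializes the slice up front, so the backward kernel
# always receives a contiguous LSE buffer.
second_half = second_half.contiguous()
print(second_half.is_contiguous())  # True

Doing the slice-and-copy once in the setup code, rather than at each _flash_attn_backward call, also avoids recomputing the same strided access on every iteration of the backward loop.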