Unverified Commit 59bfc17b authored by Kite0011's avatar Kite0011 Committed by GitHub
Browse files

[Pytorch] Update context parallel softmax lse correction func (#716)



[Pytorch] Update context parallel softmax lse correction func.
Signed-off-by: kitefang <kitefang@tencent.com>
Co-authored-by: kitefang <kitefang@tencent.com>
parent c38779be
...@@ -483,9 +483,10 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe ...@@ -483,9 +483,10 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism.

    Updates ``softmax_lse`` in place with the combined log-sum-exp of the two
    stats, i.e. ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``. The
    max/min rearrangement keeps the exponent non-positive, so ``exp`` cannot
    overflow regardless of the magnitude of the inputs.

    Args:
        softmax_lse: running softmax log-sum-exp stats; overwritten in place
            with the merged result.
        softmax_lse_per_step: softmax log-sum-exp stats of the current step.
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    # log1p(exp(x)) is more accurate than log(1 + exp(x)) when exp(x) is tiny
    # (avoids the catastrophic cancellation in 1 + eps).
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)
class AttnFuncWithCP(torch.autograd.Function): class AttnFuncWithCP(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment