"...git@developer.sourcefind.cn:modelzoo/stylegan2_mmcv.git" did not exist on "1401de15d079af4d9d9f995f2d57ddb6d930d7f0"
Unverified commit 59bfc17b, authored by Kite0011, committed by GitHub
Browse files

[Pytorch] Update context parallel softmax lse correction func (#716)



[Pytorch] Update context parallel softmax lse correction func.
Signed-off-by: kitefang <kitefang@tencent.com>
Co-authored-by: kitefang <kitefang@tencent.com>
parent c38779be
......@@ -483,9 +483,10 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism.

    Overwrites ``softmax_lse`` in place with the element-wise value of
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``.

    The result is computed in the shifted form
    ``max + log1p(exp(min - max))`` instead of exponentiating the inputs
    directly, so large log-sum-exp values do not overflow.

    Args:
        softmax_lse: Running log-sum-exp tensor; updated in place.
        softmax_lse_per_step: Log-sum-exp tensor of the current step.

    Returns:
        None. The merged statistics are written into ``softmax_lse``.
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    # min_scale - max_scale is <= 0, so exp() stays within (0, 1] and cannot
    # overflow; log1p keeps precision when that term is tiny.
    # NOTE(review): if both inputs are -inf at the same position the
    # difference is NaN -- confirm callers never pass that.
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)
class AttnFuncWithCP(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment