Unverified Commit 199e6123 authored by Sergii Dymchenko, committed by GitHub

Use log1p(x) instead of log(1+x) (#1401)

torch.log1p() is more accurate than torch.log() for small input values: https://pytorch.org/docs/stable/generated/torch.log1p.html
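
For illustration, a minimal standalone sketch of the precision difference (not part of this patch), assuming float32 tensors:

import torch

# For tiny x, 1 + x rounds to exactly 1.0 in float32 (machine epsilon is
# about 1.19e-7), so torch.log(1 + x) collapses to 0. torch.log1p(x)
# evaluates log(1 + x) without forming 1 + x and keeps the precision.
x = torch.tensor(1e-10, dtype=torch.float32)
print(torch.log(1 + x))  # tensor(0.)
print(torch.log1p(x))    # tensor(1.0000e-10)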

Found with TorchFix https://github.com/pytorch-labs/torchfix/

Signed-off-by: Sergii Dymchenko <sdym@meta.com>
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 2fce82b7
@@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction(
"""Merge softmax stats of each step in Attention with context parallelism"""
max_scale = torch.max(softmax_lse, softmax_lse_per_step)
min_scale = torch.min(softmax_lse, softmax_lse_per_step)
new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
softmax_lse.copy_(new_scale)
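
For context, the patched line is a numerically stable log-sum-exp merge: log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))). A hypothetical standalone reproduction follows; merge_lse is illustrative and not a function from the repository:

import torch

# Stable merge of two log-sum-exp statistics, mirroring the patched line.
def merge_lse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    max_scale = torch.max(a, b)
    min_scale = torch.min(a, b)
    return max_scale + torch.log1p(torch.exp(min_scale - max_scale))

a = torch.tensor([100.0, -3.0])
b = torch.tensor([99.0, -3.5])
print(merge_lse(a, b))  # tensor([100.3133,  -2.5260])
# The naive form overflows to inf for large inputs, since exp(100)
# exceeds the float32 maximum:
print(torch.log(torch.exp(a) + torch.exp(b)))  # tensor([inf, -2.5260])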