Unverified Commit 9ee2dbdd authored by 李金梁, committed by GitHub

Fix bug in torch compile when seq_dim is an integer (#1217)



* Fix bug in torch compile when seq_dim is an integer
Signed-off-by: 李金梁 <975761915@qq.com>

* Update attention.py

Change jit_fuser to torch.compile for flash_attn_fwd_out_correction
Signed-off-by: 李金梁 <975761915@qq.com>

* Annotate fused functions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: 李金梁 <975761915@qq.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3b89c36f
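
A note on the root cause: jit_fuser commonly maps to torch.jit.script, which infers unannotated parameters as Tensor, so a plain Python int passed for seq_dim can fail to script or compile; the annotations added below address that. A minimal sketch of the annotated pattern (move_seq_dim is an illustrative name, not from this diff):

import torch

# Illustrative sketch (not from the diff): TorchScript infers unannotated
# parameters as Tensor, so annotating seq_dim as `int` is what lets a plain
# Python integer flow through a scripted/fused function.
@torch.jit.script
def move_seq_dim(out: torch.Tensor, seq_dim: int) -> torch.Tensor:
    # movedim() requires integer dims; without the `int` annotation,
    # scripting would treat seq_dim as a Tensor and fail.
    return out.movedim(2, seq_dim)

print(move_seq_dim(torch.randn(2, 2, 3, 4), seq_dim=1).shape)  # torch.Size([2, 3, 2, 4])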
@@ -1359,7 +1359,13 @@ def flash_attn_p2p_communicate(
 @jit_fuser
-def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
+def flash_attn_fwd_out_correction(
+    out: torch.Tensor,
+    out_per_step: torch.Tensor,
+    seq_dim: int,
+    softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+):
     """Merge partial outputs of each step in Attention with context parallelism"""
     softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
     softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
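
The visible lines build the correction factor exp(softmax_lse_per_step - softmax_lse) and align it with the sequence dimension; the collapsed remainder of the hunk presumably scales out_per_step by it and accumulates into out. A self-contained numeric check of that merge rule, with made-up shapes and names:

import torch

torch.manual_seed(0)
scores = torch.randn(8)     # attention scores for one query over 8 keys
values = torch.randn(8, 4)  # one 4-dim value vector per key

# Reference: full softmax-weighted sum over all keys.
full = torch.softmax(scores, dim=0) @ values

# Merge two key chunks using the per-step/total LSE correction.
merged = torch.zeros(4)
lse_total = torch.logsumexp(scores, dim=0)
for s, v in ((scores[:4], values[:4]), (scores[4:], values[4:])):
    lse_step = torch.logsumexp(s, dim=0)
    out_step = torch.softmax(s, dim=0) @ v
    # Rescale the per-chunk softmax output into the global normalization.
    merged += out_step * torch.exp(lse_step - lse_total)

assert torch.allclose(full, merged, atol=1e-6)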
@@ -1368,7 +1374,10 @@ def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
 @jit_fuser
-def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
+def flash_attn_fwd_softmax_lse_correction(
+    softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+):
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
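
The visible body lines set up a numerically stable log-sum-exp merge of the two softmax stats; the usual closed form is max + log1p(exp(min - max)). A small sketch of why that form is preferred over naively exponentiating both operands (values are arbitrary):

import torch

# Stable merge of two log-sum-exp values: routing through max/min keeps the
# exp() argument non-positive and avoids overflow for large scores.
a = torch.tensor([100.0, -50.0])
b = torch.tensor([99.0, -49.5])
hi = torch.max(a, b)
lo = torch.min(a, b)
merged = hi + torch.log1p(torch.exp(lo - hi))

# Matches PyTorch's built-in stable implementation.
assert torch.allclose(merged, torch.logaddexp(a, b))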
@@ -1378,7 +1387,12 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
 @jit_fuser
 def get_cu_seqlens_on_cp_rank(
-    cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
+    cu_seqlens: torch.Tensor,
+    cu_seqlens_padded_on_cp_rank: torch.Tensor,
+    cp_size: int,
+    cp_rank: int,
+    first_half: bool,
+    second_half: bool,
 ):
     """Compute cu_seqlens of a context parallelism rank"""
     seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
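
For readers unfamiliar with the cu_seqlens convention used in the last hunk: it stores cumulative offsets of sequences packed along one dimension, so adjacent differences recover the per-sequence lengths, exactly as the first body line computes. A tiny illustration with made-up values:

import torch

# cu_seqlens holds cumulative offsets of packed sequences; the values
# below are illustrative, not from the library.
cu_seqlens = torch.tensor([0, 3, 8, 12])  # three sequences packed together
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
print(seqlens.tolist())  # [3, 5, 4]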