Unverified Commit b36bd0a4 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Add FlashAttention3 to CP implementations (#1232)



* fa2 function import renaming
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* refine fa_fwd_kwargs and fa_bwd_kwargs
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* import FA3 fucntions for CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix output of FA3 fwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix rng_state in a2a implementation with FA3
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* hack lse correction for packed lse format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make CP thd out correction work with packed lse
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix for packed softmax_lse
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix softmax_lse shape
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change lse_packed to constexpr
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 9ee2dbdd
...@@ -108,6 +108,7 @@ _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") ...@@ -108,6 +108,7 @@ _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
_flash_attn_3_plus = False _flash_attn_3_plus = False
_use_flash_attn_3 = False _use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\ _flash_attn_3_installation_steps = """\
...@@ -135,13 +136,19 @@ else: ...@@ -135,13 +136,19 @@ else:
from flashattn_hopper.flash_attn_interface import ( from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3,
) )
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
)
_use_flash_attn_3 = True _use_flash_attn_3 = True
if _flash_attn_version >= _flash_attn_version_required: if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as flash_attn_varlen_fwd
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as flash_attn_varlen_bwd
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
_attention_backends = { _attention_backends = {
...@@ -460,9 +467,6 @@ def get_attention_backend( ...@@ -460,9 +467,6 @@ def get_attention_backend(
) )
use_unfused_attention = False use_unfused_attention = False
if context_parallel and use_flash_attention: if context_parallel and use_flash_attention:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa: if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug( logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8" "Disabling FlashAttention as it does not support context parallelism with FP8"
...@@ -1362,12 +1366,15 @@ def flash_attn_p2p_communicate( ...@@ -1362,12 +1366,15 @@ def flash_attn_p2p_communicate(
def flash_attn_fwd_out_correction( def flash_attn_fwd_out_correction(
out: torch.Tensor, out: torch.Tensor,
out_per_step: torch.Tensor, out_per_step: torch.Tensor,
seq_dim: int,
softmax_lse: torch.Tensor, softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor, softmax_lse_per_step: torch.Tensor,
movedim_src: int,
movedim_dst: int,
): ):
"""Merge partial outputs of each step in Attention with context parallelism""" """Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected) out.add_(out_corrected)
...@@ -1693,13 +1700,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1693,13 +1700,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
*attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
) )
assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
fa_optional_forward_kwargs = {}
softmax_lse_in_packed_format = not use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus: if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus: if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None fa_forward_kwargs["block_table"] = None
# Flash Attn inputs # Flash Attn inputs
q_inputs = [None, None] q_inputs = [None, None]
...@@ -1840,16 +1859,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1840,16 +1859,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
( fa_outputs = flash_attn_fwd(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
...@@ -1857,12 +1867,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1857,12 +1867,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
dropout_p,
softmax_scale,
causal=True, causal=True,
return_softmax=False, **fa_forward_kwargs,
**fa_optional_forward_kwargs,
) )
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
elif i <= rank: elif i <= rank:
if pad_between_seqs_q: if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
...@@ -1952,18 +1963,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1952,18 +1963,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, -1) fa_forward_kwargs["window_size"] = (-1, -1)
( fa_outputs = flash_attn_fwd(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
...@@ -1971,12 +1973,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1971,12 +1973,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_kv // 2, max_seqlen_kv // 2,
dropout_p,
softmax_scale,
causal=False, causal=False,
return_softmax=False, **fa_forward_kwargs,
**fa_optional_forward_kwargs,
) )
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else: else:
if pad_between_seqs_q: if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
...@@ -2075,18 +2078,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2075,18 +2078,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, -1) fa_forward_kwargs["window_size"] = (-1, -1)
( fa_outputs = flash_attn_fwd(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
...@@ -2094,12 +2088,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2094,12 +2088,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
max_seqlen_q // 2, max_seqlen_q // 2,
max_seqlen_kv, max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False, causal=False,
return_softmax=False, **fa_forward_kwargs,
**fa_optional_forward_kwargs,
) )
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else: else:
if pad_between_seqs_q: if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
...@@ -2167,16 +2162,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2167,16 +2162,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, sk, np, hn] -> [2, b*sk, np, hn] # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
( fa_outputs = flash_attn_fwd(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
...@@ -2184,12 +2170,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2184,12 +2170,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False, causal=False,
return_softmax=False, **fa_forward_kwargs,
**fa_optional_forward_kwargs,
) )
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
if i > 0: if i > 0:
# wait until fwd restuls correction of last step is done # wait until fwd restuls correction of last step is done
...@@ -2199,6 +2186,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2199,6 +2186,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if use_fused_attention: if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq] # [b, np, sq, 1] -> [b, np, sq]
softmax_lse_per_step[i - 1].squeeze_(-1) softmax_lse_per_step[i - 1].squeeze_(-1)
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, t] -> [np, b, sq]
softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view(
q.shape[-2], q.shape[0], -1
)
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8: if fp8:
...@@ -2213,7 +2205,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2213,7 +2205,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd": if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format
# [np, b, sq] -> [np, b, 2, sq//2] lse in packed format
softmax_lse_ = softmax_lse.view( softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
) )
...@@ -2227,7 +2220,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2227,7 +2220,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse, softmax_lse,
softmax_lse_per_step[i - 1], softmax_lse_per_step[i - 1],
cu_seqlens_q_padded, cu_seqlens_q_padded,
max_seqlen_q, softmax_lse_in_packed_format,
) )
else: else:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_softmax_lse_correction(
...@@ -2253,9 +2246,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2253,9 +2246,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_fwd_out_correction( flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape), out.view(*out_per_step[i].shape),
out_per_step[i], out_per_step[i],
seq_dim,
softmax_lse, softmax_lse,
softmax_lse_per_step[i], softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
) )
elif qkv_format == "thd": elif qkv_format == "thd":
tex.thd_out_correction( tex.thd_out_correction(
...@@ -2265,15 +2259,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2265,15 +2259,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], softmax_lse_per_step[i],
cu_seqlens_q_padded, cu_seqlens_q_padded,
False, False,
softmax_lse_in_packed_format,
) )
else: else:
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction( flash_attn_fwd_out_correction(
out_, out_,
out_per_step[i], out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :], softmax_lse_[..., 1, :],
softmax_lse_per_step[i], softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
) )
elif qkv_format == "thd": elif qkv_format == "thd":
tex.thd_out_correction( tex.thd_out_correction(
...@@ -2283,8 +2279,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2283,8 +2279,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], softmax_lse_per_step[i],
cu_seqlens_q_padded, cu_seqlens_q_padded,
True, True,
softmax_lse_in_packed_format,
) )
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
kv = p2p_comm_buffers[-1] kv = p2p_comm_buffers[-1]
if qkv_format == "bshd": if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:]) out = out.view(out.shape[0], -1, *out.shape[-2:])
...@@ -2430,10 +2430,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2430,10 +2430,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
attn_dbias = None attn_dbias = None
softmax_lse_in_packed_format = not ctx.use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
if causal: if causal:
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd" or softmax_lse_in_packed_format:
softmax_lse_ = tex.thd_read_second_half_lse( softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format
) )
else: else:
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
...@@ -2526,11 +2530,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2526,11 +2530,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
send_recv_reqs = [] send_recv_reqs = []
fa_optional_backward_kwargs = {} if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus: if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic fa_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(cp_size): for i in range(cp_size):
# wait until KV is received # wait until KV is received
...@@ -2639,9 +2650,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2639,9 +2650,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:]) out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, 0) fa_backward_kwargs["window_size"] = (-1, 0)
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_, dout_,
q_, q_,
kv_[0], kv_[0],
...@@ -2655,11 +2668,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2655,11 +2668,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
ctx.dropout_p, causal=True,
ctx.softmax_scale, **fa_backward_kwargs,
True,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
) )
elif i >= (cp_size - rank - 1): elif i >= (cp_size - rank - 1):
if ctx.use_fused_attention: if ctx.use_fused_attention:
...@@ -2733,9 +2743,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2733,9 +2743,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:]) out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1) fa_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_, dout_,
q_, q_,
kv_[0], kv_[0],
...@@ -2749,11 +2761,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2749,11 +2761,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2, ctx.max_seqlen_kv // 2,
ctx.dropout_p, causal=False,
ctx.softmax_scale, **fa_backward_kwargs,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
) )
else: else:
if ctx.use_fused_attention: if ctx.use_fused_attention:
...@@ -2833,9 +2842,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2833,9 +2842,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1) fa_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_, dout_,
q_, q_,
kv_[0], kv_[0],
...@@ -2849,11 +2860,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2849,11 +2860,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2, ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
ctx.dropout_p, causal=False,
ctx.softmax_scale, **fa_backward_kwargs,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
) )
else: else:
if ctx.use_fused_attention: if ctx.use_fused_attention:
...@@ -2897,9 +2905,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2897,9 +2905,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, sq, np, hn] -> [b*sq, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:]) out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus: if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1) fa_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_, dout_,
q_, q_,
kv_[0], kv_[0],
...@@ -2913,11 +2923,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2913,11 +2923,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
ctx.dropout_p, causal=False,
ctx.softmax_scale, **fa_backward_kwargs,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
) )
if ctx.fp8: if ctx.fp8:
...@@ -3251,11 +3258,19 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3251,11 +3258,19 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
assert ( assert (
use_fused_attention or _flash_attn_2_3_plus use_fused_attention or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus: if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None fa_forward_kwargs["block_table"] = None
assert qkv_format != "thd", f"{qkv_format} format is not supported!" assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
...@@ -3353,8 +3368,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3353,8 +3368,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
) )
else: else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
_, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = ( fa_outputs = flash_attn_fwd(
_flash_attn_forward(
q_, q_,
k_, k_,
v_, v_,
...@@ -3362,14 +3376,14 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3362,14 +3376,14 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_kv_, max_seqlen_kv_,
dropout_p,
softmax_scale,
causal=causal, causal=causal,
return_softmax=False,
window_size=window_size_per_step[i], window_size=window_size_per_step[i],
**fa_optional_forward_kwargs, **fa_forward_kwargs,
)
) )
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
if i > 0: if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]): with torch.cuda.stream(flash_attn_streams[i - 1]):
...@@ -3459,11 +3473,18 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3459,11 +3473,18 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
fa_optional_backward_kwargs = {} if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus: if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic fa_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(len(local_seq_chunk_ids) + 1): for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids): if i < len(local_seq_chunk_ids):
...@@ -3513,7 +3534,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3513,7 +3534,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_] torch.empty_like(x) for x in [q_, k_, v_]
] ]
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[i]
flash_attn_bwd(
dout_, dout_,
q_, q_,
k_, k_,
...@@ -3527,12 +3550,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3527,12 +3550,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_kv_per_step[i], cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q, ctx.max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
ctx.dropout_p, causal="causal" in ctx.attn_mask_type,
ctx.softmax_scale,
"causal" in ctx.attn_mask_type,
window_size=window_size_per_step[i], window_size=window_size_per_step[i],
rng_state=rng_states[i], **fa_backward_kwargs,
**fa_optional_backward_kwargs,
) )
# [b*sq//2, np, hn] -> [b, sq//2, np, hn] # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
...@@ -3655,13 +3675,22 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3655,13 +3675,22 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
or use_fused_attention or use_fused_attention
or _flash_attn_2_3_plus or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
fa_forward_kwargs["window_size"] = window_size
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus: if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size fa_forward_kwargs["window_size"] = window_size
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus: if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None fa_forward_kwargs["block_table"] = None
assert ( assert (
q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
...@@ -3756,16 +3785,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3756,16 +3785,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else: else:
# [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
( fa_outputs = flash_attn_fwd(
_,
_,
_,
_,
out,
softmax_lse,
_,
rng_state,
) = _flash_attn_forward(
q, q,
k, k,
v, v,
...@@ -3773,12 +3793,11 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3773,12 +3793,11 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
dropout_p,
softmax_scale,
causal=causal, causal=causal,
return_softmax=False, **fa_forward_kwargs,
**fa_optional_forward_kwargs,
) )
out, softmax_lse = fa_outputs[4], fa_outputs[5]
rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state] aux_ctx_tensors = [softmax_lse, rng_state]
# [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
out = out.view(batch_size, -1, *out.shape[-2:]) out = out.view(batch_size, -1, *out.shape[-2:])
...@@ -3943,13 +3962,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3943,13 +3962,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
) )
fa_optional_backward_kwargs = {} if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_3_plus: if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["window_size"] = ctx.window_size
if _flash_attn_2_4_plus: if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus: if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic fa_backward_kwargs["deterministic"] = ctx.deterministic
if ctx.use_fused_attention: if ctx.use_fused_attention:
dq, dk, dv, _ = fused_attn_bwd( dq, dk, dv, _ = fused_attn_bwd(
...@@ -3981,7 +4008,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3981,7 +4008,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
softmax_lse, rng_state = aux_ctx_tensors softmax_lse, rng_state = aux_ctx_tensors
out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
_flash_attn_backward( if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_state
flash_attn_bwd(
dout, dout,
q, q,
k, k,
...@@ -3995,11 +4024,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -3995,11 +4024,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv, cu_seqlens_kv,
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
ctx.dropout_p, causal=causal,
ctx.softmax_scale, **fa_backward_kwargs,
causal,
rng_state=rng_state,
**fa_optional_backward_kwargs,
) )
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
......
...@@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
int half_idx); int half_idx);
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens); const at::Tensor &cu_seqlens, bool lse_packed);
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens); bool lse_packed);
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half); bool only_second_half, bool lse_packed);
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half, const at::Tensor &cu_seqlens, const std::string &first_half,
......
...@@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
* Support THD format for Context Parallel: softmax_lse related operations * Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/ **************************************************************************************************/
template <typename lse_dtype, typename Functor> template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int max_seqlen) { int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[]; extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) { for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2; cu_seqlens_s[i] = cu_seqlens[i] / 2;
...@@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, ...@@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id; size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id]; int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
size_t idx = row * max_seqlen + col + seq_len; idx = row * total_tokens + col + seq_len;
size_t half_idx = row * max_seqlen / 2 + col; half_idx = row * total_tokens / 2 + col;
}
Functor::run(lse, half_lse, idx, half_idx); Functor::run(lse, half_lse, idx, half_idx);
} }
...@@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor { ...@@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor {
}; };
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens) { const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, total_tokens;
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2);
} else {
NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3); NVTE_CHECK(lse_per_step.dim() == 3);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0); batch = lse.size(0);
int num_heads = lse.size(1); num_heads = lse.size(1);
int max_seqlen = lse.size(2); total_tokens = lse.size(2);
NVTE_CHECK(lse_per_step.size(0) == batch); NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads); NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2); NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1); NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr unsigned int block = 256; constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads; unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y}; dim3 grid = {grid_x, grid_y};
thd_lse_kernel<double, LseCorrectionFunctor> if (lse_packed) {
thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
num_heads, max_seqlen); batch, num_heads, total_tokens);
} else {
thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
}
} }
struct ReadLseFunctor { struct ReadLseFunctor {
...@@ -1540,29 +1567,51 @@ struct ReadLseFunctor { ...@@ -1540,29 +1567,51 @@ struct ReadLseFunctor {
}; };
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens) { bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0); int batch, num_heads, total_tokens;
int num_heads = lse.size(1); std::vector<int64_t> shape;
int max_seqlen = lse.size(2);
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
shape = {num_heads, total_tokens / 2};
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1); NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
std::vector<int64_t> shape = {batch, num_heads, max_seqlen / 2}; shape = {batch, num_heads, total_tokens / 2};
}
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256; constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads; unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y}; dim3 grid = {grid_x, grid_y};
thd_lse_kernel<float, ReadLseFunctor> if (lse_packed) {
thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, total_tokens);
} else {
thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, max_seqlen); num_heads, total_tokens);
}
return half_lse; return half_lse;
} }
...@@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ ...@@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward * Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size> template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch, float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int max_seqlen) { int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[]; extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) { for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
...@@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float ...@@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step; size_t idx, idx_per_step;
if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id; size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id]; int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * max_seqlen + col + seq_len * only_second_half; idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * max_seqlen / (only_second_half + 1) + col; idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
...@@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float ...@@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
template <typename dtype, int only_second_half> template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens) { const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
...@@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ ...@@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
int total_tokens = out.size(0); int total_tokens = out.size(0);
int num_heads = out.size(1); int num_heads = out.size(1);
int dim_per_head = out.size(2); int dim_per_head = out.size(2);
int batch = lse.size(0);
int max_seqlen = lse.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads); NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head); NVTE_CHECK(out_per_step.size(2) == dim_per_head);
int batch, lse_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = total_tokens;
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse.size(1) == lse_seqlen);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
NVTE_CHECK(lse.size(1) == num_heads); NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch); NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads); NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1)); NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1); NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr int tile = 16; constexpr int tile = 16;
constexpr int block = 512; constexpr int block = 512;
...@@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ ...@@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; (static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads}; dim3 grid = {grid_x, (unsigned int)num_heads};
thd_out_correction_kernel<dtype, only_second_half, tile> if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(), out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads, lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, max_seqlen); dim_per_head, lse_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen);
}
} }
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half) { bool only_second_half, bool lse_packed) {
if (only_second_half) { if (only_second_half) {
if (out.scalar_type() == at::ScalarType::Half) { if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half; using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) { } else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16; using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) { } else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float; using dtype = float;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else { } else {
NVTE_ERROR("Unsupported dtype of out\n"); NVTE_ERROR("Unsupported dtype of out\n");
} }
} else { } else {
if (out.scalar_type() == at::ScalarType::Half) { if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half; using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) { } else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16; using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) { } else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float; using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens); thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else { } else {
NVTE_ERROR("Unsupported dtype of out\n"); NVTE_ERROR("Unsupported dtype of out\n");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment