Unverified Commit b36bd0a4 authored by Xiaowei Ren, committed by GitHub

Add FlashAttention3 to CP implementations (#1232)



* fa2 function import renaming
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* refine fa_fwd_kwargs and fa_bwd_kwargs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* import FA3 functions for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix output of FA3 fwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix rng_state in a2a implementation with FA3
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* hack lse correction for packed lse format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make CP thd out correction work with packed lse
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix for packed softmax_lse
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse shape
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change lse_packed to constexpr
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 9ee2dbdd
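
For orientation before the diff: a condensed sketch of the FA2/FA3 dispatch this commit adds to the context-parallel attention paths. This is illustrative only; the import names mirror the hunks below, the availability flags are assumed booleans, the window_size choice mirrors the causal P2P path, and version-gated optional kwargs (alibi_slopes, block_table) are omitted.

    # Illustrative sketch, not part of the diff.
    try:
        from flashattn_hopper.flash_attn_interface import (
            _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3,
        )
        _use_flash_attn_3 = True
    except ImportError:
        flash_attn_varlen_fwd_v3 = None
        _use_flash_attn_3 = False

    try:
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as flash_attn_varlen_fwd,
        )
    except ImportError:
        flash_attn_varlen_fwd = None

    def pick_varlen_forward(softmax_scale, dropout_p, causal, fa2_2_3_plus=True):
        """Return (callable, kwargs) the way the CP forward passes now assemble them."""
        kwargs = {"softmax_scale": softmax_scale}
        if _use_flash_attn_3:
            fwd = flash_attn_varlen_fwd_v3
            kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
        else:
            fwd = flash_attn_varlen_fwd
            kwargs["dropout_p"] = dropout_p
            kwargs["return_softmax"] = False
            if fa2_2_3_plus:
                kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
        return fwd, kwargs
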
......@@ -108,6 +108,7 @@ _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
_flash_attn_3_plus = False
_use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
......@@ -135,13 +136,19 @@ else:
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
)
_use_flash_attn_3 = True
if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as flash_attn_varlen_fwd
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as flash_attn_varlen_bwd
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
_attention_backends = {
......@@ -460,9 +467,6 @@ def get_attention_backend(
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
......@@ -1362,12 +1366,15 @@ def flash_attn_p2p_communicate(
def flash_attn_fwd_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
seq_dim: int,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
movedim_src: int,
movedim_dst: int,
):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected)
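
A quick shape check of the new movedim arguments (illustrative sketch, assuming torch is available; at correction time the packed lse has been viewed as [np, b, sq], the unpacked lse is [b, np, sq], and both must broadcast against the per-step output):

    # Illustrative shape check, not part of the diff.
    import torch

    b, sq, np_, hn = 2, 8, 4, 64
    out_per_step = torch.randn(b, sq, np_, hn)      # "bshd" per-step output

    lse_unpacked = torch.randn(b, np_, sq)          # FA2 < 2.6.0: [b, np, sq]
    lse_packed = torch.randn(np_, b, sq)            # packed [np, t] viewed as [np, b, sq]

    seq_dim = 1                                     # seq dimension of "bshd" outputs
    # movedim(2, seq_dim) for unpacked and movedim(0, 2) for packed both yield [b, sq, np]
    factor_unpacked = lse_unpacked.movedim(2, seq_dim).unsqueeze(-1)
    factor_packed = lse_packed.movedim(0, 2).unsqueeze(-1)
    assert (out_per_step * factor_unpacked).shape == out_per_step.shape
    assert (out_per_step * factor_packed).shape == out_per_step.shape
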
......@@ -1693,13 +1700,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
*attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
)
assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
fa_optional_forward_kwargs = {}
softmax_lse_in_packed_format = not use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
fa_forward_kwargs["block_table"] = None
# Flash Attn inputs
q_inputs = [None, None]
......@@ -1840,16 +1859,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
......@@ -1857,12 +1867,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=True,
return_softmax=False,
**fa_optional_forward_kwargs,
**fa_forward_kwargs,
)
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
elif i <= rank:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -1952,18 +1963,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, -1)
(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
......@@ -1971,12 +1973,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv // 2,
dropout_p,
softmax_scale,
causal=False,
return_softmax=False,
**fa_optional_forward_kwargs,
**fa_forward_kwargs,
)
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -2075,18 +2078,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = (-1, -1)
(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
......@@ -2094,12 +2088,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
max_seqlen_q // 2,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False,
return_softmax=False,
**fa_optional_forward_kwargs,
**fa_forward_kwargs,
)
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -2167,16 +2162,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
(
_,
_,
_,
_,
out_per_step[i],
softmax_lse_per_step[i],
_,
rng_states[i],
) = _flash_attn_forward(
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
......@@ -2184,12 +2170,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False,
return_softmax=False,
**fa_optional_forward_kwargs,
**fa_forward_kwargs,
)
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
if i > 0:
# wait until fwd results correction of last step is done
......@@ -2199,6 +2186,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq]
softmax_lse_per_step[i - 1].squeeze_(-1)
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, t] -> [np, b, sq]
softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view(
q.shape[-2], q.shape[0], -1
)
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8:
......@@ -2213,7 +2205,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
# [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format
# [np, b, sq] -> [np, b, 2, sq//2] lse in packed format
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
......@@ -2227,7 +2220,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse,
softmax_lse_per_step[i - 1],
cu_seqlens_q_padded,
max_seqlen_q,
softmax_lse_in_packed_format,
)
else:
flash_attn_fwd_softmax_lse_correction(
......@@ -2253,9 +2246,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape),
out_per_step[i],
seq_dim,
softmax_lse,
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
......@@ -2265,15 +2259,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
cu_seqlens_q_padded,
False,
softmax_lse_in_packed_format,
)
else:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
......@@ -2283,8 +2279,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
cu_seqlens_q_padded,
True,
softmax_lse_in_packed_format,
)
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
kv = p2p_comm_buffers[-1]
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
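
The packed-lse reshapes above are plain views; a minimal round-trip, assuming torch and equal-length sequences so that t == b * sq (illustrative only):

    # Illustrative round-trip of the packed-lse views, not part of the diff.
    import torch

    np_, b, sq = 4, 2, 8
    t = b * sq
    lse_packed = torch.randn(np_, t)           # FA3 / FA2 >= 2.6.0 return format: [np, t]
    lse_bsq = lse_packed.view(np_, b, sq)      # [np, t] -> [np, b, sq] for the correction
    lse_second_half = lse_bsq.view(np_, b, 2, sq // 2)[..., 1, :]   # causal: second half only
    assert lse_second_half.shape == (np_, b, sq // 2)
    assert torch.equal(lse_bsq.view(np_, -1), lse_packed)           # back to [np, t] at the end
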
......@@ -2430,10 +2430,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
attn_dbias = None
softmax_lse_in_packed_format = not ctx.use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
if causal:
if ctx.qkv_format == "thd":
if ctx.qkv_format == "thd" or softmax_lse_in_packed_format:
softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format
)
else:
# [b, np, sq] -> [b, np, 2, sq//2]
......@@ -2526,11 +2530,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout = dout.view(*q.shape)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(cp_size):
# wait until KV is received
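
The per-step rng_state in these kwargs only matters for FA2; FA3's forward does not return one, so none is passed. A condensed sketch of the backward dispatch assembled above (illustrative only; the FA2/FA3 varlen backward callables are passed in, and the version gating of deterministic and alibi_slopes is omitted):

    # Illustrative sketch, not part of the diff.
    def pick_varlen_backward(fa2_bwd, fa3_bwd, softmax_scale, dropout_p, deterministic,
                             rng_state, use_fa3):
        if use_fa3:
            # FA3 takes no dropout_p and needs no rng_state
            return fa3_bwd, {"softmax_scale": softmax_scale, "deterministic": deterministic}
        # FA2 dropout is replayed from the rng_state saved by the matching forward step
        return fa2_bwd, {"softmax_scale": softmax_scale, "dropout_p": dropout_p,
                         "deterministic": deterministic, "rng_state": rng_state}
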
......@@ -2639,9 +2650,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, 0)
_flash_attn_backward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, 0)
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
q_,
kv_[0],
......@@ -2655,11 +2668,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
True,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
causal=True,
**fa_backward_kwargs,
)
elif i >= (cp_size - rank - 1):
if ctx.use_fused_attention:
......@@ -2733,9 +2743,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
q_,
kv_[0],
......@@ -2749,11 +2761,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
ctx.dropout_p,
ctx.softmax_scale,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
causal=False,
**fa_backward_kwargs,
)
else:
if ctx.use_fused_attention:
......@@ -2833,9 +2842,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
q_,
kv_[0],
......@@ -2849,11 +2860,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
causal=False,
**fa_backward_kwargs,
)
else:
if ctx.use_fused_attention:
......@@ -2897,9 +2905,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
q_,
kv_[0],
......@@ -2913,11 +2923,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
False,
rng_state=rng_states[cp_size - i - 1],
**fa_optional_backward_kwargs,
causal=False,
**fa_backward_kwargs,
)
if ctx.fp8:
......@@ -3251,11 +3258,19 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
assert (
use_fused_attention or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
fa_forward_kwargs["block_table"] = None
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
......@@ -3353,8 +3368,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
_, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = (
_flash_attn_forward(
fa_outputs = flash_attn_fwd(
q_,
k_,
v_,
......@@ -3362,14 +3376,14 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv_,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=False,
window_size=window_size_per_step[i],
**fa_optional_forward_kwargs,
)
**fa_forward_kwargs,
)
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[7]
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
......@@ -3459,11 +3473,18 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
fa_optional_backward_kwargs = {}
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
......@@ -3513,7 +3534,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
_flash_attn_backward(
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[i]
flash_attn_bwd(
dout_,
q_,
k_,
......@@ -3527,12 +3550,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
"causal" in ctx.attn_mask_type,
causal="causal" in ctx.attn_mask_type,
window_size=window_size_per_step[i],
rng_state=rng_states[i],
**fa_optional_backward_kwargs,
**fa_backward_kwargs,
)
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
......@@ -3655,13 +3675,22 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
or use_fused_attention
or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
fa_forward_kwargs["window_size"] = window_size
else:
flash_attn_fwd = flash_attn_varlen_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
fa_forward_kwargs["window_size"] = window_size
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
fa_forward_kwargs["block_table"] = None
assert (
q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
......@@ -3756,16 +3785,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
# [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
(
_,
_,
_,
_,
out,
softmax_lse,
_,
rng_state,
) = _flash_attn_forward(
fa_outputs = flash_attn_fwd(
q,
k,
v,
......@@ -3773,12 +3793,11 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=False,
**fa_optional_forward_kwargs,
**fa_forward_kwargs,
)
out, softmax_lse = fa_outputs[4], fa_outputs[5]
rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state]
# [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
out = out.view(batch_size, -1, *out.shape[-2:])
......@@ -3943,13 +3962,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
)
fa_optional_backward_kwargs = {}
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
fa_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["window_size"] = ctx.window_size
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["deterministic"] = ctx.deterministic
if ctx.use_fused_attention:
dq, dk, dv, _ = fused_attn_bwd(
......@@ -3981,7 +4008,9 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
softmax_lse, rng_state = aux_ctx_tensors
out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
_flash_attn_backward(
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_state
flash_attn_bwd(
dout,
q,
k,
......@@ -3995,11 +4024,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
causal,
rng_state=rng_state,
**fa_optional_backward_kwargs,
causal=causal,
**fa_backward_kwargs,
)
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
......
......@@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
int half_idx);
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens);
const at::Tensor &cu_seqlens, bool lse_packed);
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens);
bool lse_packed);
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half);
bool only_second_half, bool lse_packed);
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
......
......@@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, typename Functor>
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int max_seqlen) {
int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
......@@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
size_t idx = row * max_seqlen + col + seq_len;
size_t half_idx = row * max_seqlen / 2 + col;
idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
}
Functor::run(lse, half_lse, idx, half_idx);
}
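
For reference, the packed-branch index arithmetic of the kernel above restated in plain Python with a small worked example; names are local to this sketch and only the lse_packed case is covered:

    # Illustrative restatement, not part of the diff.  Packed lse is [num_heads, total_tokens];
    # the half tensor is [num_heads, total_tokens // 2]; token_id indexes the concatenated
    # first halves of all sequences, as in the kernel's grid-stride loop.
    def packed_second_half_indices(cu_seqlens, head_id, token_id, total_tokens):
        cu_half = [c // 2 for c in cu_seqlens]
        seq_id = next(s for s in range(len(cu_half) - 1)
                      if cu_half[s] <= token_id < cu_half[s + 1])
        idx = head_id * total_tokens + token_id + cu_half[seq_id + 1]   # flat index into [np, t]
        half_idx = head_id * (total_tokens // 2) + token_id             # flat index into [np, t//2]
        return idx, half_idx

    # two sequences of lengths 8 and 4 (total_tokens = 12, halved cu_seqlens = [0, 4, 6])
    cu_seqlens = [0, 8, 12]
    assert packed_second_half_indices(cu_seqlens, head_id=0, token_id=0, total_tokens=12) == (4, 0)
    assert packed_second_half_indices(cu_seqlens, head_id=1, token_id=4, total_tokens=12) == (22, 10)
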
......@@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor {
};
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens) {
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, total_tokens;
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2);
NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<double, LseCorrectionFunctor>
if (lse_packed) {
thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, max_seqlen);
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
} else {
thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
}
}
struct ReadLseFunctor {
......@@ -1540,29 +1567,51 @@ struct ReadLseFunctor {
};
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens) {
bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
int batch, num_heads, total_tokens;
std::vector<int64_t> shape;
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
shape = {num_heads, total_tokens / 2};
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
std::vector<int64_t> shape = {batch, num_heads, max_seqlen / 2};
shape = {batch, num_heads, total_tokens / 2};
}
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<float, ReadLseFunctor>
if (lse_packed) {
thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, total_tokens);
} else {
thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, max_seqlen);
num_heads, total_tokens);
}
return half_lse;
}
......@@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size>
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int max_seqlen) {
int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
......@@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;
if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * max_seqlen + col + seq_len * only_second_half;
idx_per_step = row * max_seqlen / (only_second_half + 1) + col;
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
......@@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens) {
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
......@@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
int total_tokens = out.size(0);
int num_heads = out.size(1);
int dim_per_head = out.size(2);
int batch = lse.size(0);
int max_seqlen = lse.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);
int batch, lse_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = total_tokens;
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse.size(1) == lse_seqlen);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1));
NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
......@@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
thd_out_correction_kernel<dtype, only_second_half, tile>
if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, max_seqlen);
dim_per_head, lse_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen);
}
}
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half) {
bool only_second_half, bool lse_packed) {
if (only_second_half) {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
} else {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
......