"superbench/runner/vscode:/vscode.git/clone" did not exist on "69b2c631fc7ea0be002fa76532e73085f3d78474"
Unverified commit 560bccf8 authored by Xiaowei Ren, committed by GitHub

clean CP implementation for flash attention and cuDNN 9.6 (#1387)



* make the pad_between_seqs check ignore padding at the end
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change the CP THD test so that it covers 0-length sequences
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change to flash func name
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* only use the varlen funcs of flash attention when qkv_format is THD (see the sketch after the commit trailers)
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try to converge the code paths of flash and fused attention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bwd compute with P2P
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant out_per_step view
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable cuDNN>9.6 and THD+GQA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable CP with FusedAttn+SWA+All_Gather
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable CP with FusedAttn+SWA+All_Gather
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for cu_seqlens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix some pylint errors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor import change for pylint
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* more fixes for pylint
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix lse_seqlen in thd out correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a4cb1d17
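The core refactor in the diff below is that flash attention's dense (batched) entry points are now used for bshd/sbhd inputs while the varlen entry points are reserved for thd, and the varlen-only positional arguments (cu_seqlens/max_seqlen) are packed into a list that is only populated for thd. A minimal sketch of that dispatch, assuming flash-attn 2.x is installed: the internal names are the ones the diff imports, but the wrapper function and its signature are made up for illustration, and fa_forward_kwargs stands in for the version-dependent keyword set (dropout_p, softmax_scale, window_size, alibi_slopes, return_softmax, ...) that TE assembles elsewhere.

```python
# Sketch only, not TE code: dispatch between flash-attn's dense and varlen kernels.
from flash_attn.flash_attn_interface import (
    _flash_attn_forward as _flash_attn_fwd,
    _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
)


def flash_fwd_call(q, k, v, qkv_format, fa_forward_kwargs, causal=True,
                   cu_seqlens_q=None, cu_seqlens_kv=None,
                   max_seqlen_q=None, max_seqlen_kv=None):
    """Pick the flash-attn forward entry point and argument list by qkv layout."""
    if qkv_format == "thd":
        flash_attn_fwd = _flash_attn_varlen_fwd
        # varlen-only positional args, mirroring fa_forward_args_thd in the diff
        extra_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv]
    else:
        # bshd/sbhd now keep their batch dimension and use the dense kernel
        flash_attn_fwd = _flash_attn_fwd
        extra_args = []
    return flash_attn_fwd(q, k, v, *extra_args, causal=causal, **fa_forward_kwargs)
```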
......@@ -163,12 +163,10 @@ def run_dpa_with_cp(
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
if kernel_backend == "FlashAttention":
cu_seqlens_q = cu_seqlens_q_padded[:-1]
else:
cu_seqlens_q = torch.cat(
[torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
......@@ -204,10 +202,8 @@ def run_dpa_with_cp(
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if fp8_mha:
dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
......@@ -276,10 +272,8 @@ def run_dpa_with_cp(
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if fp8_mha:
dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
......@@ -311,7 +305,7 @@ def run_dpa_with_cp(
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
......@@ -327,7 +321,7 @@ def run_dpa_with_cp(
).item()
== 0
)
cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
......
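To make the new cu_seqlens construction above concrete, here is a small worked example with assumed numbers (two sequences that exactly fill their padded slots, followed by four pure padding tokens at the end of a 24-token buffer); it also shows why the relaxed pad_between_seqs comparison in the attention.py hunks further down drops the last entry. The numbers are illustrative, not taken from the test.

```python
import torch

seqlens_q        = torch.tensor([8, 12], dtype=torch.int32)
seqlens_q_padded = torch.tensor([8, 12], dtype=torch.int32)
total_tokens     = 24  # q_input_shape[0]: buffer ends with 4 pure padding tokens

cu_seqlens_q_padded = torch.cat(
    [torch.zeros([1], dtype=torch.int32),
     seqlens_q_padded.cumsum(0, dtype=torch.int32),
     torch.tensor([total_tokens], dtype=torch.int32)]
)                                      # tensor([ 0,  8, 20, 24])

# FlashAttention now reuses the padded tensor directly; FusedAttention overwrites
# the interior entries with the unpadded cumulative lengths and repeats the last
# one, so the trailing padding becomes a 0-length segment.
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
cu_seqlens_q[-1] = cu_seqlens_q[-2]    # tensor([ 0,  8, 20, 20])

# Old check: full-tensor comparison, so padding only at the end of the buffer
# was already treated as "padding between sequences".
print(not torch.equal(cu_seqlens_q_padded, cu_seqlens_q))            # True
# New check (see the attention.py hunks below): the last entry is ignored.
print(not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]))  # False
```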
......@@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
pytest.skip("THD format is not supported for cuDNN 9.6+!")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
"Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
......@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
......
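Taken together with the relaxed assert in attention.py further down, the skips removed and added above amount to the following sliding-window support rule for context parallelism (a simplified restatement with a made-up helper name, not TE code):

```python
def cp_supports_sliding_window(cp_comm_type: str, window_size) -> bool:
    """Sketch of the post-change rule: sliding window attention works with the
    A2A and KV all-gather CP implementations (fused or flash attention), but
    still not with KV P2P."""
    sliding_window = window_size is not None and tuple(window_size) not in ((-1, 0), (-1, -1))
    if not sliding_window:
        return True
    return cp_comm_type in ("a2a", "all_gather")
```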
......@@ -125,11 +125,13 @@ _flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
flash_attn_varlen_fwd = None
flash_attn_varlen_bwd = None
flash_attn_cuda_bwd = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
try:
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
......@@ -141,14 +143,16 @@ except PackageNotFoundError:
)
else:
if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
from flash_attn.flash_attn_interface import (
_flash_attn_varlen_forward as flash_attn_varlen_fwd,
_flash_attn_varlen_forward as _flash_attn_varlen_fwd,
)
from flash_attn.flash_attn_interface import (
_flash_attn_varlen_backward as flash_attn_varlen_bwd,
_flash_attn_varlen_backward as _flash_attn_varlen_bwd,
)
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
_flash_attn_is_installed = True
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
......@@ -195,11 +199,13 @@ else:
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as flash_attn_varlen_fwd_v3,
_flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
_flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
)
_flash_attn_3_is_installed = True
......@@ -602,12 +608,6 @@ def get_attention_backend(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
elif cudnn_version >= (9, 6, 0) and qkv_format == "thd":
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with THD for"
" cuDNN 9.6+"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
......@@ -1804,12 +1804,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal(
cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]
)
pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal(
cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]
)
max_seqlen_q = max_seqlen_q // cp_size
max_seqlen_kv = max_seqlen_kv // cp_size
cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
cu_seqlens_q_padded = (
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size
)
cu_seqlens_kv_padded = (
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size
)
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
......@@ -1882,9 +1890,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn]
q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
# remove padded tokens at the end
k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
if attn_bias is not None:
assert len(attn_bias.shape) == 4, (
"Only support bias shape of [b, h, sq, sk] for forward, "
......@@ -1907,17 +1912,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
softmax_lse_in_packed_format = not use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
softmax_lse_in_packed_format = False
if qkv_format == "thd":
if use_fused_attention:
softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
else:
softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = _flash_attn_fwd_v3
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
else:
flash_attn_fwd = flash_attn_varlen_fwd
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd
else:
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus:
......@@ -1943,7 +1958,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
if qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
......@@ -1991,31 +2006,31 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
q_inputs[i % 2] = q
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
q_inputs[i % 2] = q
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = torch.cat(
......@@ -2060,18 +2075,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
]
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
*fa_forward_args_thd,
causal=True,
**fa_forward_kwargs,
)
......@@ -2084,7 +2108,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -2095,25 +2119,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
False,
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0]
elif qkv_format == "thd":
q_inputs[i % 2] = q
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
elif qkv_format == "thd":
q_inputs[i % 2] = q
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
......@@ -2156,28 +2181,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
fa_forward_args_thd = []
if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv // 2,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv // 2,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
)
......@@ -2190,7 +2216,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -2201,28 +2227,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
True,
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...].contiguous()
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1].contiguous()
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
q_inputs[i % 2] = q_inputs[i % 2].contiguous()
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = torch.cat(
......@@ -2271,28 +2298,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
fa_forward_args_thd = []
if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
else:
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i % 2] = (
q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q // 2,
max_seqlen_kv,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q // 2,
max_seqlen_kv,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
)
......@@ -2305,7 +2333,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
......@@ -2316,7 +2344,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
True,
)
else:
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention:
if attn_bias is not None:
......@@ -2363,18 +2391,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
]
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
q,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
)
......@@ -2389,13 +2426,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq]
# [b, np, sq, 1] -> [b, np, sq] or
# [t, np, 1] -> [t, np]
softmax_lse_per_step[i - 1].squeeze_(-1)
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, t] -> [np, b, sq]
softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view(
q.shape[-2], q.shape[0], -1
)
if softmax_lse_in_packed_format:
softmax_lse_per_step[i - 1] = (
softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
)
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8:
......@@ -2410,8 +2447,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format
# [np, b, sq] -> [np, b, 2, sq//2] lse in packed format
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
......@@ -2439,16 +2475,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
out_ = None
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(
out.shape[0], -1, *out.shape[-2:]
) # pylint: disable=used-before-assignment
out_ = out[:, 1, ...]
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]
if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
......@@ -2471,6 +2497,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
else:
if qkv_format in ["bshd", "sbhd"]:
out_ = out.select(seq_dim, 1)
flash_attn_fwd_out_correction(
out_,
out_per_step[i],
......@@ -2490,9 +2517,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_in_packed_format,
)
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
kv = p2p_comm_buffers[-1]
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
......@@ -2587,7 +2611,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.cp_global_ranks = cp_global_ranks
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.total_tokens_kv = total_tokens_kv
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
......@@ -2597,6 +2620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
......@@ -2646,14 +2670,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_dbias = None
attn_dbias_ = None
softmax_lse_in_packed_format = not ctx.use_fused_attention and (
_flash_attn_2_6_0_plus or _use_flash_attn_3
)
if causal:
if ctx.qkv_format == "thd" or softmax_lse_in_packed_format:
if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format
softmax_lse, cu_seqlens_q_padded, ctx.softmax_lse_in_packed_format
)
else:
# [b, np, sq] -> [b, np, 2, sq//2]
......@@ -2661,13 +2681,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
if ctx.softmax_lse_in_packed_format:
softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
# [b, np, sq//2] -> [b, np, sq//2, 1] or
# [t//2, np] -> [t//2, np, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
if ctx.softmax_lse_in_packed_format:
softmax_lse = softmax_lse.transpose(0, 1).contiguous()
# [b, np, sq] -> [b, np, sq, 1] or
# [t, np] -> [t, np, 1]
softmax_lse.unsqueeze_(-1)
dq = None
dout_dtype = dout.dtype
fused_attn_backend = None
fused_attn_qkv_dtype = None
......@@ -2715,8 +2742,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout_scale_inv = dout._scale_inv
dout = dout._data
dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
......@@ -2760,10 +2785,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
flash_attn_bwd = _flash_attn_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
......@@ -2808,32 +2839,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
kv = p2p_comm_buffers[i % 2][0]
dk_, dv_ = None, None
q_, kv_, out_, dout_ = None, None, None, None
dq_, dk_, dv_ = None, None, None
if ctx.fp8 and ctx.use_fused_attention:
fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
# In reversed order of fwd
if causal:
if i == (cp_size - 1):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd":
q_, kv_, out_, dout_ = q, kv, out, dout
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd":
q_, kv_, out_, dout_ = q, kv, out, dout
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
......@@ -2869,15 +2897,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fp8_meta_kwargs,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, 0)
if not _use_flash_attn_3:
......@@ -2885,42 +2914,36 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_bwd(
dout_,
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse,
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=True,
**fa_backward_kwargs,
)
elif i >= (cp_size - rank - 1):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
]
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0]
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0]
elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0].contiguous()
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
kv_ = kv_.contiguous()
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
......@@ -2958,19 +2981,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fp8_meta_kwargs,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.zeros_like(q_)
if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
else:
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
......@@ -2978,44 +2998,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_bwd(
dout_,
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse,
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
)
else:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_, out_, dout_ = q[1], out[1], dout[1]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_, out_, dout_ = [
tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
for x in [q, out, dout]
]
kv_ = kv
if ctx.use_fused_attention:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[1].contiguous()
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
out_ = out[1].contiguous()
dout_ = dout[1].contiguous()
elif ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
kv_ = kv
q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse_,
......@@ -3053,23 +3066,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fp8_meta_kwargs,
)
else:
if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
......@@ -3077,17 +3083,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_bwd(
dout_,
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse_,
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
)
......@@ -3124,50 +3127,41 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fp8_meta_kwargs,
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.zeros_like(q_)
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
dq_ = torch.empty_like(q)
dkv_ = torch.empty_like(kv)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if _use_flash_attn_3 or _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
q_,
kv_[0],
kv_[1],
out_,
dout,
q,
kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out,
softmax_lse,
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
)
if ctx.fp8:
dq = dq_fp8[(rank + i + 1) % cp_size]
if i >= (cp_size - rank - 1) or not causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal
if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
# [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
# [sq, b, np, hn] -> [2, sq//2, b, np, hn]
dq_ = dq_.view(*dq.shape)
else:
if ctx.qkv_format == "bshd":
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
elif ctx.qkv_format == "sbhd":
# [b*sq//2, np, hn] -> [sq//2, b, np, hn]
dq_ = dq_.view(-1, *dq.shape[-3:])
if ctx.fp8:
if i >= (cp_size - rank - 1) or not causal:
......@@ -3242,24 +3236,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention:
dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
elif ctx.qkv_format == "sbhd":
# [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape)
dkv_ = _combine_tensors([dk_, dv_], -2)
elif ctx.qkv_format == "thd":
dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
dkv_ = dkv_.movedim(-3, 0)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
# [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv_ = dkv_.view(*dkv.shape)
if ctx.fp8:
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
......@@ -3341,13 +3332,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd":
dkv_ = torch.empty(
2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
)
dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_
if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
if ctx.fp8 and ctx.is_input_fp8:
dq, dkv = [
......@@ -3494,9 +3481,15 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = _flash_attn_fwd_v3
else:
flash_attn_fwd = flash_attn_varlen_fwd
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd
else:
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_4_plus:
......@@ -3514,8 +3507,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
max_seqlen_q = max_seqlen_q // (2 * cp_size)
max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
if use_fused_attention or qkv_format == "thd":
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_q_padded = (
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size)
)
# [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
......@@ -3570,9 +3566,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
kv_seq_range_per_step[i][1],
)
max_seqlen_kv_ = seq_end_idx - seq_start_idx
cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
k.shape[1], max_seqlen_kv_, k.device
)
if use_fused_attention or qkv_format == "thd":
cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
k.shape[1], max_seqlen_kv_, k.device
)
k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
# [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
......@@ -3599,15 +3596,19 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
window_size=window_size_per_step[i],
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv_,
]
fa_outputs = flash_attn_fwd(
q_,
k_,
v_,
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv_,
*fa_forward_args_thd,
causal=causal,
window_size=window_size_per_step[i],
**fa_forward_kwargs,
......@@ -3620,9 +3621,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape))
out[:, i - 1].copy_(out_per_step[i - 1])
elif qkv_format == "sbhd":
out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape))
out[i - 1].copy_(out_per_step[i - 1])
torch.cuda.current_stream().wait_stream(cp_stream)
......@@ -3711,10 +3712,16 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
flash_attn_bwd = _flash_attn_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
......@@ -3764,11 +3771,17 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
deterministic=ctx.deterministic,
)
else:
batch_size = k_.shape[0]
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
max_seqlen_kv,
]
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[i]
flash_attn_bwd(
......@@ -3781,21 +3794,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dq_per_step[i],
dk_per_step[i],
dv_per_step[i],
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
max_seqlen_kv,
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
window_size=window_size_per_step[i],
**fa_backward_kwargs,
)
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
# [b*s_range, np, hn] -> [b, s_range, np, hn]
dk_per_step[i], dv_per_step[i] = [
x.view(batch_size, -1, *x.shape[-2:])
for x in [dk_per_step[i], dv_per_step[i]]
]
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
......@@ -3916,10 +3919,16 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if _use_flash_attn_3:
flash_attn_fwd = flash_attn_varlen_fwd_v3
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = _flash_attn_fwd_v3
fa_forward_kwargs["window_size"] = window_size
else:
flash_attn_fwd = flash_attn_varlen_fwd
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd
else:
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if _flash_attn_2_3_plus:
......@@ -4025,24 +4034,25 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
**fp8_meta_kwargs,
)
else:
# [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
fa_outputs = flash_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
*fa_forward_args_thd,
causal=causal,
**fa_forward_kwargs,
)
out, softmax_lse = fa_outputs[4], fa_outputs[5]
rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state]
# [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
out = out.view(batch_size, -1, *out.shape[-2:])
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
out = flash_attn_a2a_communicate(
......@@ -4214,11 +4224,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if _use_flash_attn_3:
flash_attn_bwd = flash_attn_varlen_bwd_v3
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
flash_attn_bwd = _flash_attn_bwd_v3
fa_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
flash_attn_bwd = flash_attn_varlen_bwd
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = ctx.window_size
......@@ -4255,8 +4271,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
)
else:
softmax_lse, rng_state = aux_ctx_tensors
out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if not _use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_state
flash_attn_bwd(
......@@ -4269,14 +4292,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_kv,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
*fa_backward_args_thd,
causal=causal,
**fa_backward_kwargs,
)
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
dq, dk, dv = flash_attn_a2a_communicate(
......@@ -4400,18 +4419,17 @@ def attn_forward_func_with_cp(
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
)
assert (
assert qkv_format != "thd" or (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
), "cu_seqlens_padded cannot be None with context parallelism + THD format!"
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
assert (
not sliding_window_attn
or cp_comm_type == "a2a"
or (cp_comm_type == "all_gather" and not use_fused_attention)
), "The context parallel running configs cannot support sliding window attetnion!"
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "The context parallel running configs cannot support sliding window attetnion!"
args = [
is_training,
......@@ -5419,8 +5437,8 @@ class FlashAttention(torch.nn.Module):
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q if qkv_format == "thd" else None,
cu_seqlens_kv if qkv_format == "thd" else None,
self.attention_dropout if self.training else 0.0,
cp_group,
cp_global_ranks,
......@@ -7215,7 +7233,7 @@ class FusedAttention(torch.nn.Module):
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
cu_seqlens_q_padded = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_kv
......@@ -8151,10 +8169,10 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs = (
cu_seqlens_q_padded is not None
and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
) or (
cu_seqlens_kv_padded is not None
and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
)
attention_params = AttentionParams(
......
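As the shape comments in the hunks above indicate, with THD and cuDNN 9.6+ (or flash-attn 2.6+/FA3) the softmax LSE is produced in packed layout, one row per token instead of per (batch, position). A rough shape walk-through of how the correction loop shuffles it, with assumed sizes; this mirrors the comments in the diff and is not the TE code itself.

```python
import torch

t, np_heads = 32, 8                      # assumed sizes for illustration
lse = torch.randn(t, np_heads, 1)        # packed LSE as returned by the fused kernel

lse = lse.squeeze(-1)                    # [t, np]
lse = lse.transpose(0, 1).contiguous()   # [np, t]: layout used while accumulating
                                         # per-step corrections across CP ranks

# ... per-step out/LSE corrections happen here ...

# Before the fused backward kernel the packed LSE is put back:
lse = lse.transpose(0, 1).contiguous().unsqueeze(-1)   # [t, np, 1]
```

The thd_out_correction_helper change in the C++ file below follows from this layout: in the packed case it now reads the LSE sequence length from lse.size(1) and only checks lse_seqlen >= total_tokens instead of requiring the two to be equal.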
......@@ -1537,10 +1537,10 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
int batch, lse_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = total_tokens;
lse_seqlen = lse.size(1);
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse.size(1) == lse_seqlen);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
} else {
......