Unverified Commit 9c127ef5 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Fix context parallelism implementation with THD format (#1012)



* use 2hd layout
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change qkv_format check
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor shape fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a cp rank
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save communication
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not need to zero dkv for FlashAttention any more
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix kv shape of cp test with thd format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* update cp unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove redundant code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarXiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 70117306
...@@ -6,6 +6,7 @@ import os, sys ...@@ -6,6 +6,7 @@ import os, sys
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
import transformer_engine_torch as tex import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
...@@ -86,6 +87,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -86,6 +87,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
) )
cu_seqlens_q = None cu_seqlens_q = None
cu_seqlens_kv = None cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = ( kv_input_shape = (
...@@ -101,18 +104,36 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -101,18 +104,36 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
) )
cu_seqlens_q = None cu_seqlens_q = None
cu_seqlens_kv = None cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd": elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to( q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
torch.int32 kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
) )
seqlens_q = seqlens_q - seqlens_q % (world_size * 2) attn_output_shape = (
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)]) config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
if kernel_backend == "FlashAttention":
cu_seqlens_q = cu_seqlens_q_padded[:-1]
else:
cu_seqlens_q = torch.cat(
[torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
).cuda()
cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim) cu_seqlens_kv_padded = cu_seqlens_q_padded
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
...@@ -132,7 +153,7 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -132,7 +153,7 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
for x in [q, k, v, dout] + ([] if bias is None else [bias]): for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group) dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd": if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_kv]: for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
dist.broadcast(x, 0, group=cp_comm_group) dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP # run core_attn without CP
...@@ -146,6 +167,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -146,6 +167,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
core_attention_bias=bias, core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
) )
out.backward(dout) out.backward(dout)
...@@ -171,12 +194,14 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -171,12 +194,14 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
] ]
elif qkv_format == "thd": elif qkv_format == "thd":
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) seq_idx_q = tex.thd_get_partitioned_indices(
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank) cu_seqlens_q_padded, q_.shape[0], world_size, rank
)
seq_idx_kv = tex.thd_get_partitioned_indices(
cu_seqlens_kv_padded, k_.shape[0], world_size, rank
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
...@@ -187,8 +212,6 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -187,8 +212,6 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn( out_ = core_attn(
q_, q_,
k_, k_,
...@@ -197,8 +220,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -197,8 +220,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
core_attention_bias=bias_, core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q, cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
max_seqlen_kv=max_seqlen_kv, cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
) )
out_.backward(dout_) out_.backward(dout_)
...@@ -230,9 +253,45 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ...@@ -230,9 +253,45 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
for x in [q_.grad, k_.grad, v_.grad, out_] for x in [q_.grad, k_.grad, v_.grad, out_]
] ]
elif qkv_format == "thd": elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]] dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]] dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]]
).item()
== 0
)
cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[
b + 1
]
]
).item()
== 0
)
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
......
...@@ -1180,6 +1180,27 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): ...@@ -1180,6 +1180,27 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
softmax_lse.copy_(new_scale) softmax_lse.copy_(new_scale)
@jit_fuser
def get_cu_seqlens_on_cp_rank(
cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
):
"""Compute cu_seqlens of a context parallelism rank"""
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
zeros = torch.zeros_like(seqlens)
cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens)
if first_half:
seqlens_1 = seqlens - cp_rank * seqlens_padded
seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded)
cu_seqlens_on_cp_rank[1:].add_(seqlens_1)
if second_half:
seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded
seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded)
cu_seqlens_on_cp_rank[1:].add_(seqlens_2)
cu_seqlens_on_cp_rank.cumsum_(dim=0)
return cu_seqlens_on_cp_rank
class AttnFuncWithCP(torch.autograd.Function): class AttnFuncWithCP(torch.autograd.Function):
""" """
Attention implementation with context parallelism. Attention implementation with context parallelism.
...@@ -1195,9 +1216,9 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1195,9 +1216,9 @@ class AttnFuncWithCP(torch.autograd.Function):
k, k,
v, v,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
dropout_p, dropout_p,
...@@ -1224,8 +1245,20 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1224,8 +1245,20 @@ class AttnFuncWithCP(torch.autograd.Function):
causal = "causal" in attn_mask_type causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type padding = "padding" in attn_mask_type
if qkv_format in ["bshd", "sbhd"]:
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
max_seqlen_q = max_seqlen_q // cp_size
max_seqlen_kv = max_seqlen_kv // cp_size
cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
if causal: if causal:
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn] # [b, s, np, hn] -> [b, 2, s//2, np, hn]
...@@ -1233,6 +1266,9 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1233,6 +1266,9 @@ class AttnFuncWithCP(torch.autograd.Function):
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn] # [s, b, np, hn] -> [2, s//2, b, np, hn]
q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
# remove padded tokens at the end
k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
if attn_bias is not None: if attn_bias is not None:
assert len(attn_bias.shape) == 4, ( assert len(attn_bias.shape) == 4, (
"Only support bias shape of [b, h, sq, sk] for forward, " "Only support bias shape of [b, h, sq, sk] for forward, "
...@@ -1273,6 +1309,9 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1273,6 +1309,9 @@ class AttnFuncWithCP(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event() fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)] p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []] send_recv_reqs = [[], []]
...@@ -1298,19 +1337,33 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1298,19 +1337,33 @@ class AttnFuncWithCP(torch.autograd.Function):
kv_inputs[i % 2] = p2p_comm_buffers[i] kv_inputs[i % 2] = p2p_comm_buffers[i]
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention: if use_fused_attention:
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
2, k.shape[0], -1, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
) )
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd": elif qkv_format == "thd":
q_inputs[i % 2] = q q_inputs[i % 2] = q
if attn_bias is not None: if attn_bias is not None:
...@@ -1326,12 +1379,20 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1326,12 +1379,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd( fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], (
kv_inputs[i % 2][1], kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, attn_scale=softmax_scale,
...@@ -1364,10 +1425,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1364,10 +1425,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=True, causal=True,
...@@ -1375,22 +1436,39 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1375,22 +1436,39 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs, **fa_optional_forward_kwargs,
) )
elif i <= rank: elif i <= rank:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
False,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
if use_fused_attention: if use_fused_attention:
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
elif qkv_format == "thd": elif qkv_format == "thd":
q_inputs[i % 2] = q q_inputs[i % 2] = q
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_k, 0 kv_inputs[i % 2], cu_seqlens_kv_padded, 0
) )
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
...@@ -1399,12 +1477,20 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1399,12 +1477,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd( fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_k // 2, max_seqlen_kv // 2,
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k // 2, cu_seqlens_kv_per_step[i],
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], (
kv_inputs[i % 2][1], kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, attn_scale=softmax_scale,
...@@ -1429,7 +1515,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1429,7 +1515,7 @@ class AttnFuncWithCP(torch.autograd.Function):
if qkv_format == "thd": if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_k, 0 kv_inputs[i % 2], cu_seqlens_kv_padded, 0
) )
else: else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
...@@ -1451,10 +1537,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1451,10 +1537,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k // 2, cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_k // 2, max_seqlen_kv // 2,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=False, causal=False,
...@@ -1462,22 +1548,43 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1462,22 +1548,43 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs, **fa_optional_forward_kwargs,
) )
else: else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
True,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention: if use_fused_attention:
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...].contiguous() q_inputs[i % 2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
2, k.shape[0], -1, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
) )
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1].contiguous() q_inputs[i % 2] = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd": elif qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn] # [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = torch.cat( attn_bias_inputs[i % 2] = torch.cat(
...@@ -1491,12 +1598,20 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1491,12 +1598,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd( fused_attn_fwd(
is_training, is_training,
max_seqlen_q // 2, max_seqlen_q // 2,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q // 2, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], (
kv_inputs[i % 2][1], kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, attn_scale=softmax_scale,
...@@ -1518,7 +1633,9 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1518,7 +1633,9 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
if qkv_format == "thd": if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn] # [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
else: else:
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i % 2] = ( q_inputs[i % 2] = (
...@@ -1541,10 +1658,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1541,10 +1658,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
cu_seqlens_q // 2, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
max_seqlen_q // 2, max_seqlen_q // 2,
max_seqlen_k, max_seqlen_kv,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=False, causal=False,
...@@ -1552,6 +1669,23 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1552,6 +1669,23 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs, **fa_optional_forward_kwargs,
) )
else: else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
True,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention: if use_fused_attention:
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
...@@ -1566,12 +1700,20 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1566,12 +1700,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd( fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
q, q,
kv_inputs[i % 2][0], (
kv_inputs[i % 2][1], kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, attn_scale=softmax_scale,
...@@ -1604,10 +1746,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1604,10 +1746,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2], q_inputs[i % 2],
kv_inputs[i % 2][0], kv_inputs[i % 2][0],
kv_inputs[i % 2][1], kv_inputs[i % 2][1],
cu_seqlens_q, cu_seqlens_q_per_step[i],
cu_seqlens_k, cu_seqlens_kv_per_step[i],
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=False, causal=False,
...@@ -1626,7 +1768,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1626,7 +1768,7 @@ class AttnFuncWithCP(torch.autograd.Function):
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if i == 1: if i == 1:
out = torch.empty_like(q).zero_() out = torch.zeros_like(q)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd": if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
...@@ -1640,7 +1782,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1640,7 +1782,10 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
if qkv_format == "thd": if qkv_format == "thd":
tex.thd_second_half_lse_correction( tex.thd_second_half_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0) softmax_lse,
softmax_lse_per_step[i - 1],
cu_seqlens_q_padded,
max_seqlen_q,
) )
else: else:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_softmax_lse_correction(
...@@ -1678,7 +1823,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1678,7 +1823,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_per_step[i], out_per_step[i],
softmax_lse, softmax_lse,
softmax_lse_per_step[i], softmax_lse_per_step[i],
cu_seqlens_q, cu_seqlens_q_padded,
False, False,
) )
else: else:
...@@ -1698,7 +1843,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1698,7 +1843,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_per_step[i], out_per_step[i],
softmax_lse, softmax_lse,
softmax_lse_per_step[i], softmax_lse_per_step[i],
cu_seqlens_q, cu_seqlens_q_padded,
True, True,
) )
else: else:
...@@ -1718,18 +1863,19 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1718,18 +1863,19 @@ class AttnFuncWithCP(torch.autograd.Function):
kv, kv,
out, out,
softmax_lse, softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
*cu_seqlens_q_per_step,
*cu_seqlens_kv_per_step,
*rng_states, *rng_states,
*attn_biases, *attn_biases,
) )
ctx.cp_group = cp_group ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks ctx.cp_global_ranks = cp_global_ranks
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.total_tokens_kv = total_tokens_kv
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format ctx.qkv_format = qkv_format
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
...@@ -1741,19 +1887,23 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1741,19 +1887,23 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
(cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8]
cp_size = get_distributed_world_size(ctx.cp_group) cp_size = get_distributed_world_size(ctx.cp_group)
rng_states = ctx.saved_tensors[8 : 8 + cp_size]
attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rank = get_distributed_rank(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size]
cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2]
rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3]
attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4]
causal = "causal" in ctx.attn_mask_type causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type
if ctx.qkv_format in ["bshd", "sbhd"]:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
if attn_biases[0] is not None: if attn_biases[0] is not None:
...@@ -1770,7 +1920,9 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1770,7 +1920,9 @@ class AttnFuncWithCP(torch.autograd.Function):
if causal: if causal:
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
)
else: else:
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view( softmax_lse_ = softmax_lse.view(
...@@ -1788,6 +1940,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1788,6 +1940,8 @@ class AttnFuncWithCP(torch.autograd.Function):
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
# Flash Attn outputs # Flash Attn outputs
dq = torch.empty_like(q) dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [ p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
...@@ -1828,16 +1982,16 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1828,16 +1982,16 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:]) q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:]) out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:]) q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:]) out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:])
...@@ -1848,12 +2002,12 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1848,12 +2002,12 @@ class AttnFuncWithCP(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
q_, q_,
kv_[0], kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], TE_DType[q.dtype],
...@@ -1871,7 +2025,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1871,7 +2025,7 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:]) q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_) dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:]) kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_) dkv_ = torch.empty_like(kv_)
...@@ -1890,10 +2044,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1890,10 +2044,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dq_,
dkv_[0], dkv_[0],
dkv_[1], dkv_[1],
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
True, True,
...@@ -1905,34 +2059,34 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1905,34 +2059,34 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:]) q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous() kv_ = kv[:, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:]) out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:]) q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[:, 0, ...].contiguous() kv_ = kv[0].contiguous()
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:]) out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k // 2, ctx.max_seqlen_kv // 2,
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k // 2, cu_seqlens_kv_per_step[cp_size - i - 1],
q_, q_,
kv_[0], kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], TE_DType[q.dtype],
...@@ -1952,10 +2106,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1952,10 +2106,10 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:]) q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_) dq_ = torch.zeros_like(q_)
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
else: else:
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
...@@ -1975,10 +2129,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1975,10 +2129,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dq_,
dkv_[0], dkv_[0],
dkv_[1], dkv_[1],
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k // 2, cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k // 2, ctx.max_seqlen_kv // 2,
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
False, False,
...@@ -1990,36 +2144,36 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1990,36 +2144,36 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous() q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous() out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous() dout_ = dout[:, 1, ...].contiguous()
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[1].contiguous() q_ = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
out_ = out[1].contiguous() out_ = out[1].contiguous()
dout_ = dout[1].contiguous() dout_ = dout[1].contiguous()
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn] # [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
kv_ = kv kv_ = kv
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q // 2, ctx.max_seqlen_q // 2,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
cu_seqlens_q // 2, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
q_, q_,
kv_[0], kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], TE_DType[q.dtype],
...@@ -2039,17 +2193,17 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2039,17 +2193,17 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn] # [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
else: else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_) dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:]) kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_) dkv_ = torch.empty_like(kv_)
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
else: else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
...@@ -2066,10 +2220,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2066,10 +2220,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dq_,
dkv_[0], dkv_[0],
dkv_[1], dkv_[1],
cu_seqlens_q // 2, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2, ctx.max_seqlen_q // 2,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
False, False,
...@@ -2083,12 +2237,12 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2083,12 +2237,12 @@ class AttnFuncWithCP(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
q, q,
kv[0], kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
kv[1], kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out, out,
dout, dout,
TE_DType[q.dtype], TE_DType[q.dtype],
...@@ -2106,7 +2260,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2106,7 +2260,7 @@ class AttnFuncWithCP(torch.autograd.Function):
else: else:
# [b, sq, np, hn] -> [b*sq, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:]) q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_) dq_ = torch.zeros_like(q_)
# [2, b, sk, np, hn] -> [2, b*sk, np, hn] # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:]) kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_) dkv_ = torch.empty_like(kv_)
...@@ -2125,10 +2279,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2125,10 +2279,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dq_,
dkv_[0], dkv_[0],
dkv_[1], dkv_[1],
cu_seqlens_q, cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_k, cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_k, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
False, False,
...@@ -2162,21 +2316,21 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2162,21 +2316,21 @@ class AttnFuncWithCP(torch.autograd.Function):
dq[0].copy_(dq_[0]) dq[0].copy_(dq_[0])
dq[1].add_(dq_[1]) dq[1].add_(dq_[1])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add") tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
elif i > 0: elif i > 0:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dq[:, 1, ...].add_(dq_) dq[:, 1, ...].add_(dq_)
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dq[1].add_(dq_) dq[1].add_(dq_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add") tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
else: else:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dq[:, 1, ...].copy_(dq_) dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dq[1].copy_(dq_) dq[1].copy_(dq_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy") tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
else: else:
if i == 0: if i == 0:
dq.copy_(dq_) dq.copy_(dq_)
...@@ -2206,6 +2360,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2206,6 +2360,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv = p2p_comm_buffers[(i + 1) % 2][1] dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention: if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
...@@ -2228,7 +2386,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2228,7 +2386,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
else: else:
dkv.add_(dkv_) dkv.add_(dkv_)
elif i >= (cp_size - rank - 1): elif i >= (cp_size - rank - 1):
...@@ -2238,14 +2396,14 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2238,14 +2396,14 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_) dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none") tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
else: else:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_) dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_) dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none") tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
elif i > 0: elif i > 0:
dkv.add_(dkv_) dkv.add_(dkv_)
else: else:
...@@ -2259,14 +2417,22 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -2259,14 +2417,22 @@ class AttnFuncWithCP(torch.autograd.Function):
if causal: if causal:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:]) dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *q.shape[-3:]) dq = dq.view(-1, *dq.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:]) dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd":
dkv_ = torch.empty(
2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
)
dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_
if attn_dbias is not None: if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
...@@ -2303,9 +2469,9 @@ def attn_forward_func_with_cp( ...@@ -2303,9 +2469,9 @@ def attn_forward_func_with_cp(
k, k,
v, v,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
dropout_p, dropout_p,
...@@ -2341,15 +2507,18 @@ def attn_forward_func_with_cp( ...@@ -2341,15 +2507,18 @@ def attn_forward_func_with_cp(
"""Attention bias is only supported with FusedAttention and "causal" """ """Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!""" """or "no_mask" mask types!"""
) )
assert (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
out = AttnFuncWithCP.apply( out = AttnFuncWithCP.apply(
is_training, is_training,
q, q,
k, k,
v, v,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_kv,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
dropout_p, dropout_p,
...@@ -3140,7 +3309,8 @@ class FlashAttention(torch.nn.Module): ...@@ -3140,7 +3309,8 @@ class FlashAttention(torch.nn.Module):
qkv_layout in QKVLayouts qkv_layout in QKVLayouts
), f"FlashAttention does not support qkv_layout = {qkv_layout}!" ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
...@@ -3167,6 +3337,8 @@ class FlashAttention(torch.nn.Module): ...@@ -3167,6 +3337,8 @@ class FlashAttention(torch.nn.Module):
if qkv_format in ["sbhd", "bshd"]: if qkv_format in ["sbhd", "bshd"]:
max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if not context_parallel: if not context_parallel:
# [b * s, h, d] # [b * s, h, d]
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [
...@@ -3247,8 +3419,8 @@ class FlashAttention(torch.nn.Module): ...@@ -3247,8 +3419,8 @@ class FlashAttention(torch.nn.Module):
cu_seqlens_kv, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
None, cu_seqlens_q,
None, cu_seqlens_kv,
self.attention_dropout if self.training else 0.0, self.attention_dropout if self.training else 0.0,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -3295,10 +3467,12 @@ class FlashAttention(torch.nn.Module): ...@@ -3295,10 +3467,12 @@ class FlashAttention(torch.nn.Module):
if qkv_format == "sbhd": if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() output = (
output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous()
)
elif qkv_format == "bshd": elif qkv_format == "bshd":
# (bs)hd -> bs(hd) # (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q, -1).contiguous() output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous()
elif qkv_format == "thd": elif qkv_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.view(output.shape[0], -1).contiguous() output = output.view(output.shape[0], -1).contiguous()
...@@ -4835,7 +5009,8 @@ class FusedAttention(torch.nn.Module): ...@@ -4835,7 +5009,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout in QKVLayouts qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!" ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
...@@ -4852,6 +5027,8 @@ class FusedAttention(torch.nn.Module): ...@@ -4852,6 +5027,8 @@ class FusedAttention(torch.nn.Module):
query_layer.shape[1], query_layer.shape[1],
key_layer.shape[1], key_layer.shape[1],
) )
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if "padding" in attn_mask_type: if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!" assert not context_parallel, "Padding mask not supported with context parallelism!"
...@@ -5540,13 +5717,22 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5540,13 +5717,22 @@ class DotProductAttention(TransformerEngineBaseModule):
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
if max_seqlen_q is None: if max_seqlen_q is None:
if cu_seqlens_q_padded is not None:
seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
else:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None: if max_seqlen_kv is None:
if cu_seqlens_kv_padded is not None:
seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
else:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
batch_size = len(cu_seqlens_q) - 1 batch_size = len(cu_seqlens_q) - 1
cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
context_parallel = cp_size > 1
if qkv_format in ["sbhd", "bshd"]: if qkv_format in ["sbhd", "bshd"]:
assert all( assert all(
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
...@@ -5557,6 +5743,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5557,6 +5743,8 @@ class DotProductAttention(TransformerEngineBaseModule):
if qkv_format == "bshd": if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0] batch_size = query_layer.shape[0]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if cu_seqlens_q is not None: if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all( assert all(
...@@ -5628,10 +5816,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5628,10 +5816,6 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True
context_parallel = (
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)
core_attention_bias_shape = None core_attention_bias_shape = None
if core_attention_bias is not None: if core_attention_bias is not None:
if ( if (
......
...@@ -1565,7 +1565,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float ...@@ -1565,7 +1565,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step); dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data); dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += p_per_step[k] * lse_corrected_exp; p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
} }
reinterpret_cast<float4 *>(cur_out)[j] = data; reinterpret_cast<float4 *>(cur_out)[j] = data;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment