Unverified Commit 9c127ef5 authored by Xiaowei Ren, committed by GitHub

Fix context parallelism implementation with THD format (#1012)



* use 2hd layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change qkv_format check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor shape fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a CP rank
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism (see the sketch below the commit message)
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* no longer need to zero dKV for FlashAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix KV shape of CP test with THD format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 70117306
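As background for the test changes below, a minimal sketch (illustrative only, not the exact test code; sizes are hypothetical) of how the THD inputs are now set up: each sequence length is rounded up to a multiple of 2*cp_size so it can be split into 2*cp_size equal chunks for load-balanced context parallelism, and cu_seqlens_padded tracks the cumulative padded lengths alongside the cumulative valid lengths in cu_seqlens.

import torch

cp_size, batch_size, max_seqlen = 2, 4, 512  # hypothetical sizes
seqlens = torch.randint(0, max_seqlen + 1, [batch_size], dtype=torch.int32)
# round each length up to a multiple of 2*cp_size (the number of chunks per sequence)
seqlens_padded = (seqlens + 2 * cp_size - 1) // (2 * cp_size) * (2 * cp_size)
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0, dtype=torch.int32)]
)
cu_seqlens_padded = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens_padded.cumsum(0, dtype=torch.int32)]
)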
......@@ -6,6 +6,7 @@ import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
......@@ -86,6 +87,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (
......@@ -101,18 +104,36 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(
torch.int32
q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
)
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
if kernel_backend == "FlashAttention":
cu_seqlens_q = cu_seqlens_q_padded[:-1]
else:
cu_seqlens_q = torch.cat(
[torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
).cuda()
cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
......@@ -132,7 +153,7 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_kv]:
for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
......@@ -146,6 +167,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
)
out.backward(dout)
......@@ -171,12 +194,14 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
]
elif qkv_format == "thd":
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
seq_idx_q = tex.thd_get_partitioned_indices(
cu_seqlens_q_padded, q_.shape[0], world_size, rank
)
seq_idx_kv = tex.thd_get_partitioned_indices(
cu_seqlens_kv_padded, k_.shape[0], world_size, rank
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
......@@ -187,8 +212,6 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn(
q_,
k_,
......@@ -197,8 +220,8 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
)
out_.backward(dout_)
......@@ -230,9 +253,45 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
for x in [q_.grad, k_.grad, v_.grad, out_]
]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]]
).item()
== 0
)
cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[
b + 1
]
]
).item()
== 0
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
......
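The token partitioning across ranks in the test above relies on tex.thd_get_partitioned_indices with the padded cumulative lengths. Below is a rough pure-PyTorch sketch of the selection it is expected to perform under the load-balanced causal split (a hypothetical helper written for illustration, not the extension's actual implementation): each padded sequence is divided into 2*cp_size equal chunks and rank r keeps chunk r and chunk 2*cp_size-1-r.

import torch

def partitioned_indices(cu_seqlens_padded, cp_size, rank):
    # For every padded sequence, select chunk `rank` and chunk `2*cp_size-1-rank`
    # out of its 2*cp_size equal chunks.
    idx = []
    for start, end in zip(cu_seqlens_padded[:-1].tolist(), cu_seqlens_padded[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)
        idx.append(torch.arange(start + rank * chunk, start + (rank + 1) * chunk))
        idx.append(torch.arange(end - (rank + 1) * chunk, end - rank * chunk))
    return torch.cat(idx)

print(partitioned_indices(torch.tensor([0, 8, 16]), cp_size=2, rank=0))
# tensor([ 0,  1,  6,  7,  8,  9, 14, 15])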
......@@ -1180,6 +1180,27 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
softmax_lse.copy_(new_scale)
@jit_fuser
def get_cu_seqlens_on_cp_rank(
cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
):
"""Compute cu_seqlens of a context parallelism rank"""
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
zeros = torch.zeros_like(seqlens)
cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens)
if first_half:
seqlens_1 = seqlens - cp_rank * seqlens_padded
seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded)
cu_seqlens_on_cp_rank[1:].add_(seqlens_1)
if second_half:
seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded
seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded)
cu_seqlens_on_cp_rank[1:].add_(seqlens_2)
cu_seqlens_on_cp_rank.cumsum_(dim=0)
return cu_seqlens_on_cp_rank
class AttnFuncWithCP(torch.autograd.Function):
"""
Attention implementation with context parallelism.
......@@ -1195,9 +1216,9 @@ class AttnFuncWithCP(torch.autograd.Function):
k,
v,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_k,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
......@@ -1224,8 +1245,20 @@ class AttnFuncWithCP(torch.autograd.Function):
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
if qkv_format in ["bshd", "sbhd"]:
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
max_seqlen_q = max_seqlen_q // cp_size
max_seqlen_kv = max_seqlen_kv // cp_size
cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
if causal:
if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
......@@ -1233,6 +1266,9 @@ class AttnFuncWithCP(torch.autograd.Function):
elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn]
q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
# remove padded tokens at the end
k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
if attn_bias is not None:
assert len(attn_bias.shape) == 4, (
"Only support bias shape of [b, h, sq, sk] for forward, "
......@@ -1273,6 +1309,9 @@ class AttnFuncWithCP(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
......@@ -1298,19 +1337,33 @@ class AttnFuncWithCP(torch.autograd.Function):
kv_inputs[i % 2] = p2p_comm_buffers[i]
if causal:
if i == 0:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
2, k.shape[0], -1, *k.shape[-2:]
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
q_inputs[i % 2] = q
if attn_bias is not None:
......@@ -1326,12 +1379,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
......@@ -1364,10 +1425,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_k,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=True,
......@@ -1375,22 +1436,39 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs,
)
elif i <= rank:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
False,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
elif qkv_format == "thd":
q_inputs[i % 2] = q
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_k, 0
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
if attn_bias is not None:
idx = (rank - i) % cp_size
......@@ -1399,12 +1477,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_k // 2,
cu_seqlens_q,
cu_seqlens_k // 2,
max_seqlen_kv // 2,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
......@@ -1429,7 +1515,7 @@ class AttnFuncWithCP(torch.autograd.Function):
if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_k, 0
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
......@@ -1451,10 +1537,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q,
cu_seqlens_k // 2,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_k // 2,
max_seqlen_kv // 2,
dropout_p,
softmax_scale,
causal=False,
......@@ -1462,22 +1548,43 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs,
)
else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
True,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention:
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
2, k.shape[0], -1, *k.shape[-2:]
k.shape[0], -1, 2, *k.shape[-2:]
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
)
elif qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = torch.cat(
......@@ -1491,12 +1598,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd(
is_training,
max_seqlen_q // 2,
max_seqlen_k,
cu_seqlens_q // 2,
cu_seqlens_k,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
......@@ -1518,7 +1633,9 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
q_inputs[i % 2] = tex.thd_read_half_tensor(
q, cu_seqlens_q_padded, 1
)
else:
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i % 2] = (
......@@ -1541,10 +1658,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q // 2,
cu_seqlens_k,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q // 2,
max_seqlen_k,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False,
......@@ -1552,6 +1669,23 @@ class AttnFuncWithCP(torch.autograd.Function):
**fa_optional_forward_kwargs,
)
else:
if pad_between_seqs_q:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
cp_size,
(rank - i) % cp_size,
True,
True,
)
else:
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
if use_fused_attention:
if attn_bias is not None:
idx = (rank - i) % cp_size
......@@ -1566,12 +1700,20 @@ class AttnFuncWithCP(torch.autograd.Function):
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q,
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
......@@ -1604,10 +1746,10 @@ class AttnFuncWithCP(torch.autograd.Function):
q_inputs[i % 2],
kv_inputs[i % 2][0],
kv_inputs[i % 2][1],
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_k,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=False,
......@@ -1626,7 +1768,7 @@ class AttnFuncWithCP(torch.autograd.Function):
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if i == 1:
out = torch.empty_like(q).zero_()
out = torch.zeros_like(q)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
......@@ -1640,7 +1782,10 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
if qkv_format == "thd":
tex.thd_second_half_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0)
softmax_lse,
softmax_lse_per_step[i - 1],
cu_seqlens_q_padded,
max_seqlen_q,
)
else:
flash_attn_fwd_softmax_lse_correction(
......@@ -1678,7 +1823,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
cu_seqlens_q_padded,
False,
)
else:
......@@ -1698,7 +1843,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q,
cu_seqlens_q_padded,
True,
)
else:
......@@ -1718,18 +1863,19 @@ class AttnFuncWithCP(torch.autograd.Function):
kv,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*cu_seqlens_q_per_step,
*cu_seqlens_kv_per_step,
*rng_states,
*attn_biases,
)
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
ctx.dropout_p = dropout_p
ctx.total_tokens_kv = total_tokens_kv
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format
ctx.attn_mask_type = attn_mask_type
......@@ -1741,19 +1887,23 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
(cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8]
cp_size = get_distributed_world_size(ctx.cp_group)
rng_states = ctx.saved_tensors[8 : 8 + cp_size]
attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size]
cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2]
rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3]
attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4]
causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
if ctx.qkv_format in ["bshd", "sbhd"]:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
if attn_biases[0] is not None:
......@@ -1770,7 +1920,9 @@ class AttnFuncWithCP(torch.autograd.Function):
if causal:
if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
)
else:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
......@@ -1788,6 +1940,8 @@ class AttnFuncWithCP(torch.autograd.Function):
dout = dout.view(*q.shape)
# Flash Attn outputs
dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
......@@ -1828,16 +1982,16 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
......@@ -1848,12 +2002,12 @@ class AttnFuncWithCP(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
......@@ -1871,7 +2025,7 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
......@@ -1890,10 +2044,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
True,
......@@ -1905,34 +2059,34 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous()
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_ = q.view(-1, *q.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
kv_ = kv[:, 0, ...].contiguous()
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0].contiguous()
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_k // 2,
cu_seqlens_q,
cu_seqlens_k // 2,
ctx.max_seqlen_kv // 2,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
......@@ -1952,10 +2106,10 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
dq_ = torch.zeros_like(q_)
if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
else:
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
......@@ -1975,10 +2129,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q,
cu_seqlens_k // 2,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_k // 2,
ctx.max_seqlen_kv // 2,
ctx.dropout_p,
ctx.softmax_scale,
False,
......@@ -1990,36 +2144,36 @@ class AttnFuncWithCP(torch.autograd.Function):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[1].contiguous()
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
out_ = out[1].contiguous()
dout_ = dout[1].contiguous()
elif ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
kv_ = kv
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q // 2,
ctx.max_seqlen_k,
cu_seqlens_q // 2,
cu_seqlens_k,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
q_,
kv_[0],
kv_[1],
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
......@@ -2039,17 +2193,17 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
dq_ = torch.zeros_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
......@@ -2066,10 +2220,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q // 2,
cu_seqlens_k,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2,
ctx.max_seqlen_k,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
False,
......@@ -2083,12 +2237,12 @@ class AttnFuncWithCP(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
q,
kv[0],
kv[1],
kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out,
dout,
TE_DType[q.dtype],
......@@ -2106,7 +2260,7 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
dq_ = torch.zeros_like(q_)
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
......@@ -2125,10 +2279,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_,
dkv_[0],
dkv_[1],
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
False,
......@@ -2162,21 +2316,21 @@ class AttnFuncWithCP(torch.autograd.Function):
dq[0].copy_(dq_[0])
dq[1].add_(dq_[1])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
elif i > 0:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].add_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].add_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
else:
if ctx.qkv_format == "bshd":
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[1].copy_(dq_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
else:
if i == 0:
dq.copy_(dq_)
......@@ -2206,6 +2360,10 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
......@@ -2228,7 +2386,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
else:
dkv.add_(dkv_)
elif i >= (cp_size - rank - 1):
......@@ -2238,14 +2396,14 @@ class AttnFuncWithCP(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
else:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
elif i > 0:
dkv.add_(dkv_)
else:
......@@ -2259,14 +2417,22 @@ class AttnFuncWithCP(torch.autograd.Function):
if causal:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(q.shape[0], -1, *q.shape[-2:])
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *q.shape[-3:])
dq = dq.view(-1, *dq.shape[-3:])
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])
dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd":
dkv_ = torch.empty(
2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
)
dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_
if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
......@@ -2303,9 +2469,9 @@ def attn_forward_func_with_cp(
k,
v,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_k,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
......@@ -2341,15 +2507,18 @@ def attn_forward_func_with_cp(
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
)
assert (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
out = AttnFuncWithCP.apply(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_k,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
......@@ -3140,7 +3309,8 @@ class FlashAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
......@@ -3167,6 +3337,8 @@ class FlashAttention(torch.nn.Module):
if qkv_format in ["sbhd", "bshd"]:
max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if not context_parallel:
# [b * s, h, d]
query_layer, key_layer, value_layer = [
......@@ -3247,8 +3419,8 @@ class FlashAttention(torch.nn.Module):
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
None,
None,
cu_seqlens_q,
cu_seqlens_kv,
self.attention_dropout if self.training else 0.0,
cp_group,
cp_global_ranks,
......@@ -3295,10 +3467,12 @@ class FlashAttention(torch.nn.Module):
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous()
)
elif qkv_format == "bshd":
# (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q, -1).contiguous()
output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous()
elif qkv_format == "thd":
# thd -> t(hd)
output = output.view(output.shape[0], -1).contiguous()
......@@ -4835,7 +5009,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
......@@ -4852,6 +5027,8 @@ class FusedAttention(torch.nn.Module):
query_layer.shape[1],
key_layer.shape[1],
)
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
......@@ -5540,13 +5717,22 @@ class DotProductAttention(TransformerEngineBaseModule):
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
if max_seqlen_q is None:
if cu_seqlens_q_padded is not None:
seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
else:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None:
if cu_seqlens_kv_padded is not None:
seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
else:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
batch_size = len(cu_seqlens_q) - 1
cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
context_parallel = cp_size > 1
if qkv_format in ["sbhd", "bshd"]:
assert all(
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
......@@ -5557,6 +5743,8 @@ class DotProductAttention(TransformerEngineBaseModule):
if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
batch_size = query_layer.shape[0]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all(
......@@ -5628,10 +5816,6 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
context_parallel = (
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
)
core_attention_bias_shape = None
if core_attention_bias is not None:
if (
......
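To make the new get_cu_seqlens_on_cp_rank helper above concrete, here is a standalone re-implementation sketch with a worked example (illustrative assumptions only; the helper name and tensors below are made up for the example). With the load-balanced causal split, each padded sequence is divided into 2*cp_size equal chunks; rank r holds chunks r and 2*cp_size-1-r, and the helper counts how many valid (non-padded) tokens of each sequence fall into the chunks held by that rank.

import torch

def cu_seqlens_on_rank(cu_seqlens, cu_seqlens_padded_per_rank, cp_size, rank):
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # valid tokens per sequence
    chunk = (cu_seqlens_padded_per_rank[1:] - cu_seqlens_padded_per_rank[:-1]) // 2
    zeros = torch.zeros_like(seqlens)
    first = (seqlens - rank * chunk).clamp(zeros, chunk)  # valid tokens in chunk `rank`
    second = (seqlens - (2 * cp_size - rank - 1) * chunk).clamp(zeros, chunk)  # chunk `2*cp_size-1-rank`
    out = torch.zeros_like(cu_seqlens)
    out[1:] = first + second
    return out.cumsum(0)

# Two sequences of length 5 and 8, padded to multiples of 2*cp_size = 4 with cp_size = 2.
cu_seqlens = torch.tensor([0, 5, 13])
cu_seqlens_padded = torch.tensor([0, 8, 16])
print(cu_seqlens_on_rank(cu_seqlens, cu_seqlens_padded // 2, cp_size=2, rank=0))  # tensor([0, 2, 6])
print(cu_seqlens_on_rank(cu_seqlens, cu_seqlens_padded // 2, cp_size=2, rank=1))  # tensor([0, 3, 7])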
......@@ -1565,7 +1565,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += p_per_step[k] * lse_corrected_exp;
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
......
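The one-line change in thd_out_correction_kernel above makes the accumulation skip entries whose per-step output is exactly zero. Presumably this protects the zeroed padded positions: their per-step softmax_lse can be -inf (or otherwise undefined), so the correction factor exp(lse_per_step - lse) evaluates to NaN/Inf and would otherwise poison the accumulated output even though the contribution should be zero. A rough PyTorch illustration of the failure mode and the guard (hypothetical values, not the kernel code):

import torch

lse = torch.tensor(float("-inf"))             # softmax_lse of a fully padded row
lse_per_step = torch.tensor(float("-inf"))
out_per_step = torch.tensor(0.0)              # padded outputs are zeroed out
scale = torch.exp(lse_per_step - lse)         # exp(-inf - (-inf)) -> nan
naive = out_per_step + out_per_step * scale   # nan: 0 * nan propagates
guarded = out_per_step + torch.where(out_per_step == 0, out_per_step, out_per_step * scale)
print(naive.item(), guarded.item())           # nan 0.0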