"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "7e1d5e5308fa3549dfed1821188d588260a03c8a"
Unverified Commit 2a95efd3 authored by Xiaowei Ren, committed by GitHub

CP implementation refinement for BSHD/SBHD format (#1523)



* fix recompilation of out and lse correction in p2p+bshd/sbhd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix recompilation of get_seq_chunk_ids_for_reordering
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix recompilation of reorder_seq_chunks_for_a2a
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* recover a change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change to softmax_lse correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cache cu_seqlens for BSHD/SBHD format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* do not need to allocate out buffer for BSHD/SBHD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* refactor init out correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix init out correction dtype
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to DPA API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2ad5da95
@@ -1589,24 +1589,52 @@ def flash_attn_p2p_communicate(
return send_recv_reqs
@jit_fuser
def flash_attn_fwd_out_correction_init(
out_init_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_init_step: torch.Tensor,
seq_dim: int,
):
"""Merge partial outputs of the first step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse).movedim(2, seq_dim)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_init_step * softmax_lse_corrected_exp
return out_corrected.to(out_init_step.dtype)
@jit_fuser
def flash_attn_fwd_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
movedim_src: int,
movedim_dst: int,
seq_dim: int,
):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected)
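The two helpers above implement the standard log-sum-exp merge of partial attention outputs. The following is an illustrative, standalone sketch (not part of this commit; plain PyTorch, no @jit_fuser, a single query vector and simplified shapes are assumptions) showing that rescaling per-chunk outputs by exp(lse_i - lse) and summing them reproduces softmax attention over the full key/value sequence:

import torch

torch.manual_seed(0)
d = 8
q = torch.randn(d)       # a single query vector
k = torch.randn(16, d)   # 16 keys, processed as two chunks of 8
v = torch.randn(16, d)

def partial_attn(q, k, v):
    # chunk-local attention output and its log-sum-exp statistic
    s = k @ q                            # [n]
    lse = torch.logsumexp(s, dim=0)
    out = torch.softmax(s, dim=0) @ v
    return out, lse

out0, lse0 = partial_attn(q, k[:8], v[:8])
out1, lse1 = partial_attn(q, k[8:], v[8:])

# merged statistic, as in flash_attn_fwd_softmax_lse_correction
lse = torch.logaddexp(lse0, lse1)
# first step is rescaled (out_correction_init), later steps are rescaled and accumulated
out = out0 * torch.exp(lse0 - lse) + out1 * torch.exp(lse1 - lse)

ref = torch.softmax(k @ q, dim=0) @ v    # attention over the full sequence
assert torch.allclose(out, ref, atol=1e-5)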
@jit_fuser
def flash_attn_fwd_second_half_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
seq_dim: int,
):
"""Merge second half of partial outputs of each step in Attention with context parallelism"""
out_ = out.select(seq_dim, 1)
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :]
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse_).movedim(2, seq_dim)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out_.add_(out_corrected)
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
softmax_lse: torch.Tensor,
@@ -1619,6 +1647,19 @@ def flash_attn_fwd_softmax_lse_correction(
softmax_lse.copy_(new_scale)
@jit_fuser
def flash_attn_fwd_second_half_softmax_lse_correction(
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
):
"""Merge second half of softmax stats of each step in Attention with context parallelism"""
softmax_lse_ = softmax_lse[..., 1, :]
max_scale = torch.max(softmax_lse_, softmax_lse_per_step)
min_scale = torch.min(softmax_lse_, softmax_lse_per_step)
new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
softmax_lse_.copy_(new_scale)
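Both LSE-merge helpers rely on the identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))), which avoids overflow for large statistics. A small standalone check (illustrative only, not part of this commit):

import torch

a = torch.tensor([1000.0, -2.0])   # large values overflow a naive exp()
b = torch.tensor([999.0, 3.0])

naive = torch.log(torch.exp(a) + torch.exp(b))      # inf for the first entry
mx, mn = torch.maximum(a, b), torch.minimum(a, b)
stable = mx + torch.log1p(torch.exp(mn - mx))        # formula used above

assert torch.isinf(naive[0]) and torch.isfinite(stable[0])
assert torch.allclose(stable, torch.logaddexp(a, b))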
@jit_fuser
def get_cu_seqlens_on_cp_rank(
cu_seqlens: torch.Tensor,
@@ -1646,46 +1687,59 @@ def get_cu_seqlens_on_cp_rank(
@jit_fuser
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
sequence chunk ids for reordering.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to
be contiguous before attention compute. This function is to compute sequence chunk ids for
reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
if to_contiguous:
for rank in range(cp_size):
chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
else:
for rank in range(cp_size):
chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
for rank in range(cp_size):
chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
return chunk_ids
@jit_fuser
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous()
# [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
else:
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous()
# reorder the sequence chunks
x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
# [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
x = x.view(cp_size, 2, *x.shape[1:])
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
We need to reorder sequence chunks back to discontiguous after attention compute. This function
is to compute sequence chunk ids for reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
for rank in range(cp_size):
chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
return chunk_ids
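The two id helpers above produce mutually inverse permutations: one gathers the load-balanced chunks into contiguous order before attention, the other scatters them back afterwards. An illustrative standalone sketch (not part of this commit) that re-implements the same loops in plain Python and checks the round trip:

import torch

def _chunk_ids_before(cp_size):
    # same logic as get_seq_chunk_ids_for_reordering_before_attn
    ids = torch.empty(2 * cp_size, dtype=torch.int64)
    for rank in range(cp_size):
        ids[rank] = 2 * rank
        ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
    return ids

def _chunk_ids_after(cp_size):
    # same logic as get_seq_chunk_ids_for_reordering_after_attn
    ids = torch.empty(2 * cp_size, dtype=torch.int64)
    for rank in range(cp_size):
        ids[2 * rank] = rank
        ids[2 * rank + 1] = 2 * cp_size - rank - 1
    return ids

cp_size = 4
before, after = _chunk_ids_before(cp_size), _chunk_ids_after(cp_size)
# cp_size=4: before = [0, 2, 4, 6, 7, 5, 3, 1], after = [0, 7, 1, 6, 2, 5, 3, 4]
identity = torch.arange(2 * cp_size)
assert torch.equal(before[after], identity) and torch.equal(after[before], identity)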
@jit_fuser
def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
"""Reorder sequence chunk for A2A communication before attention compute."""
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous()
# [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
return x
@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
"""Reorder sequence chunk for A2A communication after attention compute."""
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous()
# reorder the sequence chunks
x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
# [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
x = x.view(cp_size, 2, *x.shape[1:])
return x
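Illustrative sketch (not part of this commit) of the before-attention reorder on a toy bshd tensor with cp_size=2: each source rank carries chunks (r, 2*cp_size-1-r), and index_select with the "before" chunk ids restores ascending chunk order along the sequence dimension:

import torch

cp_size, b, s_local, h, d = 2, 1, 4, 1, 1   # toy sizes, bshd layout (seq_dim=1)
seq_dim = 1
# fake A2A output: dim 0 is the source CP rank; every token is labelled with its
# global chunk id so the resulting order is easy to inspect
x = torch.empty(cp_size, b, s_local, h, d)
for r in range(cp_size):
    x[r, :, : s_local // 2] = r                     # first local chunk
    x[r, :, s_local // 2 :] = 2 * cp_size - 1 - r   # second local chunk

chunk_ids = torch.tensor([0, 2, 3, 1])  # "before attn" ids for cp_size=2
# same moves as reorder_seq_chunks_for_a2a_before_attn
y = x.movedim(0, seq_dim).contiguous()
y = y.view(*y.shape[:seq_dim], cp_size * 2, -1, *y.shape[(seq_dim + 2):])
y = torch.index_select(y, dim=seq_dim, index=chunk_ids)

# chunks now appear in ascending (contiguous) order along the sequence
assert torch.equal(y.flatten(), torch.tensor([0., 0., 1., 1., 2., 2., 3., 3.]))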
@@ -1713,8 +1767,8 @@ def flash_attn_a2a_communicate(
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
x = reorder_seq_chunks_for_a2a_before_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
@@ -1740,8 +1794,8 @@ def flash_attn_a2a_communicate(
# or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size
)
if i > 1:
with torch.cuda.stream(cp_stream):
@@ -1800,6 +1854,25 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
_cu_seqlens_info_with_cp_cache = {}
def _get_cu_seqlens_info_with_cp(
batch_size: int,
max_seqlen: int,
cp_size: int,
cu_seqlens: torch.Tensor,
):
"""Cumulative sequence lengths with CP being considered."""
global _cu_seqlens_info_with_cp_cache
if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache:
_cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = (
cu_seqlens // cp_size,
cu_seqlens // (cp_size * 2),
)
return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)]
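Illustrative check (not part of this commit) of the caching behaviour: repeated calls with the same (batch_size, max_seqlen, cp_size) key return the same divided tensors instead of recomputing them, which matches the "cache cu_seqlens for BSHD/SBHD format" and recompilation fixes described in the commit message:

import torch

cu_seqlens = torch.tensor([0, 8, 16, 32], dtype=torch.int32)
a = _get_cu_seqlens_info_with_cp(batch_size=3, max_seqlen=16, cp_size=2, cu_seqlens=cu_seqlens)
b = _get_cu_seqlens_info_with_cp(batch_size=3, max_seqlen=16, cp_size=2, cu_seqlens=cu_seqlens)
assert a[0] is b[0] and a[1] is b[1]   # second call hits the cache, no new tensors
# a[0] == cu_seqlens // cp_size == [0, 4, 8, 16]; a[1] == cu_seqlens // (2 * cp_size) == [0, 2, 4, 8]
# note: the key is (batch_size, max_seqlen, cp_size), so the cached result is only
# valid while cu_seqlens itself is unchanged for that key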
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
"""
Attention implementation with context parallelism. Exchange KV between CP ranks
@@ -1839,6 +1912,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cp_global_ranks,
cp_stream,
quantizers,
pad_between_seqs,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
@@ -1871,27 +1945,28 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
batch_dim = None
seq_dim = None
cu_seqlens_q_half, cu_seqlens_kv_half = None, None
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
if use_fused_attention:
batch_dim = qkv_format.index("b")
cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp(
q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q
)
cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp(
q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv
)
else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal(
cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]
)
pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal(
cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]
)
max_seqlen_q = max_seqlen_q // cp_size
max_seqlen_kv = max_seqlen_kv // cp_size
cu_seqlens_q_padded = (
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size
)
cu_seqlens_kv_padded = (
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size
)
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
@@ -1948,7 +2023,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device)
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
@@ -2048,7 +2123,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
softmax_lse_ = None
out = None
for i in range(cp_size + 1):
if i < cp_size:
@@ -2076,18 +2150,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if causal:
if i == 0:
if pad_between_seqs_q:
if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
)
elif use_fused_attention or qkv_format == "thd":
elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
@@ -2202,13 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3]
elif i <= rank:
if pad_between_seqs_q:
if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
@@ -2217,8 +2289,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
False,
)
elif use_fused_attention or qkv_format == "thd":
elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
@@ -2338,13 +2414,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs_q:
if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
)
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
@@ -2353,8 +2426,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
True,
)
elif use_fused_attention or qkv_format == "thd":
elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q_half
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...]
@@ -2483,13 +2560,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs_q:
if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
)
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv,
cu_seqlens_kv_padded,
@@ -2498,8 +2572,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
True,
)
elif use_fused_attention or qkv_format == "thd":
elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if use_fused_attention:
if attn_bias is not None:
idx = (rank - i) % cp_size
@@ -2615,13 +2693,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if fp8:
out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
if i == 1:
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
if qkv_format == "thd":
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
elif (i - 1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1]
@@ -2635,8 +2709,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_in_packed_format,
)
else:
flash_attn_fwd_softmax_lse_correction(
softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
flash_attn_fwd_second_half_softmax_lse_correction(
softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
softmax_lse_per_step[i - 1],
)
if i < cp_size:
@@ -2652,14 +2727,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
for i in range(cp_size):
if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape),
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
if i == 0:
out = flash_attn_fwd_out_correction_init(
out_per_step[0],
softmax_lse,
softmax_lse_per_step[0],
seq_dim,
)
out = out.view(q.shape)
else:
flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape),
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
out,
@@ -2672,14 +2755,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
else:
if qkv_format in ["bshd", "sbhd"]:
out_ = out.select(seq_dim, 1)
flash_attn_fwd_out_correction(
out_,
flash_attn_fwd_second_half_out_correction(
out,
out_per_step[i],
softmax_lse_[..., 1, :],
softmax_lse,
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
@@ -2701,7 +2782,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.batch_size = out.shape[1]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device)
out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
)
@@ -2842,9 +2923,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
else:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
if ctx.softmax_lse_in_packed_format:
@@ -2932,7 +3011,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not ctx.use_fused_attention:
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(
cp_size_a2a, out.device
)
out, dout = flash_attn_a2a_communicate(
[out, dout],
chunk_ids_for_a2a,
@@ -3642,7 +3723,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device)
dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv],
chunk_ids_for_a2a,
@@ -3692,6 +3773,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -3806,9 +3888,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
if use_fused_attention or qkv_format == "thd":
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_q_padded = (
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size)
)
if cu_seqlens_q_padded is not None and qkv_format == "thd":
cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
else:
cu_seqlens_q_padded = None
# [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
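Illustrative sketch (not part of this commit) of the view above, which splits each rank's local sequence into its two load-balancing half-chunks for both supported layouts; shapes are made-up toy sizes:

import torch

for qkv_format in ("bshd", "sbhd"):
    seq_dim = qkv_format.index("s")
    b, s, np_, hn = 2, 8, 4, 16
    shape = (b, s, np_, hn) if qkv_format == "bshd" else (s, b, np_, hn)
    q = torch.randn(*shape)
    # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
    q_ = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1):])
    expected = (b, 2, s // 2, np_, hn) if qkv_format == "bshd" else (2, s // 2, b, np_, hn)
    assert q_.shape == torch.Size(expected)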
@@ -3822,7 +3905,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
@@ -4011,7 +4094,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
@@ -4147,7 +4230,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device)
dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
@@ -4312,7 +4395,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fp8_meta_kwargs = {}
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)
@@ -4383,7 +4466,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
rng_state = fa_outputs[3] if not _use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
)
@@ -4534,7 +4617,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
out, dout = flash_attn_a2a_communicate(
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
)
@@ -4657,7 +4740,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
**fa_backward_kwargs,
)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device)
dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
)
@@ -4737,6 +4820,7 @@ def attn_forward_func_with_cp(
fp8=False,
fp8_meta=None,
quantizers=None,
pad_between_seqs=False,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
@@ -4804,7 +4888,7 @@
]
if cp_comm_type in ["p2p", "a2a+p2p"]:
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers]
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
out = AttnFuncWithCPAndKVP2P.apply(*args)
elif cp_comm_type == "all_gather":
args.pop(5)
@@ -5823,6 +5907,7 @@ class FlashAttention(torch.nn.Module):
deterministic=self.deterministic,
window_size=window_size,
quantizers=quantizers,
pad_between_seqs=False,
)
else:
@@ -6529,6 +6614,7 @@ class FusedAttention(torch.nn.Module):
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
pad_between_seqs: bool = False,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
@@ -6667,6 +6753,7 @@
fp8=fp8,
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
)
else:
with self.attention_dropout_ctx():
@@ -7083,6 +7170,7 @@ class DotProductAttention(TransformerEngineBaseModule):
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -7252,6 +7340,9 @@
Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
"""
with self.prepare_forward(
@@ -7526,13 +7617,17 @@
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
pad_between_seqs = (
cu_seqlens_q_padded is not None
and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
) or (
cu_seqlens_kv_padded is not None
and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
)
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
cu_seqlens_q_padded is not None
and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
) or (
cu_seqlens_kv_padded is not None
and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
)
else:
pad_between_seqs = False
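Illustrative example (not part of this commit) of when the inferred value is True for a packed "thd" batch; the tensor names mirror the arguments above:

import torch

# packed "thd" batch of two sequences of lengths 5 and 7, each padded to 8 tokens
cu_seqlens_q = torch.tensor([0, 5, 12], dtype=torch.int32)         # actual token offsets
cu_seqlens_q_padded = torch.tensor([0, 8, 16], dtype=torch.int32)  # offsets incl. padding

# same check as above: compare all but the last entry
pad_between_seqs = not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
assert pad_between_seqs  # True: padding tokens sit between the packed sequences

# a caller that already knows this can now pass it explicitly instead, e.g.
# DotProductAttention.forward(..., cu_seqlens_q=..., cu_seqlens_q_padded=..., pad_between_seqs=True)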
attention_params = AttentionParams(
qkv_type=type(query_layer),
@@ -7666,6 +7761,7 @@
cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
pad_between_seqs=pad_between_seqs,
)
return self.fused_attention(
query_layer,
@@ -7692,6 +7788,7 @@
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
)
from .cpu_offload import CPUOffloadEnabled
@@ -8188,6 +8285,7 @@ class MultiheadAttention(torch.nn.Module):
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
Forward propagation for MultiheadAttention layer.
@@ -8266,6 +8364,9 @@
Calculated from `cu_seqlens_kv` if not provided.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
"""
# hidden_states: [sq, b, h]
@@ -8523,6 +8624,7 @@
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
pad_between_seqs=pad_between_seqs,
)
# ===================
......
@@ -546,6 +546,7 @@ class TransformerLayer(torch.nn.Module):
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> torch.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
@@ -637,6 +638,9 @@
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
"""
if self_attn_mask_type is None:
@@ -697,6 +701,7 @@
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
......