Unverified Commit 2a95efd3 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

CP implementation refinement for BSHD/SBHD format (#1523)



* fix recompilation of out and lse correction in p2p+bshd/sbhd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix recompilation of get_seq_chunk_ids_for_reordering
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix recompilation of reorder_seq_chunks_for_a2a
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* recover a change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change to softmax_lse correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cache cu_seqlens for BSHD/SBHD format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* do not need to allocate out buffer for BSHD/SBHD
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* refactor init out correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix a docstring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix init out correct dtype
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to DPA API
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2ad5da95
...@@ -1589,24 +1589,52 @@ def flash_attn_p2p_communicate( ...@@ -1589,24 +1589,52 @@ def flash_attn_p2p_communicate(
return send_recv_reqs return send_recv_reqs
@jit_fuser
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of the first step in Attention with context parallelism"""
    # Correction factor exp(lse_step - lse_total); move the softmax-stats
    # sequence axis (dim 2) to line up with the attention output layout.
    scale = torch.exp(softmax_lse_init_step - softmax_lse)
    scale = scale.movedim(2, seq_dim).unsqueeze(-1)
    corrected = out_init_step * scale
    # Cast back to the step output's dtype (the stats may be higher precision).
    return corrected.to(out_init_step.dtype)
@jit_fuser @jit_fuser
def flash_attn_fwd_out_correction( def flash_attn_fwd_out_correction(
out: torch.Tensor, out: torch.Tensor,
out_per_step: torch.Tensor, out_per_step: torch.Tensor,
softmax_lse: torch.Tensor, softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor, softmax_lse_per_step: torch.Tensor,
movedim_src: int, seq_dim: int,
movedim_dst: int,
): ):
"""Merge partial outputs of each step in Attention with context parallelism""" """Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected) out.add_(out_corrected)
@jit_fuser
def flash_attn_fwd_second_half_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge second half of partial outputs of each step in Attention with context parallelism"""
    # View of the second sequence chunk of the accumulated output; updated in place.
    second_half_out = out.select(seq_dim, 1)
    # Second half of the softmax stats along the last (sequence) axis.
    second_half_lse = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    # Correction factor exp(lse_step - lse_total), aligned with the output layout.
    scale = torch.exp(softmax_lse_per_step - second_half_lse)
    scale = scale.movedim(2, seq_dim).unsqueeze(-1)
    second_half_out.add_(out_per_step * scale)
@jit_fuser @jit_fuser
def flash_attn_fwd_softmax_lse_correction( def flash_attn_fwd_softmax_lse_correction(
softmax_lse: torch.Tensor, softmax_lse: torch.Tensor,
...@@ -1619,6 +1647,19 @@ def flash_attn_fwd_softmax_lse_correction( ...@@ -1619,6 +1647,19 @@ def flash_attn_fwd_softmax_lse_correction(
softmax_lse.copy_(new_scale) softmax_lse.copy_(new_scale)
@jit_fuser
def flash_attn_fwd_second_half_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge second half of softmax stats of each step in Attention with context parallelism"""
    # Second sequence chunk (index 1 on the second-to-last axis); the indexed
    # view writes through to softmax_lse via copy_.
    second_half = softmax_lse[..., 1, :]
    # Numerically stable log-sum-exp merge: max + log(1 + exp(min - max)).
    bigger = torch.max(second_half, softmax_lse_per_step)
    smaller = torch.min(second_half, softmax_lse_per_step)
    second_half.copy_(bigger + torch.log1p(torch.exp(smaller - bigger)))
@jit_fuser @jit_fuser
def get_cu_seqlens_on_cp_rank( def get_cu_seqlens_on_cp_rank(
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
...@@ -1646,19 +1687,28 @@ def get_cu_seqlens_on_cp_rank( ...@@ -1646,19 +1687,28 @@ def get_cu_seqlens_on_cp_rank(
@jit_fuser @jit_fuser
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
""" """
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to
before or after CP communications (e.g., all-gather, all-to-all). This function is to compute    be contiguous before attention compute. This function is to compute sequence chunk ids for
sequence chunk ids for reordering. reordering.
""" """
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
if to_contiguous:
for rank in range(cp_size): for rank in range(cp_size):
chunk_ids[rank] = 2 * rank chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
else: return chunk_ids
@jit_fuser
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
We need to reorder sequence chunks back to discontiguous after attention compute. This function
is to compute sequence chunk ids for reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
for rank in range(cp_size): for rank in range(cp_size):
chunk_ids[2 * rank] = rank chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
...@@ -1666,9 +1716,8 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): ...@@ -1666,9 +1716,8 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
@jit_fuser @jit_fuser
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
"""Reorder sequence chunk for A2A communication.""" """Reorder sequence chunk for A2A communication before attention compute."""
if before_attn:
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous() x = x.movedim(0, seq_dim).contiguous()
...@@ -1677,7 +1726,12 @@ def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_at ...@@ -1677,7 +1726,12 @@ def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_at
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks # reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
else: return x
@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
"""Reorder sequence chunk for A2A communication after attention compute."""
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous() x = x.movedim(seq_dim, 0).contiguous()
...@@ -1713,8 +1767,8 @@ def flash_attn_a2a_communicate( ...@@ -1713,8 +1767,8 @@ def flash_attn_a2a_communicate(
a2a_reqs[i - 2].wait() a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2] x = a2a_outputs[i - 2]
# reorder the sequence chunks # reorder the sequence chunks
x = reorder_seq_chunks_for_a2a( x = reorder_seq_chunks_for_a2a_before_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn x, chunk_ids_for_a2a, seq_dim, cp_size
) )
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
...@@ -1740,8 +1794,8 @@ def flash_attn_a2a_communicate( ...@@ -1740,8 +1794,8 @@ def flash_attn_a2a_communicate(
# or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks # reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a( a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn x, chunk_ids_for_a2a, seq_dim, cp_size
) )
if i > 1: if i > 1:
with torch.cuda.stream(cp_stream): with torch.cuda.stream(cp_stream):
...@@ -1800,6 +1854,25 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ...@@ -1800,6 +1854,25 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
_cu_seqlens_info_with_cp_cache = {}
def _get_cu_seqlens_info_with_cp(
batch_size: int,
max_seqlen: int,
cp_size: int,
cu_seqlens: torch.Tensor,
):
"""Cumulative sequence lengths with CP being considered."""
global _cu_seqlens_info_with_cp_cache
if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache:
_cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = (
cu_seqlens // cp_size,
cu_seqlens // (cp_size * 2),
)
return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)]
class AttnFuncWithCPAndKVP2P(torch.autograd.Function): class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
""" """
Attention implementation with context parallelism. Exchange KV between CP ranks Attention implementation with context parallelism. Exchange KV between CP ranks
...@@ -1839,6 +1912,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1839,6 +1912,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cp_global_ranks, cp_global_ranks,
cp_stream, cp_stream,
quantizers, quantizers,
pad_between_seqs,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
...@@ -1871,27 +1945,28 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1871,27 +1945,28 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in attn_mask_type causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type padding = "padding" in attn_mask_type
batch_dim = None
seq_dim = None seq_dim = None
cu_seqlens_q_half, cu_seqlens_kv_half = None, None
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s") seq_dim = qkv_format.index("s")
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
if use_fused_attention:
batch_dim = qkv_format.index("b")
cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp(
q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q
)
cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp(
q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv
)
else: else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal(
cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]
)
pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal(
cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]
)
max_seqlen_q = max_seqlen_q // cp_size max_seqlen_q = max_seqlen_q // cp_size
max_seqlen_kv = max_seqlen_kv // cp_size max_seqlen_kv = max_seqlen_kv // cp_size
cu_seqlens_q_padded = (
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size
)
cu_seqlens_kv_padded = (
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size
)
cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
...@@ -1948,7 +2023,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1948,7 +2023,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if cp_size_a2a > 1: if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device)
q, k, v = flash_attn_a2a_communicate( q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
...@@ -2048,7 +2123,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2048,7 +2123,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []] send_recv_reqs = [[], []]
softmax_lse_ = None
out = None out = None
for i in range(cp_size + 1): for i in range(cp_size + 1):
if i < cp_size: if i < cp_size:
...@@ -2076,18 +2150,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2076,18 +2150,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs_q: if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
) )
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
) )
elif use_fused_attention or qkv_format == "thd": elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
...@@ -2202,13 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2202,13 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3: if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3] rng_states[i] = fa_outputs[3]
elif i <= rank: elif i <= rank:
if pad_between_seqs_q: if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
) )
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
...@@ -2217,8 +2289,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2217,8 +2289,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True, True,
False, False,
) )
elif use_fused_attention or qkv_format == "thd": elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
...@@ -2338,13 +2414,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2338,13 +2414,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3: if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3] rng_states[i] = fa_outputs[3]
else: else:
if pad_between_seqs_q: if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
) )
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
...@@ -2353,8 +2426,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2353,8 +2426,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True, True,
True, True,
) )
elif use_fused_attention or qkv_format == "thd": elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q_half
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...] q_inputs[i % 2] = q[:, 1, ...]
...@@ -2483,13 +2560,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2483,13 +2560,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not _use_flash_attn_3: if not _use_flash_attn_3:
rng_states[i] = fa_outputs[3] rng_states[i] = fa_outputs[3]
else: else:
if pad_between_seqs_q: if pad_between_seqs:
cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
) )
elif use_fused_attention or qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
if pad_between_seqs_kv:
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
...@@ -2498,8 +2572,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2498,8 +2572,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True, True,
True, True,
) )
elif use_fused_attention or qkv_format == "thd": elif qkv_format == "thd":
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
else:
cu_seqlens_q_per_step[i] = cu_seqlens_q
cu_seqlens_kv_per_step[i] = cu_seqlens_kv
if use_fused_attention: if use_fused_attention:
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
...@@ -2615,13 +2693,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2615,13 +2693,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if fp8: if fp8:
out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
if i == 1: if i == 1:
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd": if qkv_format == "thd":
# [b, np, sq] -> [b, np, 2, sq//2] out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
elif (i - 1) <= rank or not causal: elif (i - 1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_softmax_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1] softmax_lse, softmax_lse_per_step[i - 1]
...@@ -2635,8 +2709,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2635,8 +2709,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_in_packed_format, softmax_lse_in_packed_format,
) )
else: else:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_second_half_softmax_lse_correction(
softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1] softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
softmax_lse_per_step[i - 1],
) )
if i < cp_size: if i < cp_size:
...@@ -2652,13 +2727,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2652,13 +2727,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
for i in range(cp_size): for i in range(cp_size):
if i <= rank or not causal: if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
if i == 0:
out = flash_attn_fwd_out_correction_init(
out_per_step[0],
softmax_lse,
softmax_lse_per_step[0],
seq_dim,
)
out = out.view(q.shape)
else:
flash_attn_fwd_out_correction( flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape), out.view(*out_per_step[i].shape),
out_per_step[i], out_per_step[i],
softmax_lse, softmax_lse,
softmax_lse_per_step[i], softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2, seq_dim,
2 if softmax_lse_in_packed_format else seq_dim,
) )
elif qkv_format == "thd": elif qkv_format == "thd":
tex.thd_out_correction( tex.thd_out_correction(
...@@ -2672,14 +2755,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2672,14 +2755,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
else: else:
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
out_ = out.select(seq_dim, 1) flash_attn_fwd_second_half_out_correction(
flash_attn_fwd_out_correction( out,
out_,
out_per_step[i], out_per_step[i],
softmax_lse_[..., 1, :], softmax_lse,
softmax_lse_per_step[i], softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2, seq_dim,
2 if softmax_lse_in_packed_format else seq_dim,
) )
elif qkv_format == "thd": elif qkv_format == "thd":
tex.thd_out_correction( tex.thd_out_correction(
...@@ -2701,7 +2782,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2701,7 +2782,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.batch_size = out.shape[1] ctx.batch_size = out.shape[1]
if cp_size_a2a > 1: if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device)
out = flash_attn_a2a_communicate( out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
) )
...@@ -2842,9 +2923,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2842,9 +2923,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
else: else:
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view( softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous() softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.softmax_lse_in_packed_format: if ctx.softmax_lse_in_packed_format:
...@@ -2932,7 +3011,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2932,7 +3011,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
out = out.view(ctx.batch_size, -1, *out.shape[-2:]) out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape) dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(
cp_size_a2a, out.device
)
out, dout = flash_attn_a2a_communicate( out, dout = flash_attn_a2a_communicate(
[out, dout], [out, dout],
chunk_ids_for_a2a, chunk_ids_for_a2a,
...@@ -3642,7 +3723,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3642,7 +3723,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dk, dv = dkv[0], dkv[1] dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1: if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device)
dq, dk, dv = flash_attn_a2a_communicate( dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv], [dq, dk, dv],
chunk_ids_for_a2a, chunk_ids_for_a2a,
...@@ -3692,6 +3773,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3692,6 +3773,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -3806,9 +3888,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3806,9 +3888,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
max_seqlen_kv = max_seqlen_kv // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
if use_fused_attention or qkv_format == "thd": if use_fused_attention or qkv_format == "thd":
cu_seqlens_q = cu_seqlens_q // (2 * cp_size) cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_q_padded = ( if cu_seqlens_q_padded is not None and qkv_format == "thd":
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size) cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
) else:
cu_seqlens_q_padded = None
# [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
...@@ -3822,7 +3905,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -3822,7 +3905,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
...@@ -4011,7 +4094,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -4011,7 +4094,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
...@@ -4147,7 +4230,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -4147,7 +4230,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device)
dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
...@@ -4312,7 +4395,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4312,7 +4395,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fp8_meta_kwargs = {} fp8_meta_kwargs = {}
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
q, k, v = flash_attn_a2a_communicate( q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
) )
...@@ -4383,7 +4466,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4383,7 +4466,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
rng_state = fa_outputs[3] if not _use_flash_attn_3 else None rng_state = fa_outputs[3] if not _use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state] aux_ctx_tensors = [softmax_lse, rng_state]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
out = flash_attn_a2a_communicate( out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
) )
...@@ -4534,7 +4617,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4534,7 +4617,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out = out.view(ctx.batch_size, -1, *out.shape[-2:]) out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape) dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
out, dout = flash_attn_a2a_communicate( out, dout = flash_attn_a2a_communicate(
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
) )
...@@ -4657,7 +4740,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4657,7 +4740,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
**fa_backward_kwargs, **fa_backward_kwargs,
) )
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device)
dq, dk, dv = flash_attn_a2a_communicate( dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
) )
...@@ -4737,6 +4820,7 @@ def attn_forward_func_with_cp( ...@@ -4737,6 +4820,7 @@ def attn_forward_func_with_cp(
fp8=False, fp8=False,
fp8_meta=None, fp8_meta=None,
quantizers=None, quantizers=None,
pad_between_seqs=False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Attention implementation with context parallelism. Attention implementation with context parallelism.
...@@ -4804,7 +4888,7 @@ def attn_forward_func_with_cp( ...@@ -4804,7 +4888,7 @@ def attn_forward_func_with_cp(
] ]
if cp_comm_type in ["p2p", "a2a+p2p"]: if cp_comm_type in ["p2p", "a2a+p2p"]:
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers] args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
out = AttnFuncWithCPAndKVP2P.apply(*args) out = AttnFuncWithCPAndKVP2P.apply(*args)
elif cp_comm_type == "all_gather": elif cp_comm_type == "all_gather":
args.pop(5) args.pop(5)
...@@ -5823,6 +5907,7 @@ class FlashAttention(torch.nn.Module): ...@@ -5823,6 +5907,7 @@ class FlashAttention(torch.nn.Module):
deterministic=self.deterministic, deterministic=self.deterministic,
window_size=window_size, window_size=window_size,
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=False,
) )
else: else:
...@@ -6529,6 +6614,7 @@ class FusedAttention(torch.nn.Module): ...@@ -6529,6 +6614,7 @@ class FusedAttention(torch.nn.Module):
fp8: bool = False, fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None, quantizers=None,
pad_between_seqs: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert ( assert (
...@@ -6667,6 +6753,7 @@ class FusedAttention(torch.nn.Module): ...@@ -6667,6 +6753,7 @@ class FusedAttention(torch.nn.Module):
fp8=fp8, fp8=fp8,
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
) )
else: else:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -7083,6 +7170,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7083,6 +7170,7 @@ class DotProductAttention(TransformerEngineBaseModule):
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Dot Product Attention Layer. Dot Product Attention Layer.
...@@ -7252,6 +7340,9 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7252,6 +7340,9 @@ class DotProductAttention(TransformerEngineBaseModule):
Adjustments of the sequence_len_offset should be done after a complete forward pass. Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
""" """
with self.prepare_forward( with self.prepare_forward(
...@@ -7526,6 +7617,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7526,6 +7617,8 @@ class DotProductAttention(TransformerEngineBaseModule):
False False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = ( pad_between_seqs = (
cu_seqlens_q_padded is not None cu_seqlens_q_padded is not None
and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
...@@ -7533,6 +7626,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7533,6 +7626,8 @@ class DotProductAttention(TransformerEngineBaseModule):
cu_seqlens_kv_padded is not None cu_seqlens_kv_padded is not None
and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
) )
else:
pad_between_seqs = False
attention_params = AttentionParams( attention_params = AttentionParams(
qkv_type=type(query_layer), qkv_type=type(query_layer),
...@@ -7666,6 +7761,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7666,6 +7761,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type=self.cp_comm_type, cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
pad_between_seqs=pad_between_seqs,
) )
return self.fused_attention( return self.fused_attention(
query_layer, query_layer,
...@@ -7692,6 +7788,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7692,6 +7788,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
quantizers=self.quantizers, quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
) )
from .cpu_offload import CPUOffloadEnabled from .cpu_offload import CPUOffloadEnabled
...@@ -8188,6 +8285,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8188,6 +8285,7 @@ class MultiheadAttention(torch.nn.Module):
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
""" """
Forward propagation for MultiheadAttention layer. Forward propagation for MultiheadAttention layer.
...@@ -8266,6 +8364,9 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8266,6 +8364,9 @@ class MultiheadAttention(torch.nn.Module):
Calculated from `cu_seqlens_kv` if not provided. Calculated from `cu_seqlens_kv` if not provided.
fast_zero_fill: bool, default = `True` fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use. Whether to set output tensors to 0 or not before use.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
""" """
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
...@@ -8523,6 +8624,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -8523,6 +8624,7 @@ class MultiheadAttention(torch.nn.Module):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
inference_params=inference_params, inference_params=inference_params,
pad_between_seqs=pad_between_seqs,
) )
# =================== # ===================
......
...@@ -546,6 +546,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -546,6 +546,7 @@ class TransformerLayer(torch.nn.Module):
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Transformer Layer: attention block and a feedforward network (MLP) Transformer Layer: attention block and a feedforward network (MLP)
...@@ -637,6 +638,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -637,6 +638,9 @@ class TransformerLayer(torch.nn.Module):
inference_params: InferenceParams, default = None inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference. to efficiently calculate and store the context during inference.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
""" """
if self_attn_mask_type is None: if self_attn_mask_type is None:
...@@ -697,6 +701,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -697,6 +701,7 @@ class TransformerLayer(torch.nn.Module):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs,
) )
if self.apply_residual_connection_post_layernorm and not self.output_layernorm: if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment