Unverified Commit 40dda924 authored by Xiaowei Ren, committed by GitHub

Add a context parallelism implementation with QKVO all-to-all (#1160)



* clean code for CP function args
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a placeholder for Ulysses implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit code change to CP+A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* finish the draft fwd implementation of Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft bwd implementation of Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make swa work with ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit FP8 code for Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv type in the bwd of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv_dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor code change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* config cp correction dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code style change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save chunk_ids
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try to make Ulysses A2A async
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make more a2a async
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a2a_outputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix chunk_ids generation for A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* avoid code duplication of a2a before attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove code duplication of a2a after attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cp_stream in A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv of fp8_fwd + bf16_bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kernel order in cp a2a communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for CP a2a
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix merging with main
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a2a communication order
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* adjust sequence chunk reordering for a2a
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add docstring for A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change an assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add unit tests of A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more A2A unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more cp unit tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix window size of no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fused attn does not support swa+no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change num_gqa_groups to 2 for A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* function and variable renaming
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for CP all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* some function renaming
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit code change for kv all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a window size check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add unit test of all_gather+no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix all-gather cp implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FP8 with A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add paper references to CP implementations with all-gather and all-to-all
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change pdf to abs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* elaborate cp_comm_type
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2a9845e1
......@@ -22,10 +22,16 @@ model_configs_flash_attn = {
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}
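# For reference (assuming the ModelConfig definition used by these CP tests), the
# positional arguments are: batch_size, num_heads, num_gqa_groups, head_dim,
# max_seqlen_q, max_seqlen_kv, dropout_p, attn_mask_type, attn_bias_type.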
......@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
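# With cp_size=2 in these tests, the A2A implementation scatters attention heads
# across CP ranks, so both num_heads and num_gqa_groups must divide evenly by 2;
# this is why the GQA configs above use 2 KV groups instead of 1.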
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
......@@ -81,10 +88,16 @@ model_configs_fused_attn = {
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}
......@@ -93,37 +106,27 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
"Fused attention does not support sliding window attention + context parallelism yet!"
"Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
if cp_comm_type == "all_gather" and dtype == "fp8":
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
......@@ -131,10 +134,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
......@@ -614,12 +614,6 @@ def get_attention_backend(
"with causal mask, no dropout, and qkv_format = bshd/sbhd"
)
use_fused_attention = False
elif context_parallel:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with context parallelism"
)
use_fused_attention = False
elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
"no_mask",
"padding",
......@@ -1429,9 +1423,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
......@@ -1441,6 +1432,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
use_fused_attention,
fp8,
fp8_meta,
cp_group,
cp_global_ranks,
cp_stream,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -2946,10 +2940,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None,
None,
None,
attn_dbias,
None,
None,
None,
attn_dbias,
None,
None,
None,
......@@ -2958,30 +2952,56 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
@torch.compile
def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
before or after CP communications (e.g., all-gather, all-to-all). This function computes the
sequence chunk ids used for that reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
if to_contiguous:
for rank in range(cp_size):
chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
else:
for rank in range(cp_size):
chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
return chunk_ids
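# Worked example (a sketch, not part of the change): with cp_size=2, rank 0 holds
# chunks (0, 3) and rank 1 holds chunks (1, 2), so the all-gathered chunks arrive
# rank-major as [0, 3, 1, 2]. to_contiguous=True yields chunk_ids [0, 2, 3, 1]:
#   torch.index_select(torch.tensor([0, 3, 1, 2]), 0, torch.tensor([0, 2, 3, 1]))
#   # -> tensor([0, 1, 2, 3]), i.e., chunks back in sequence order.
# to_contiguous=False yields [0, 3, 1, 2], which maps a contiguous sequence back
# to the rank-major layout.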
def get_kv_seq_info_after_all_gather(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
"""Compute sequence chunk ids to the all-gathered KV."""
seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left)
seqlen = seq_end_idx - seq_start_idx
num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv
chunk_ids = torch.arange(
local_chunk_id - num_chunks + 1,
local_chunk_id + 1,
dtype=torch.int32,
device=device,
)
chunk_ids_to_all_gathered_kv = torch.where(
chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
)
return chunk_ids_to_all_gathered_kv
"""Compute KV sequence index range and update window size after all-gather."""
local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
full_seq_end_idx = max_seqlen_kv * cp_size * 2
if window_size is None:
window_size = (-1, 0) if causal else (-1, -1)
if window_size[1] == -1:
seq_end_idx = full_seq_end_idx
window_size_right = -1
else:
seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx
if window_size[0] == -1:
seq_start_idx = 0
window_size_left = -1
else:
seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx
return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)
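# Worked example (a sketch with assumed sizes): cp_size=2, per-chunk
# max_seqlen_q = max_seqlen_kv = 512, causal mask, window_size=(512, 0), and
# local_chunk_id=3. Then local_chunk_end_idx = 4*512 = 2048 and
# full_seq_end_idx = 512*2*2 = 2048, so:
#   seq_end_idx   = min(2048, 2048 + 0) = 2048,       window_size_right = 0
#   seq_start_idx = max(0, 2048 - 512 - 512) = 1024,  window_size_left = 512
# i.e., this chunk attends to KV indices [1024, 2048) with window (512, 0).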
class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
"""
Attention implementation with context parallelism.
KV all-gather between CP ranks is exposed.
Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
Refer to section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.
"""
@staticmethod
......@@ -2992,14 +3012,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
......@@ -3008,6 +3024,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
deterministic,
use_fused_attention,
window_size,
cp_group,
cp_stream,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -3017,10 +3035,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
assert causal and not padding, f"{attn_mask_type} mask type is not supported!"
assert not padding, f"{attn_mask_type} mask type is not supported!"
if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
attn_mask_type = attn_mask_type + "_bottom_right"
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
......@@ -3029,6 +3046,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
fa_optional_forward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
......@@ -3041,31 +3060,35 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
max_seqlen_q = max_seqlen_q // (2 * cp_size)
max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size)
cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size)
if causal:
if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
# [b, s, np, hn] -> [s, b, np, hn]
k, v = [x.transpose(0, 1).contiguous() for x in [k, v]]
elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(2, q.shape[0] // 2, *q.shape[1:])
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
# [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]
# [s, b, np, hn] -> [cp, s, b, np, hn]
k_ag, _ = gather_along_first_dim(k, cp_group)
v_ag, _ = gather_along_first_dim(v, cp_group)
cp_stream.wait_stream(torch.cuda.current_stream())
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
k_ag = k_ag.view(-1, *k.shape[1:])
v_ag = v_ag.view(-1, *v.shape[1:])
cp_stream.wait_stream(torch.cuda.current_stream())
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
chunk_ids_to_kv_ag_per_step = [None, None]
kv_seq_range_per_step = [None, None]
window_size_per_step = [None, None]
cu_seqlens_kv_per_step = [None, None]
out_per_step = [None, None]
softmax_lse_per_step = [None, None]
rng_states = [None, None]
......@@ -3074,53 +3097,36 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv(
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q.select(seq_dim, i).contiguous()
kv_seq_range_per_step[i], window_size_per_step[i] = (
get_kv_seq_info_after_all_gather(
local_seq_chunk_ids[i],
cp_size,
max_seqlen_q,
max_seqlen_kv,
(
max_seqlen_kv * cp_size * 2
if (window_size is None or window_size[0] == -1)
else window_size[0]
),
k.device,
)
chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
num_kv_chunks = chunk_ids_to_kv_ag.numel()
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
k_ = (
torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(k.shape[1], -1, *k.shape[-2:])
window_size,
causal,
)
v_ = (
torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(v.shape[1], -1, *v.shape[-2:])
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *k.shape[-3:]
seq_start_idx, seq_end_idx = (
kv_seq_range_per_step[i][0],
kv_seq_range_per_step[i][1],
)
v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *v.shape[-3:]
max_seqlen_kv_ = seq_end_idx - seq_start_idx
cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
k.shape[1], max_seqlen_kv_, k.device
)
k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
# [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv * num_kv_chunks,
max_seqlen_kv_,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
cu_seqlens_kv_per_step[i],
q_,
k_,
v_,
......@@ -3133,8 +3139,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
attn_bias_type=attn_bias_type,
attn_bias=attn_bias,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
window_size=window_size,
cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
window_size=window_size_per_step[i],
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
......@@ -3144,14 +3150,14 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
k_,
v_,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv * num_kv_chunks,
max_seqlen_kv_,
dropout_p,
softmax_scale,
causal=True,
causal=causal,
return_softmax=False,
window_size=window_size,
window_size=window_size_per_step[i],
**fa_optional_forward_kwargs,
)
)
......@@ -3159,9 +3165,9 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1]))
out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape))
elif qkv_format == "sbhd":
out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1]))
out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape))
torch.cuda.current_stream().wait_stream(cp_stream)
......@@ -3178,26 +3184,24 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*chunk_ids_to_kv_ag_per_step,
*cu_seqlens_kv_per_step,
*out_per_step,
*softmax_lse_per_step,
*rng_states,
)
ctx.kv_seq_range_per_step = kv_seq_range_per_step
ctx.window_size_per_step = window_size_per_step
ctx.cp_group = cp_group
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format
ctx.attn_mask_type = attn_mask_type
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
ctx.window_size = window_size
return out
@staticmethod
......@@ -3205,21 +3209,20 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
(q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = (
ctx.saved_tensors[:7]
)
chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9]
out_per_step = ctx.saved_tensors[9:11]
softmax_lse_per_step = ctx.saved_tensors[11:13]
rng_states = ctx.saved_tensors[13:15]
(q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
out_per_step = ctx.saved_tensors[7:9]
softmax_lse_per_step = ctx.saved_tensors[9:11]
rng_states = ctx.saved_tensors[11:13]
kv_seq_range_per_step = ctx.kv_seq_range_per_step
window_size_per_step = ctx.window_size_per_step
seq_dim = ctx.qkv_format.index("s")
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
dout = dout.view_as(q)
dout = dout.view(q.shape)
dq = torch.empty_like(q)
dk = torch.zeros(
(2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device
)
dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
dv = torch.zeros_like(dk)
dq_per_step = [None, None]
dk_per_step = [None, None]
......@@ -3230,11 +3233,20 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# synchronize dkv update across steps
dkv_update_done = torch.cuda.Event()
# [s, b, np, hn] -> [cp, s, b, np, hn]
k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
ctx.cp_stream.wait_stream(torch.cuda.current_stream())
# [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
k_ag = k_ag.view(-1, *k.shape[1:])
v_ag = v_ag.view(-1, *v.shape[1:])
ctx.cp_stream.wait_stream(torch.cuda.current_stream())
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
......@@ -3247,66 +3259,46 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i]
num_kv_chunks = chunk_ids_to_kv_ag.numel()
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q.select(seq_dim, i).contiguous()
seq_start_idx, seq_end_idx = (
kv_seq_range_per_step[i][0],
kv_seq_range_per_step[i][1],
)
max_seqlen_kv = seq_end_idx - seq_start_idx
k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
# [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
out_ = out_per_step[i]
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
k_ = (
torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(k.shape[1], -1, *k.shape[-2:])
)
v_ = (
torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(v.shape[1], -1, *v.shape[-2:])
)
dout_ = dout[:, i].contiguous().view_as(out_)
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *k.shape[-3:]
)
v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *v.shape[-3:]
)
dout_ = dout[i].contiguous().view_as(out_)
dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
if ctx.use_fused_attention:
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv * num_kv_chunks,
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
cu_seqlens_kv_per_step[i],
q_,
k_,
v_,
out_,
dout_,
TE_DType[q.dtype],
TE_DType[k.dtype],
TE_DType[dout.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
window_size=ctx.window_size,
window_size=window_size_per_step[i],
deterministic=ctx.deterministic,
)
else:
batch_size = k_.shape[0]
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
......@@ -3322,65 +3314,64 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dk_per_step[i],
dv_per_step[i],
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
ctx.max_seqlen_kv * num_kv_chunks,
max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
True,
window_size=ctx.window_size,
"causal" in ctx.attn_mask_type,
window_size=window_size_per_step[i],
rng_state=rng_states[i],
**fa_optional_backward_kwargs,
)
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
# [b*s_range, np, hn] -> [b, s_range, np, hn]
dk_per_step[i], dv_per_step[i] = [
x.view(batch_size, -1, *x.shape[-2:])
for x in [dk_per_step[i], dv_per_step[i]]
]
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1]
num_kv_chunks = chunk_ids_to_kv_ag.numel()
if ctx.qkv_format == "bshd":
dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1]))
dk_per_step[i - 1] = (
dk_per_step[i - 1]
.view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:])
.movedim(0, 2)
.contiguous()
)
dv_per_step[i - 1] = (
dv_per_step[i - 1]
.view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:])
.movedim(0, 2)
.contiguous()
)
dq[:, i - 1].copy_(dq_per_step[i - 1])
elif ctx.qkv_format == "sbhd":
dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1]))
dk_per_step[i - 1] = dk_per_step[i - 1].view(
num_kv_chunks, -1, *k.shape[-3:]
)
dv_per_step[i - 1] = dv_per_step[i - 1].view(
num_kv_chunks, -1, *v.shape[-3:]
)
dq[i - 1].copy_(dq_per_step[i - 1])
# [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
dk_per_step[i - 1], dv_per_step[i - 1] = [
x.movedim(seq_dim, 0).contiguous()
for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
]
# wait until dkv update of last step is done
if i > 1:
flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1])
dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1])
seq_start_idx, seq_end_idx = (
kv_seq_range_per_step[i - 1][0],
kv_seq_range_per_step[i - 1][1],
)
dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
if i < len(local_seq_chunk_ids):
flash_attn_streams[i - 1].record_event(dkv_update_done)
torch.cuda.current_stream().wait_stream(ctx.cp_stream)
# [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
# [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)
if ctx.qkv_format == "bshd":
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
dk = dk.transpose(0, 1).contiguous()
dv = dv.transpose(0, 1).contiguous()
elif ctx.qkv_format == "sbhd":
dq = dq.view(-1, *dq.shape[-3:])
dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
dk = dk.movedim(0, seq_dim).contiguous()
dv = dv.movedim(0, seq_dim).contiguous()
return (
None,
......@@ -3402,72 +3393,100 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
None,
None,
None,
None,
None,
)
def attn_forward_func_with_cp(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
cp_comm_type,
softmax_scale=None,
qkv_format="bshd",
attn_mask_type="causal",
attn_bias_type="no_bias",
attn_bias=None,
deterministic=False,
use_fused_attention=False,
window_size=None,
fp8=False,
fp8_meta=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
"""
assert qkv_format in [
"bshd",
"sbhd",
"thd",
], f"QKV format of {qkv_format} is not supported with context parallelism!"
assert (
qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert (
qkv_format != "thd"
or not use_fused_attention
or attn_mask_type in ["padding", "padding_causal"]
), (
f"Context parallelism is not supported for {attn_mask_type} mask type and "
f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous()
# [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
else:
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous()
# reorder the sequence chunks
x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
# [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
x = x.view(cp_size, 2, *x.shape[1:])
return x
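# Quick shape check (a sketch with assumed sizes): for cp_size=2 and bshd
# (seq_dim=1), an A2A output x of shape [cp, b, s, np//cp, hn] = [2, 1, 4, 6, 8]
# passed through reorder_seq_chunks_for_a2a(x, torch.tensor([0, 2, 3, 1]), 1, 2,
# before_attn=True) comes out as [1, 4, 2, 6, 8], which the caller below then
# flattens to [b, cp*s, np//cp, hn] = [1, 8, 6, 8].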
def flash_attn_a2a_communicate(
a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
chunk_ids_for_a2a: torch.Tensor,
seq_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""A2A communication for context parallelism."""
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
if before_attn:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
# [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
a2a_inputs[i] = x.movedim(-3, 0).contiguous()
else:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
assert (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
# [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
torch.cuda.current_stream().wait_stream(cp_stream)
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
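# The loop above pipelines three stages (pre-A2A reshape, the A2A itself, and
# post-A2A reordering on cp_stream) so that communication for tensor i overlaps
# layout work for tensor i+1. A stripped-down sketch of the same overlap pattern,
# assuming an initialized torch.distributed process group:
#
#   outs, reqs = [], []
#   for t in tensors:
#       out = torch.empty_like(t)
#       reqs.append(torch.distributed.all_to_all_single(out, t, group=cp_group, async_op=True))
#       outs.append(out)  # reshape work for the next tensor can run while this one is in flight
#   for req in reqs:
#       req.wait()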
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
if sliding_window_attn or cp_comm_type == "all_gather":
out = AttnFuncWithCPAndKVAllGather.apply(
class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
"""
Attention implementation with context parallelism. Like DeepSpeed Ulysses, it applies A2A to QKVO.
Refer to the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
"""
@staticmethod
def forward(
ctx,
is_training,
q,
k,
......@@ -3479,8 +3498,6 @@ def attn_forward_func_with_cp(
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
......@@ -3489,33 +3506,534 @@ def attn_forward_func_with_cp(
deterministic,
use_fused_attention,
window_size,
)
elif cp_comm_type == "p2p":
out = AttnFuncWithCPAndKVP2P.apply(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
fp8,
fp8_meta,
)
cp_group,
cp_stream,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
cp_size = get_distributed_world_size(cp_group)
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
assert not padding, f"{attn_mask_type} mask type is not supported!"
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
window_size == (-1, 0)
or window_size == (-1, -1)
or use_fused_attention
or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
assert (
q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
), "The number of attention heads needs to be divisible by CP size!"
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
batch_dim = qkv_format.index("b")
seq_dim = qkv_format.index("s")
assert (
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
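# NVTE_FP8_DPA_BWD=1 (the default) keeps the backward pass in FP8 as well;
# setting it to 0 runs the FP8 forward with an F16 backward instead (the
# "fp8_fwd + bf16_bwd" path mentioned in the commit list above).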
q_f16, k_f16, v_f16 = q, k, v
q, k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [q_f16, k_f16, v_f16]
]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_o_offset"] = META_O
fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history
fp8_meta_kwargs["amax_s_offset"] = META_S
fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history
fp8_meta_kwargs["amax_o_offset"] = META_O
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)
if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16, k_f16, v_f16 = q, k, v
q, k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [q_f16, k_f16, v_f16]
]
batch_size = q.shape[batch_dim]
if use_fused_attention:
out, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
q,
k,
v,
fused_attn_qkv_dtype,
fused_attn_backend,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
window_size=window_size,
**fp8_meta_kwargs,
)
else:
# [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
(
_,
_,
_,
_,
out,
softmax_lse,
_,
rng_state,
) = _flash_attn_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=False,
**fa_optional_forward_kwargs,
)
aux_ctx_tensors = [softmax_lse, rng_state]
# [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
out = out.view(batch_size, -1, *out.shape[-2:])
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
)
if use_fused_attention:
if qkv_format == "bshd":
# [b*s, np, hn] -> [b, s, np, hn]
out = out.view(batch_size, -1, *out.shape[-2:])
elif qkv_format == "sbhd":
# [s*b, np, hn] -> [s, b, np, hn]
out = out.view(-1, batch_size, *out.shape[-2:])
if fp8:
if fp8_meta["recipe"].fp8_mha:
out_fp8 = Float8Tensor(
data=out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
)
out = out_fp8._data
out_ret = out_fp8
else:
out_f16 = cast_from_fp8(
out,
fp8_meta["scaling_fwd"],
META_O,
fp8_dtype_forward,
TE_DType[q_f16.dtype],
)
out_ret = out_f16
else:
out_ret = out
if fp8:
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, k_save, v_save, out_save = q, k, v, out
elif fp8_meta["recipe"].fp8_mha:
q_fp8, k_fp8, v_fp8 = [
Float8Tensor(
data=x,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_QKV,
fp8_dtype=fp8_dtype_forward,
dtype=out_fp8.dtype,
)
for x in [q, k, v]
]
q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8
else:
q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
else:
q_save, k_save, v_save, out_save = q, k, v, out
if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
else:
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
ctx.save_for_backward(
q_save,
k_save,
v_save,
out_save,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fp8_fwd_scales,
fp8_fwd_scale_invs,
*aux_ctx_tensors,
)
ctx.batch_size = batch_size
ctx.cp_group = cp_group
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format
ctx.attn_mask_type = attn_mask_type
ctx.attn_bias_type = attn_bias_type
ctx.deterministic = deterministic
ctx.window_size = window_size
ctx.use_fused_attention = use_fused_attention
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
return out_ret
@staticmethod
def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
q, k, v, out = ctx.saved_tensors[:4]
cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
4:8
]
fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
aux_ctx_tensors = ctx.saved_tensors[10:]
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s")
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_dqkv_dtype = fp8_dtype_backward
fused_attn_backend = FusedAttnBackend["FP8"]
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
dout_fp8 = dout
dout = dout_fp8._data
else:
dout_f16 = dout
dout = cast_to_fp8(
dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV]
fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP]
fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][
META_DQKV
]
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_dqkv_dtype = TE_DType[dout.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if not ctx.use_fused_attention:
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
out, dout = flash_attn_a2a_communicate(
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
)
fa_optional_backward_kwargs = {}
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = ctx.window_size
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
if ctx.use_fused_attention:
dq, dk, dv, _ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
q,
k,
v,
out,
dout,
fused_attn_qkv_dtype,
fused_attn_dqkv_dtype,
aux_ctx_tensors,
fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
window_size=ctx.window_size,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
)
else:
softmax_lse, rng_state = aux_ctx_tensors
out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_kv,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
causal,
rng_state=rng_state,
**fa_optional_backward_kwargs,
)
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
)
if ctx.qkv_format == "bshd":
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
if ctx.fp8:
if ctx.fp8_meta["recipe"].fp8_mha:
dq, dk, dv = [
Float8Tensor(
data=x,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=dout_fp8.dtype,
)
for x in [dq, dk, dv]
]
else:
dq, dk, dv = [
cast_from_fp8(
x,
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
TE_DType[dout_f16.dtype],
)
for x in [dq, dk, dv]
]
return (
None,
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
def attn_forward_func_with_cp(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
cp_comm_type,
softmax_scale=None,
qkv_format="bshd",
attn_mask_type="causal",
attn_bias_type="no_bias",
attn_bias=None,
deterministic=False,
use_fused_attention=False,
window_size=None,
fp8=False,
fp8_meta=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
"""
assert qkv_format in [
"bshd",
"sbhd",
"thd",
], f"QKV format of {qkv_format} is not supported with context parallelism!"
assert (
qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert (
qkv_format != "thd"
or not use_fused_attention
or attn_mask_type in ["padding", "padding_causal"]
), (
f"Context parallelism is not supported for {attn_mask_type} mask type and "
f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
)
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
)
assert (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
assert (
not sliding_window_attn
or cp_comm_type == "a2a"
or (cp_comm_type == "all_gather" and not use_fused_attention)
), "The context parallel running configs cannot support sliding window attetnion!"
args = [
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
]
if cp_comm_type == "p2p":
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream]
out = AttnFuncWithCPAndKVP2P.apply(*args)
elif cp_comm_type == "all_gather":
args.pop(5)
args.pop(8)
args += [window_size, cp_group, cp_stream]
out = AttnFuncWithCPAndKVAllGather.apply(*args)
elif cp_comm_type == "a2a":
args += [window_size, fp8, fp8_meta, cp_group, cp_stream]
out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
......@@ -6416,7 +6934,13 @@ class DotProductAttention(TransformerEngineBaseModule):
can overlap two flash attention kernels.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
Can be "p2p" or "all_gather" or "a2a".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"""
def __init__(
......@@ -6608,7 +7132,13 @@ class DotProductAttention(TransformerEngineBaseModule):
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
Can be "p2p" or "all_gather" or "a2a".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
......@@ -7633,7 +8163,13 @@ class MultiheadAttention(torch.nn.Module):
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
Can be "p2p" or "all_gather" or "a2a".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
......
......@@ -503,7 +503,13 @@ class TransformerLayer(torch.nn.Module):
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
Can be "p2p" or "all_gather" or "a2a".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
......