Unverified commit c24a4c41, authored by Xiaowei Ren and committed by GitHub

Hierarchical CP implementation (Ulysses + Ring) (#1209)



* change API for hierarchical CP
* move fp8 code before qkv reshape
* try to insert A2A for hierarchical CP
* make fwd work
* remove a redundant sync
* make bwd of hierarchical CP work
* fix dout a2a in bwd
* fix q_f16 with fp8
* assert hierarchical CP implementation does not support THD format
* bug fix
* assert hierarchical CP does not support attn bias
* add unit test for hierarchical CP
* fix cp_comm_type in unit test
* bug fix and code cleaning
* minor change
* an assert info change
* dout shape fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* move function definitions to the front of the first call
* fix tensor view comments
* refine CP unit test
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* typo fix
* typo fix
* save cp_size_a2a and rank_a2a in fwd
* add more explanations of cp_group in docstring

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 60f738ff
@@ -59,6 +59,17 @@ def run_dpa_with_cp(
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
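With world_size = 8 (so an A2A size of 2 and a P2P size of 4), the sub-group construction above yields the layout below; this is a quick self-contained check by the editor, not part of the diff:

world_size = 8
# A2A sub-groups pair consecutive ranks; P2P sub-groups stride across the pairs.
a2a_ranks = [list(range(i * 2, (i + 1) * 2)) for i in range(world_size // 2)]
p2p_ranks = [list(range(i, world_size, 2)) for i in range(2)]
assert a2a_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert p2p_ranks == [[0, 2, 4, 6], [1, 3, 5, 7]]
# Each rank joins exactly one group of each kind, so on every rank
# cp_comm_sub_groups ends up as [its A2A sub-group, its P2P sub-group].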
@@ -167,13 +178,6 @@ def run_dpa_with_cp(
else:
bias = None
# make sure all GPU ranks have same inputs
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
@@ -239,7 +243,10 @@ def run_dpa_with_cp(
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
@@ -36,8 +36,13 @@ model_configs_flash_attn = {
}
def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
def get_bash_arguments(num_gpus_per_node, **kwargs):
args = [
"python",
"-m",
"torch.distributed.launch",
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
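For example, an "a2a+p2p" case now launches on four GPUs; a sketch of the resulting argv (the trailing test kwargs are serialized by code outside this hunk and are elided here):

# ["python", "-m", "torch.distributed.launch", "--nproc-per-node=4",
#  "<TE_PATH>/tests/pytorch/fused_attn/run_fused_attn_with_cp.py", ...]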
@@ -51,20 +56,20 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
@@ -72,6 +77,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
@@ -106,7 +112,7 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
@@ -122,7 +128,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
@@ -140,9 +146,9 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
@@ -150,6 +156,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
@@ -1402,11 +1402,128 @@ def get_cu_seqlens_on_cp_rank(
return cu_seqlens_on_cp_rank
@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
before or after CP communications (e.g., all-gather, all-to-all). This function computes
the sequence chunk ids for that reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
if to_contiguous:
for rank in range(cp_size):
chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
else:
for rank in range(cp_size):
chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
return chunk_ids
@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous()
# [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
# or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
else:
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous()
# reorder the sequence chunks
x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
# [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
x = x.view(cp_size, 2, *x.shape[1:])
return x
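A toy check of the two helpers above (editor's sketch; runs on CPU, bshd layout so seq_dim = 1, cp_size = 2). With load balancing, rank r holds chunks (r, 2*cp_size - 1 - r), so rank 0 carries tokens [0, 1, 6, 7] and rank 1 carries [2, 3, 4, 5]; reordering restores the contiguous sequence:

import torch

cp_size, seq_dim = 2, 1
# [cp, b, s, np//cp, hn] with b = 1, s = 4, np//cp = hn = 1
x = torch.tensor([[0, 1, 6, 7], [2, 3, 4, 5]]).view(2, 1, 4, 1, 1)
ids = get_seq_chunk_ids_for_reordering(cp_size, x.device, True)   # -> [0, 2, 3, 1]
y = reorder_seq_chunks_for_a2a(x, ids, seq_dim, cp_size, before_attn=True)
assert y.reshape(-1).tolist() == [0, 1, 2, 3, 4, 5, 6, 7]
# The to_contiguous=False ids ([0, 3, 1, 2] for cp_size = 2) invert the mapping.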
def flash_attn_a2a_communicate(
a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
chunk_ids_for_a2a: torch.Tensor,
seq_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""A2A communication for context parallelism."""
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
if before_attn:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
# or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, s, np, hn] -> [b, s, cp, np//cp, hn]
# or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
# [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
# or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
a2a_inputs[i] = x.movedim(-3, 0).contiguous()
else:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
# or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
# or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
# [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
# or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
torch.cuda.current_stream().wait_stream(cp_stream)
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
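Further down in this diff, AttnFuncWithCPAndKVP2P.forward drives this helper symmetrically around the attention compute; a condensed sketch (tensors, group, and stream come from the surrounding code, so this is not runnable standalone):

# Before attention: scatter heads / gather sequence for q, k, v.
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
q, k, v = flash_attn_a2a_communicate(
    [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
)
# ... ring-P2P attention compute ...
# After attention: the inverse movement on the output.
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
out = flash_attn_a2a_communicate(
    out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
)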
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
"""
Attention implementation with context parallelism. Exchange KV between CP ranks
with P2P in a ring topology. Split attention compute into multiple steps, and overlap
current-step compute with next-step communication.
This implementation also supports hierarchical CP, which parallelizes attention
heads within low-level CP groups and parallelizes the sequence dimension across
high-level CP groups. For more details, please refer to
`LongVILA <https://arxiv.org/abs/2408.10188>`_ and `USP <https://arxiv.org/abs/2405.07719>`_.
"""
@staticmethod
@@ -1439,10 +1556,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if isinstance(cp_group, list):
assert (
qkv_format != "thd"
), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
assert attn_bias_type == "no_bias", (
f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
" yet!"
)
cp_group_a2a = cp_group[0]
cp_size_a2a = get_distributed_world_size(cp_group_a2a)
rank_a2a = get_distributed_rank(cp_group_a2a)
cp_group = cp_group[1]
else:
cp_group_a2a = None
cp_size_a2a = 1
rank_a2a = 0
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
send_dst = cp_global_ranks[(rank + 1) % cp_size]
recv_src = cp_global_ranks[(rank - 1) % cp_size]
send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
causal = "causal" in attn_mask_type
@@ -1463,6 +1597,59 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
else:
q_f16, k_f16, v_f16 = q, k, v
if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [k_f16, v_f16]
]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
q_f16 = q
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
)
if not fp8:
q_f16 = q
elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16 = q
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
assert qkv_format == "thd" or (
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
@@ -1520,47 +1707,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
else:
q_f16, k_f16, v_f16 = q, k, v
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [k_f16, v_f16]
]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
fp8_meta_kwargs["d_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_s_offset"] = META_S
fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
q_f16 = q
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
@@ -2131,12 +2277,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
kv = p2p_comm_buffers[-1]
if use_fused_attention:
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
ctx.batch_size = out.shape[0]
elif qkv_format == "sbhd":
out = out.view(-1, *out.shape[-3:])
else:
ctx.batch_size = out.shape[1]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
out = flash_attn_a2a_communicate(
out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
)
if use_fused_attention:
if qkv_format == "bshd":
# [b*s, np, hn] -> [b, s, np, hn]
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
elif qkv_format == "sbhd":
# [s*b, np, hn] -> [s, b, np, hn]
out = out.view(-1, ctx.batch_size, *out.shape[-2:])
elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:])
if fp8 and use_fused_attention:
@@ -2165,6 +2325,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
elif fp8 and fp8_meta["recipe"].fp8_mha:
q_fp8 = Float8Tensor(
data=q,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_QKV,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
)
kv_fp8 = Float8Tensor(
data=kv,
fp8_meta=fp8_meta,
@@ -2176,6 +2344,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
else:
q_f16 = q_f16.view(q.shape)
q_save, kv_save, out_save = q_f16, kv, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
@@ -2193,8 +2362,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
*rng_states,
*attn_biases,
)
ctx.cp_group_a2a = cp_group_a2a
ctx.cp_size_a2a = cp_size_a2a
ctx.rank_a2a = rank_a2a
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.total_tokens_kv = total_tokens_kv
ctx.max_seqlen_q = max_seqlen_q
@@ -2212,10 +2385,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
@@ -2228,6 +2404,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
if ctx.qkv_format in ["bshd", "sbhd"]:
seq_dim = ctx.qkv_format.index("s")
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
@@ -2262,6 +2439,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
dout_dtype = dout.dtype
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
@@ -2272,7 +2450,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
dkv_fp8_ = torch.empty_like(dkv_fp8)
dout_dtype = dout.dtype
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
@@ -2296,7 +2473,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]]
q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
if cp_size_a2a == 1:
dout = dout.from_float8(dout_dtype)
else:
dout_fp8_dtype = dout._fp8_dtype
dout_scale_inv = dout._scale_inv
dout = dout._data
dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
@@ -2308,9 +2491,28 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_dqkv_dtype = TE_DType[dout.dtype]
fused_attn_dqkv_dtype = TE_DType[dout_dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if cp_size_a2a > 1:
if not ctx.use_fused_attention:
out = out.view(ctx.batch_size, -1, *out.shape[-2:])
dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
out, dout = flash_attn_a2a_communicate(
[out, dout],
chunk_ids_for_a2a,
seq_dim,
cp_size_a2a,
ctx.cp_group_a2a,
ctx.cp_stream,
True,
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
dout = cast_from_fp8(
dout, None, None, dout_fp8_dtype, TE_DType[dout_dtype], scale_inv=dout_scale_inv
)
out = out.view(*q.shape)
dout = dout.view(*q.shape)
send_recv_reqs = []
@@ -2906,6 +3108,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
for x in [dq, dkv]
]
dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
dq, dk, dv = flash_attn_a2a_communicate(
[dq, dk, dv],
chunk_ids_for_a2a,
seq_dim,
cp_size_a2a,
ctx.cp_group_a2a,
ctx.cp_stream,
False,
)
if ctx.qkv_format == "bshd":
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
dq, dk, dv = [
Float8Tensor(
data=x,
@@ -2915,10 +3136,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fp8_dtype=fp8_dtype_backward,
dtype=dout_dtype,
)
for x in [dq, dkv[0], dkv[1]]
for x in [dq, dk, dv]
]
else:
dk, dv = dkv[0], dkv[1]
if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
@@ -2951,26 +3170,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
sequence chunk ids for reordering.
"""
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
if to_contiguous:
for rank in range(cp_size):
chunk_ids[rank] = 2 * rank
chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
else:
for rank in range(cp_size):
chunk_ids[2 * rank] = rank
chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
return chunk_ids
def get_kv_seq_info_after_all_gather(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
@@ -3097,7 +3296,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
# or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q.select(seq_dim, i).contiguous()
kv_seq_range_per_step[i], window_size_per_step[i] = (
get_kv_seq_info_after_all_gather(
@@ -3259,7 +3459,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
# or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q.select(seq_dim, i).contiguous()
seq_start_idx, seq_end_idx = (
kv_seq_range_per_step[i][0],
@@ -3396,88 +3597,6 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
)
@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
# [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
x = x.movedim(0, seq_dim).contiguous()
# [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
# reorder the sequence chunks
x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
else:
# [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.movedim(seq_dim, 0).contiguous()
# reorder the sequence chunks
x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
# [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
x = x.view(cp_size, 2, *x.shape[1:])
return x
def flash_attn_a2a_communicate(
a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
chunk_ids_for_a2a: torch.Tensor,
seq_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""A2A communication for context parallelism."""
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
if before_attn:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# reorder the sequence chunks
x = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
# [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
# [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
a2a_inputs[i] = x.movedim(-3, 0).contiguous()
else:
for i in range(len(a2a_inputs) + 2):
if 0 < i < len(a2a_inputs) + 1:
a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
)
if i < len(a2a_inputs):
x = a2a_inputs[i]
# [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
# reorder the sequence chunks
a2a_inputs[i] = reorder_seq_chunks_for_a2a(
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
)
if i > 1:
with torch.cuda.stream(cp_stream):
a2a_reqs[i - 2].wait()
x = a2a_outputs[i - 2]
# [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
# [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
torch.cuda.current_stream().wait_stream(cp_stream)
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
"""
Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
@@ -3969,6 +4088,22 @@ def attn_forward_func_with_cp(
Attention implementation with context parallelism.
"""
if cp_comm_type == "a2a+p2p":
assert isinstance(
cp_group, list
), "Hierarchical CP implementation needs multi-level CP groups!"
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
if get_distributed_world_size(cp_group[0]) == 1:
cp_group = cp_group[1]
cp_comm_type = "p2p"
elif get_distributed_world_size(cp_group[1]) == 1:
cp_group = cp_group[0]
cp_comm_type = "a2a"
else:
assert isinstance(
cp_group, dist_group_type
), f"Unsupported process group for CP communication type {cp_comm_type}!"
assert qkv_format in [
"bshd",
"sbhd",
@@ -4023,7 +4158,7 @@ def attn_forward_func_with_cp(
use_fused_attention,
]
if cp_comm_type == "p2p":
if cp_comm_type in ["p2p", "a2a+p2p"]:
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream]
out = AttnFuncWithCPAndKVP2P.apply(*args)
elif cp_comm_type == "all_gather":
@@ -4843,7 +4978,7 @@ class FlashAttention(torch.nn.Module):
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
cp_group: Optional[dist_group_type] = None,
cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
@@ -4863,7 +4998,12 @@ class FlashAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
cp_size = 1
if isinstance(cp_group, dist_group_type):
cp_size = get_distributed_world_size(cp_group)
elif isinstance(cp_group, list):
for group in cp_group:
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -6652,7 +6792,7 @@ class FusedAttention(torch.nn.Module):
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
cp_group: Optional[dist_group_type] = None,
cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
@@ -6674,7 +6814,12 @@ class FusedAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
cp_size = 1
if isinstance(cp_group, dist_group_type):
cp_size = get_distributed_world_size(cp_group)
elif isinstance(cp_group, list):
for group in cp_group:
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -6920,8 +7065,11 @@ class DotProductAttention(TransformerEngineBaseModule):
tensor parallel world size.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
cp_group : ProcessGroup, default = `None`
cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks : list of global rank IDs, default = `None`
global rank IDs of GPUs that are in cp_group.
cp_stream : CUDA stream, default = `None`
@@ -6929,15 +7077,18 @@ class DotProductAttention(TransformerEngineBaseModule):
compute and communication overlapping. To address the wave quantization
issue of each split step, we add an additional CUDA stream so that we
can overlap two flash attention kernels.
cp_comm_type : str
cp_comm_type : str, default = `p2p`
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a".
Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
def __init__(
@@ -6955,7 +7106,7 @@ class DotProductAttention(TransformerEngineBaseModule):
tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None,
attention_type: str = "self",
cp_group: Optional[dist_group_type] = None,
cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
@@ -7110,7 +7261,7 @@ class DotProductAttention(TransformerEngineBaseModule):
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
cp_group: Union[dist_group_type, List[dist_group_type], None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
@@ -7121,21 +7272,27 @@ class DotProductAttention(TransformerEngineBaseModule):
Parameters
----------
cp_group : ProcessGroup
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
cp_comm_type : str, default = `p2p`
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a".
Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
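For reference, a minimal end-to-end sketch of enabling "a2a+p2p" on 8 GPUs (the 2 x 4 split and the names a2a_group, p2p_group, core_attn are illustrative; group creation mirrors the unit test at the top of this diff):

import torch
import torch.distributed as dist

rank, world_size = dist.get_rank(), dist.get_world_size()
a2a_size = 2                                   # heads scattered within this group
a2a_group = p2p_group = None
for i in range(world_size // a2a_size):        # consecutive ranks: A2A (e.g., NVLink)
    ranks = range(i * a2a_size, (i + 1) * a2a_size)
    group = dist.new_group(ranks, backend="nccl")
    if rank in ranks:
        a2a_group = group
for i in range(a2a_size):                      # strided ranks: P2P ring (e.g., IBLink)
    ranks = range(i, world_size, a2a_size)
    group = dist.new_group(ranks, backend="nccl")
    if rank in ranks:
        p2p_group = group
core_attn.set_context_parallel_group(
    [a2a_group, p2p_group], list(range(world_size)), torch.cuda.Stream(), "a2a+p2p"
)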
@@ -7445,7 +7602,12 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
batch_size = len(cu_seqlens_q) - 1
cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
cp_size = 1
if isinstance(self.cp_group, dist_group_type):
cp_size = get_distributed_world_size(self.cp_group)
elif isinstance(self.cp_group, list):
for group in self.cp_group:
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
if qkv_format in ["sbhd", "bshd"]:
@@ -8141,7 +8303,7 @@ class MultiheadAttention(torch.nn.Module):
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
cp_group: Union[dist_group_type, List[dist_group_type], None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
@@ -8152,21 +8314,27 @@ class MultiheadAttention(torch.nn.Module):
Parameters
----------
cp_group : ProcessGroup
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
cp_comm_type : str, default = `p2p`
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a".
Can be "p2p" or "all_gather" or "a2a", "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
@@ -484,7 +484,7 @@ class TransformerLayer(torch.nn.Module):
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
cp_group: Union[dist_group_type, List[dist_group_type], None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
@@ -495,21 +495,27 @@ class TransformerLayer(torch.nn.Module):
Parameters
----------
cp_group : ProcessGroup
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
cp_comm_type : str, default = `p2p`
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a".
Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):