Unverified Commit c24a4c41 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Hierarchical CP implementation (Ulysses + Ring) (#1209)



* change API for hierarchical CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* move fp8 code before qkv reshape
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* try to insert A2A for hierarchical CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make fwd work
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove a redundant sync
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make bwd of hierarchical CP work
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix dout a2a in bwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix q_f16 with fp8
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert hierarchical CP implementation does not support THD format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert hierarchical CP does not support attn bias
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add unit test for hierarchical CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix cp_comm_type in unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix and code cleaning
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* an assert info change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* dout shape fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move function definitions to the front of the first call
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix tensor view comments
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* refine CP unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* save cp_size_a2a and rank_a2a in fwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add more explainations of cp_group in doc_string
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 60f738ff
......@@ -59,6 +59,17 @@ def run_dpa_with_cp(
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
......@@ -167,13 +178,6 @@ def run_dpa_with_cp(
else:
bias = None
# make sure all GPU ranks have same inputs
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
dist.broadcast(x, 0, group=cp_comm_group)
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
......@@ -239,7 +243,10 @@ def run_dpa_with_cp(
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
......
......@@ -36,8 +36,13 @@ model_configs_flash_attn = {
}
def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
def get_bash_arguments(num_gpus_per_node, **kwargs):
args = [
"python",
"-m",
"torch.distributed.launch",
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
......@@ -51,20 +56,20 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
......@@ -72,6 +77,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
......@@ -106,7 +112,7 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
......@@ -122,7 +128,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
......@@ -140,9 +146,9 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
......@@ -150,6 +156,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
......
This diff is collapsed.
......@@ -484,7 +484,7 @@ class TransformerLayer(torch.nn.Module):
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
cp_group: Union[dist_group_type, List[dist_group_type], None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
......@@ -495,21 +495,27 @@ class TransformerLayer(torch.nn.Module):
Parameters
----------
cp_group : ProcessGroup
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
cp_comm_type : str, default = `p2p`
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather" or "a2a".
Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment