Unverified commit 40dda924, authored by Xiaowei Ren, committed by GitHub

Add a context parallelism implementation with QKVO all-to-all (#1160)



* clean code for CP function args
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a placeholder for Ulysses implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit code change to CP+A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* finish the draft fwd implementation of Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft bwd implementation of Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make swa work with ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit FP8 code for Ulysses
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv type in the bwd of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv_dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor code change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* config cp correction dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code style change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save chunk_ids
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try to make Ulysses A2A async
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make more a2a async
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a2a_outputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix chunk_ids generation for A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* avoid code duplication of a2a before attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove code duplication of a2a after attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cp_stream in A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv of fp8_fwd + bf16_bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kernel order in cp a2a communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for CP a2a
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix merging with main
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a2a communication order
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* adjust sequence chunk reordering for a2a
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add docstring for A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change an assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add unit tests of A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more A2A unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more cp unit tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix window size of no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fused attn does not support swa+no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change num_gqa_groups to 2 for A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* function and variable renaming
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for CP all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* some function renaming
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit code change for kv all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix all-gather implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a window size check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add unit test of all_gather+no_mask
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix all-gather cp implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FP8 with A2A implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add paper references to CP implementations with all-gather and all-to-all
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change pdf to abs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* elaborate cp_comm_type
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2a9845e1
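
For orientation before the diffs: the core of the "a2a" (DeepSpeed-Ulysses-style) approach is a pair of all-to-all exchanges around the attention call, which scatter attention heads across the context-parallel (CP) group while gathering the full sequence, and then invert that layout afterwards. The sketch below is illustrative only, not this PR's code: helper names are hypothetical, it assumes equal sequence chunks stored in rank order, and it omits the causal load-balancing chunk reordering, async overlap on cp_stream, and FP8 handling that the commits above add.

import torch
import torch.distributed as dist

def a2a_before_attn(x: torch.Tensor, cp_group) -> torch.Tensor:
    # [s_local, b, h, d] -> [s_local * cp_size, b, h // cp_size, d]:
    # scatter attention heads across CP ranks, gather the full sequence.
    cp_size = dist.get_world_size(cp_group)
    s, b, h, d = x.shape
    assert h % cp_size == 0, "num_heads must be divisible by cp_size"
    # Head group i of the local sequence chunk is sent to rank i.
    x = x.view(s, b, cp_size, h // cp_size, d).movedim(2, 0).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=cp_group)
    # out[r] is rank r's sequence chunk for the head group this rank keeps;
    # concatenating in rank order recovers the full sequence.
    return out.reshape(cp_size * s, b, h // cp_size, d)

def a2a_after_attn(x: torch.Tensor, cp_group) -> torch.Tensor:
    # [s, b, h_local, d] -> [s // cp_size, b, h_local * cp_size, d]:
    # the inverse exchange, re-gathering heads and re-scattering the sequence.
    cp_size = dist.get_world_size(cp_group)
    s, b, h_local, d = x.shape
    x = x.view(cp_size, s // cp_size, b, h_local, d).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=cp_group)
    # out[r] is this rank's sequence chunk computed with head group r.
    return out.movedim(0, 2).reshape(s // cp_size, b, cp_size * h_local, d)

In the forward pass Q, K, and V would each go through the first exchange before attention and the output through the second; the backward pass mirrors this, which is why the commits above track A2A ordering and chunk_ids carefully.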
@@ -22,10 +22,16 @@ model_configs_flash_attn = {
     "cp_1_2": ModelConfig(
         2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
     ),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
+    "cp_1_3": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
+    ),  # MHA
+    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
+    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
     "cp_2_2": ModelConfig(
-        2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+    ),  # GQA
+    "cp_2_3": ModelConfig(
+        2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
     ),  # GQA
 }
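
For readers of the config tables above, the positional ModelConfig fields are easier to follow annotated. The mapping below is an assumption inferred from the surrounding test helpers (including the import path), not something this diff states:

from test_fused_attn import ModelConfig  # test helper; import path assumed

config = ModelConfig(
    2,          # batch_size
    12,         # num_heads (query heads)
    2,          # num_gqa_groups (KV heads; 12 here means MHA, 2 means GQA)
    128,        # head_dim
    4096,       # max_seqlen_q
    4096,       # max_seqlen_kv
    0.0,        # dropout probability
    "causal",   # attn_mask_type
    "no_bias",  # attn_bias_type
    window_size=(512, 0),  # sliding window (left, right)
)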
@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
 def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     config = model_configs_flash_attn[model]
-    if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
-        pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and qkv_format == "thd":
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
-        )
-    if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
-            " type yet!"
-        )
+        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
-            " type yet!"
-        )
+        pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+    if cp_comm_type == "a2a" and qkv_format == "thd":
+        pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+    if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+        pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+    if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+        pytest.skip(
+            f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+            f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
+        )
+    if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip(
+            f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
+        )
     subprocess.run(
         get_bash_arguments(
-            dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
+            dtype=dtype,
+            model=model,
+            qkv_format=qkv_format,
+            kernel_backend="FlashAttention",
+            cp_comm_type=cp_comm_type,
         ),
         check=True,
     )
@@ -81,10 +88,16 @@ model_configs_fused_attn = {
     "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
     "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
     "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
-    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
-    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
-    "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
-    "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
+    "cp_1_4": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+    ),  # MHA
+    "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
+    "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
+    "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
+    "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
+    "cp_2_4": ModelConfig(
+        2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+    ),  # GQA
 }
@@ -93,37 +106,27 @@ model_configs_fused_attn = {
 @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
 def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
     if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
-        pytest.skip("THD format is only supported on sm90+.")
+        pytest.skip("THD format is only supported on sm90+!")
     if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
-        pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
+        pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     config = model_configs_fused_attn[model]
     if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
-        pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
+        pytest.skip("THD format does not support QGA/MQA yet!")
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
-        pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
-    if cp_comm_type == "all_gather" and qkv_format == "thd":
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
-        )
-    if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
-            " type yet!"
-        )
-    if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
-        pytest.skip(
-            f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
-            " type yet!"
-        )
-    if config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip("THD format does not support post_scale_bias yet!")
+    if qkv_format == "thd" and cp_comm_type == "all_gather":
+        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
+    if qkv_format == "thd" and cp_comm_type == "a2a":
+        pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+    if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
         pytest.skip(
-            "Fused attention does not support sliding window attention + context parallelism yet!"
+            "Sliding window attention only can be supported with the implementation of QKVO A2A!"
         )
-    if cp_comm_type == "all_gather" and dtype == "fp8":
+    if dtype == "fp8" and cp_comm_type == "all_gather":
         pytest.skip(
             "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
         )
@@ -131,10 +134,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
         pytest.skip("FP8 attention cannot work with THD format yet!")
     if dtype == "fp8" and config.attn_bias_type != "no_bias":
         pytest.skip("FP8 attention cannot work with bias yet!")
+    if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip("FP8 attention cannot work with sliding window yet!")
+    if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+        pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+    if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+        pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+    if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+        pytest.skip(
+            f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+            f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
+        )
     subprocess.run(
         get_bash_arguments(
-            dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
+            dtype=dtype,
+            model=model,
+            qkv_format=qkv_format,
+            kernel_backend="FusedAttention",
+            cp_comm_type=cp_comm_type,
        ),
         check=True,
     )
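
The new "a2a" skip condition reflects the head-scatter layout: every CP rank must receive a whole number of query heads and GQA groups, which is also why the GQA configs above moved from 1 to 2 KV groups. A standalone illustration (hypothetical snippet; cp_size is fixed at 2, as in these tests):

cp_size = 2  # the CP world size used by these tests

for num_heads, num_gqa_groups in [(12, 2), (12, 1)]:
    ok = num_heads % cp_size == 0 and num_gqa_groups % cp_size == 0
    print(f"heads={num_heads}, gqa_groups={num_gqa_groups}: "
          f"{'runs' if ok else 'skipped for a2a'}")
# heads=12, gqa_groups=2: runs
# heads=12, gqa_groups=1: skipped for a2a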
@@ -503,7 +503,13 @@ class TransformerLayer(torch.nn.Module):
                     cuda stream for context parallel execution.
         cp_comm_type : str
                       inter-gpu communication type for context parallelism.
-                      Can be "p2p" or "all_gather".
+                      Can be "p2p" or "all_gather" or "a2a".
+                      "p2p": Exchange KV chunks with P2P communications in ring topology.
+                             P2P is async and can be overlapped with attention compute.
+                      "all_gather": All-gather to get full sequence of KV before attention.
+                                    The all-gather is not async, and cannot be overlapped.
+                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
+                             group, and gather to get full sequence of QKV.
         """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
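
To show where cp_comm_type lands for users, here is a hedged usage sketch. Sizes and ranks are illustrative; it assumes torch.distributed is already initialized across 2 GPUs and that set_context_parallel_group accepts the new keyword as documented above:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1536,         # illustrative sizes
    ffn_hidden_size=4096,
    num_attention_heads=12,   # must be divisible by the CP size for "a2a"
).cuda()

cp_ranks = [0, 1]                         # illustrative 2-GPU CP group
cp_group = dist.new_group(ranks=cp_ranks)
cp_stream = torch.cuda.Stream()           # side stream for CP communication

# Select the QKVO all-to-all implementation added by this PR.
layer.set_context_parallel_group(cp_group, cp_ranks, cp_stream, cp_comm_type="a2a")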