Unverified Commit 30407856 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Add a CP implementation variant with KV all-gather. (#1060)



* add window_size to AttnFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo for cudnn thd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* assert swa+CP cannot work with thd format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add a new CP function for swa
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add a missing dgrads
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add draft fwd function for swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* enable flash attention for swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove an assert of swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* call SWAFuncWithCP for swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* use 2hd layout
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change qkv_format check
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor shape fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a cp rank
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save comounication
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not need to zero dkv for FlashAttention any mroe
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix kv shape of cp test with thd format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* update cp unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add simple code framework
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* try not to have a separate CP function for SWA
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* backup some code change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* back up code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* clean up fwd implementation of SWAFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix assert info
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* reduce kv chunk concat overheads
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP and SWAFuncWithCP have same API
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add a docstring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* preliminary implementation of SWAFuncWithCP forward seems working
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix output shape of SWAFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code refactoring for FlashAttention and add a code placeholder for bwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* use gather_along_first_dim
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* finish the preliminary implementation of bwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix assert condition
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add draft implementation of SWA+CP with FusedAttention
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix attention mask type of swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add qkv_layout
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add missing window_size argument
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix kv shape of swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug and typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix dout shape
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add multi stream in fwd of swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* save chunk_ids_to_kv_ag in fwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add multi stream in bwd of swa+cp
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor fix to cp stream sync
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* rename AttnFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* check if window size is None
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix docstring of AttnFuncWithCP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add env var for users to choose KV ag or KV p2p
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* update cp tests
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix window size in cp unit test
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix pytest skip messages
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add cp_comm_type into API
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code cleaning
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* assert sequence length divisible requirements
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support table of context parallelism
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo and code format fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not print multiple disabling messages
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix device in torch.arange and adjust code for the PR of MLA
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix typos and clean asserts
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 941364df
......@@ -13,7 +13,9 @@ from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fuse
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
):
"""Test DotProductAttention module with context parallelism"""
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -24,10 +26,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == "thd" and (
config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
):
return
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
......@@ -49,73 +57,77 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
q_input_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
......@@ -211,7 +223,9 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
)
out_ = core_attn(
q_,
k_,
......
......@@ -16,11 +16,17 @@ from transformer_engine.pytorch.utils import (
)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}
......@@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
)
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
......@@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
......@@ -66,9 +93,37 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!"
)
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
......
This diff is collapsed.
......@@ -487,6 +487,7 @@ class TransformerLayer(torch.nn.Module):
cp_group: Union[dist_group_type, None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
) -> None:
"""
Set the context parallel attributes for the given
......@@ -500,13 +501,16 @@ class TransformerLayer(torch.nn.Module):
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment