Unverified commit 30407856, authored by Xiaowei Ren, committed by GitHub

Add a CP implementation variant with KV all-gather. (#1060)



* add window_size to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo for cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert swa+CP cannot work with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a new CP function for swa
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add missing dgrads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft fwd function for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable flash attention for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove an assert of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* call SWAFuncWithCP for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use 2hd layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change qkv_format check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor shape fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a cp rank
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not need to zero dkv for FlashAttention anymore
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of cp test with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add simple code framework
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try not to have a separate CP function for SWA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* backup some code change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* back up code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clean up fwd implementation of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* reduce kv chunk concat overheads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP and SWAFuncWithCP have same API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* preliminary implementation of SWAFuncWithCP forward seems working
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix output shape of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring for FlashAttention and add a code placeholder for bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use gather_along_first_dim
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* finish the preliminary implementation of bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert condition
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft implementation of SWA+CP with FusedAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attention mask type of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add qkv_layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add missing window_size argument
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug and typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dout shape
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in fwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save chunk_ids_to_kv_ag in fwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in bwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix to cp stream sync
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* rename AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if window size is None
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix docstring of AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add env var for users to choose KV ag or KV p2p
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix window size in cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix pytest skip messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cp_comm_type into API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* assert sequence length divisible requirements
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support table of context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo and code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not print multiple disabling messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix device in torch.arange and adjust code for the PR of MLA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix typos and clean asserts
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 941364df
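For orientation, a minimal usage sketch of the new option (illustrative only: module arguments follow the test changes below, and the process-group setup is assumed, e.g. a torchrun-style launch):

import torch
import torch.distributed as dist
from transformer_engine.pytorch import DotProductAttention

# All ranks form one context-parallel group in this sketch.
dist.init_process_group(backend="nccl")
cp_ranks = list(range(dist.get_world_size()))
cp_group = dist.new_group(cp_ranks, backend="nccl")

core_attn = DotProductAttention(12, 128, attn_mask_type="causal").cuda()
# "p2p" (default) exchanges KV chunks ring-style; "all_gather" gathers all KV
# up front and reduce-scatters the KV gradients in backward.
core_attn.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream(), "all_gather")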
@@ -13,7 +13,9 @@ from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn

 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}

-def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
+def run_dpa_with_cp(
+    dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
+):
     """Test DotProductAttention module with context parallelism"""
     os.environ["NVTE_FLASH_ATTN"] = "0"

@@ -24,10 +26,16 @@ def run_dpa_with_cp(
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         config = model_configs_fused_attn[model]
-    if qkv_format == "thd" and (
-        config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
-    ):
-        return
+    assert config.attn_mask_type in [
+        "causal",
+        "no_mask",
+    ], f"{config.attn_mask_type} is an unsupported attention mask type!"
+    if kernel_backend == "FusedAttention" and qkv_format == "thd":
+        if "causal" in config.attn_mask_type:
+            config.attn_mask_type = "padding_causal"
+        else:
+            config.attn_mask_type = "padding"

     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))

@@ -49,73 +57,77 @@ def run_dpa_with_cp(
     assert rank in cp_comm_ranks
     cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

-    assert config.attn_mask_type in [
-        "causal",
-        "no_mask",
-    ], f"{config.attn_mask_type} is an unsupported attention mask type!"
-    if kernel_backend == "FusedAttention" and qkv_format == "thd":
-        if "causal" in config.attn_mask_type:
-            config.attn_mask_type = "padding_causal"
-        else:
-            config.attn_mask_type = "padding"
-
     # instantiate core attn module
     core_attn = DotProductAttention(
         config.num_heads,
-        config.head_dim,
+        config.head_dim_qk,
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,
         attn_mask_type=config.attn_mask_type,
+        window_size=config.window_size,
     )
     core_attn = core_attn.cuda()

     # create flash attn inputs
     if qkv_format == "bshd":
-        q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
+        q_input_shape = (
+            config.batch_size,
+            config.max_seqlen_q,
+            config.num_heads,
+            config.head_dim_qk,
+        )
         kv_input_shape = (
             config.batch_size,
             config.max_seqlen_kv,
             config.num_gqa_groups,
-            config.head_dim,
+            config.head_dim_qk,
         )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
-            config.num_heads * config.head_dim,
+            config.num_heads * config.head_dim_qk,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
         cu_seqlens_q_padded = None
         cu_seqlens_kv_padded = None
     elif qkv_format == "sbhd":
-        q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
+        q_input_shape = (
+            config.max_seqlen_q,
+            config.batch_size,
+            config.num_heads,
+            config.head_dim_qk,
+        )
         kv_input_shape = (
             config.max_seqlen_kv,
             config.batch_size,
             config.num_gqa_groups,
-            config.head_dim,
+            config.head_dim_qk,
         )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
-            config.num_heads * config.head_dim,
+            config.num_heads * config.head_dim_qk,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
         cu_seqlens_q_padded = None
         cu_seqlens_kv_padded = None
     elif qkv_format == "thd":
-        q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
+        q_input_shape = (
+            config.batch_size * config.max_seqlen_q,
+            config.num_heads,
+            config.head_dim_qk,
+        )
         kv_input_shape = (
             config.batch_size * config.max_seqlen_q,
             config.num_gqa_groups,
-            config.head_dim,
+            config.head_dim_qk,
         )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
-            config.num_heads * config.head_dim,
+            config.num_heads * config.head_dim_qk,
         )
         seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)

@@ -211,7 +223,9 @@ def run_dpa_with_cp(
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
-    core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
+    core_attn.set_context_parallel_group(
+        cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
+    )
     out_ = core_attn(
         q_,
         k_,
......
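The THD branch above rounds every sampled sequence length up to a multiple of 2 * world_size, which is exactly the per-sequence divisibility that context parallelism needs. A standalone sketch of the same arithmetic (values are illustrative):

import torch

world_size = 4  # number of CP ranks
seqlens_q = torch.tensor([5, 17, 32, 0], dtype=torch.int32)
multiple = 2 * world_size  # each sequence must split into 2 * cp_size chunks
seqlens_q_padded = (seqlens_q + multiple - 1) // multiple * multiple
print(seqlens_q_padded.tolist())  # [8, 24, 32, 0]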
@@ -16,11 +16,17 @@ from transformer_engine.pytorch.utils import (
 )

 model_configs_flash_attn = {
     # test:  b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
     "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
     "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
+    "cp_1_2": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+    ),  # MHA
     "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
     "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
+    "cp_2_2": ModelConfig(
+        2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+    ),  # GQA
 }

@@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs):
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-def test_cp_with_flash_attention(dtype, model, qkv_format):
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
+    config = model_configs_flash_attn[model]
+    if cp_comm_type == "all_gather" and qkv_format == "thd":
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
+        )
+    if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
+            " type yet!"
+        )
+    if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
+            " type yet!"
+        )
+    if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip(
+            f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
+        )
     subprocess.run(
         get_bash_arguments(
             dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"

@@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):

 model_configs_fused_attn = {
     # test:  b,  h, hg,   d,   sq,  skv,   p,      mask,             bias
     "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
     "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
     "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA

@@ -66,9 +93,37 @@ model_configs_fused_attn = {
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-def test_cp_with_fused_attention(dtype, model, qkv_format):
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
     if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
         pytest.skip("THD format is only supported on sm90+.")
+    if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
+        pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
+    config = model_configs_fused_attn[model]
+    if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
+        pytest.skip(f"{qkv_format} format does not support GQA/MQA yet!")
+    if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
+        pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
+    if cp_comm_type == "all_gather" and qkv_format == "thd":
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
+        )
+    if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
+            " type yet!"
+        )
+    if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+        pytest.skip(
+            f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
+            " type yet!"
+        )
+    if config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip(
+            "Fused attention does not support sliding window attention + context parallelism yet!"
+        )
     subprocess.run(
         get_bash_arguments(
             dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
......
@@ -65,6 +65,8 @@ from transformer_engine.pytorch.distributed import (
     set_all_rng_states,
     CudaRNGStatesTracker,
     graph_safe_rng_available,
+    gather_along_first_dim,
+    reduce_scatter_along_first_dim,
 )
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
@@ -321,13 +323,6 @@ def get_attention_backend(
         logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
         use_fused_attention = False

-    # Filter: Context parallelism
-    if context_parallel and use_unfused_attention:
-        logger.debug(
-            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
-        )
-        use_unfused_attention = False
-
     # Filter: Data type
     if use_flash_attention and (
         qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
@@ -398,6 +393,81 @@ def get_attention_backend(
         )
         use_flash_attention = False

+    # Filter: Context parallelism
+    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
+    # ----------------------------------------------------------------------------------------------------
+    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
+    #            |     no_mask, causal         |                          |
+    #            | cross-attention:            |                          |
+    #            |     no_mask                 |                          |
+    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
+    #            |     padding, padding_causal |                          | if no padding between sequences,
+    #            | cross-attention:            |                          | FusedAttention
+    #            |     padding                 |                          | if there is padding between sequences
+    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
+    if context_parallel and use_unfused_attention:
+        logger.debug(
+            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
+        )
+        use_unfused_attention = False
+    if context_parallel and use_flash_attention:
+        if "bottom_right" in attn_mask_type:
+            logger.debug(
+                "Disabling FlashAttention as it does not support context parallelism with"
+                " causal_bottom_right masking"
+            )
+            use_flash_attention = False
+        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
+            logger.debug(
+                "Disabling FlashAttention as it does not support context parallelism with causal"
+                " masking for cross-attention"
+            )
+            use_flash_attention = False
+        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
+            logger.debug(
+                "Disabling FlashAttention as it does not support context parallelism with bias type"
+                " of %s",
+                core_attention_bias_type,
+            )
+            use_flash_attention = False
+        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
+            logger.debug(
+                "Disabling FlashAttention as it does not support context parallelism with attention"
+                " bias for THD format"
+            )
+            use_flash_attention = False
+    if context_parallel and use_fused_attention:
+        if "bottom_right" in attn_mask_type:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with"
+                " causal_bottom_right masking"
+            )
+            use_fused_attention = False
+        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with causal"
+                " masking for cross-attention"
+            )
+            use_fused_attention = False
+        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with bias type"
+                " of %s",
+                core_attention_bias_type,
+            )
+            use_fused_attention = False
+        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with attention"
+                " bias for THD format"
+            )
+            use_fused_attention = False
+        elif head_dim_qk != head_dim_v:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with MLA"
+            )
+            use_fused_attention = False

     # Filter: Attention mask
     # attn_mask_type                              | supported backends
     # -------------------------------------------------------------------
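The divisibility note in the table above is easy to check up front; a hypothetical helper (not part of this diff) makes the rule concrete:

def cp_shardable(seq_len: int, cp_size: int) -> bool:
    # Load balancing splits every sequence into 2 * cp_size equal chunks.
    return seq_len % (2 * cp_size) == 0

assert cp_shardable(4096, cp_size=4)       # 4096 = 8 * 512
assert not cp_shardable(4100, cp_size=4)   # would leave unbalanced leftover tokens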
@@ -498,11 +568,10 @@ def get_attention_backend(
     if (
         use_flash_attention
         and (window_size[0] != -1 or window_size[1] not in [-1, 0])
-        and (not _flash_attn_2_3_plus or context_parallel)
+        and not _flash_attn_2_3_plus
     ):
         logger.debug(
-            "Disabling FlashAttention as sliding window attention requires "
-            "flash-attn 2.3+ and no context parallelism"
+            "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
         )
         use_flash_attention = False
@@ -1222,11 +1291,11 @@ def get_cu_seqlens_on_cp_rank(
     return cu_seqlens_on_cp_rank

-class AttnFuncWithCP(torch.autograd.Function):
+class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
     """
-    Attention implementation with context parallelism.
-    Split attention compute into multiple steps, and overlap current-step
-    compute with next-step communication.
+    Attention implementation with context parallelism. Exchange KV between CP ranks
+    with P2P in ring topology. Split attention compute into multiple steps, and overlap
+    current-step compute with next-step communication.
     """

     @staticmethod

@@ -1267,6 +1336,7 @@ class AttnFuncWithCP(torch.autograd.Function):
         padding = "padding" in attn_mask_type

         if qkv_format in ["bshd", "sbhd"]:
+            seq_dim = qkv_format.index("s")
             qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
         else:
             qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

@@ -1280,6 +1350,9 @@ class AttnFuncWithCP(torch.autograd.Function):
         cu_seqlens_q_per_step = [None for _ in range(cp_size)]
         cu_seqlens_kv_per_step = [None for _ in range(cp_size)]

+        assert qkv_format == "thd" or (
+            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
+        ), "Sequence length per GPU needs to be divisible by 2!"
         if causal:
             if qkv_format == "bshd":
                 # [b, s, np, hn] -> [b, 2, s//2, np, hn]
@@ -1295,6 +1368,9 @@ class AttnFuncWithCP(torch.autograd.Function):
                     "Only support bias shape of [b, h, sq, sk] for forward, "
                     "and [1, h, sq, sk] for backward!"
                 )
+                assert (
+                    attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
+                ), "Sequence length does not meet divisible requirements!"
                 # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
                 attn_bias_ = attn_bias.view(
                     *attn_bias.shape[:-2],
@@ -1310,7 +1386,7 @@ class AttnFuncWithCP(torch.autograd.Function):
         assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
         fa_optional_forward_kwargs = {}
         if _flash_attn_2_3_plus:
-            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
+            fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
         if _flash_attn_2_4_plus:
             fa_optional_forward_kwargs["alibi_slopes"] = None
         if _flash_attn_2_5_7_plus:
@@ -1546,7 +1622,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                             # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                             kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                             if _flash_attn_2_3_plus:
-                                fa_optional_forward_kwargs["window_size"] = [-1, -1]
+                                fa_optional_forward_kwargs["window_size"] = (-1, -1)
                             (
                                 _,
                                 _,

@@ -1667,7 +1743,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                             # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                             kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                             if _flash_attn_2_3_plus:
-                                fa_optional_forward_kwargs["window_size"] = [-1, -1]
+                                fa_optional_forward_kwargs["window_size"] = (-1, -1)
                             (
                                 _,
                                 _,
@@ -1821,8 +1897,6 @@ class AttnFuncWithCP(torch.autograd.Function):
         torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

         softmax_lse = softmax_lse.to(torch.float)
-        if qkv_format in ["bshd", "sbhd"]:
-            seq_dim = qkv_format.index("s")
         for i in range(cp_size):
             if qkv_format == "bshd":
                 out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])

@@ -1849,8 +1923,6 @@ class AttnFuncWithCP(torch.autograd.Function):
                     cu_seqlens_q_padded,
                     False,
                 )
-            else:
-                assert False, f"{qkv_format} is an unsupported qkv_format!"
         else:
             if qkv_format in ["bshd", "sbhd"]:
                 flash_attn_fwd_out_correction(

@@ -1869,8 +1941,6 @@ class AttnFuncWithCP(torch.autograd.Function):
                     cu_seqlens_q_padded,
                     True,
                 )
-            else:
-                assert False, f"{qkv_format} is an unsupported qkv_format!"

         kv = p2p_comm_buffers[-1]
         if use_fused_attention:
@@ -2056,7 +2126,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                     out_ = out.view(-1, *out.shape[-2:])
                     dout_ = dout.view(-1, *dout.shape[-2:])
                 if _flash_attn_2_3_plus:
-                    fa_optional_backward_kwargs["window_size"] = [-1, 0]
+                    fa_optional_backward_kwargs["window_size"] = (-1, 0)
                 _flash_attn_backward(
                     dout_,
                     q_,

@@ -2141,7 +2211,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                     out_ = out.view(-1, *out.shape[-2:])
                     dout_ = dout.view(-1, *dout.shape[-2:])
                 if _flash_attn_2_3_plus:
-                    fa_optional_backward_kwargs["window_size"] = [-1, -1]
+                    fa_optional_backward_kwargs["window_size"] = (-1, -1)
                 _flash_attn_backward(
                     dout_,
                     q_,

@@ -2232,7 +2302,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                     out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                     dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
                 if _flash_attn_2_3_plus:
-                    fa_optional_backward_kwargs["window_size"] = [-1, -1]
+                    fa_optional_backward_kwargs["window_size"] = (-1, -1)
                 _flash_attn_backward(
                     dout_,
                     q_,

@@ -2291,7 +2361,7 @@ class AttnFuncWithCP(torch.autograd.Function):
                     out_ = out.view(-1, *out.shape[-2:])
                     dout_ = dout.view(-1, *dout.shape[-2:])
                 if _flash_attn_2_3_plus:
-                    fa_optional_backward_kwargs["window_size"] = [-1, -1]
+                    fa_optional_backward_kwargs["window_size"] = (-1, -1)
                 _flash_attn_backward(
                     dout_,
                     q_,
@@ -2486,6 +2556,455 @@ class AttnFuncWithCP(torch.autograd.Function):
         )
@jit_fuser
def get_seq_chunk_ids_to_all_gathered_kv(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left
):
    """Compute sequence chunk ids to the all-gathered KV."""
    seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left)
    seqlen = seq_end_idx - seq_start_idx
    num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv
    chunk_ids = torch.arange(
        local_chunk_id - num_chunks + 1,
        local_chunk_id + 1,
        dtype=torch.int32,
        device="cuda",
    )
    chunk_ids_to_all_gathered_kv = torch.where(
        chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
    )
    return chunk_ids_to_all_gathered_kv
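A worked example of the mapping above (editorial sketch): for load balancing under causal masking, rank r keeps sequence chunks r and 2*cp_size - 1 - r, so the all-gathered KV buffer stores chunks in rank order rather than sequence order. With cp_size = 2 the gathered order is [0, 3, 1, 2], and logical chunk ids map to buffer slots like this:

import torch

cp_size = 2
chunk_ids = torch.arange(2 * cp_size)  # logical chunks in sequence order
slots = torch.where(chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1)
print(slots.tolist())  # [0, 2, 3, 1]: e.g. logical chunk 1 sits in buffer slot 2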
class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism.
    KV all-gather between CP ranks is exposed.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        cp_group,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert causal and not padding, f"{attn_mask_type} mask type is not supported!"
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"

        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        fa_optional_forward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size)
        cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size)

        if causal:
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
                q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
                # [b, s, np, hn] -> [s, b, np, hn]
                k, v = [x.transpose(0, 1).contiguous() for x in [k, v]]
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
                q = q.view(2, q.shape[0] // 2, *q.shape[1:])

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)
        cp_stream.wait_stream(torch.cuda.current_stream())
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        chunk_ids_to_kv_ag_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]

        out = torch.empty_like(q)
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv(
                        local_seq_chunk_ids[i],
                        cp_size,
                        max_seqlen_q,
                        max_seqlen_kv,
                        (
                            max_seqlen_kv * cp_size * 2
                            if (window_size is None or window_size[0] == -1)
                            else window_size[0]
                        ),
                    )
                    chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    if qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
                        k_ = (
                            torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(k.shape[1], -1, *k.shape[-2:])
                        )
                        v_ = (
                            torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(v.shape[1], -1, *v.shape[-2:])
                        )
                    elif qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_ = q[i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
                        k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *k.shape[-3:]
                        )
                        v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *v.shape[-3:]
                        )
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv * num_kv_chunks,
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            q_,
                            k_,
                            v_,
                            TE_DType[q.dtype],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
                            window_size=window_size,
                        )
                    else:
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = (
                            _flash_attn_forward(
                                q_,
                                k_,
                                v_,
                                cu_seqlens_q,
                                cu_seqlens_kv * num_kv_chunks,
                                max_seqlen_q,
                                max_seqlen_kv * num_kv_chunks,
                                dropout_p,
                                softmax_scale,
                                causal=True,
                                return_softmax=False,
                                window_size=window_size,
                                **fa_optional_forward_kwargs,
                            )
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1]))
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1]))

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *chunk_ids_to_kv_ag_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        ctx.window_size = window_size
        return out

    @staticmethod
    def backward(ctx, dout):
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = (
            ctx.saved_tensors[:7]
        )
        chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9]
        out_per_step = ctx.saved_tensors[9:11]
        softmax_lse_per_step = ctx.saved_tensors[11:13]
        rng_states = ctx.saved_tensors[13:15]

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view_as(q)
        dq = torch.empty_like(q)
        dk = torch.zeros(
            (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device
        )
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i]
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    out_ = out_per_step[i]
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
                        k_ = (
                            torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(k.shape[1], -1, *k.shape[-2:])
                        )
                        v_ = (
                            torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(v.shape[1], -1, *v.shape[-2:])
                        )
                        dout_ = dout[:, i].contiguous().view_as(out_)
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_ = q[i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
                        k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *k.shape[-3:]
                        )
                        v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *v.shape[-3:]
                        )
                        dout_ = dout[i].contiguous().view_as(out_)
                    if ctx.use_fused_attention:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv * num_kv_chunks,
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[k.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=ctx.window_size,
                        )
                    else:
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        _flash_attn_backward(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv * num_kv_chunks,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            window_size=ctx.window_size,
                            rng_state=rng_states[i],
                            **fa_optional_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1]
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1]))
                        dk_per_step[i - 1] = (
                            dk_per_step[i - 1]
                            .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:])
                            .movedim(0, 2)
                            .contiguous()
                        )
                        dv_per_step[i - 1] = (
                            dv_per_step[i - 1]
                            .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:])
                            .movedim(0, 2)
                            .contiguous()
                        )
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1]))
                        dk_per_step[i - 1] = dk_per_step[i - 1].view(
                            num_kv_chunks, -1, *k.shape[-3:]
                        )
                        dv_per_step[i - 1] = dv_per_step[i - 1].view(
                            num_kv_chunks, -1, *v.shape[-3:]
                        )
                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1])
                    dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        if ctx.qkv_format == "bshd":
            dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
            dk = dk.transpose(0, 1).contiguous()
            dv = dv.transpose(0, 1).contiguous()
        elif ctx.qkv_format == "sbhd":
            dq = dq.view(-1, *dq.shape[-3:])

        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
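Both passes above use the same two-stream pattern: each of the rank's two sequence chunks runs on its own stream so the attention kernels can overlap, while a CUDA event (dkv_update_done) serializes the index_add_ updates to the shared dKV buffers. A minimal standalone sketch of that pattern (illustrative only; requires a CUDA device):

import torch

main_stream = torch.cuda.current_stream()
side_stream = torch.cuda.Stream()
side_stream.wait_stream(main_stream)  # side stream must see prior main-stream work
update_done = torch.cuda.Event()

acc = torch.zeros(8, device="cuda")  # stands in for the shared dKV buffer
for i, stream in enumerate([main_stream, side_stream]):
    with torch.cuda.stream(stream):
        part = torch.full((8,), float(i + 1), device="cuda")  # one chunk's grads
        if i > 0:
            stream.wait_event(update_done)  # do not race the previous update
        acc += part
        stream.record_event(update_done)
main_stream.wait_stream(side_stream)  # join before reading acc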
def attn_forward_func_with_cp(
    is_training,
    q,

@@ -2501,6 +3020,7 @@ def attn_forward_func_with_cp(
    cp_group,
    cp_global_ranks,
    cp_stream,
+   cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",

@@ -2508,8 +3028,12 @@ def attn_forward_func_with_cp(
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
+   window_size=None,
) -> torch.Tensor:
-   """Attention implementation with context parallelism"""
+   """
+   Attention implementation with context parallelism.
+   """
+
    assert qkv_format in [
        "bshd",
        "sbhd",

@@ -2533,29 +3057,62 @@ def attn_forward_func_with_cp(
    assert (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
-   out = AttnFuncWithCP.apply(
-       is_training,
-       q,
-       k,
-       v,
-       cu_seqlens_q,
-       cu_seqlens_kv,
-       max_seqlen_q,
-       max_seqlen_kv,
-       cu_seqlens_q_padded,
-       cu_seqlens_kv_padded,
-       dropout_p,
-       cp_group,
-       cp_global_ranks,
-       cp_stream,
-       softmax_scale,
-       qkv_format,
-       attn_mask_type,
-       attn_bias_type,
-       attn_bias,
-       deterministic,
-       use_fused_attention,
-   )
+
+   sliding_window_attn = (
+       window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
+   )
+
+   if sliding_window_attn or cp_comm_type == "all_gather":
+       out = AttnFuncWithCPAndKVAllGather.apply(
+           is_training,
+           q,
+           k,
+           v,
+           cu_seqlens_q,
+           cu_seqlens_kv,
+           max_seqlen_q,
+           max_seqlen_kv,
+           cu_seqlens_q_padded,
+           cu_seqlens_kv_padded,
+           dropout_p,
+           cp_group,
+           cp_stream,
+           softmax_scale,
+           qkv_format,
+           attn_mask_type,
+           attn_bias_type,
+           attn_bias,
+           deterministic,
+           use_fused_attention,
+           window_size,
+       )
+   elif cp_comm_type == "p2p":
+       out = AttnFuncWithCPAndKVP2P.apply(
+           is_training,
+           q,
+           k,
+           v,
+           cu_seqlens_q,
+           cu_seqlens_kv,
+           max_seqlen_q,
+           max_seqlen_kv,
+           cu_seqlens_q_padded,
+           cu_seqlens_kv_padded,
+           dropout_p,
+           cp_group,
+           cp_global_ranks,
+           cp_stream,
+           softmax_scale,
+           qkv_format,
+           attn_mask_type,
+           attn_bias_type,
+           attn_bias,
+           deterministic,
+           use_fused_attention,
+       )
+   else:
+       raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
+
    return out
@@ -3316,6 +3873,7 @@ class FlashAttention(torch.nn.Module):
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
+       cp_comm_type: str = "p2p",
    ) -> torch.Tensor:
        """flash-attn fprop"""

@@ -3424,10 +3982,6 @@ class FlashAttention(torch.nn.Module):
            max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
-           assert window_size in (
-               (-1, -1),
-               (-1, 0),
-           ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."

@@ -3447,10 +4001,12 @@ class FlashAttention(torch.nn.Module):
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
+                   cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
+                   window_size=window_size,
                )
        else:
@@ -4995,6 +5551,7 @@ class FusedAttention(torch.nn.Module):
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
+       cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:

@@ -5107,12 +5664,14 @@ class FusedAttention(torch.nn.Module):
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
+                   cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
+                   window_size=window_size,
                )
        else:
            with self.attention_dropout_ctx():
@@ -5260,6 +5819,9 @@ class DotProductAttention(TransformerEngineBaseModule):
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
+   cp_comm_type : str
+              inter-gpu communication type for context parallelism.
+              Can be "p2p" or "all_gather".
    """

    def __init__(

@@ -5280,6 +5842,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
+       cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        super().__init__()

@@ -5307,6 +5870,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
+       self.cp_comm_type = cp_comm_type

        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]

@@ -5430,6 +5994,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
+       cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given

@@ -5443,10 +6008,14 @@ class DotProductAttention(TransformerEngineBaseModule):
                list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                cuda stream for context parallel execution.
+       cp_comm_type : str
+               inter-gpu communication type for context parallelism.
+               Can be "p2p" or "all_gather".
        """
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
+       self.cp_comm_type = cp_comm_type

    @no_torch_dynamo(recursive=False)
    def forward(

@@ -5943,6 +6512,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            cp_group=self.cp_group,
            cp_global_ranks=self.cp_global_ranks,
            cp_stream=self.cp_stream,
+           cp_comm_type=self.cp_comm_type,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
        )

@@ -5985,6 +6555,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            cp_group=self.cp_group,
            cp_global_ranks=self.cp_global_ranks,
            cp_stream=self.cp_stream,
+           cp_comm_type=self.cp_comm_type,
            fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
            fp8_meta=self.fp8_meta,
        )

@@ -6009,6 +6580,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            cp_group=self.cp_group,
            cp_global_ranks=self.cp_global_ranks,
            cp_stream=self.cp_stream,
+           cp_comm_type=self.cp_comm_type,
            fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
            fp8_meta=self.fp8_meta,
        )
@@ -6437,6 +7009,7 @@ class MultiheadAttention(torch.nn.Module):
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
+       cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given

@@ -6450,13 +7023,16 @@ class MultiheadAttention(torch.nn.Module):
                list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                cuda stream for context parallel execution.
+       cp_comm_type : str
+               inter-gpu communication type for context parallelism.
+               Can be "p2p" or "all_gather".
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
-               child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
+               child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
......
@@ -487,6 +487,7 @@ class TransformerLayer(torch.nn.Module):
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
+       cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given

@@ -500,13 +501,16 @@ class TransformerLayer(torch.nn.Module):
                list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                cuda stream for context parallel execution.
+       cp_comm_type : str
+               inter-gpu communication type for context parallelism.
+               Can be "p2p" or "all_gather".
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
-               child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
+               child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
......
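The all-gather variant leans on the two collectives imported at the top of the attention.py diff. A hedged sketch of their semantics in plain torch.distributed terms (assumption: TE's gather_along_first_dim and reduce_scatter_along_first_dim behave like this, returning the result plus an async-work handle as the second value):

import torch
import torch.distributed as dist

def gather_along_first_dim(t, group):
    # Concatenate every rank's shard along dim 0: [s, ...] -> [ws * s, ...].
    ws = dist.get_world_size(group)
    out = torch.empty((ws * t.shape[0], *t.shape[1:]), dtype=t.dtype, device=t.device)
    dist.all_gather_into_tensor(out, t.contiguous(), group=group)
    return out, None

def reduce_scatter_along_first_dim(t, group):
    # Sum across ranks, then keep this rank's dim-0 shard: [ws * s, ...] -> [s, ...].
    ws = dist.get_world_size(group)
    out = torch.empty((t.shape[0] // ws, *t.shape[1:]), dtype=t.dtype, device=t.device)
    dist.reduce_scatter_tensor(out, t.contiguous(), group=group)
    return out, None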