Unverified Commit 30407856 authored by Xiaowei Ren, committed by GitHub

Add a CP implementation variant with KV all-gather. (#1060)



* add window_size to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo for cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert swa+CP cannot work with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a new CP function for swa
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a missing dgrads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft fwd function for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable flash attention for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove an assert of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* call SWAFuncWithCP for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use 2hd layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change qkv_format check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor shape fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a cp rank
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not need to zero dkv for FlashAttention anymore
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of cp test with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add simple code framework
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try not to have a separate CP function for SWA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* backup some code change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* back up code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clean up fwd implementation of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* reduce kv chunk concat overheads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP and SWAFuncWithCP have same API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* preliminary implementation of SWAFuncWithCP forward seems working
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix output shape of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring for FlashAttention and add a code placeholder for bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use gather_along_first_dim
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* finish the preliminary implementation of bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert condition
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft implementation of SWA+CP with FusedAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attention mask type of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add qkv_layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add missing window_size argument
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug and typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dout shape
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in fwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save chunk_ids_to_kv_ag in fwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in bwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix to cp stream sync
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* rename AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if window size is None
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix docstring of AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add env var for users to choose KV ag or KV p2p
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix window size in cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix pytest skip messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cp_comm_type into API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* assert sequence length divisible requirements
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support table of context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo and code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not print multiple disabling messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix device in torch.arange and adjust code for the PR of MLA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix typos and clean asserts
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 941364df
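
For orientation, a minimal usage sketch of the new cp_comm_type knob follows; the process-group setup and the head/channel sizes are illustrative assumptions mirroring the updated unit tests, not part of this commit:

import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention

# assume one process per GPU, launched e.g. with torchrun
dist.init_process_group(backend="nccl")
cp_ranks = list(range(dist.get_world_size()))  # ranks participating in context parallelism
cp_group = dist.new_group(cp_ranks, backend="nccl")

core_attn = DotProductAttention(
    12,  # number of attention heads (example value)
    128,  # head dimension (example value)
    num_gqa_groups=12,
    qkv_format="bshd",
    attn_mask_type="causal",
    window_size=(512, 0),  # sliding window attention is routed to the KV all-gather variant
).cuda()

# choose how KV is exchanged across CP ranks: "p2p" (ring exchange) or "all_gather"
core_attn.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream(), cp_comm_type="all_gather")

# q, k, v below would be this rank's sequence shards (1/cp_size of the full sequence each)
# out = core_attn(q, k, v)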
......@@ -13,7 +13,9 @@ from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fuse
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
):
"""Test DotProductAttention module with context parallelism"""
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -24,10 +26,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == "thd" and (
config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
):
return
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
......@@ -49,73 +57,77 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
q_input_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
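# round each sequence length up to a multiple of 2*world_size (world_size == CP size here),
# since context parallelism requires every sequence to split evenly into 2*cp_size chunks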
......@@ -211,7 +223,9 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
)
out_ = core_attn(
q_,
k_,
......
......@@ -16,11 +16,17 @@ from transformer_engine.pytorch.utils import (
)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}
......@@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
)
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
......@@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
......@@ -66,9 +93,37 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!"
)
subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
......
......@@ -65,6 +65,8 @@ from transformer_engine.pytorch.distributed import (
set_all_rng_states,
CudaRNGStatesTracker,
graph_safe_rng_available,
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
......@@ -321,13 +323,6 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
# Filter: Context parallelism
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
# Filter: Data type
if use_flash_attention and (
qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
......@@ -398,6 +393,81 @@ def get_attention_backend(
)
use_flash_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
# bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention
# | no_mask, causal | |
# | cross-attention: | |
# | no_mask | |
# thd | self-attention: | no_bias | FlashAttention, FusedAttention
# | padding, padding_causal | | if no padding between sequences,
# | cross-attention: | | FusedAttention
# | padding | | if there is padding between sequences
# Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
if context_parallel and use_unfused_attention:
logger.debug(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with causal"
" masking for cross-attention"
)
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with bias type"
" of %s",
core_attention_bias_type,
)
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with attention"
" bias for THD format"
)
use_flash_attention = False
if context_parallel and use_fused_attention:
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_fused_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with causal"
" masking for cross-attention"
)
use_fused_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with bias type"
" of %s",
core_attention_bias_type,
)
use_fused_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with attention"
" bias for THD format"
)
use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | supported backends
# -------------------------------------------------------------------
......@@ -498,11 +568,10 @@ def get_attention_backend(
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and (not _flash_attn_2_3_plus or context_parallel)
and not _flash_attn_2_3_plus
):
logger.debug(
"Disabling FlashAttention as sliding window attention requires "
"flash-attn 2.3+ and no context parallelism"
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention = False
......@@ -1222,11 +1291,11 @@ def get_cu_seqlens_on_cp_rank(
return cu_seqlens_on_cp_rank
class AttnFuncWithCP(torch.autograd.Function):
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
"""
Attention implementation with context parallelism.
Split attention compute into multiple steps, and overlap current-step
compute with next-step communication.
Attention implementation with context parallelism. Exchange KV between CP ranks
with P2P in ring topology. Split attention compute into multiple steps, and overlap
current-step compute with next-step communication.
"""
@staticmethod
......@@ -1267,6 +1336,7 @@ class AttnFuncWithCP(torch.autograd.Function):
padding = "padding" in attn_mask_type
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
else:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
......@@ -1280,6 +1350,9 @@ class AttnFuncWithCP(torch.autograd.Function):
cu_seqlens_q_per_step = [None for _ in range(cp_size)]
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
assert qkv_format == "thd" or (
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
if causal:
if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
......@@ -1295,6 +1368,9 @@ class AttnFuncWithCP(torch.autograd.Function):
"Only support bias shape of [b, h, sq, sk] for forward, "
"and [1, h, sq, sk] for backward!"
)
assert (
attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
), "Sequence length does not meet divisible requirements!"
# [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
attn_bias_ = attn_bias.view(
*attn_bias.shape[:-2],
......@@ -1310,7 +1386,7 @@ class AttnFuncWithCP(torch.autograd.Function):
assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
if _flash_attn_2_5_7_plus:
......@@ -1546,7 +1622,7 @@ class AttnFuncWithCP(torch.autograd.Function):
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
fa_optional_forward_kwargs["window_size"] = (-1, -1)
(
_,
_,
......@@ -1667,7 +1743,7 @@ class AttnFuncWithCP(torch.autograd.Function):
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
fa_optional_forward_kwargs["window_size"] = (-1, -1)
(
_,
_,
......@@ -1821,8 +1897,6 @@ class AttnFuncWithCP(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
for i in range(cp_size):
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
......@@ -1849,8 +1923,6 @@ class AttnFuncWithCP(torch.autograd.Function):
cu_seqlens_q_padded,
False,
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
else:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
......@@ -1869,8 +1941,6 @@ class AttnFuncWithCP(torch.autograd.Function):
cu_seqlens_q_padded,
True,
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
kv = p2p_comm_buffers[-1]
if use_fused_attention:
......@@ -2056,7 +2126,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, 0]
fa_optional_backward_kwargs["window_size"] = (-1, 0)
_flash_attn_backward(
dout_,
q_,
......@@ -2141,7 +2211,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
dout_,
q_,
......@@ -2232,7 +2302,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
dout_,
q_,
......@@ -2291,7 +2361,7 @@ class AttnFuncWithCP(torch.autograd.Function):
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
fa_optional_backward_kwargs["window_size"] = (-1, -1)
_flash_attn_backward(
dout_,
q_,
......@@ -2486,6 +2556,455 @@ class AttnFuncWithCP(torch.autograd.Function):
)
@jit_fuser
def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left
):
"""Compute sequence chunk ids to the all-gathered KV."""
seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left)
seqlen = seq_end_idx - seq_start_idx
num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv
chunk_ids = torch.arange(
local_chunk_id - num_chunks + 1,
local_chunk_id + 1,
dtype=torch.int32,
device="cuda",
)
chunk_ids_to_all_gathered_kv = torch.where(
chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
)
return chunk_ids_to_all_gathered_kv
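# Worked example (illustrative): with cp_size=2, the all-gathered KV holds the logical sequence
# chunks in rank-major order [0, 3, 1, 2] (rank r contributes chunks r and 2*cp_size-1-r).
# For local_chunk_id=3, max_seqlen_q == max_seqlen_kv, and a full window, chunk_ids = [0, 1, 2, 3],
# which maps to gathered indices [0, 2, 3, 1].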
class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
"""
Attention implementation with context parallelism.
KV all-gather between CP ranks is exposed.
"""
@staticmethod
def forward(
ctx,
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
window_size,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
assert causal and not padding, f"{attn_mask_type} mask type is not supported!"
if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
attn_mask_type = attn_mask_type + "_bottom_right"
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
use_fused_attention or _flash_attn_2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
fa_optional_forward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
seq_dim = qkv_format.index("s")
assert (
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
max_seqlen_q = max_seqlen_q // (2 * cp_size)
max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size)
cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size)
if causal:
if qkv_format == "bshd":
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
# [b, s, np, hn] -> [s, b, np, hn]
k, v = [x.transpose(0, 1).contiguous() for x in [k, v]]
elif qkv_format == "sbhd":
# [s, b, np, hn] -> [2, s//2, b, np, hn]
q = q.view(2, q.shape[0] // 2, *q.shape[1:])
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
k_ag, _ = gather_along_first_dim(k, cp_group)
v_ag, _ = gather_along_first_dim(v, cp_group)
cp_stream.wait_stream(torch.cuda.current_stream())
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
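# each rank owns two sequence chunks, rank and 2*cp_size-1-rank, which balances the causal-mask
# workload across ranks; in k_ag/v_ag, index 2*r is rank r's first chunk and 2*r+1 its second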
chunk_ids_to_kv_ag_per_step = [None, None]
out_per_step = [None, None]
softmax_lse_per_step = [None, None]
rng_states = [None, None]
out = torch.empty_like(q)
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv(
local_seq_chunk_ids[i],
cp_size,
max_seqlen_q,
max_seqlen_kv,
(
max_seqlen_kv * cp_size * 2
if (window_size is None or window_size[0] == -1)
else window_size[0]
),
)
chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
num_kv_chunks = chunk_ids_to_kv_ag.numel()
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
k_ = (
torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(k.shape[1], -1, *k.shape[-2:])
)
v_ = (
torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(v.shape[1], -1, *v.shape[-2:])
)
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *k.shape[-3:]
)
v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *v.shape[-3:]
)
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv * num_kv_chunks,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
q_,
k_,
v_,
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
window_size=window_size,
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
_, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = (
_flash_attn_forward(
q_,
k_,
v_,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
max_seqlen_q,
max_seqlen_kv * num_kv_chunks,
dropout_p,
softmax_scale,
causal=True,
return_softmax=False,
window_size=window_size,
**fa_optional_forward_kwargs,
)
)
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1]))
elif qkv_format == "sbhd":
out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1]))
torch.cuda.current_stream().wait_stream(cp_stream)
if use_fused_attention:
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
elif qkv_format == "sbhd":
out = out.view(-1, *out.shape[-3:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*chunk_ids_to_kv_ag_per_step,
*out_per_step,
*softmax_lse_per_step,
*rng_states,
)
ctx.cp_group = cp_group
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format
ctx.attn_mask_type = attn_mask_type
ctx.attn_bias_type = attn_bias_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
ctx.window_size = window_size
return out
@staticmethod
def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
(q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = (
ctx.saved_tensors[:7]
)
chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9]
out_per_step = ctx.saved_tensors[9:11]
softmax_lse_per_step = ctx.saved_tensors[11:13]
rng_states = ctx.saved_tensors[13:15]
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
dout = dout.view_as(q)
dq = torch.empty_like(q)
dk = torch.zeros(
(2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device
)
dv = torch.zeros_like(dk)
dq_per_step = [None, None]
dk_per_step = [None, None]
dv_per_step = [None, None]
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
# synchronize dkv update across steps
dkv_update_done = torch.cuda.Event()
k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
ctx.cp_stream.wait_stream(torch.cuda.current_stream())
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
fa_optional_backward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i]
num_kv_chunks = chunk_ids_to_kv_ag.numel()
out_ = out_per_step[i]
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
k_ = (
torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(k.shape[1], -1, *k.shape[-2:])
)
v_ = (
torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
.movedim(2, 0)
.contiguous()
.view(v.shape[1], -1, *v.shape[-2:])
)
dout_ = dout[:, i].contiguous().view_as(out_)
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_ = q[i].contiguous()
# [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *k.shape[-3:]
)
v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
-1, *v.shape[-3:]
)
dout_ = dout[i].contiguous().view_as(out_)
if ctx.use_fused_attention:
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv * num_kv_chunks,
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
q_,
k_,
v_,
out_,
dout_,
TE_DType[q.dtype],
TE_DType[k.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
window_size=ctx.window_size,
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
_flash_attn_backward(
dout_,
q_,
k_,
v_,
out_,
softmax_lse_per_step[i],
dq_per_step[i],
dk_per_step[i],
dv_per_step[i],
cu_seqlens_q,
cu_seqlens_kv * num_kv_chunks,
ctx.max_seqlen_q,
ctx.max_seqlen_kv * num_kv_chunks,
ctx.dropout_p,
ctx.softmax_scale,
True,
window_size=ctx.window_size,
rng_state=rng_states[i],
**fa_optional_backward_kwargs,
)
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1]
num_kv_chunks = chunk_ids_to_kv_ag.numel()
if ctx.qkv_format == "bshd":
dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1]))
dk_per_step[i - 1] = (
dk_per_step[i - 1]
.view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:])
.movedim(0, 2)
.contiguous()
)
dv_per_step[i - 1] = (
dv_per_step[i - 1]
.view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:])
.movedim(0, 2)
.contiguous()
)
elif ctx.qkv_format == "sbhd":
dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1]))
dk_per_step[i - 1] = dk_per_step[i - 1].view(
num_kv_chunks, -1, *k.shape[-3:]
)
dv_per_step[i - 1] = dv_per_step[i - 1].view(
num_kv_chunks, -1, *v.shape[-3:]
)
# wait until dkv update of last step is done
if i > 1:
flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1])
dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1])
if i < len(local_seq_chunk_ids):
flash_attn_streams[i - 1].record_event(dkv_update_done)
torch.cuda.current_stream().wait_stream(ctx.cp_stream)
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)
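# the forward pass all-gathered K/V across CP ranks, so the accumulated dK/dV are reduce-scattered
# back to per-rank shards here; dQ never left this rank and needs no collective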
if ctx.qkv_format == "bshd":
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
dk = dk.transpose(0, 1).contiguous()
dv = dv.transpose(0, 1).contiguous()
elif ctx.qkv_format == "sbhd":
dq = dq.view(-1, *dq.shape[-3:])
return (
None,
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
def attn_forward_func_with_cp(
is_training,
q,
......@@ -2501,6 +3020,7 @@ def attn_forward_func_with_cp(
cp_group,
cp_global_ranks,
cp_stream,
cp_comm_type,
softmax_scale=None,
qkv_format="bshd",
attn_mask_type="causal",
......@@ -2508,8 +3028,12 @@ def attn_forward_func_with_cp(
attn_bias=None,
deterministic=False,
use_fused_attention=False,
window_size=None,
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
"""
Attention implementation with context parallelism.
"""
assert qkv_format in [
"bshd",
"sbhd",
......@@ -2533,29 +3057,62 @@ def attn_forward_func_with_cp(
assert (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"
out = AttnFuncWithCP.apply(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
if sliding_window_attn or cp_comm_type == "all_gather":
out = AttnFuncWithCPAndKVAllGather.apply(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
window_size,
)
elif cp_comm_type == "p2p":
out = AttnFuncWithCPAndKVP2P.apply(
is_training,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
cp_group,
cp_global_ranks,
cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
attn_bias_type,
attn_bias,
deterministic,
use_fused_attention,
)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
return out
......@@ -3316,6 +3873,7 @@ class FlashAttention(torch.nn.Module):
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -3424,10 +3982,6 @@ class FlashAttention(torch.nn.Module):
max_seqlen_kv = seqlens_kv.max().item()
if context_parallel:
assert window_size in (
(-1, -1),
(-1, 0),
), "Sliding window attention is not supported with context parallelism."
assert (
alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism."
......@@ -3447,10 +4001,12 @@ class FlashAttention(torch.nn.Module):
cp_group,
cp_global_ranks,
cp_stream,
cp_comm_type,
softmax_scale=self.softmax_scale,
qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
attn_mask_type=attn_mask_type,
deterministic=self.deterministic,
window_size=window_size,
)
else:
......@@ -4995,6 +5551,7 @@ class FusedAttention(torch.nn.Module):
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
......@@ -5107,12 +5664,14 @@ class FusedAttention(torch.nn.Module):
cp_group,
cp_global_ranks,
cp_stream,
cp_comm_type,
softmax_scale=self.softmax_scale,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
attn_bias_type=core_attention_bias_type,
attn_bias=core_attention_bias,
use_fused_attention=True,
window_size=window_size,
)
else:
with self.attention_dropout_ctx():
......@@ -5260,6 +5819,9 @@ class DotProductAttention(TransformerEngineBaseModule):
compute and communication overlapping. To address the wave quantization
issue of each split step, we add an additional CUDA stream so that we
can overlap two flash attention kernels.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
"""
def __init__(
......@@ -5280,6 +5842,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
) -> None:
super().__init__()
......@@ -5307,6 +5870,7 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type
self.hidden_size_per_attention_head_k = (
kv_channels if isinstance(kv_channels, int) else kv_channels[0]
......@@ -5430,6 +5994,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_group: Union[dist_group_type, None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
) -> None:
"""
Set the context parallel attributes for the given
......@@ -5443,10 +6008,14 @@ class DotProductAttention(TransformerEngineBaseModule):
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type
@no_torch_dynamo(recursive=False)
def forward(
......@@ -5943,6 +6512,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
cp_comm_type=self.cp_comm_type,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
......@@ -5985,6 +6555,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
)
......@@ -6009,6 +6580,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
)
......@@ -6437,6 +7009,7 @@ class MultiheadAttention(torch.nn.Module):
cp_group: Union[dist_group_type, None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
) -> None:
"""
Set the context parallel attributes for the given
......@@ -6450,13 +7023,16 @@ class MultiheadAttention(torch.nn.Module):
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
def forward(
self,
......
......@@ -487,6 +487,7 @@ class TransformerLayer(torch.nn.Module):
cp_group: Union[dist_group_type, None],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
) -> None:
"""
Set the context parallel attributes for the given
......@@ -500,13 +501,16 @@ class TransformerLayer(torch.nn.Module):
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
Can be "p2p" or "all_gather".
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
def forward(
self,
......