Unverified Commit 9709147e authored by Xiaowei Ren, committed by GitHub

Add attention bias and qkv format to context parallelism (#726)



* make FusedAttn with CP support bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert Alibi cannot work with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* syntax fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix variable name
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix tensor shapes
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* a typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias indexing for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add attn bias tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change dbias update location
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP test model configs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change CP test sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP support qkv format of sbhd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure qkv are contiguous for CP in cuDNN fused attn
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change assert message
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent f85553ea
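In short, this change lets `DotProductAttention` accept a post-scale attention bias and an `sbhd` qkv format while context parallelism (CP) is enabled. A minimal single-rank sketch of the call pattern, mirroring the updated test below — the bias kwargs and the `set_context_parallel_group` signature are taken verbatim from the test, but the constructor arguments and tensor shapes are assumptions for illustration:

```python
import torch
from transformer_engine.pytorch.attention import DotProductAttention

b, sq, skv, h, d = 1, 4096, 4096, 12, 128

# Constructor arguments are assumed for illustration; the test configs
# below show the exact settings exercised in CI.
core_attn = DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    attn_mask_type="causal",  # CP supports 'causal' and 'no_mask' only
    qkv_format="bshd",        # 'sbhd' now works with CP as well
)

# bshd layout: (batch, seq, heads, head_dim)
q = torch.randn(b, sq, h, d, dtype=torch.bfloat16).cuda()
k = torch.randn(b, skv, h, d, dtype=torch.bfloat16).cuda()
v = torch.randn(b, skv, h, d, dtype=torch.bfloat16).cuda()

# Post-scale bias broadcast over batch and heads, matching the test's shape.
bias = torch.randn(1, 1, sq, skv, dtype=torch.bfloat16).cuda()

# With multiple ranks, CP is enabled as in the test (cp_comm_group and
# cp_comm_ranks come from the torch.distributed setup):
#   core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks,
#                                        torch.cuda.Stream())

out = core_attn(
    q, k, v,
    core_attention_bias_type="post_scale_bias",
    core_attention_bias=bias,
)
```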
@@ -6,7 +6,7 @@ import os, sys
 import torch
 import torch.distributed as dist
 from transformer_engine.pytorch.attention import DotProductAttention
-from test_fused_attn_with_cp import model_configs
+from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn

 dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
@@ -17,8 +17,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
+        config = model_configs_flash_attn[model]
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        config = model_configs_fused_attn[model]

     rank = int(os.getenv('RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
@@ -40,8 +42,6 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     assert(rank in cp_comm_ranks)
     cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
-
-    config = model_configs[model]
     assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"

     # instantiate core attn module
@@ -69,18 +69,30 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
+    # create flash attention bias
+    if config.attn_bias_type not in ["no_bias", "alibi"]:
+        attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
+        bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
+    else:
+        bias = None
+
     # make sure all GPU ranks have same inputs
-    for x in [q, k, v, dout]:
+    for x in [q, k, v, dout] + ([] if bias is None else [bias]):
         dist.broadcast(x, 0, group=cp_comm_group)

     # run core_attn without CP
     for x in [q, k, v]:
         x.requires_grad = True
-    out = core_attn(q, k, v)
+    out = core_attn(
+        q, k, v,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias=bias,
+    )
     out.backward(dout)

     # run core_attn with CP
-    q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
+    q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
+    bias_ = rest[0] if len(rest) else None
     seq_dim = qkv_format.index('s')
     q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
         for x in [q_, k_, v_, dout_]]
@@ -88,8 +100,16 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
     q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
+    if bias_ is not None:
+        bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
+        bias_ = bias_.index_select(2, seq_idx)
+        bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
     core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
-    out_ = core_attn(q_, k_, v_)
+    out_ = core_attn(
+        q_, k_, v_,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias=bias_,
+    )
     out_.backward(dout_)

     for x in [out_, q_.grad, k_.grad, v_.grad]:
......
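The collapsed lines above build `seq_idx`, the per-rank chunk selector that the `index_select` calls consume. For orientation, a minimal sketch of the sharding pattern the test applies to q/k/v/dout (and now the bias): the sequence dimension is split into `2*world_size` chunks and each rank keeps a front chunk and its mirror from the back, so a causal mask yields roughly equal work per rank. The helper name `shard_for_cp` and the exact `seq_idx` pairing are assumptions, since that construction is collapsed in the diff:

```python
import torch

def shard_for_cp(x: torch.Tensor, seq_dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Slice one rank's portion of a sequence for context parallelism.

    Mirrors the reshape/index_select pattern in the test above; the chunk
    pairing (rank i takes chunks i and 2*world_size-1-i) is an assumption.
    """
    # Split the sequence dimension into 2*world_size equal chunks.
    x = x.view(*x.shape[:seq_dim], 2 * world_size,
               x.shape[seq_dim] // (2 * world_size), *x.shape[seq_dim + 1:])
    # Each rank takes one chunk from the front half and its mirror from the
    # back half, balancing causal-mask compute across ranks.
    seq_idx = torch.tensor([rank, 2 * world_size - 1 - rank], device=x.device)
    x = x.index_select(seq_dim, seq_idx)
    # Merge the two selected chunks back into one contiguous sequence dim.
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

# Example: a [b, s, h, d] ('bshd') tensor sharded across 2 CP ranks.
q = torch.randn(1, 4096, 12, 128)
q_local = shard_for_cp(q, seq_dim=1, rank=0, world_size=2)
assert q_local.shape[1] == 4096 // 2
```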
@@ -11,12 +11,12 @@ from test_fused_attn import (
     _cudnn_version,
 )

-model_configs = {
-    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA
-    "cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA
-    "cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA
-    "cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA
+model_configs_flash_attn = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
 }

 def get_bash_arguments(**kwargs):
@@ -30,7 +30,7 @@ def get_bash_arguments(**kwargs):
 @pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
 @pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
 @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
 def test_cp_with_flash_attention(dtype, model, qkv_format):
     subprocess.run(
@@ -43,9 +43,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
         check=True
     )

+model_configs_fused_attn = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
+    "cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
+    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+    "cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+    "cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+}
+
 @pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
 def test_cp_with_fused_attention(dtype, model, qkv_format):
     subprocess.run(
......
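The new `model_configs_fused_attn` table adds `post_scale_bias` cases alongside `no_bias` for both MHA and GQA. For readers decoding the positional `ModelConfig` arguments, here is a sketch of one entry based on the `# test: b, h, hg, d, sq, skv, p, mask, bias` comment; the inline annotations are assumptions for illustration, as `ModelConfig`'s definition in `test_fused_attn` is not shown in this diff:

```python
# One entry from model_configs_fused_attn, with assumed field meanings.
config = ModelConfig(
    1,                  # b:    batch size
    12,                 # h:    number of attention (query) heads
    1,                  # hg:   number of KV head groups (1 -> GQA; 12 -> MHA)
    128,                # d:    head dimension
    4096,               # sq:   query sequence length
    4096,               # skv:  key/value sequence length
    0.0,                # p:    attention dropout probability
    "causal",           # mask: attention mask type ('causal' or 'no_mask')
    "post_scale_bias",  # bias: attention bias type
)
```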
This diff is collapsed.