Unverified commit f68df153, authored by Xiaowei Ren, committed by GitHub

[PyTorch] Add support for cuDNN FusedAttention + THD + CP (#885)



* add seq_offsets_qkvo for cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 90f3c9ad
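For readers new to the THD layout this commit exercises: variable-length sequences are packed along a single token dimension (the "t" in "thd") and delimited by cumulative sequence lengths (cu_seqlens). A minimal illustration (shapes and values are ours for exposition, not part of the commit):

    import torch

    # Three sequences of lengths 5, 3, 8 packed into one [t, h, d] tensor, t = 16.
    seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)
    cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seqlens]), 0).to(torch.int32)
    # cu_seqlens == [0, 5, 8, 16]; sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1].
    q, k, v = [torch.randn(16, 12, 128) for _ in range(3)]  # h=12 heads, d=128 per head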
@@ -22,6 +22,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         config = model_configs_fused_attn[model]
+        if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
+            return
 
     rank = int(os.getenv('RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
@@ -45,6 +47,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
+    if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
+        if 'causal' in config.attn_mask_type:
+            config.attn_mask_type = 'padding_causal'
+        else:
+            config.attn_mask_type = 'padding'
+
     # instantiate core attn module
     core_attn = DotProductAttention(config.num_heads,
                                     config.head_dim,
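With THD inputs every batch is token-packed, so the cuDNN fused-attention path needs a padding-aware mask; the hunk above therefore remaps 'causal' to 'padding_causal' and 'no_mask' to 'padding'. A sketch of the resulting call, with the DotProductAttention signature assumed from TE's public API and q/k/v/cu_seqlens as in the illustration after the commit message:

    from transformer_engine.pytorch import DotProductAttention

    core_attn = DotProductAttention(
        12, 128,                          # num_heads, head_dim, matching the test configs
        attn_mask_type="padding_causal",  # 'causal' remapped for THD, as in the hunk above
        qkv_format="thd",
    )
    # Sequence boundaries come from cu_seqlens rather than an explicit mask tensor.
    out = core_attn(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)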
@@ -112,9 +120,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     out.backward(dout)
 
     # run core_attn wit CP
+    q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
+    bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
-        q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
-        bias_ = rest[0] if len(rest) else None
         seq_dim = qkv_format.index('s')
         q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
             for x in [q_, k_, v_, dout_]]
@@ -122,14 +130,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
         q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
     elif qkv_format == "thd":
-        q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
         seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
         seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
         cu_seqlens_q = cu_seqlens_q // world_size
         cu_seqlens_kv = cu_seqlens_kv // world_size
-        bias_ = None
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
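For context on the THD branch above: TE's context parallelism splits each sequence into 2*world_size chunks and gives rank r the r-th and (2*world_size-1-r)-th chunks so that causal-attention work is balanced across ranks; tex.thd_get_partitioned_indices returns the token indices realizing that split over packed THD data. A pure-Python stand-in for the assumed semantics (an illustration, not the library code):

    import torch

    def thd_partitioned_indices(cu_seqlens, world_size, rank):
        # For each packed sequence, keep chunks `rank` and `2*world_size - 1 - rank`
        # out of 2*world_size equal chunks (assumed load-balanced causal CP split).
        idx = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            chunk = (end - start) // (2 * world_size)
            for c in sorted((rank, 2 * world_size - 1 - rank)):
                idx.extend(range(start + c * chunk, start + (c + 1) * chunk))
        return torch.tensor(idx, dtype=torch.int64)

Dividing cu_seqlens_q and cu_seqlens_kv by world_size afterwards, as the hunk does, keeps the offsets consistent with the 1/world_size share of each sequence's tokens that every rank retains.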
@@ -158,7 +164,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     # compare results with and without CP
     tols = dict(atol=5e-3, rtol=5e-3)
     if dtype == 'bf16':
-        tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        if config.num_heads == config.num_gqa_groups:
+            tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        else:
+            tols = dict(atol=3.5e-2, rtol=3.5e-2)
 
     if qkv_format == "bshd" or qkv_format == "sbhd":
         dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
...
@@ -14,10 +14,10 @@ from transformer_engine.pytorch.utils import get_device_compute_capability
 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
-    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
-    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
-    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
 }
 
 def get_bash_arguments(**kwargs):

@@ -47,21 +47,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):
 model_configs_fused_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
-    "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
-    "cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
-    "cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
-    "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
-    "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
-    "cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
-    "cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
+    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
+    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+    "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+    "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
 }
 
 @pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
+@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
 def test_cp_with_fused_attention(dtype, model, qkv_format):
     subprocess.run(
         get_bash_arguments(
...
This diff is collapsed.
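For reference, the new THD cases can be selected with pytest's -k filter (the test-file path is an assumption based on the usual repository layout):

    pytest -s tests/pytorch/fused_attn/test_fused_attn_with_cp.py -k "test_cp_with_fused_attention and thd"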