Commit 76669cdd authored by cyanguwa, committed by GitHub

[C/Pytorch] Expand layout support for fused attention (#403)



* add flexible layout support

* add support for flexible qkv layout

* add more changes

* fixes for compiling

* remove redundant file

* fix options device error

* fix typos

* more changes; WIP

* more changes; WIP

* fixes and tests

* fixes for wrong results

* sb3hd/bs3hd working on top of 3xsbhd/bshd/thd

* fix dQ, dK, dV

* add nvtx

* remove qkvso_strides on torch side; cover it in generateQKVStrides

* all 15 layouts pass

* add workspace optimization

* minor fixes and test

* remove most debug info / clean up

* add note to deprecate some qkv layouts

* fix code for unit tests in test_fused_attn.py

* further remove debug info

* remove a couple more comments

* fix numerics tests

* fixes for lint

* fix fp8 tests

* fix onnx for core attn; not fully fixed

* remove nvtx and add env var for workspace opt

* remove testing for env var

* replace zeros/zeros_like with empty/empty_like

* fix nvtx marker name for _q_k_v API

* remove sm80 when compiling for h100

* add mapping from qkv layout to layout group and qkv format

* clean up enums mapping and remove trailing spaces

* simplify workspace opt control logic; only need env var

* fix fp8 test, and minor modifications for other tests

* avoid overwriting model configs in unit test

* assorted fixes/improvements: get_qkv_format/etc, default values, docstrings, comments

* fix minor issues: invalid syntax

* change workspace opt logic back to FORCE_WORKSPACE_OPT

* fix FP8 tests and generateStrides function

* fix get_backend logic for max512/arbitrary

* fix unit tests; needs cleanup

* clean up unit tests for layouts, and fix minor lint issue

* minor tweaks for CI testing: onnx string issue and test fused attn first

* remove one unsupported layout from max512 and add a check to the qkvpacked API

* fix te layer test; reduce test time

* revert compiler option changes; add back sm80 even for h100

* remove some unit tests or make them optional to reduce CI time

* remove more unit tests temporarily

* remove _q_k_v in naming and add NVTE_ERROR for FP8 Aux_CTX_Tensors size checks

* add more deprecation notes

* remove temp tests from last commit

* replace with te::getenv

* remove prints from last commit

* remove redundant contiguous()

* remove thd->bs3hd user warning to avoid GPU sync

* adjust fused attn bs in tests

* temporary fix for onnx issue; more fixes in PR 437

* remove unused variables
---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent db589510
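For orientation, a minimal usage sketch of the expanded layout API, not part of this commit: the argument names (qkv_format, and cu_seqlens_q/cu_seqlens_kv for 'thd') and shapes follow the tests added in test_fused_attn.py below; sizes are illustrative.

# Hedged sketch; assumes default DotProductAttention arguments suffice.
import torch
from transformer_engine.pytorch import DotProductAttention

b, s, h, d = 2, 512, 16, 64
q, k, v = (torch.randn(s, b, h, d, dtype=torch.float16, device="cuda",
                       requires_grad=True) for _ in range(3))
block = DotProductAttention(h, d, attention_dropout=0.0).to(dtype=torch.float16).cuda()
# qkv_format describes the memory format of q/k/v: 'sbhd', 'bshd' (and later 'thd')
out = block(q, k, v, qkv_format="sbhd")
out.backward(torch.randn_like(out))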
@@ -9,6 +9,6 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -39,20 +39,23 @@ class ModelConfig:
model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"),
"test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"),
"test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
"test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}
if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2, 32]
batch_sizes = [1, 2] # add more if needed, e.g. 32
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@@ -77,10 +80,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
@@ -126,7 +129,11 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v, attn_mask_type=config.attn_mask_type,
op = block(q, k, v,
qkv_format='sbhd',
cu_seqlens_q = cu_seqlens,
cu_seqlens_kv = cu_seqlens,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=bias_type,
core_attention_bias=bias)
@@ -134,6 +141,130 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
return op, inp.grad
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
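A layout string encodes both packing and format: one segment per input tensor, with a digit marking the axis along which tensors are packed. A small illustrative sketch of the decomposition that _run_dpa_qkv_layout below performs (sizes are arbitrary examples):

# Illustrative only; mirrors the dim_to_num/split logic in _run_dpa_qkv_layout.
dim_to_num = {'b': 2, 's': 512, 'h': 16, 'd': 64, 't': 2 * 512, '3': 3, '2': 2}
for part in 'bshd_bs2hd'.split('_'):
    shape = [dim_to_num[c] for c in part]
    print(part, shape)  # bshd [2, 512, 16, 64]; bs2hd [2, 512, 2, 16, 64]
# 'bs2hd' packs K and V into one tensor along the '2' axis; the qkv_format is
# the alphabetic part of the first segment, here 'bshd'.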
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
for i in range(len(flash_attn_bwd)):
torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
dim_to_num = {'b': bs,
's': config.seq_len,
'h': config.num_attention_heads,
'd': config.head_dim,
't': bs * config.seq_len,
'3': 3,
'2': 2}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
tensor_shape = [dim_to_num[j] for j in layout]
tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda()
tensor_count = 1
split_dim = 0
for dim,l in enumerate(layout):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
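# cu_seqlens holds cumulative sequence lengths, i.e. [0, s, 2*s, ...] here since every sequence has length config.seq_len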
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd]
op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda()
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
if qkv_format != 'thd':
op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
else:
cu_seqlens_q = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[0].device)
cu_seqlens_kv = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[1].device)
op = block(inp[0], inp[1], inp[2],
qkv_format=qkv_format,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv)
op.backward(op_grad)
return op, (inp[0].grad, inp[1].grad, inp[2].grad)
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@@ -158,10 +289,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):
@@ -231,7 +362,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
.cuda()
)
num_iters = 10
num_iters = 5
for i in range(num_iters):
op = block(inp, self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
@@ -269,8 +400,8 @@ def test_transformer_layer_gqa(dtype, bs, model):
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-2
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
@@ -363,8 +494,8 @@ def test_dpa_fp8(dtype, bs, model):
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (2.5e-2, 2.5e-2)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_dpa_fp8(dtype, bs, config, backend):
@@ -427,7 +558,7 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self"
@@ -439,8 +570,6 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
v = inp[:, :,2,:,:]
op = block(q, k, v, attn_mask_type=config.attn_mask_type)
op.backward(op_grad)
torch.save(op,'ctx_ref.pt')
torch.save(inp.grad,'dqkv_ref.pt')
return op, inp.grad
@@ -455,6 +584,8 @@ from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd,
fused_attn_bwd,
FusedAttnBackend)
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
@@ -542,11 +673,15 @@ class _dpa_fp8(torch.autograd.Function):
torch.save(qkv_out_fp16, 'qkv.pt')
# FMHA
context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked(
context_, aux_ctx_tensors, *rest = fused_attn_fwd(
is_training,
max_s,
max_s,
cu_seqlens,
qkv_out,
cu_seqlens,
qkv_out[:,0,:,:],
qkv_out[:,1,:,:],
qkv_out[:,2,:,:],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
@@ -558,7 +693,7 @@ class _dpa_fp8(torch.autograd.Function):
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="qkv_interleaved",
qkv_layout="t3hd",
attn_bias_type="no_bias",
attn_mask_type="padding",
rng_gen=None,
@@ -617,10 +752,14 @@ class _dpa_fp8(torch.autograd.Function):
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
dqkv, *rest = fused_attn_bwd_qkvpacked(
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
ctx.max_s,
ctx.cu_seqlens,
qkv_out,
ctx.cu_seqlens,
qkv_out[:,0,:,:],
qkv_out[:,1,:,:],
qkv_out[:,2,:,:],
context,
proj_dgrad.view_as(context),
fp8_dtype_forward,
@@ -638,10 +777,11 @@ class _dpa_fp8(torch.autograd.Function):
None,
ctx.p_dropout,
ctx.fast_zero_fill,
"qkv_interleaved",
"t3hd",
"no_bias",
"padding",
)
dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)
dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
@@ -871,7 +871,7 @@ def _test_dpa_accuracy(block, bs, dtype, config):
key.retain_grad()
value.retain_grad()
out = block(query, key, value, mask)
out = block(query, key, value, attention_mask=mask)
loss = out.sum()
loss.backward()
@@ -1005,6 +1005,7 @@ def test_export_core_attention(
# Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
qkv_format = "sbhd"
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
@@ -1025,6 +1026,7 @@
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
@@ -12,6 +12,66 @@
#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_3HD;
case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_TH3D:
return NVTE_QKV_Layout_Group::NVTE_H3D;
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
return NVTE_QKV_Layout_Group::NVTE_HD_H2D;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
// map NVTE_QKV_Layout to NVTE_QKV_Format
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_TH3D:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
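To make the two mappings concrete, a small Python sketch (the helper name is illustrative; it mirrors this C mapping and the way the unit tests derive qkv_format from a layout string):

# Illustrative sketch, assuming the string-based layout names used in this PR.
def get_qkv_format(qkv_layout: str) -> str:
    # The format is the alphabetic part of the first layout segment:
    # 'sbh3d' -> 'sbhd', 'bshd_bs2hd' -> 'bshd', 't3hd' -> 'thd'
    return ''.join(c for c in qkv_layout.split('_')[0] if c.isalpha())

assert get_qkv_format('sbh3d') == 'sbhd'
assert get_qkv_format('bshd_bs2hd') == 'bshd'
assert get_qkv_format('t3hd') == 'thd'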
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
@@ -26,6 +86,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
@@ -33,7 +94,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD))) {
#if (CUDNN_VERSION >= 8900)
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else
@@ -52,7 +114,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
if (
@@ -65,7 +132,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
@@ -438,3 +507,201 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention BWD with separate Q, K and V
void nvte_fused_attn_bwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dK,
NVTETensor dV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
Tensor *output_dK = reinterpret_cast<Tensor*>(dK);
Tensor *output_dV = reinterpret_cast<Tensor*>(dV);
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_dO,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQ, output_dK, output_dV,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
@@ -15,6 +15,7 @@
#include "../common.h"
#include "utils.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#if (CUDNN_VERSION >= 8900)
#define Q_ID 1
@@ -1059,6 +1060,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
if (!use_workspace_opt) {
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
@@ -1221,9 +1223,6 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;
@@ -1295,9 +1294,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
@@ -1337,31 +1333,173 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
const char* env_workspace_limit_char = std::getenv("NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT");
if (env_workspace_limit_char != nullptr) {
  try {
    std::string env_dp_workspace_limit(env_workspace_limit_char);
    int dp_workspace_limit = std::stoi(env_dp_workspace_limit);
    if (dp_workspace_limit > max_allowed_dp_workspace) {
      max_allowed_dp_workspace = dp_workspace_limit;
    }
  } catch (...) {
    NVTE_ERROR("Invalid argument for NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT (integer; in MBytes)! \n");
  }
}
if (required_dp_workspace <= max_allowed_dp_workspace) {
  use_workspace_opt = true;
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
}
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
bool use_workspace_opt = false;
#if (CUDNN_VERSION >= 8905)
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
if (sm_arch_ >= 90) {
// quick estimate of dp workspace size
size_t max_seqlen_div_up_q = ((max_seqlen_q + 64 - 1) / 64) * 64;
size_t max_seqlen_div_up_kv = ((max_seqlen_kv + 64 - 1) / 64) * 64;
size_t required_dp_workspace =
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
if (required_dp_workspace <= max_allowed_dp_workspace) {
use_workspace_opt = true;
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
}
#endif
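The heuristic above can be overridden at runtime, as the tests added in this PR do; for example:

import os
# Force the workspace-optimized path on ("0" forces it off); when unset,
# the size heuristic above decides.
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"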
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type), workspace->data.dptr,
get_cudnn_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt);
if (workspace_size > 0) {
@@ -38,6 +38,30 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_size, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
@@ -1250,9 +1250,6 @@ void fused_attn_max_512_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;
@@ -1323,8 +1320,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED.");
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
@@ -1391,6 +1386,76 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_K->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
@@ -1402,9 +1467,6 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
@@ -1465,9 +1527,6 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
"qkv_layout must be NVTE_KV_INTERLEAVED.");
// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
auto stride = 2 * num_head * head_dim;
@@ -1518,5 +1577,63 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
const auto q_type = input_Q->data.dtype;
const auto kv_type = input_K->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
} // namespace transformer_engine
#endif // CUDNN_VERSION >= 8901
@@ -38,6 +38,17 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -56,6 +67,18 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8901
} // namespace transformer_engine
@@ -173,6 +173,7 @@ static cudnn_frontend::Tensor createScale(
static cudnn_frontend::Tensor createScaleWithOffset(
const cudnn_frontend::Tensor& prevBlockOutputTensor,
const std::string& scale_tensor_name,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
bool isOutputVirtual,
bool isScaleByValue,
@@ -192,7 +193,7 @@ static cudnn_frontend::Tensor createScaleWithOffset(
generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2],
0 /*s_kv = 0 for placeholder*/,
output_dim[3], output_stride,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, NVTE_QKV_Matrix::NVTE_Q_Matrix);
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
} else {
// Otherwise output dim and stride should be the same as prev block dim and stride
for (int i = 0; i < 4; i++) {
@@ -1163,6 +1164,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto OTensor = createScaleWithOffset(
OTensor_before_quan_O_tensor, // input tensor
"scaleO", // scale tensor
layout, // qkv layout
tensorType, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1515,6 +1517,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dVTensor = createScaleWithOffset(
dVTensor_before_quan_dV, // input tensor
"scaledV", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1653,6 +1656,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dQ = createScaleWithOffset(
After_dS_K_before_quan_dQ, // input tensor
"scaledQ", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1693,6 +1697,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dK = createScaleWithOffset(
After_dSTranspose_Q_before_quan_dK, // input tensor
"scaledK", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1911,6 +1916,8 @@ void fused_attn_fp8_fwd_qkvpacked(
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
@@ -2048,5 +2055,204 @@ void fused_attn_fp8_bwd_qkvpacked(
return;
}
}
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
void* devPtrQ = input_Q->data.dptr;
void* devPtrK = input_K->data.dptr;
void* devPtrV = input_V->data.dptr;
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_Q->scale_inv.dptr;
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
if (is_training) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
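// rng_state holds two 64-bit values: the CUDA RNG seed at index 0 and its offset at index 1.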
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
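// Workspace query pass: when called with a null workspace pointer, the code below only
// reports the required size and returns; the caller allocates and calls again to run.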
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dK,
const Tensor *output_dV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
void* devPtrQ = input_Q->data.dptr;
void* devPtrK = input_K->data.dptr;
void* devPtrV = input_V->data.dptr;
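// As in the forward pass, Q, K and V share a single descale tensor, so K and V
// reuse input_Q's scale_inv.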
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_Q->scale_inv.dptr;
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdS = input_output_dP->amax.dptr;
void* devPtrScaledS = input_output_dP->scale.dptr;
void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
void* devPtrdQ = output_dQ->data.dptr;
void* devPtrdK = output_dK->data.dptr;
void* devPtrdV = output_dV->data.dptr;
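// dQ, dK and dV share one amax and one scale tensor (amax_dQKV/scale_dQKV on the
// framework side), so output_dQ supplies the amax and scale pointers for all three.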
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dQ->amax.dptr;
void* devPtrAmaxdV = output_dQ->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dQ->scale.dptr;
void* devPtrScaledV = output_dQ->scale.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledS,
devPtrScaleS, devPtrScaledS,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdS,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
#endif // end of CUDNN>=8900
} // namespace transformer_engine
......@@ -46,5 +46,44 @@ void fused_attn_fp8_bwd_qkvpacked(
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dK,
const Tensor *output_dV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
#endif // end of CUDNN>=8900
} // namespace transformer_engine
......@@ -30,6 +30,7 @@ void generateMatrixStrides(
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
// Strides for the legacy interleaved layouts; to be deprecated in the future.
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
......@@ -37,7 +38,8 @@ void generateMatrixStrides(
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else {
} else if ((layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
......@@ -55,7 +57,7 @@ void generateMatrixStrides(
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
......@@ -73,7 +75,7 @@ void generateMatrixStrides(
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
......@@ -91,7 +93,7 @@ void generateMatrixStrides(
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
......@@ -109,7 +111,7 @@ void generateMatrixStrides(
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
......@@ -129,6 +131,228 @@ void generateMatrixStrides(
strideA[batch_dim_idx] = s_q * h * d;
break;
}
// Stride generation for the new QKV layouts.
switch (layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBH3D:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_dim_idx] = b * 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_dim_idx] = b * 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * h * d;
strideA[hidden_transpose_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_TH3D:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
}
break;
}
if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
}
}
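For concreteness, here is a minimal standalone sketch of the stride arithmetic in the
NVTE_SB3HD / NVTE_Q_Matrix branch above (sizes are illustrative; the dimension-index
constants mirror the ones used by generateMatrixStrides):

// Sketch only: SB3HD stores one buffer of shape [s, b, 3, h, d];
// Q, K and V are offset views into it, and strides are reported in [b, h, s, d] order.
#include <cstdio>
#include <cstdint>

int main() {
  const int64_t b = 2, h = 3, d = 8;        // illustrative sizes
  constexpr int batch_dim_idx  = 0;
  constexpr int head_dim_idx   = 1;
  constexpr int seqlen_dim_idx = 2;
  constexpr int hidden_dim_idx = 3;

  int64_t strideA[4];
  strideA[batch_dim_idx]  = 3 * h * d;      // next batch element within a timestep
  strideA[head_dim_idx]   = d;              // next head inside the Q slice
  strideA[seqlen_dim_idx] = b * 3 * h * d;  // next sequence position
  strideA[hidden_dim_idx] = 1;              // head_dim is contiguous

  std::printf("SB3HD Q strides [b, h, s, d] = [%lld, %lld, %lld, %lld]\n",
              static_cast<long long>(strideA[0]), static_cast<long long>(strideA[1]),
              static_cast<long long>(strideA[2]), static_cast<long long>(strideA[3]));
  return 0;
}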
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
......
......@@ -18,7 +18,17 @@ extern "C" {
#endif
/*! \enum NVTE_QKV_Layout
 * \brief Memory layouts of QKV tensors
 *    `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, the number of heads,
 *    head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
 *    `SBHD`- and `BSHD`-based layouts are used when sequences in a batch are of equal length
 *    or padded to the same length, and `THD`-based layouts are used when sequences in a batch
 *    have different lengths.
 * \note {`NVTE_QKV_INTERLEAVED`, `NVTE_KV_INTERLEAVED` and `NVTE_NOT_INTERLEAVED`
 *    will be deprecated in the next release. Please use their equivalent enums instead,
 *    i.e. `NVTE_T3HD`, `NVTE_THD_T2HD` and `NVTE_THD_THD_THD` when sequences have variable
 *    lengths, and `NVTE_BS3HD`, `NVTE_BSHD_BS2HD` and `NVTE_BSHD_BSHD_BSHD` when sequences
 *    are of equal length or padded to equal length.}
*/
enum NVTE_QKV_Layout {
/*! Separate Q, K, V tensors.
......@@ -67,7 +77,51 @@ enum NVTE_QKV_Layout {
| num_heads * head_dim
\endverbatim
*/
NVTE_KV_INTERLEAVED = 2,
NVTE_SB3HD = 3,
NVTE_SBH3D = 4,
NVTE_SBHD_SB2HD = 5,
NVTE_SBHD_SBH2D = 6,
NVTE_SBHD_SBHD_SBHD = 7,
NVTE_BS3HD = 8,
NVTE_BSH3D = 9,
NVTE_BSHD_BS2HD = 10,
NVTE_BSHD_BSH2D = 11,
NVTE_BSHD_BSHD_BSHD = 12,
NVTE_T3HD = 13,
NVTE_TH3D = 14,
NVTE_THD_T2HD = 15,
NVTE_THD_TH2D = 16,
NVTE_THD_THD_THD = 17,
};
/*! \enum NVTE_QKV_Layout_Group
* \brief Grouping of QKV layouts
*/
enum NVTE_QKV_Layout_Group {
/*! 3HD QKV layouts, e.g. BS3HD */
NVTE_3HD = 0,
/*! H3D QKV layouts, e.g. BSH3D */
NVTE_H3D = 1,
/*! HD_2HD QKV layouts, e.g. BSHD_BS2HD */
NVTE_HD_2HD = 2,
/*! HD_H2D QKV layouts, e.g. BSHD_BSH2D */
NVTE_HD_H2D = 3,
/*! HD_HD_HD QKV layouts, e.g. BSHD_BSHD_BSHD */
NVTE_HD_HD_HD = 4,
};
/*! \enum NVTE_QKV_Format
* \brief Dimension formats for QKV tensors
*/
enum NVTE_QKV_Format {
/*! SBHD QKV format */
NVTE_SBHD = 0,
/*! BSHD QKV format */
NVTE_BSHD = 1,
/*! THD QKV format */
NVTE_THD = 2,
};
/*! \enum NVTE_Bias_Type
......@@ -94,6 +148,9 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2,
};
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
......@@ -105,6 +162,22 @@ enum NVTE_Fused_Attn_Backend {
NVTE_FP8 = 2,
};
/*! \brief Get layout group for a given QKV layout
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
* \return qkv layout group, e.g. h3d.
*/
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
/*! \brief Get QKV format for a given QKV layout
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
* \return qkv format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
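A self-contained sketch of the mapping these two helpers expose (the mini enums and
helper names below are illustrative only; the shipped implementations live in the common
C++ sources and may be structured differently):

// Sketch only: mirrors the NVTE_QKV_Layout -> layout-group/format relationship.
#include <cstdio>

enum Layout { SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD,
              BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD,
              T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD };
enum Group  { G_3HD, G_H3D, G_HD_2HD, G_HD_H2D, G_HD_HD_HD };
enum Format { F_SBHD, F_BSHD, F_THD };

// The group describes how Q, K and V are packed together.
Group layout_group(Layout l) {
  switch (l) {
    case SB3HD:      case BS3HD:      case T3HD:     return G_3HD;
    case SBH3D:      case BSH3D:      case TH3D:     return G_H3D;
    case SBHD_SB2HD: case BSHD_BS2HD: case THD_T2HD: return G_HD_2HD;
    case SBHD_SBH2D: case BSHD_BSH2D: case THD_TH2D: return G_HD_H2D;
    default:                                         return G_HD_HD_HD;
  }
}

// The format is the dimension order shared by all tensors of the layout.
Format layout_format(Layout l) {
  if (l <= SBHD_SBHD_SBHD) return F_SBHD;
  if (l <= BSHD_BSHD_BSHD) return F_BSHD;
  return F_THD;
}

int main() {
  std::printf("sbh3d -> group %d (H3D), format %d (SBHD)\n",
              layout_group(SBH3D), layout_format(SBH3D));
  return 0;
}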
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
......@@ -152,7 +225,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
......@@ -199,7 +272,7 @@ void nvte_fused_attn_fwd_qkvpacked(
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -250,9 +323,9 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
......@@ -301,9 +374,9 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -332,6 +405,122 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
* Computes:
* - P = Q * Transpose(K) + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * Transpose(V)
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dK,
NVTETensor dV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -28,6 +28,11 @@ AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
QKVLayouts = (
"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd")
LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)
......
......@@ -106,6 +106,52 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
......
......@@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
......@@ -148,7 +152,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
......
......@@ -74,6 +74,7 @@ class TransformerLayer(torch.nn.Module):
are deprecated and will be fully removed in future releases.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
......@@ -638,5 +639,5 @@ class TransformerLayer(torch.nn.Module):
if self.output_layernorm:
output = self.layernorm(output)
# output: [s, b, h]
return output