Unverified commit 76669cdd authored by cyanguwa, committed by GitHub

[C/Pytorch] Expand layout support for fused attention (#403)

* add flexible layout support
* add support for flexible qkv layout
* add more changes
* fixes for compiling
* remove redundant file
* fix options device error
* fix typos
* more changes; WIP
* more changes; WIP
* fixes and tests
* fixes and wrong results
* sb3hd/bs3hd working on top of 3xsbhd/bshd/thd
* fix dQ, dK, dV
* add nvtx
* remove qkvso_strides on torch side; cover it in generateQKVStrides
* all 15 layouts pass
* add workspace optimization
* minor fixes and test
* removed most debug info/clean up
* add note to deprecate some qkv layouts
* fix code for unit tests in test_fused_attn.py
* further remove debug info
* remove a couple more comments
* fix numerics tests
* fixes for lint
* fix fp8 tests
* fix onnx for core attn; not fixed
* remove nvtx and add env var for workspace opt
* remove testing for env var
* replace zeros/zeros_like with empty/empty_like
* fix nvtx marker name for _q_k_v API
* remove sm80 when compiling for h100
* add mapping from qkv layout to layout group and qkv format
* clean up enums mapping and remove trailing spaces
* simplify workspace opt control logic; only need env var
* fix fp8 test, and minor modifications for other tests
* avoid overwriting model configs in unit test
* random fixes/improvements: get_qkv_format/etc, default values, docstrings, comments
* fix minor issues: invalid syntax
* change workspace opt logic back to FORCE_WORKSPACE_OPT
* fix FP8 tests and generateStrides function
* fix get_backend logic for max512/arbitrary
* fix unit tests; need cleanup
* clean up unit tests for layouts, and fix minor lint issue
* minor tweaks for CI testing: onnx string issue and test fused attn first
* remove one unsupported layout from max512 and add a check to qkvpacked API
* fix te layer test; reduce test time
* revert compiler option changes; add back sm80 even for h100
* remove some unit tests or make them optional to reduce CI time
* remove more unit tests temporarily
* remove _q_k_v in naming and add NVTE_ERROR for FP8 Aux_CTX_Tensors size checks
* add more deprecation notes
* remove temp tests from last commit
* replace with te::getenv
* remove prints from last commit
* remove redundant contiguous()
* remove thd->bs3hd user warning to avoid GPU sync
* adjust fused attn bs in tests
* temporary fix for onnx issue; more fixes in PR 437
* remove unused variables

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent db589510
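For orientation, here is a minimal usage sketch distilled from the unit tests added in this commit (test_fused_attn.py). It is illustrative only: tensor sizes are arbitrary and constructor defaults may differ; the point is that DotProductAttention can now take separate q/k/v tensors together with a qkv_format string ("sbhd", "bshd", or "thd") and cumulative sequence lengths.

# Illustrative sketch based on the tests in this commit; not an authoritative API reference.
import torch
from transformer_engine.pytorch import DotProductAttention

b, s, h, d = 2, 128, 16, 64
q, k, v = (torch.randn(s, b, h, d, dtype=torch.float16, device="cuda") for _ in range(3))
cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device="cuda")

attn = DotProductAttention(h, d, attention_dropout=0.0).to(dtype=torch.float16).cuda()
# qkv_format tells the backend how q/k/v are laid out; cu_seqlens_* give cumulative lengths.
out = attn(q, k, v, qkv_format="sbhd",
           cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)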
@@ -9,6 +9,6 @@ set -e
 pip install pytest==6.2.5 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
 PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
+NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -39,20 +39,23 @@ class ModelConfig:
model_configs = {
    "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
-   "test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"),
-   "test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
-   "test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
-   "test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"),
-   "test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
-   "test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
-   "test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
+   "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
+   "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
+   "test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
+   "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}
+if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
+   model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
+   model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
+   model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
+   model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")

param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

-batch_sizes = [1, 2, 32]
+batch_sizes = [1, 2]  # add more if needed, e.g. 32

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@@ -77,10 +80,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
    if bias_type == "no_bias":
-       assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
-       assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
-   assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-   assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+       torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
+       torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

@@ -126,7 +129,11 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
    q = inp[:, :,0,:,:]
    k = inp[:, :,1,:,:]
    v = inp[:, :,2,:,:]
-   op = block(q, k, v, attn_mask_type=config.attn_mask_type,
+   op = block(q, k, v,
+              qkv_format='sbhd',
+              cu_seqlens_q = cu_seqlens,
+              cu_seqlens_kv = cu_seqlens,
+              attn_mask_type=config.attn_mask_type,
               checkpoint_core_attention=ckpt_attn,
               core_attention_bias_type=bias_type,
               core_attention_bias=bias)

@@ -134,6 +141,130 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
    return op, inp.grad
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
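As a reading aid (not part of the test), each layout string above names one tensor per underscore-separated chunk, and a digit folds the following dimensions into that tensor. The helper below is only a sketch that mirrors the dim_to_num expansion used in _run_dpa_qkv_layout further down.

# Sketch: expand a qkv_layout string into per-tensor shapes (mirrors dim_to_num below).
def layout_shapes(layout, b, s, h, d):
    dims = {'b': b, 's': s, 'h': h, 'd': d, 't': b * s, '3': 3, '2': 2}
    return [[dims[c] for c in chunk] for chunk in layout.split('_')]

assert layout_shapes('sb3hd', 2, 128, 16, 64) == [[128, 2, 3, 16, 64]]
assert layout_shapes('bshd_bs2hd', 2, 128, 16, 64) == [[2, 128, 16, 64], [2, 128, 2, 16, 64]]
assert layout_shapes('sbhd_sbhd_sbhd', 2, 128, 16, 64) == [[128, 2, 16, 64]] * 3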
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)
atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
for i in range(len(flash_attn_bwd)):
torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol = atol, rtol = rtol)
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol)
def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
dim_to_num = {'b': bs,
's': config.seq_len,
'h': config.num_attention_heads,
'd': config.head_dim,
't': bs * config.seq_len,
'3': 3,
'2': 2}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
tensor_shape = [dim_to_num[j] for j in layout]
tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda()
tensor_count = 1
split_dim = 0
for dim,l in enumerate(layout):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd]
op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda()
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
)
if qkv_format != 'thd':
op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
else:
cu_seqlens_q = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[0].device)
cu_seqlens_kv = torch.arange(
0,
(bs + 1) * config.seq_len,
step=config.seq_len,
dtype=torch.int32,
device=inp[1].device)
op = block(inp[0], inp[1], inp[2],
qkv_format=qkv_format,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv)
op.backward(op_grad)
return op, (inp[0].grad, inp[1].grad, inp[2].grad)
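The thd branch above builds cumulative sequence lengths with torch.arange; for fixed-length sequences this is just an arithmetic ramp, as the small sketch below (illustrative only) shows.

# Sketch: cu_seqlens for a batch of fixed-length sequences, as used in the thd path above.
import torch

def make_cu_seqlens(batch, seqlen, device="cuda"):
    return torch.arange(0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)

# e.g. make_cu_seqlens(4, 128, device="cpu") -> tensor([0, 128, 256, 384, 512], dtype=torch.int32)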
@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)

@@ -158,10 +289,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
    atol, rtol = (5e-1, 5e-2)
    if bias_type == "no_bias":
-       assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
-       assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
-   assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-   assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+       torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
+       torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):

@@ -231,7 +362,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):
        .cuda()
    )

-   num_iters = 10
+   num_iters = 5
    for i in range(num_iters):
        op = block(inp, self_attn_mask_type=config.attn_mask_type,
                   checkpoint_core_attention=ckpt_attn,

@@ -269,8 +400,8 @@ def test_transformer_layer_gqa(dtype, bs, model):
        dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

    atol, rtol = 5e-1, 5e-2
-   assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-   assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):

@@ -363,8 +494,8 @@ def test_dpa_fp8(dtype, bs, model):
        dtype, bs, config, "UnfusedDotProductAttention")

    atol, rtol = (2.5e-2, 2.5e-2)
-   assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
-   assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
+   torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dpa_fp8(dtype, bs, config, backend):

@@ -427,7 +558,7 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
-       get_rng_state_tracker=None,
+       get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self"

@@ -439,8 +570,6 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
    v = inp[:, :,2,:,:]
    op = block(q, k, v, attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)
-   torch.save(op,'ctx_ref.pt')
-   torch.save(inp.grad,'dqkv_ref.pt')

    return op, inp.grad
@@ -455,6 +584,8 @@ from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
+   fused_attn_fwd,
+   fused_attn_bwd,
    FusedAttnBackend)

_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB

@@ -542,11 +673,15 @@ class _dpa_fp8(torch.autograd.Function):
        torch.save(qkv_out_fp16, 'qkv.pt')

        # FMHA
-       context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked(
+       context_, aux_ctx_tensors, *rest = fused_attn_fwd(
            is_training,
            max_s,
+           max_s,
            cu_seqlens,
-           qkv_out,
+           cu_seqlens,
+           qkv_out[:,0,:,:],
+           qkv_out[:,1,:,:],
+           qkv_out[:,2,:,:],
            fp8_dtype_forward,
            FusedAttnBackend["FP8"],
            None,

@@ -558,7 +693,7 @@ class _dpa_fp8(torch.autograd.Function):
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
-           qkv_layout="qkv_interleaved",
+           qkv_layout="t3hd",
            attn_bias_type="no_bias",
            attn_mask_type="padding",
            rng_gen=None,

@@ -617,10 +752,14 @@ class _dpa_fp8(torch.autograd.Function):
            grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
        )

-       dqkv, *rest = fused_attn_bwd_qkvpacked(
+       dq, dk, dv, *rest = fused_attn_bwd(
            ctx.max_s,
+           ctx.max_s,
            ctx.cu_seqlens,
-           qkv_out,
+           ctx.cu_seqlens,
+           qkv_out[:,0,:,:],
+           qkv_out[:,1,:,:],
+           qkv_out[:,2,:,:],
            context,
            proj_dgrad.view_as(context),
            fp8_dtype_forward,

@@ -638,10 +777,11 @@ class _dpa_fp8(torch.autograd.Function):
            None,
            ctx.p_dropout,
            ctx.fast_zero_fill,
-           "qkv_interleaved",
+           "t3hd",
            "no_bias",
            "padding",
        )
+       dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)
        dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
        dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
...
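In the FP8 path above, the packed QKV tensor in "t3hd" layout ([total_tokens, 3, heads, head_dim]) is sliced into separate q/k/v for the new fused_attn_fwd/fused_attn_bwd calls, and the backward recombines dq/dk/dv into a packed gradient. A small illustrative sketch (shapes made up):

# Sketch of the packing/unpacking around the separate-QKV calls above (illustrative shapes).
import torch

qkv = torch.randn(32, 3, 16, 64)                 # packed QKV in t3hd layout
q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]        # slices handed to fused_attn_fwd(...)

dq, dk, dv = torch.randn_like(q), torch.randn_like(k), torch.randn_like(v)
dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)
assert dqkv.shape == qkv.shape                   # back to [t, 3, h, d]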
@@ -871,7 +871,7 @@ def _test_dpa_accuracy(block, bs, dtype, config):
    key.retain_grad()
    value.retain_grad()

-   out = block(query, key, value, mask)
+   out = block(query, key, value, attention_mask=mask)
    loss = out.sum()
    loss.backward()
...

@@ -1005,6 +1005,7 @@ def test_export_core_attention(
    # Set dimensions (these are arbitrary).
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
+   qkv_format = "sbhd"

    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")

@@ -1025,6 +1026,7 @@ def test_export_core_attention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
+       qkv_format=qkv_format,
        attn_mask_type=attn_mask_type,
    ).to(device='cuda')
    do_export(model,
...
@@ -12,6 +12,66 @@
#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_3HD;
case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_TH3D:
return NVTE_QKV_Layout_Group::NVTE_H3D;
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
return NVTE_QKV_Layout_Group::NVTE_HD_H2D;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
// map NVTE_QKV_Layout to NVTE_QKV_Format
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
case NVTE_QKV_Layout::NVTE_SBH3D:
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_TH3D:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED:
case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
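The same mapping, written out in Python for reference (a sketch mirroring the two switch statements above; the C enums remain the source of truth): every layout belongs to one of five packing groups, and its format is the first chunk of the layout name with digits removed.

# Sketch: Python mirror of nvte_get_qkv_layout_group / nvte_get_qkv_format.
LAYOUT_GROUP = {
    'sb3hd': '3hd', 'bs3hd': '3hd', 't3hd': '3hd',
    'sbh3d': 'h3d', 'bsh3d': 'h3d', 'th3d': 'h3d',
    'sbhd_sb2hd': 'hd_2hd', 'bshd_bs2hd': 'hd_2hd', 'thd_t2hd': 'hd_2hd',
    'sbhd_sbh2d': 'hd_h2d', 'bshd_bsh2d': 'hd_h2d', 'thd_th2d': 'hd_h2d',
    'sbhd_sbhd_sbhd': 'hd_hd_hd', 'bshd_bshd_bshd': 'hd_hd_hd', 'thd_thd_thd': 'hd_hd_hd',
}

def qkv_format(layout):
    # "bshd_bs2hd" -> "bshd", "th3d" -> "thd"
    return ''.join(c for c in layout.split('_')[0] if c.isalpha())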
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,

@@ -26,6 +86,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
+NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)

@@ -33,7 +94,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
-&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
+&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD))) {
#if (CUDNN_VERSION >= 8900)
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else

@@ -52,7 +114,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
-|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
+|| (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
if (

@@ -65,7 +132,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
-&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
+&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
+|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)
+|| (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))

@@ -438,3 +507,201 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
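Paraphrased in Python, the layout-related part of the backend choice after this change looks roughly like the sketch below; it deliberately omits the dropout, bias and mask conditions shown above and is not the authoritative logic.

# Rough, simplified sketch of nvte_get_fused_attn_backend's layout handling (see C code above).
def pick_backend(dtype, sm_arch, head_dim, max_seqlen_q, max_seqlen_kv, qkv_layout, qkv_format):
    if dtype == "fp8" and sm_arch >= 90 and head_dim == 64 \
            and qkv_layout in ("qkv_interleaved", "t3hd"):
        return "FP8"
    if max_seqlen_q <= 512 and max_seqlen_kv <= 512 and qkv_layout in (
            "qkv_interleaved", "kv_interleaved",
            "sb3hd", "sbhd_sb2hd", "bs3hd", "bshd_bs2hd", "bshd_bshd_bshd"):
        return "F16_max512_seqlen"
    if head_dim in (64, 128) and qkv_format in ("sbhd", "bshd"):
        return "F16_arbitrary_seqlen"
    return "No_Backend"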
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention BWD with separate Q, K and V
void nvte_fused_attn_bwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dK,
NVTETensor dV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
Tensor *output_dK = reinterpret_cast<Tensor*>(dK);
Tensor *output_dV = reinterpret_cast<Tensor*>(dV);
Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_dO,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout,
input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQ, output_dK, output_dV,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
@@ -15,6 +15,7 @@
#include "../common.h"
#include "utils.h"
#include "../util/cuda_runtime.h"
+#include "../util/system.h"

#if (CUDNN_VERSION >= 8900)
#define Q_ID 1

@@ -1059,6 +1060,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();

if (!use_workspace_opt) {
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)

@@ -1221,9 +1223,6 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
-"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;

@@ -1295,9 +1294,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
-"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;

@@ -1337,31 +1333,173 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
-const char* env_workspace_limit_char = std::getenv("NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT");
-if (env_workspace_limit_char != nullptr) {
-try {
-std::string env_dp_workspace_limit(env_workspace_limit_char);
-int dp_workspace_limit = std::stoi(env_dp_workspace_limit);
-if (dp_workspace_limit > max_allowed_dp_workspace) {
-max_allowed_dp_workspace = dp_workspace_limit;
-}
-} catch (...) {
-NVTE_ERROR(
-"Invalid argument for NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT (integer; in MBytes)! \n");
-}
-}
if (required_dp_workspace <= max_allowed_dp_workspace) {
use_workspace_opt = true;
}
+use_workspace_opt = transformer_engine::getenv<bool>(
+"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
+// will not be needed in cuDNN 8.9.6
+NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
+|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
+use_workspace_opt = false;
+}
}
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen_q, 1};
output_S->data.dtype = DType::kFloat32;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 2) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}

void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
} }
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
}
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
bool use_workspace_opt = false;
#if (CUDNN_VERSION >= 8905)
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
if (sm_arch_ >= 90) {
// quick estimate of dp workspace size
size_t max_seqlen_div_up_q = ((max_seqlen_q + 64 - 1) / 64) * 64;
size_t max_seqlen_div_up_kv = ((max_seqlen_kv + 64 - 1) / 64) * 64;
size_t required_dp_workspace =
(batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
// default upper limit for dp workspace 256MB
size_t max_allowed_dp_workspace = 256;
if (required_dp_workspace <= max_allowed_dp_workspace) {
use_workspace_opt = true;
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
}
#endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt);
if (workspace_size > 0) {
...
@@ -38,6 +38,30 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_size, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
...
@@ -1250,9 +1250,6 @@ void fused_attn_max_512_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
-"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = 2 * num_head * head_dim;

@@ -1323,8 +1320,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
-"qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED.");
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");

@@ -1391,6 +1386,76 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_K->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,

@@ -1402,9 +1467,6 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
-"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;

@@ -1465,9 +1527,6 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

-NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
-"qkv_layout must be NVTE_KV_INTERLEAVED.");
// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
auto stride = 2 * num_head * head_dim;

@@ -1518,5 +1577,63 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_ERROR("Unexpected workspace_size.");
}
}
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dK->data.dptr;
void *devPtrdV = output_dV->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
const auto q_type = input_Q->data.dtype;
const auto kv_type = input_K->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
} // namespace transformer_engine
#endif // CUDNN_VERSION >= 8901
@@ -38,6 +38,17 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -56,6 +67,18 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8901
} // namespace transformer_engine
...
@@ -173,6 +173,7 @@ static cudnn_frontend::Tensor createScale(
static cudnn_frontend::Tensor createScaleWithOffset(
const cudnn_frontend::Tensor& prevBlockOutputTensor,
const std::string& scale_tensor_name,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
bool isOutputVirtual,
bool isScaleByValue,
@@ -192,7 +193,7 @@ static cudnn_frontend::Tensor createScaleWithOffset(
generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2],
0 /*s_kv = 0 for placeholder*/,
output_dim[3], output_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
} else {
// Otherwise output dim and stride should be the same as prev block dim and stride
for (int i = 0; i < 4; i++) {
@@ -1163,6 +1164,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto OTensor = createScaleWithOffset(
OTensor_before_quan_O_tensor, // input tensor
"scaleO", // scale tensor
layout, // qkv layout
tensorType, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1515,6 +1517,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dVTensor = createScaleWithOffset(
dVTensor_before_quan_dV, // input tensor
"scaledV", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1653,6 +1656,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dQ = createScaleWithOffset(
After_dS_K_before_quan_dQ, // input tensor
"scaledQ", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1693,6 +1697,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
auto dK = createScaleWithOffset(
After_dSTranspose_Q_before_quan_dK, // input tensor
"scaledK", // scale tensor
layout, // qkv layout
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
@@ -1911,6 +1916,8 @@ void fused_attn_fp8_fwd_qkvpacked(
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
@@ -2048,5 +2055,204 @@ void fused_attn_fp8_bwd_qkvpacked(
return;
}
}
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
void* devPtrQ = input_Q->data.dptr;
void* devPtrK = input_K->data.dptr;
void* devPtrV = input_V->data.dptr;
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_Q->scale_inv.dptr;
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
if (is_training) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
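The `Aux_CTX_Tensors` branch in the function above follows the same query-then-allocate pattern as the workspace: a first call with `Aux_CTX_Tensors->size == 0` only records the shapes and dtypes of M, ZInv and the rng state, and a second call with `size == 3` consumes caller-allocated buffers. A minimal allocation sketch follows; it is illustrative only, and `typeToSize` is assumed to be the usual DType-to-bytes helper.
// Hedged sketch (illustrative only): allocate the auxiliary tensors whose shapes were
// recorded by a size-query call to fused_attn_fp8_fwd.
inline void allocate_aux_ctx_tensors_example(NVTETensorPack *aux) {
  for (size_t i = 0; i < aux->size; ++i) {
    Tensor *t = reinterpret_cast<Tensor *>(aux->tensors[i]);
    if (t->data.dptr != nullptr) continue;  // buffer already provided by the caller
    size_t elems = 1;
    for (size_t dim : t->data.shape) elems *= dim;
    cudaMalloc(&t->data.dptr, elems * typeToSize(t->data.dtype));  // typeToSize assumed
  }
}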
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dK,
const Tensor *output_dV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
void* devPtrQ = input_Q->data.dptr;
void* devPtrK = input_K->data.dptr;
void* devPtrV = input_V->data.dptr;
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_Q->scale_inv.dptr;
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdS = input_output_dP->amax.dptr;
void* devPtrScaledS = input_output_dP->scale.dptr;
void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
void* devPtrdQ = output_dQ->data.dptr;
void* devPtrdK = output_dK->data.dptr;
void* devPtrdV = output_dV->data.dptr;
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dQ->amax.dptr;
void* devPtrAmaxdV = output_dQ->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dQ->scale.dptr;
void* devPtrScaledV = output_dQ->scale.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
size_t workspace_size = 0;
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledS,
devPtrScaleS, devPtrScaledS,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdS,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
#endif // end of CUDNN>=8900
} // namespace transformer_engine
@@ -46,5 +46,44 @@ void fused_attn_fp8_bwd_qkvpacked(
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dK,
const Tensor *output_dV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
#endif // end of CUDNN>=8900
} // namespace transformer_engine
@@ -30,6 +30,7 @@ void generateMatrixStrides(
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
// to be deprecated in the future
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
@@ -37,7 +38,8 @@ void generateMatrixStrides(
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else if ((layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
|| (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
@@ -55,7 +57,7 @@ void generateMatrixStrides(
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
@@ -73,7 +75,7 @@ void generateMatrixStrides(
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
@@ -91,7 +93,7 @@ void generateMatrixStrides(
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
@@ -109,7 +111,7 @@ void generateMatrixStrides(
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
@@ -129,6 +131,228 @@ void generateMatrixStrides(
strideA[batch_dim_idx] = s_q * h * d;
break;
}
// New path: generate strides directly from the expanded set of QKV layouts
switch (layout) {
case NVTE_QKV_Layout::NVTE_SB3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBH3D:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_dim_idx] = b * 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_dim_idx] = b * 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * h * d;
strideA[hidden_transpose_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_T3HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_TH3D:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_q * 3 * h * d;
strideA[head_dim_idx] = 3 * d;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * 2 * h * d;
strideA[head_dim_idx] = 2 * d;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose)
|| (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
}
break;
}
if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
}
}
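As a quick sanity check of the layout-based stride generation above, the hedged sketch below hard-codes the Q-matrix strides that the switch produces for two of the new layouts and compares them against hand-computed values for a small problem size. The stride slots follow the same batch/head/seqlen/hidden positions used by strideA above; the snippet is illustrative and not part of the library.
#include <cassert>
#include <cstdint>
// Expected Q-matrix strides for b=2, s_q=4, h=3, d=8, mirroring the switch above.
// Slot order here: [batch, head, seqlen, hidden].
inline void check_q_strides_example() {
  const int64_t b = 2, s_q = 4, h = 3, d = 8;
  // NVTE_BS3HD: QKV packed as [b, s, 3, h, d] -> batch stride s_q*3*h*d, seqlen stride 3*h*d.
  int64_t bs3hd_q[4] = {s_q * 3 * h * d, d, 3 * h * d, 1};
  assert(bs3hd_q[0] == 288 && bs3hd_q[1] == 8 && bs3hd_q[2] == 72 && bs3hd_q[3] == 1);
  // NVTE_SBH3D: QKV packed as [s, b, h, 3, d] -> head stride becomes 3*d and the
  // seqlen stride carries the batch dimension, b*3*h*d.
  int64_t sbh3d_q[4] = {3 * h * d, 3 * d, b * 3 * h * d, 1};
  assert(sbh3d_q[0] == 72 && sbh3d_q[1] == 24 && sbh3d_q[2] == 144 && sbh3d_q[3] == 1);
}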
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
...
@@ -18,7 +18,17 @@ extern "C" {
#endif
/*! \enum NVTE_QKV_Layout
 * \brief Memory layouts of QKV tensors.
 *   `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, the number of heads,
 *   head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
 *   `SBHD`- and `BSHD`-based layouts are used when sequences in a batch are of equal length
 *   or padded to the same length, and `THD`-based layouts are used when sequences have
 *   different lengths in a batch.
 * \note {`NVTE_QKV_INTERLEAVED`, `NVTE_KV_INTERLEAVED` and `NVTE_NOT_INTERLEAVED`
 *   will be deprecated in the next release. Please use their equivalent enums instead, i.e.
 *   `NVTE_T3HD`, `NVTE_THD_T2HD` and `NVTE_THD_THD_THD` when sequences have variable lengths,
 *   and `NVTE_BS3HD`, `NVTE_BSHD_BS2HD` and `NVTE_BSHD_BSHD_BSHD` when sequences are of equal
 *   length or padded to equal length.}
 */
enum NVTE_QKV_Layout {
/*! Separate Q, K, V tensors.
@@ -67,7 +77,51 @@ enum NVTE_QKV_Layout {
| num_heads * head_dim
\endverbatim
*/
NVTE_KV_INTERLEAVED = 2,
NVTE_SB3HD = 3,
NVTE_SBH3D = 4,
NVTE_SBHD_SB2HD = 5,
NVTE_SBHD_SBH2D = 6,
NVTE_SBHD_SBHD_SBHD = 7,
NVTE_BS3HD = 8,
NVTE_BSH3D = 9,
NVTE_BSHD_BS2HD = 10,
NVTE_BSHD_BSH2D = 11,
NVTE_BSHD_BSHD_BSHD = 12,
NVTE_T3HD = 13,
NVTE_TH3D = 14,
NVTE_THD_T2HD = 15,
NVTE_THD_TH2D = 16,
NVTE_THD_THD_THD = 17,
};
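Given the deprecation note above, callers that still hold one of the three legacy values can translate them to their documented replacements before calling the new API. The helper below is a hedged sketch of that mapping and is not part of the public header; it follows the note's equivalences for variable-length batches, while padded or equal-length batches would map to the BS-based counterparts instead.
/* Illustrative mapping from the deprecated layouts to the THD-based equivalents
 * named in the note above. Not part of the public API. */
static inline NVTE_QKV_Layout nvte_upgrade_deprecated_layout_example(NVTE_QKV_Layout layout) {
  switch (layout) {
    case NVTE_QKV_INTERLEAVED: return NVTE_T3HD;
    case NVTE_KV_INTERLEAVED:  return NVTE_THD_T2HD;
    case NVTE_NOT_INTERLEAVED: return NVTE_THD_THD_THD;
    default:                   return layout;  /* already one of the new layouts */
  }
}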
/*! \enum NVTE_QKV_Layout_Group
* \brief Grouping of QKV layouts
*/
enum NVTE_QKV_Layout_Group {
/*! 3HD QKV layouts, e.g. BS3HD */
NVTE_3HD = 0,
/*! H3D QKV layouts, e.g. BSH3D */
NVTE_H3D = 1,
/*! HD_2HD QKV layouts, e.g. BSHD_BS2HD */
NVTE_HD_2HD = 2,
/*! HD_H2D QKV layouts, e.g. BSHD_BSH2D */
NVTE_HD_H2D = 3,
/*! HD_HD_HD QKV layouts, e.g. BSHD_BSHD_BSHD */
NVTE_HD_HD_HD = 4,
};
/*! \enum NVTE_QKV_Format
* \brief Dimension formats for QKV tensors
*/
enum NVTE_QKV_Format {
/*! SBHD QKV format */
NVTE_SBHD = 0,
/*! BSHD QKV format */
NVTE_BSHD = 1,
/*! THD QKV format */
NVTE_THD = 2,
}; };
/*! \enum NVTE_Bias_Type
@@ -94,6 +148,9 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2,
};
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
@@ -105,6 +162,22 @@ enum NVTE_Fused_Attn_Backend {
NVTE_FP8 = 2,
};
/*! \brief Get layout group for a given QKV layout
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
* \return qkv layout group, e.g. h3d.
*/
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
/*! \brief Get QKV format for a given QKV layout
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
*
* \return qkv format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
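The two helpers declared above let backend code branch on coarse properties of a layout instead of enumerating all fifteen values. The small wrappers below are a hedged usage sketch built only on the enums and functions declared in this header; they are illustrative and not part of the API.
/* Illustrative helper built on the functions declared above: returns 1 when K and V
 * share one packed tensor (the HD_2HD and HD_H2D layout groups), 0 otherwise. */
static inline int qkv_has_packed_kv_example(NVTE_QKV_Layout layout) {
  NVTE_QKV_Layout_Group group = nvte_get_qkv_layout_group(layout);
  return (group == NVTE_HD_2HD) || (group == NVTE_HD_H2D);
}
/* Illustrative helper: returns 1 for layouts with variable-length (THD) sequences,
 * which require cu_seqlens tensors. */
static inline int qkv_is_varlen_example(NVTE_QKV_Layout layout) {
  return nvte_get_qkv_format(layout) == NVTE_THD;
}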
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
@@ -152,7 +225,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
@@ -199,7 +272,7 @@ void nvte_fused_attn_fwd_qkvpacked(
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
@@ -250,9 +323,9 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
@@ -301,9 +374,9 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
@@ -332,6 +405,122 @@ void nvte_fused_attn_bwd_kvpacked(
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
* Computes:
* - P = Q * Transpose(K) + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * Transpose(V)
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
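Like the packed variants, `nvte_fused_attn_fwd` is typically called twice: once with an empty workspace tensor so the backend can report the workspace (and auxiliary tensor) sizes it needs, and once with real buffers. The sketch below only illustrates that calling convention; `reallocate_to_reported_size` is a hypothetical caller-side helper, and all tensor handles are assumed to be created and filled elsewhere.
/* Hedged sketch of the query-then-run calling convention for nvte_fused_attn_fwd.
 * reallocate_to_reported_size() is a hypothetical helper that resizes a tensor's
 * device buffer to the shape reported by the size-query call. */
void reallocate_to_reported_size(NVTETensor workspace);  /* hypothetical */
static void fused_attn_fwd_example(
    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
    NVTETensor S, NVTETensor O, NVTETensorPack *aux,
    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTETensor workspace, cudaStream_t stream) {
  /* 1st call: workspace has no device buffer yet, so only sizes are reported. */
  nvte_fused_attn_fwd(Q, K, V, Bias, S, O, aux, cu_seqlens_q, cu_seqlens_kv, rng_state,
                      max_seqlen_q, max_seqlen_kv, /*is_training=*/true, attn_scale, dropout,
                      qkv_layout, bias_type, mask_type, workspace, stream);
  reallocate_to_reported_size(workspace);  /* hypothetical helper */
  /* 2nd call: launches the actual cuDNN graph. */
  nvte_fused_attn_fwd(Q, K, V, Bias, S, O, aux, cu_seqlens_q, cu_seqlens_kv, rng_state,
                      max_seqlen_q, max_seqlen_kv, /*is_training=*/true, attn_scale, dropout,
                      qkv_layout, bias_type, mask_type, workspace, stream);
}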
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd(
const NVTETensor Q,
const NVTETensor K,
const NVTETensor V,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dK,
NVTETensor dV,
NVTETensor dBias,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -20,6 +20,8 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
QKVLayout,
AttnBiasType,
AttnMaskType,
@@ -37,6 +39,7 @@ from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
AttnBiasTypes,
QKVLayouts,
dist_group_type,
TE_DType,
)
@@ -565,64 +568,6 @@ class _SplitAlongDim(torch.autograd.Function):
return torch.cat(grad_outputs, dim = split_dim), None, None
class _CombineQKV(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
query_layer: torch.Tensor,
key_layer: torch.Tensor, # pylint: disable=unused-argument
value_layer: torch.Tensor, # pylint: disable=unused-argument
dim: int,
) -> torch.Tensor:
mixed_layer = torch.Tensor().to(device=query_layer.device,
dtype=query_layer.dtype)
new_shape = list(query_layer.shape)
new_shape[dim] = new_shape[dim] * 3
mixed_layer.set_(query_layer.untyped_storage(),
query_layer.storage_offset(),
new_shape,
query_layer.stride())
ctx.dim = dim
return mixed_layer
@staticmethod
def backward(ctx,
*grad_outputs,
) -> Tuple[torch.Tensor, ...]:
assert len(grad_outputs) > 0, "No gradients received for backprop!"
tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 3)
return tensors[0], tensors[1], tensors[2], None
class _CombineKV(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
key_layer: torch.Tensor,
value_layer: torch.Tensor, # pylint: disable=unused-argument
dim: int,
) -> torch.Tensor:
mixed_layer = torch.Tensor().to(device=key_layer.device,
dtype=key_layer.dtype)
new_shape = list(key_layer.shape)
new_shape[dim] = new_shape[dim] * 2
mixed_layer.set_(key_layer.untyped_storage(),
key_layer.storage_offset(),
new_shape,
key_layer.stride())
ctx.dim = dim
return mixed_layer
@staticmethod
def backward(ctx,
*grad_outputs,
) -> Tuple[torch.Tensor, ...]:
assert len(grad_outputs) > 0, "No gradients received for backprop!"
tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2)
return tensors[0], tensors[1], None
class UnfusedDotProductAttention(torch.nn.Module):
@@ -659,6 +604,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[torch.Tensor] = None,
core_attention_bias_type: str = "no_bias",
@@ -666,6 +614,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
) -> torch.Tensor:
"""core attention fprop"""
assert (qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
assert (qkv_format != 'thd'
), """UnfusedDotProductAttention does not support variable sequence lengths!"""
if qkv_format == 'bshd':
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [x.transpose(0, 1)
for x in [query_layer, key_layer, value_layer]]
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
@@ -681,7 +638,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer.size(0),
)
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if key_layer.shape[2] != query_layer.shape[2]:
assert (query_layer.shape[2]%key_layer.shape[2]==0
),"The number of attention heads must be divisible by the number of GQA groups!"
@@ -791,12 +747,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
if qkv_format == 'sbhd':
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
context_layer = context_layer.view(seqlen, batch_size, -1)
if qkv_format == 'bshd':
# [b, np, sq, hn] --> [b, sq, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [b, sq, hp]
context_layer = context_layer.view(batch_size, seqlen, -1)
return context_layer
@@ -830,66 +794,100 @@ class _PrepareQKVForFA(torch.autograd.Function):
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
def _check_qkv_layout(q, k, v):
data_ptr = q.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs:
return False
stride = q.stride()
check_strides = all(stride == x.stride() for x in [q, k, v])
if not check_strides:
return False
shape = q.shape
check_shapes = all(shape == x.shape for x in [q, k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
if check_offsets:
return "sbh3d"
last_dims_size = shape[-1] * shape[-2]
check_offsets = all(i * last_dims_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
if check_offsets:
return "sb3hd"
return "other"
def _check_kv_layout(k, v):
data_ptr = k.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
if not check_ptrs:
return False
stride = k.stride()
check_strides = all(stride == x.stride() for x in [k, v])
if not check_strides:
return False
shape = k.shape
check_shapes = all(shape == x.shape for x in [k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([k, v]))
if check_offsets:
return "sbh2d"
last_dims_size = shape[-1] * shape[-2]
check_offsets = all(i * last_dims_size == x.storage_offset()
for i, x in enumerate([k, v]))
if check_offsets:
return "sb2hd"
return "other"
def _get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_format: str = 'sbhd',
) -> str:
"""Get qkv layout.
Parameters
----------
q: torch.Tensor
Query tensor.
k: torch.Tensor
Key tensor.
v: torch.Tensor
Value tensor.
qkv_format: str, default = `sbhd`
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
Returns
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
`q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
`v = kv[:,:,:,1,:]`.
Mapping:
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
"""
check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
stride = k.stride()
check_strides_kv = all(stride == x.stride() for x in [k, v])
shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = all(shape == x.shape for x in [k, v])
last_dim_size = q.shape[-1]
check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
last_dim_size = k.shape[-1]
check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([k, v]))
last_two_dims_size = q.shape[-1] * q.shape[-2]
check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
last_two_dims_size = k.shape[-1] * k.shape[-2]
check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
for i, x in enumerate([k, v]))
qkv_layout = None
if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
and check_last_two_dims_offsets_qkv
and not check_last_dim_offsets_qkv):
# sb3hd, bs3hd, t3hd
qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv:
# sbh3d, bsh3d, th3d
qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
and check_last_two_dims_offsets_kv
and not check_last_dim_offsets_kv):
# sbhd_sb2hd, bshd_bs2hd, thd_t2hd
qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
and check_last_dim_offsets_kv):
# sbhd_sbh2d, bshd_bsh2d, thd_th2d
qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
elif check_strides_kv and check_shapes_kv:
# sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
qkv_layout = '_'.join(list([qkv_format])*3)
else:
raise Exception("The provided qkv memory layout is not supported!")
return qkv_layout
class FlashAttention(torch.nn.Module): class FlashAttention(torch.nn.Module):
...@@ -920,6 +918,9 @@ class FlashAttention(torch.nn.Module): ...@@ -920,6 +918,9 @@ class FlashAttention(torch.nn.Module):
query_layer: torch.Tensor, query_layer: torch.Tensor,
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor, value_layer: torch.Tensor,
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
cp_group: Optional[dist_group_type] = None, cp_group: Optional[dist_group_type] = None,
cp_global_ranks: Union[int] = None, cp_global_ranks: Union[int] = None,
...@@ -931,16 +932,21 @@ class FlashAttention(torch.nn.Module): ...@@ -931,16 +932,21 @@ class FlashAttention(torch.nn.Module):
query_layer.dtype in [torch.float16, torch.bfloat16] query_layer.dtype in [torch.float16, torch.bfloat16]
and key_layer.dtype in [torch.float16, torch.bfloat16] and key_layer.dtype in [torch.float16, torch.bfloat16]
and value_layer.dtype in [torch.float16, torch.bfloat16] and value_layer.dtype in [torch.float16, torch.bfloat16]
), 'FlashAttention currently only supports FP16 and BF16.' ), "FlashAttention currently only supports FP16 and BF16."
assert ( assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FlashAttention currently only supports CUDA tensors.' ), "FlashAttention currently only supports CUDA tensors."
assert (
qkv_layout in QKVLayouts
), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
# For now just 128, will make it more general in the future qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if qkv_format == 'sbhd':
# For now just 128, will make it more general in the future
if (query_layer.shape[-1] == 128 and if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_qkv_layout(query_layer, key_layer, value_layer) == "sbh3d"): qkv_layout == "sbh3d"):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer, key_layer,
value_layer) value_layer)
...@@ -948,18 +954,42 @@ class FlashAttention(torch.nn.Module):
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
if qkv_format == 'bshd':
query_layer, key_layer, value_layer = [x.contiguous()
for x in (query_layer, key_layer, value_layer)]
if qkv_format in ['sbhd', 'bshd']:
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
if qkv_format == 'thd':
assert (cp_group is None or get_distributed_world_size(cp_group) == 1
), "thd format is not supported for context parallelism!"
assert (_flash_attn_2_available
), "flash-attn v2 is required for variable sequence length support!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item()
if cp_group is None or get_distributed_world_size(cp_group) == 1:
# [b * s, h, d]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
...@@ -971,7 +1001,7 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_kwargs["deterministic"] = self.deterministic
output = flash_attn_forward_func(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor,
causal=attn_mask_type=="causal",
...@@ -981,7 +1011,7 @@ class FlashAttention(torch.nn.Module):
with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
...@@ -989,8 +1019,14 @@ class FlashAttention(torch.nn.Module):
deterministic=self.deterministic
)
if qkv_format == 'sbhd':
# (bs)hd -> bs(hd) -> sb(hd)
output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
if qkv_format == 'bshd':
# (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q, -1).contiguous()
return output
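The `thd` branch above relies on caller-provided cumulative sequence lengths; a minimal sketch of that convention in plain PyTorch (the sequence lengths are illustrative assumptions):

import torch

# Three sequences of lengths 5, 3 and 7, packed back to back along the token dimension.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
total_tokens, h, d = int(cu_seqlens[-1]), 16, 64
q = torch.randn(total_tokens, h, d, dtype=torch.float16, device="cuda")   # 'thd' tensor
# max_seqlen is recovered exactly as in the forward above:
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())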
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
...@@ -1126,6 +1162,77 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd):
out, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
ctx.use_FAv2_bwd = use_FAv2_bwd
return out
@staticmethod
def backward(ctx, d_out):
q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x)
for x in (d_out, q, k, v, out)]
flash_attn_cuda_bwd(
d_out, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
ctx.dropout_p, ctx.attn_scale, False,
ctx.attn_mask_type == "causal", None, rng_state
)
dq = dq[..., :d_out.shape[-1]]
dk = dk[..., :d_out.shape[-1]]
dv = dv[..., :d_out.shape[-1]]
else:
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return (dq, dk, dv)
if ctx.attn_bias_type == "no_bias":
return (None, None, None, None, None, dq, dk, dv, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None)
# else, return (dq, dk, dv, dbias)
return (None, None, None, None, None, dq, dk, dv, None, rest[0], None,
None, None, None, None, None, None,
None, None, None, None, None, None)
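The long tuples of `None` returned above follow the standard autograd contract: `backward` must return one entry per `forward` argument, with `None` for non-differentiable arguments such as flags, layout strings and backend handles. A toy sketch of the same pattern in generic PyTorch (not TransformerEngine code):

import torch

class _ScaledCopy(torch.autograd.Function):
    """Toy Function mirroring the one-gradient-per-argument convention used above."""

    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one entry per forward argument: x, scale, some_flag
        return grad_out * ctx.scale, None, None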
class FusedAttention(torch.nn.Module):
"""Dot product attention, with multiple backends:
...@@ -1144,6 +1251,9 @@ class FusedAttention(torch.nn.Module):
| qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | |
| - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd |
| | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd |
| | bshd_bshd_bshd | sbhd_sbhd_sbhd, bshd_bshd_bshd |
| mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias |
| dropout | yes | yes |
...@@ -1174,6 +1284,9 @@ class FusedAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
fused_attention_backend:
tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
...@@ -1194,116 +1307,51 @@ class FusedAttention(torch.nn.Module):
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FusedAttention only supports CUDA tensors.'
assert (
qkv_layout in QKVLayouts
), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if qkv_format in ['sbhd', 'bshd']:
if qkv_format == 'sbhd':
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
if qkv_format == 'bshd':
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
if qkv_format == 'thd':
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item()
qkv_dtype = TE_DType[query_layer.dtype]
use_FAv2_bwd = (self.use_FAv2_bwd
and (fused_attention_backend
== tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
with self.attention_dropout_ctx():
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
query_layer, key_layer, value_layer,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
...@@ -1314,12 +1362,11 @@ class FusedAttention(torch.nn.Module):
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
)
# ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1)
class DotProductAttention(torch.nn.Module):
...@@ -1358,6 +1405,16 @@ class DotProductAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
qkv_format: str, default = `sbhd`
dimension format for `query_layer`, `key_layer` and `value_layer`,
{`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
`h` the number of heads, `d` head size, and `t` the total number of sequences
in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
are used when sequences in a batch are of equal length or padded to
equal length, and the `thd` format is used when sequences in a batch
have different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to obtain the layout information.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
...@@ -1390,6 +1447,7 @@ class DotProductAttention(torch.nn.Module):
kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
qkv_format: str = "sbhd",
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_size: int = 1, tp_size: int = 1,
...@@ -1403,6 +1461,7 @@ class DotProductAttention(torch.nn.Module): ...@@ -1403,6 +1461,7 @@ class DotProductAttention(torch.nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group self.tp_group = tp_group
...@@ -1496,6 +1555,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -1496,6 +1555,9 @@ class DotProductAttention(torch.nn.Module):
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor, value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
...@@ -1538,9 +1600,11 @@ class DotProductAttention(torch.nn.Module):
If FusedAttention is being used, users can also choose to switch to flash-attn's
implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
(default: 0), because of the performance differences between various versions of
flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace-related
optimizations in FusedAttention. When unset, TransformerEngine determines the code path
based on its internal logic. These optimizations trade memory for performance
and should be used with care.
Parameters
----------
...@@ -1550,6 +1614,14 @@ class DotProductAttention(torch.nn.Module):
Key tensor.
value_layer : torch.Tensor
Value tensor.
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
...@@ -1567,12 +1639,57 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not.
"""
assert (key_layer.shape == value_layer.shape
), "Keys and values must have the same shape!"
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
if qkv_format is None:
qkv_format = self.qkv_format
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
assert (qkv_format in ['sbhd', 'bshd', 'thd']
), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"
if qkv_format == 'thd':
assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
and len(cu_seqlens_q.shape) == 1
and len(cu_seqlens_kv.shape) == 1
), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
assert (cu_seqlens_q.dtype == torch.int32
and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item()
if qkv_format in ['sbhd', 'bshd']:
assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == 'sbhd':
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
if qkv_format == 'bshd':
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert (all(seqlens_q <= max_seqlen_q)
), """Sequence lengths indicated by cu_seqlens_q must be no greater than
the sequence dimension in 'query_layer'!"""
if cu_seqlens_kv is not None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
assert (all(seqlens_kv <= max_seqlen_kv)
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimension in 'key_layer' and 'value_layer'!"""
qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
qkv_format = qkv_format)
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
...@@ -1603,8 +1720,6 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False
use_fused_attention = False
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
...@@ -1613,7 +1728,7 @@ class DotProductAttention(torch.nn.Module):
AttnBiasType[core_attention_bias_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
max_seqlen_q, max_seqlen_kv,
query_layer.shape[-1])
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
...@@ -1635,17 +1750,21 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
cp_group = self.cp_group,
cp_global_ranks = self.cp_global_ranks,
cp_stream = self.cp_stream)
return self.flash_attention(query_layer, key_layer, value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
cp_group = self.cp_group,
cp_global_ranks = self.cp_global_ranks,
cp_stream = self.cp_stream)
assert (
self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
...@@ -1657,17 +1776,23 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
...@@ -1675,19 +1800,23 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias)
return self.unfused_attention(query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias)
class MultiheadAttention(torch.nn.Module):
...@@ -2313,6 +2442,9 @@ class MultiheadAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
qkv_format='sbhd',
cu_seqlens_q=None,
cu_seqlens_kv=None,
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
checkpoint_core_attention=checkpoint_core_attention,
......
...@@ -28,6 +28,11 @@ AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
QKVLayouts = (
"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd")
LayerTypes = ("encoder", "decoder") LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None) GemmParallelModes = ("row", "column", None)
......
...@@ -18,7 +18,9 @@ from transformer_engine_extensions import (
__all__ = ['fused_attn_fwd_qkvpacked',
'fused_attn_bwd_qkvpacked',
'fused_attn_fwd_kvpacked',
'fused_attn_bwd_kvpacked',
'fused_attn_fwd',
'fused_attn_bwd']
TORCH_DType = {
...@@ -34,6 +36,21 @@ QKVLayout = {
"not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
"qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
"kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
"sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
"sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
"sbhd_sb2hd": NVTE_QKV_Layout.NVTE_SBHD_SB2HD,
"sbhd_sbh2d": NVTE_QKV_Layout.NVTE_SBHD_SBH2D,
"sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_SBHD_SBHD_SBHD,
"bs3hd": NVTE_QKV_Layout.NVTE_BS3HD,
"bsh3d": NVTE_QKV_Layout.NVTE_BSH3D,
"bshd_bs2hd": NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
"bshd_bsh2d": NVTE_QKV_Layout.NVTE_BSHD_BSH2D,
"bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
"t3hd": NVTE_QKV_Layout.NVTE_T3HD,
"th3d": NVTE_QKV_Layout.NVTE_TH3D,
"thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD,
"thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D,
"thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD,
}
AttnBiasType = {
...@@ -166,9 +183,10 @@ def fused_attn_fwd_qkvpacked(
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than max(seqlens),
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
...@@ -336,9 +354,10 @@ def fused_attn_bwd_qkvpacked(
Parameters
----------
max_seqlen: int
max sequence length for QKV, used for padding; may be larger than max(seqlens),
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
cu_seqlens: torch.Tensor
cumulative sequence lengths for QKV; shape [batch_size + 1]
qkv: torch.Tensor
input tensor QKV;
shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
...@@ -482,7 +501,7 @@ def fused_attn_fwd_kvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "kv_interleaved",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
...@@ -495,13 +514,15 @@ def fused_attn_fwd_kvpacked(
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
...@@ -535,7 +556,7 @@ def fused_attn_fwd_kvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "kv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
...@@ -659,7 +680,7 @@ def fused_attn_bwd_kvpacked(
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "kv_interleaved",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -668,13 +689,15 @@ def fused_attn_bwd_kvpacked(
Parameters
----------
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max sequence length for KV, used for padding; may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for KV; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
...@@ -723,7 +746,7 @@ def fused_attn_bwd_kvpacked(
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "kv_interleaved"
layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
...@@ -812,3 +835,365 @@ def fused_attn_bwd_kvpacked(
return output_tensors
# otherwise return (d_q, d_kv), d_bias
return output_tensors[:2], output_tensors[2]
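The padding rule these docstrings repeat, in code form (illustrative values): max_seqlen may exceed every actual sequence length, the surplus is simply padding, while the per-sequence lengths always come from differencing the cumulative array.

import torch

cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([5, 3, 7], dtype=torch.int32)
max_seqlen = 16                              # >= seqlens.max(); the extra length is padding
assert int(seqlens.max()) <= max_seqlen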
def fused_attn_fwd(
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
amax_o: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
Parameters
----------
is_training: bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int
max sequence length for Q, used for padding;
may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim],
where total_seqs_q = cu_seqlens_q[-1],
or [batch_size, seqlen_q, num_heads, head_dim],
or [seqlen_q, batch_size, num_heads, head_dim]
k: torch.Tensor
input tensor K;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
v: torch.Tensor
input tensor V;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
qkv_dtype: tex.DType
data type of Q, K and V; in tex.DType, not torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations
amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations
amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
Returns
----------
o: torch.Tensor
output tensor O, of the attention calculation; same data type as Q, K and V;
same shape as Q
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None
softmax-related tensors:
1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
softmax: torch.Tensor
Softmax(Q*K.T)
shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32
2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
softmaxStats: torch.Tensor
log(sum(e^(x - max(x)))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
3. if fused_attention_backend == FusedAttnBackend["FP8"]
M: torch.Tensor
max(Q*K.T)
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
ZInv: torch.Tensor
1/sum(e^(x - max(x))), where x=Q*K.T
shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator;
[seed, offset], dtype uint64
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if attn_bias_type != "no_bias":
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
# FP8 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["FP8"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
# execute kernel
output_tensors = tex.fused_attn_fwd(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:]
def fused_attn_bwd(
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
amax_dp: torch.Tensor = None,
amax_dqkv: torch.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
fast_zero_fill: bool = True,
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
Parameters
----------
max_seqlen_q: int
max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
input tensor Q;
shape [total_seqs_q, num_heads, head_dim],
where total_seqs_q = cu_seqlens_q[-1],
or [batch_size, seqlen_q, num_heads, head_dim],
or [seqlen_q, batch_size, num_heads, head_dim]
k: torch.Tensor
input tensor K;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
v: torch.Tensor
input tensor V;
shape [total_seqs_kv, num_heads, head_dim],
where total_seqs_kv = cu_seqlens_kv[-1],
or [batch_size, seqlen_kv, num_heads, head_dim],
or [seqlen_kv, batch_size, num_heads, head_dim]
o: torch.Tensor
input tensor O (output of forward); same data type as Q, K and V;
same shape as Q
d_o: torch.Tensor
input tensor dO (gradient of O); same data type as Q, K and V;
same shape as Q
qkv_dtype: tex.DType
data type of Q, K and V; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_o: torch.Tensor, default = None
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
input tensor for the quantization of dP in FP8 computations, P = Q * K.T
q_scale_dqkv: torch.Tensor, default = None
input tensor for the quantization of dQ, dK and dV in FP8 computations
amax_dp: torch.Tensor, default = None
output tensor, amax of dP, used by the next iteration in FP8 computations,
P = Q * K.T
amax_dqkv: torch.Tensor, default = None
output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "no_mask"}
Returns
----------
d_q: torch.Tensor
gradient tensor of Q; same data type and shape as Q
d_k: torch.Tensor
gradient tensor of K; same data type and shape as K
d_v: torch.Tensor
gradient tensor of V; same data type and shape as V
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
"""
check_cu_seqlens(cu_seqlens_q)
check_cu_seqlens(cu_seqlens_kv)
assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
), "cu_seqlens_q and cu_seqlens_kv must have the same length."
b = cu_seqlens_q.numel() - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
assert (len(aux_ctx_tensors) >= 1
), "aux_ctx_tensors must contain rng_state as its last element."
rng_state = aux_ctx_tensors[-1]
check_rng_state(rng_state)
if fused_attention_backend == FusedAttnBackend["FP8"]:
assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention."
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
assert (amax_dp is not None), "amax_dp is required for FP8 fused attention."
assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention."
assert (len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."
check_scalar(d_scale_qkv)
check_scalar(d_scale_s)
check_scalar(d_scale_o)
check_scalar(d_scale_do)
check_scalar(q_scale_s)
check_scalar(q_scale_dp)
check_scalar(q_scale_dqkv)
check_scalar(amax_dp)
check_scalar(amax_dqkv)
m, z_inv = aux_ctx_tensors[:2]
check_stats(m, b, h, max_seqlen_q)
check_stats(z_inv, b, h, max_seqlen_q)
# execute kernel
output_tensors = tex.fused_attn_bwd(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
return tuple(output_tensors)
...@@ -106,6 +106,52 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
......
...@@ -717,6 +717,444 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
return {dQ, dKV, dBias};
}
// fused attention FWD with separate Q, K and V tensors
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
auto k_sizes = K.sizes().vec();
std::vector<size_t> k_shape{k_sizes.begin(), k_sizes.end()};
auto v_sizes = V.sizes().vec();
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
// create output tensor O
auto O = torch::empty_like(Q);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = Q.size(-2);
auto d = Q.size(-1);
if (set_zero && ((h * d) % block_size == 0)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) {
auto bias_sizes = Bias.value().sizes().vec();
std::vector<size_t> bias_shape{bias_sizes.begin(), bias_sizes.end()};
te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options);
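// write the Philox seed and offset into a 2-element int64 tensor on the device;
// it is returned to the caller so the backward pass can regenerate the same dropout mask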
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// first call with unallocated buffers: the backend only populates the workspace
// and auxiliary tensor shapes/dtypes, it does not execute the kernel
nvte_fused_attn_fwd(
te_Q.data(),
te_K.data(),
te_V.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// output_tensors = [O, nvte_aux_tensor_pack.tensors]
std::vector<at::Tensor> output_tensors;
output_tensors.push_back(O);
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
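// when the pack holds more than one tensor (training), its last entry is the RNG state;
// reuse the rng_state tensor created above instead of allocating a new one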
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
// execute the kernel
nvte_fused_attn_fwd(
te_Q.data(),
te_K.data(),
te_V.data(),
te_Bias.data(),
te_S.data(),
te_O.data(),
&nvte_aux_tensor_pack,
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
te_rng_state.data(),
max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
// if training, [O, softmax-related tensors, rng_state]; if inference, [O]
return output_tensors;
}
// fused attention BWD with separate Q, K and V
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor K,
const at::Tensor V,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine;
auto q_sizes = Q.sizes().vec();
std::vector<size_t> q_shape{q_sizes.begin(), q_sizes.end()};
auto k_sizes = K.sizes().vec();
std::vector<size_t> k_shape{k_sizes.begin(), k_sizes.end()};
auto v_sizes = V.sizes().vec();
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dQ;
at::Tensor dK;
at::Tensor dV;
at::Tensor dQKV, dKV;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
std::vector<int64_t> tmp_shape;
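// allocate gradients so they come back in the same layout as the inputs: for packed
// layouts, dQ, dK and dV are views into a single contiguous dQKV (or dKV) buffer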
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_3HD:
tmp_shape = std::vector<int64_t>{q_sizes.begin(), q_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3));
dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1),
torch::indexing::Slice(0, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1),
torch::indexing::Slice(0, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
break;
case NVTE_QKV_Layout_Group::NVTE_H3D:
tmp_shape = std::vector<int64_t>{q_sizes.begin(), q_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3));
dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
dQ = torch::empty_like(Q);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1),
torch::indexing::Slice(0, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
dQ = torch::empty_like(Q);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1),
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
dQ = torch::empty_like(Q);
dK = torch::empty_like(K);
dV = torch::empty_like(V);
break;
default:
NVTE_ERROR("QKV layout not supported!");
}
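// e.g. with Q of shape [s, b, h, d], NVTE_3HD yields dQKV of shape [s, b, 3, h, d]
// and NVTE_H3D yields [s, b, h, 3, d]; dQ/dK/dV are the slices along the inserted dimension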
at::Tensor dBias;
TensorWrapper te_dBias;
if (bias_type != NVTE_NO_BIAS) {
dBias = torch::empty({1, static_cast<int64_t>(Q.size(-2)),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h_q = Q.size(-2);
auto h_kv = K.size(-2);
auto d = Q.size(-1);
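// zero the gradient buffers; the cheaper mha_fill path clears only the padded tail
// and requires contiguous tensors whose row sizes are multiples of block_size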
if (set_zero
&& ((h_q * d) % block_size == 0)
&& ((h_kv * d) % block_size == 0)
&& dQ.is_contiguous()
&& dK.is_contiguous()
&& dV.is_contiguous()) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQ.fill_(0);
dK.fill_(0);
dV.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, qkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
// create cu_seqlens tensorwrappers
auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec();
std::vector<size_t> cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()};
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
std::vector<size_t> cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()};
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr);
// convert auxiliary tensors from forward to NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
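// wrap the auxiliary tensors saved by the forward pass (softmax statistics, RNG state)
// in place, without copying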
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create workspace
TensorWrapper workspace;
// first call with an unallocated workspace: the backend only populates the required
// workspace shape and dtype, it does not execute the kernel
nvte_fused_attn_bwd(
te_Q.data(),
te_K.data(),
te_V.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dK.data(),
te_dV.data(),
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(
workspace_data.data_ptr(),
workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(
te_Q.data(),
te_K.data(),
te_V.data(),
te_O.data(),
te_dO.data(),
te_S.data(),
te_dP.data(),
&nvte_aux_tensor_pack,
te_dQ.data(),
te_dK.data(),
te_dV.data(),
te_dBias.data(),
te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(),
max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type,
workspace.data(),
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {dQ, dK, dV, dBias};
}
namespace flash_attention {
constexpr int warp_size = 32;
...
...@@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output");
...@@ -148,7 +152,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -148,7 +152,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout") py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED)
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
...
...@@ -74,6 +74,7 @@ class TransformerLayer(torch.nn.Module):
are deprecated and will be fully removed in future releases.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
...@@ -638,5 +639,5 @@ class TransformerLayer(torch.nn.Module):
if self.output_layernorm:
output = self.layernorm(output)
# output: [s, b, h]
return output