Unverified commit cbfb8c6b authored by cyanguwa, committed by GitHub

Miscellaneous fixes for core attention (#344)



* miscellaneous fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back pytorch csrc extensions.h
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add unit tests for dpa checkpointing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove seqlen%32/64 checks for now
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests for core attn bias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add tests for changes regarding rng_state in aux_ctx_tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reuse rng tracker from numerics in fused attn; skip checkpointing if FAv2 in numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* uncomment comments used for testing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pre/post scale bias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* remove skipifs for FAv2 check after PR366
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove checkpointing tests for transformer layer; dpa tests still provide coverage
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust random number range for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Add upper bound to FA version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Check backend only when using FusedAttention
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove imports/variables related to FAv2 checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further fix random number ranges for tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix variable referenced before assignment error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a0f44354
......@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"])
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
......
......@@ -17,6 +17,7 @@ import os
from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
......@@ -58,29 +59,32 @@ batch_sizes = [1, 2, 32]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dot_product_attention(dtype, bs, model):
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
"""Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention")
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention")
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention")
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dot_product_attention(dtype, bs, config, backend):
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
......@@ -88,7 +92,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
......@@ -96,9 +100,14 @@ def _run_dot_product_attention(dtype, bs, config, backend):
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None
block = (
DotProductAttention(
......@@ -108,7 +117,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
get_rng_state_tracker = get_dummy_cuda_rng_tracker,
tp_group = None,
layer_number = 1,
attention_type = "self"
......@@ -118,7 +127,10 @@ def _run_dot_product_attention(dtype, bs, config, backend):
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op = block(q, k, v,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)
return op, inp.grad
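
For reference, a minimal usage sketch of the forward arguments the test above exercises (`checkpoint_core_attention`, `core_attention_bias_type`, `core_attention_bias`). The constructor keywords follow the standard `transformer_engine.pytorch.DotProductAttention` API and the bias layout follows the test; treat the exact shapes and keyword names as illustrative and verify against your TE version. Requires a CUDA device.

```python
# Hedged sketch: calling DotProductAttention with the new bias/checkpoint kwargs.
import torch
from transformer_engine.pytorch import DotProductAttention

seq_len, bs, heads, head_dim = 128, 2, 16, 64
dpa = DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,          # per-head channels
    attention_dropout=0.0,
).cuda()

q, k, v = (
    torch.randn(seq_len, bs, heads, head_dim,
                dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)
# Bias must be [1, num_heads, seq_len_q, seq_len_kv] for post_scale_bias.
bias = torch.randn(1, heads, seq_len, seq_len, dtype=torch.bfloat16, device="cuda")

out = dpa(q, k, v,
          checkpoint_core_attention=False,
          core_attention_bias_type="post_scale_bias",
          core_attention_bias=bias)
out.backward(torch.randn_like(out))
```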
......@@ -128,29 +140,32 @@ def _run_dot_product_attention(dtype, bs, config, backend):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer(dtype, bs, model):
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention")
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention")
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention")
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer(dtype, bs, config, backend):
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
......@@ -158,7 +173,7 @@ def _run_transformer_layer(dtype, bs, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
......@@ -166,9 +181,9 @@ def _run_transformer_layer(dtype, bs, config, backend):
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -178,6 +193,11 @@ def _run_transformer_layer(dtype, bs, config, backend):
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None
block = (
TransformerLayer(
......@@ -215,8 +235,13 @@ def _run_transformer_layer(dtype, bs, config, backend):
.cuda()
)
op = block(inp)
op.backward(op_grad)
num_iters = 10
for i in range(num_iters):
op = block(inp,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)
return op, inp.grad
......@@ -246,19 +271,18 @@ def test_transformer_layer_gqa(dtype, bs, model):
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-1
atol, rtol = 5e-1, 5e-2
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
......@@ -266,9 +290,9 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -342,14 +366,13 @@ def test_dpa_fp8(dtype, bs, model):
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-2, 1e-1)
atol, rtol = (2.5e-2, 2.5e-2)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dpa_fp8(dtype, bs, config, backend):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......@@ -361,9 +384,9 @@ def _run_dpa_fp8(dtype, bs, config, backend):
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
bs * config.seq_len, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
torch.save(op_grad, 'op_grad.pt')
fp8_recipe = recipe.DelayedScaling(
......
......@@ -25,7 +25,6 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
seed = 1234
rng_str = "rng_state"
torch.manual_seed(seed)
......
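
The attention tests now import `get_dummy_cuda_rng_tracker` and `reset_rng_states` from `test_numerics` instead of keeping a local `rng_str`. As a hedged sketch of what such a shared dummy tracker typically looks like (assuming TE's `CudaRNGStatesTracker`; the state name, seed, and the `reset_rng_states` body are illustrative approximations, not the test file's exact contents):

```python
# Illustrative dummy CUDA RNG-state tracker of the kind the tests reuse.
import torch
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", 1234)

def get_dummy_cuda_rng_tracker():
    """Return one shared tracker so dropout patterns match across backends."""
    return _DUMMY_CUDA_RNG_STATE_TRACKER

def reset_rng_states():
    """Approximation: reset CPU and CUDA RNG to a fixed seed before each run."""
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
```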
......@@ -32,9 +32,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
#if (CUDNN_VERSION >= 8900)
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
#endif
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
......@@ -76,6 +82,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
#if (CUDNN_VERSION < 8901)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
#if (CUDNN_VERSION < 8900)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
......
......@@ -136,10 +136,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -181,10 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -235,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......@@ -283,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......
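
As an illustration only (not the library's implementation, which also depends on the QKV layout, cuDNN version, and environment variables), the dispatch described by the support matrices above can be summarized as:

```python
# Toy restatement of the fused-attention support matrices above.
def pick_fused_attn_backend(precision, bias, mask, max_seqlen, head_dim):
    if precision in ("fp16", "bf16"):
        if (max_seqlen <= 512 and head_dim == 64
                and bias in ("no_bias", "post_scale_bias")
                and mask in ("padding", "causal", "no_mask")):
            return 0  # F16_max512_seqlen
        if head_dim in (64, 128) and bias == "no_bias" and mask == "causal":
            return 1  # F16_arbitrary_seqlen (covers the remaining F16/BF16 cases here)
    if (precision == "fp8" and max_seqlen <= 512 and head_dim == 64
            and bias == "no_bias" and mask == "padding"):
        return 2  # FP8
    return None   # no fused backend; fall back to flash/unfused attention

assert pick_fused_attn_backend("bf16", "post_scale_bias", "no_mask", 512, 64) == 0
assert pick_fused_attn_backend("fp16", "no_bias", "causal", 2048, 128) == 1
assert pick_fused_attn_backend("fp8", "no_bias", "padding", 512, 64) == 2
```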
......@@ -8,7 +8,7 @@ import warnings
import math
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union, Dict
from pkg_resources import packaging
import torch
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
AttnBiasTypes,
dist_group_type,
TE_DType,
)
......@@ -227,6 +228,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""core attention fprop"""
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
......@@ -275,13 +278,42 @@ class UnfusedDotProductAttention(torch.nn.Module):
scale *= self.layer_number
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
if core_attention_bias_type == "no_bias":
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
elif core_attention_bias_type == "pre_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
)
matmul_result = (matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3])
matmul_result /= scale
elif core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
matmul_result = (matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3])
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
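
In short, the two bias modes added above differ only in where the bias enters relative to the 1/scale factor. A reference-only sketch of the intended math, with `S = Q @ K^T` of shape [b, np, sq, sk] and bias of shape [1, np, sq, sk]:

```python
# Reference sketch of pre-scale vs. post-scale bias placement.
import torch

def scores_with_bias(q, k, bias, scale, bias_type):
    s = torch.einsum("bnqd,bnkd->bnqk", q, k)   # raw Q @ K^T, [b, np, sq, sk]
    if bias_type == "pre_scale_bias":
        return (s + bias) / scale               # bias added before scaling
    if bias_type == "post_scale_bias":
        return s / scale + bias                 # bias added after scaling
    return s / scale                            # no_bias
```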
......@@ -689,13 +721,17 @@ class FusedAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
fused_attention_backend:
tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""fused attention fprop"""
assert (fused_attention_backend
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!'
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
......@@ -865,7 +901,7 @@ class DotProductAttention(torch.nn.Module):
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal`
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
......@@ -964,11 +1000,12 @@ class DotProductAttention(torch.nn.Module):
self,
attention_func: Callable,
*forward_args: Tuple[torch.Tensor, ...],
**forward_kwargs: Dict[str, Any],
) -> torch.Tensor:
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
return attention_func(*inputs)
def custom_forward(*input_args, **input_kwargs):
return attention_func(*input_args, **input_kwargs)
hidden_states = checkpoint(
custom_forward,
......@@ -976,6 +1013,7 @@ class DotProductAttention(torch.nn.Module):
self.get_rng_state_tracker,
self.tp_group,
*forward_args,
**forward_kwargs,
)
return hidden_states
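
The change above lets keyword arguments flow through activation checkpointing. A standalone illustration of the pattern using plain `torch.utils.checkpoint` (PyTorch >= 1.11); note it binds the kwargs in a closure for simplicity, whereas TE's own `checkpoint` wrapper forwards them directly and additionally threads the FP8 state, RNG tracker, and TP group:

```python
# Minimal sketch: forwarding keyword arguments into a checkpointed function.
import torch
from torch.utils.checkpoint import checkpoint

def attention_like(q, k, v, bias=None, scale=1.0):
    s = torch.matmul(q, k.transpose(-2, -1)) / scale
    if bias is not None:
        s = s + bias
    return torch.matmul(torch.softmax(s, dim=-1), v)

def checkpointed_forward(func, *args, **kwargs):
    # Bind kwargs in the closure so checkpoint only sees tensor inputs.
    def custom_forward(*inputs):
        return func(*inputs, **kwargs)
    return checkpoint(custom_forward, *args, use_reentrant=False)

q = torch.randn(2, 4, 8, 16, requires_grad=True)
k = torch.randn(2, 4, 8, 16, requires_grad=True)
v = torch.randn(2, 4, 8, 16, requires_grad=True)
bias = torch.randn(1, 4, 8, 8)

out = checkpointed_forward(attention_like, q, k, v, bias=bias, scale=16 ** 0.5)
out.sum().backward()
```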
......@@ -1067,33 +1105,38 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False
use_fused_attention = False
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
if is_in_onnx_export_mode():
use_flash_attention = False
use_fused_attention = False
qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved"
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
self.attention_dropout,
query_layer.shape[0], key_layer.shape[0],
query_layer.shape[-1])
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = (use_fused_attention
and is_backend_avail
and self.num_gqa_groups == self.num_attention_heads)
if (self.deterministic
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
use_fused_attention = False
warnings.warn(
"Disabling usage of FusedAttention since this FusedAttention"
"backend does not support deterministic execution."
)
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
self.attention_dropout,
query_layer.shape[0], key_layer.shape[0],
query_layer.shape[-1])
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = (use_fused_attention
and is_backend_avail
and self.num_gqa_groups == self.num_attention_heads)
if (self.deterministic
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
use_fused_attention = False
warnings.warn(
"Disabling usage of FusedAttention since the FusedAttention"
"backend does not support deterministic exection."
)
if use_flash_attention:
if checkpoint_core_attention:
......@@ -1106,18 +1149,18 @@ class DotProductAttention(torch.nn.Module):
if use_fused_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention,
query_layer,
key_layer,
value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
query_layer,
key_layer,
value_layer,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer,
fused_attention_backend,
core_attention_bias_type,
core_attention_bias,
fast_zero_fill)
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
......@@ -1125,9 +1168,17 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
attention_mask,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
)
return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)
return self.unfused_attention(query_layer,
key_layer,
value_layer,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
)
class MultiHeadAttention(torch.nn.Module):
......@@ -1350,6 +1401,8 @@ class MultiHeadAttention(torch.nn.Module):
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
assert (core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
......
......@@ -26,6 +26,8 @@ AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)
......
......@@ -254,7 +254,7 @@ def fused_attn_fwd_qkvpacked(
if attn_bias_type != "no_bias":
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == [1, h, max_seqlen, max_seqlen]
assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv."
......@@ -599,7 +599,7 @@ def fused_attn_fwd_kvpacked(
if attn_bias_type != "no_bias":
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
......
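
The shape asserts above previously compared `attn_bias.shape` against a Python list, which never matches because `torch.Size` is a tuple subclass; the `torch.Size([...])` form fixes that. A quick standalone check of the behavior:

```python
# Why the assert needed torch.Size: a tuple subclass never equals a list.
import torch

bias = torch.zeros(1, 16, 128, 128)
print(bias.shape == [1, 16, 128, 128])              # False: tuple vs list
print(bias.shape == torch.Size([1, 16, 128, 128]))  # True
print(bias.shape == (1, 16, 128, 128))              # also True (plain tuple)
```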