Unverified Commit cbfb8c6b authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Miscellaneous fixes for core attention (#344)



* miscellenous fixes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back pytorch csrc extensions.h
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add unit tests for dpa checkpointing
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove seqlen%32/64 checks for now
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests for core attn bias
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add tests for changes regarding rng_state in aux_ctx_tensor
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reuse rng tracker from numerics in fused attn; skip checkpointing if FAv2 in numerics
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* uncomment comments used for testing
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pre/post scale bias
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarcyanguwa <8636796+cyanguwa@users.noreply.github.com>

* remove skipifs for FAv2 check after PR366
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove checkpointing tests for transformer layer; dpa tests still provide coverage
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust random number range for tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Add upper bound to FA version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Check backend only when using FusedAttention
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove imports/variables related to FAv2 checks
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further fix random number ranges for tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix variable referenced before assignment error
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: default avatarcyanguwa <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a0f44354
...@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if "pytorch" in frameworks(): if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"]) add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks(): if "jax" in frameworks():
if not found_pybind11(): if not found_pybind11():
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
from pkg_resources import packaging from pkg_resources import packaging
from importlib.metadata import version from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = is_fp8_available() fp8_available, reason_for_no_fp8 = is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
...@@ -58,29 +59,32 @@ batch_sizes = [1, 2, 32] ...@@ -58,29 +59,32 @@ batch_sizes = [1, 2, 32]
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_dot_product_attention(dtype, bs, model): @pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
"""Test DotProductAttention module with three backends, """Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention""" FlashAttention, FusedAttention and UnfusedDotProductAttention"""
config = model_configs[model] config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( if bias_type == "no_bias":
dtype, bs, config, "FlashAttention") flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention") dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention") dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3) atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) if bias_type == "no_bias":
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dot_product_attention(dtype, bs, config, backend): def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
torch.manual_seed(1234) reset_rng_states()
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention": if backend == "FlashAttention":
...@@ -88,7 +92,7 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -88,7 +92,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn( inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim, config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda() dtype = dtype).cuda()
inp.requires_grad=True inp.requires_grad=True
...@@ -96,9 +100,14 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -96,9 +100,14 @@ def _run_dot_product_attention(dtype, bs, config, backend):
seqlens.fill_(config.seq_len) seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, ( op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim config.seq_len, bs, config.num_attention_heads * config.head_dim,
), dtype = dtype).cuda() dtype = dtype).cuda()
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None
block = ( block = (
DotProductAttention( DotProductAttention(
...@@ -108,7 +117,7 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -108,7 +117,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
attn_mask_type = config.attn_mask_type, attn_mask_type = config.attn_mask_type,
sequence_parallel = False, sequence_parallel = False,
tp_size = 1, tp_size = 1,
get_rng_state_tracker = None, get_rng_state_tracker = get_dummy_cuda_rng_tracker,
tp_group = None, tp_group = None,
layer_number = 1, layer_number = 1,
attention_type = "self" attention_type = "self"
...@@ -118,7 +127,10 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -118,7 +127,10 @@ def _run_dot_product_attention(dtype, bs, config, backend):
q = inp[:, :,0,:,:] q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:] k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:] v = inp[:, :,2,:,:]
op = block(q, k, v) op = block(q, k, v,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad) op.backward(op_grad)
return op, inp.grad return op, inp.grad
...@@ -128,29 +140,32 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -128,29 +140,32 @@ def _run_dot_product_attention(dtype, bs, config, backend):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer(dtype, bs, model): @pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
"""Test TransformerLayer module when its DotProductAttention is enabled with """Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model] config = model_configs[model]
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( if bias_type == "no_bias":
dtype, bs, config, "FlashAttention") flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention") dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention") dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1) atol, rtol = (5e-1, 5e-2)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) if bias_type == "no_bias":
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer(dtype, bs, config, backend): def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
torch.manual_seed(1234) reset_rng_states()
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention": if backend == "FlashAttention":
...@@ -158,7 +173,7 @@ def _run_transformer_layer(dtype, bs, config, backend): ...@@ -158,7 +173,7 @@ def _run_transformer_layer(dtype, bs, config, backend):
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.1 * torch.randn( inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim, config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda() dtype = dtype).cuda()
inp.requires_grad=True inp.requires_grad=True
...@@ -166,9 +181,9 @@ def _run_transformer_layer(dtype, bs, config, backend): ...@@ -166,9 +181,9 @@ def _run_transformer_layer(dtype, bs, config, backend):
seqlens.fill_(config.seq_len) seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, ( op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim config.seq_len, bs, config.num_attention_heads * config.head_dim,
), dtype = dtype).cuda() dtype = dtype).cuda()
sigma = 0.02 sigma = 0.02
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
...@@ -178,6 +193,11 @@ def _run_transformer_layer(dtype, bs, config, backend): ...@@ -178,6 +193,11 @@ def _run_transformer_layer(dtype, bs, config, backend):
drop_path_rate = 0.0 drop_path_rate = 0.0
drop_path_rates = [ drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None
block = ( block = (
TransformerLayer( TransformerLayer(
...@@ -215,8 +235,13 @@ def _run_transformer_layer(dtype, bs, config, backend): ...@@ -215,8 +235,13 @@ def _run_transformer_layer(dtype, bs, config, backend):
.cuda() .cuda()
) )
op = block(inp) num_iters = 10
op.backward(op_grad) for i in range(num_iters):
op = block(inp,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)
return op, inp.grad return op, inp.grad
...@@ -246,19 +271,18 @@ def test_transformer_layer_gqa(dtype, bs, model): ...@@ -246,19 +271,18 @@ def test_transformer_layer_gqa(dtype, bs, model):
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa( unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group) dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-1 atol, rtol = 5e-1, 5e-2
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group): def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
torch.manual_seed(1234) reset_rng_states()
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
if backend == "FlashAttention": if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
inp = 0.1 * torch.randn( inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim, config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda() dtype = dtype).cuda()
inp.requires_grad=True inp.requires_grad=True
...@@ -266,9 +290,9 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr ...@@ -266,9 +290,9 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
seqlens.fill_(config.seq_len) seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, ( op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim config.seq_len, bs, config.num_attention_heads * config.head_dim,
), dtype = dtype).cuda() dtype = dtype).cuda()
sigma = 0.02 sigma = 0.02
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
...@@ -342,14 +366,13 @@ def test_dpa_fp8(dtype, bs, model): ...@@ -342,14 +366,13 @@ def test_dpa_fp8(dtype, bs, model):
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref( unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention") dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (5e-2, 1e-1) atol, rtol = (2.5e-2, 2.5e-2)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_dpa_fp8(dtype, bs, config, backend): def _run_dpa_fp8(dtype, bs, config, backend):
torch.manual_seed(1234) reset_rng_states()
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
...@@ -361,9 +384,9 @@ def _run_dpa_fp8(dtype, bs, config, backend): ...@@ -361,9 +384,9 @@ def _run_dpa_fp8(dtype, bs, config, backend):
seqlens.fill_(config.seq_len) seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, ( op_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim bs * config.seq_len, config.num_attention_heads * config.head_dim,
), dtype = dtype).cuda() dtype = dtype).cuda()
torch.save(op_grad, 'op_grad.pt') torch.save(op_grad, 'op_grad.pt')
fp8_recipe = recipe.DelayedScaling( fp8_recipe = recipe.DelayedScaling(
......
...@@ -25,7 +25,6 @@ from transformer_engine.pytorch import ( ...@@ -25,7 +25,6 @@ from transformer_engine.pytorch import (
) )
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
seed = 1234 seed = 1234
rng_str = "rng_state" rng_str = "rng_state"
torch.manual_seed(seed) torch.manual_seed(seed)
......
...@@ -32,9 +32,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -32,9 +32,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q <= 512) && (max_seqlen_q <= 512)
&& (head_dim == 64) && (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) { && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
#if (CUDNN_VERSION >= 8900)
backend = NVTE_Fused_Attn_Backend::NVTE_FP8; backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
#endif
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false; bool flag_m512 = false;
bool flag_arb = false; bool flag_arb = false;
...@@ -76,6 +82,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -76,6 +82,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) { NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
} }
#if (CUDNN_VERSION < 8901)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
#if (CUDNN_VERSION < 8900)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
} else { } else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
} }
......
...@@ -136,10 +136,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -136,10 +136,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
...@@ -181,10 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -181,10 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
...@@ -235,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -235,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
...@@ -283,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -283,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......
...@@ -8,7 +8,7 @@ import warnings ...@@ -8,7 +8,7 @@ import warnings
import math import math
from importlib.metadata import version from importlib.metadata import version
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union from typing import Any, Callable, Optional, Tuple, Union, Dict
from pkg_resources import packaging from pkg_resources import packaging
import torch import torch
...@@ -34,6 +34,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -34,6 +34,7 @@ from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
AttnMaskTypes, AttnMaskTypes,
AttnTypes, AttnTypes,
AttnBiasTypes,
dist_group_type, dist_group_type,
TE_DType, TE_DType,
) )
...@@ -227,6 +228,8 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -227,6 +228,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor, value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""core attention fprop""" """core attention fprop"""
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
...@@ -275,13 +278,42 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -275,13 +278,42 @@ class UnfusedDotProductAttention(torch.nn.Module):
scale *= self.layer_number scale *= self.layer_number
# Raw attention scores. [b * np, sq, sk] # Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm( if core_attention_bias_type == "no_bias":
matmul_result, matmul_result = torch.baddbmm(
query_layer.transpose(0, 1), # [b * np, sq, hn] matmul_result,
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] query_layer.transpose(0, 1), # [b * np, sq, hn]
beta=0.0, key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
alpha=(1.0 / scale), beta=0.0,
) alpha=(1.0 / scale),
)
elif core_attention_bias_type == "pre_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size(1, *output_size[1:])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
)
matmul_result = (matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3])
matmul_result /= scale
elif core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
matmul_result = (matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias).view(-1, output_size[2], output_size[3])
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
...@@ -689,13 +721,17 @@ class FusedAttention(torch.nn.Module): ...@@ -689,13 +721,17 @@ class FusedAttention(torch.nn.Module):
query_layer: torch.Tensor, query_layer: torch.Tensor,
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor, value_layer: torch.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend:
tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert (fused_attention_backend
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!'
assert ( assert (
(query_layer.dtype in [torch.float16, torch.bfloat16]) (query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16]) and (key_layer.dtype in [torch.float16, torch.bfloat16])
...@@ -865,7 +901,7 @@ class DotProductAttention(torch.nn.Module): ...@@ -865,7 +901,7 @@ class DotProductAttention(torch.nn.Module):
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0 attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention. dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal` attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation. type of attention mask passed into softmax operation.
layer_number: int, default = `None` layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules layer number of the current `DotProductAttention` when multiple such modules
...@@ -964,11 +1000,12 @@ class DotProductAttention(torch.nn.Module): ...@@ -964,11 +1000,12 @@ class DotProductAttention(torch.nn.Module):
self, self,
attention_func: Callable, attention_func: Callable,
*forward_args: Tuple[torch.Tensor, ...], *forward_args: Tuple[torch.Tensor, ...],
**forward_kwargs: Dict[str, Any],
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward method with activation checkpointing.""" """Forward method with activation checkpointing."""
def custom_forward(*inputs): def custom_forward(*input_args, **input_kwargs):
return attention_func(*inputs) return attention_func(*input_args, **input_kwargs)
hidden_states = checkpoint( hidden_states = checkpoint(
custom_forward, custom_forward,
...@@ -976,6 +1013,7 @@ class DotProductAttention(torch.nn.Module): ...@@ -976,6 +1013,7 @@ class DotProductAttention(torch.nn.Module):
self.get_rng_state_tracker, self.get_rng_state_tracker,
self.tp_group, self.tp_group,
*forward_args, *forward_args,
**forward_kwargs,
) )
return hidden_states return hidden_states
...@@ -1067,33 +1105,38 @@ class DotProductAttention(torch.nn.Module): ...@@ -1067,33 +1105,38 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False use_flash_attention = False
use_fused_attention = False use_fused_attention = False
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
use_flash_attention = False use_flash_attention = False
use_fused_attention = False use_fused_attention = False
qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved" qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved"
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype], if use_fused_attention:
TE_DType[key_layer.dtype], fused_attention_backend = tex.get_fused_attn_backend(
QKVLayout[qkv_layout], TE_DType[query_layer.dtype],
AttnBiasType[core_attention_bias_type], TE_DType[key_layer.dtype],
AttnMaskType[self.attn_mask_type], QKVLayout[qkv_layout],
self.attention_dropout, AttnBiasType[core_attention_bias_type],
query_layer.shape[0], key_layer.shape[0], AttnMaskType[self.attn_mask_type],
query_layer.shape[-1]) self.attention_dropout,
# DPA does not support FP8; for FP8, use cpp_extensions modules directly query_layer.shape[0], key_layer.shape[0],
is_backend_avail = (fused_attention_backend in query_layer.shape[-1])
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]]) # DPA does not support FP8; for FP8, use cpp_extensions modules directly
use_fused_attention = (use_fused_attention is_backend_avail = (fused_attention_backend in
and is_backend_avail [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
and self.num_gqa_groups == self.num_attention_heads) use_fused_attention = (use_fused_attention
if (self.deterministic and is_backend_avail
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]): and self.num_gqa_groups == self.num_attention_heads)
use_fused_attention = False if (self.deterministic
warnings.warn( and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
"Disabling usage of FusedAttention since this FusedAttention" use_fused_attention = False
"backend does not support deterministic execution." warnings.warn(
) "Disabling usage of FusedAttention since the FusedAttention"
"backend does not support deterministic exection."
)
if use_flash_attention: if use_flash_attention:
if checkpoint_core_attention: if checkpoint_core_attention:
...@@ -1106,18 +1149,18 @@ class DotProductAttention(torch.nn.Module): ...@@ -1106,18 +1149,18 @@ class DotProductAttention(torch.nn.Module):
if use_fused_attention: if use_fused_attention:
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.fused_attention, return self._checkpointed_attention_forward(self.fused_attention,
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
fused_attention_backend, fused_attention_backend = fused_attention_backend,
core_attention_bias_type, core_attention_bias_type = core_attention_bias_type,
core_attention_bias, core_attention_bias = core_attention_bias,
fast_zero_fill) fast_zero_fill = fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer, return self.fused_attention(query_layer, key_layer, value_layer,
fused_attention_backend, fused_attention_backend = fused_attention_backend,
core_attention_bias_type, core_attention_bias_type = core_attention_bias_type,
core_attention_bias, core_attention_bias = core_attention_bias,
fast_zero_fill) fast_zero_fill = fast_zero_fill)
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
...@@ -1125,9 +1168,17 @@ class DotProductAttention(torch.nn.Module): ...@@ -1125,9 +1168,17 @@ class DotProductAttention(torch.nn.Module):
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
attention_mask, attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
) )
return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask) return self.unfused_attention(query_layer,
key_layer,
value_layer,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
)
class MultiHeadAttention(torch.nn.Module): class MultiHeadAttention(torch.nn.Module):
...@@ -1350,6 +1401,8 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1350,6 +1401,8 @@ class MultiHeadAttention(torch.nn.Module):
attention_mask.dtype == torch.bool attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor" ), "Attention mask must be a boolean tensor"
assert (core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
# ================================================= # =================================================
# Pre-allocate memory for key-values for inference. # Pre-allocate memory for key-values for inference.
# ================================================= # =================================================
......
...@@ -26,6 +26,8 @@ AttnMaskTypes = ("causal", "padding", "no_mask") ...@@ -26,6 +26,8 @@ AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross") AttnTypes = ("self", "cross")
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
LayerTypes = ("encoder", "decoder") LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None) GemmParallelModes = ("row", "column", None)
......
...@@ -254,7 +254,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -254,7 +254,7 @@ def fused_attn_fwd_qkvpacked(
if attn_bias_type != "no_bias": if attn_bias_type != "no_bias":
assert (attn_bias is not None assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias." ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == [1, h, max_seqlen, max_seqlen] assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." ), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv." ), "attn_bias tensor must be in the same dtype as qkv."
...@@ -599,7 +599,7 @@ def fused_attn_fwd_kvpacked( ...@@ -599,7 +599,7 @@ def fused_attn_fwd_kvpacked(
if attn_bias_type != "no_bias": if attn_bias_type != "no_bias":
assert (attn_bias is not None assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias." ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
assert (attn_bias.shape == [1, h, max_seqlen_q, max_seqlen_kv] assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv." ), "attn_bias tensor must be in the same dtype as q and kv."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment