Unverified commit 56e0b351, authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add support for bottom-right-diagonal causal mask (#960)



* update to FE 1.5.1 and add bottom right causal
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjust logic for backend selection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.5.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add get_attention_backend function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tweak get_attention_backend and fix unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes for unfused, get_backend, etc.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cpu offload
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for get_attention_backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* explicitly skip FP32 and padding tests because there is no support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for window size check
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update check_set_window_size and add enc_dec_attn_mask_type/enc_dec_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
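For readers unfamiliar with the term in the title: a bottom-right-diagonal causal mask anchors the causal diagonal to the bottom-right corner of the seqlen_q x seqlen_kv score matrix rather than the top-left, which matters when seqlen_q != seqlen_kv (e.g. KV-cache decoding). A minimal illustrative sketch, not Transformer Engine's actual implementation:

```python
# Illustrative sketch: causal mask aligned to the top-left vs. the
# bottom-right corner of the (seqlen_q x seqlen_kv) score matrix.

def causal_mask(seqlen_q, seqlen_kv, bottom_right=False):
    """Return a boolean matrix; True marks keys a query may attend to."""
    # With bottom-right alignment, query i lines up with key i + (kv - q).
    offset = seqlen_kv - seqlen_q if bottom_right else 0
    return [
        [j <= i + offset for j in range(seqlen_kv)]
        for i in range(seqlen_q)
    ]

# With seqlen_q=2, seqlen_kv=4 (e.g. incremental decoding with a KV cache):
top_left = causal_mask(2, 4)            # row 0 attends to key 0 only
bottom_right = causal_mask(2, 4, True)  # row 0 attends to keys 0..2
```

For square attention (seqlen_q == seqlen_kv) the two alignments coincide, which is why the distinction only surfaces for cross attention and cached inference.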
parent f9dd37f7
@@ -345,10 +345,10 @@
     "| cuDNN attention | `bshd`, `sbhd`, `thd` | PyTorch: 3 formats, i.e. 15 layouts<br>JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts |\n",
     "| Framework-native attention | `bshd`, `sbhd`<br>(`sbhd` requires transpose operations) | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |\n",
     "\n",
-    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](../../tests/pytorch/fused_attention/test_fused_attn.py) and [test_dpa_qkv_layout_thd](../../tests/pytorch/fused_attention/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](../../transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
     "\n",
     "<div class=\"alert alert-info\">\n",
-    "<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [_get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
+    "<b>Note:</b> When RoPE is employed, the <code>qkv_layout</code> may change in Transformer Engine PyTorch through [get_qkv_layout](../../transformer_engine/pytorch/attention.py). This is due to the in-place nature of our RoPE implementations. We convert `q`, `k`, `v` tensors from their initial layout to the corresponding <code>hd_hd_hd</code> layout. For example, from <code>sbh3d</code> in <code>pytorch.MultiHeadAttention</code> before RoPE, to <code>sbhd_sbhd_sbhd</code> in <code>pytorch.DotProductAttention</code> after RoPE.\n",
     "</div>\n"
    ]
   },
......
@@ -7,7 +7,7 @@ import logging
 import math
 import os
 from importlib.metadata import version
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Optional
 import pytest
 import torch
@@ -19,6 +19,9 @@ from transformer_engine.pytorch.attention import (
     DotProductAttention,
     MultiheadAttention,
     RotaryPositionEmbedding,
+    get_attention_backend,
+    _flash_attn_2_plus,
+    _flash_attn_2_3_plus,
 )
 from transformer_engine.pytorch.constants import TE_DType
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -99,104 +102,96 @@ class ModelConfig:
         self.bias_shape = bias_shape


-def _is_fused_attention_supported(
-    config: ModelConfig,
-    dtype: torch.dtype,
-    qkv_layout: str = "sbh3d",
-) -> Tuple[bool, NVTE_Fused_Attn_Backend]:
-    """Check if FusedAttention supports a model configuration"""
-    backends = []
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    backend = tex.get_fused_attn_backend(
-        TE_DType[dtype],
-        TE_DType[dtype],
-        QKVLayout[qkv_layout],
-        AttnBiasType[config.attn_bias_type],
-        AttnMaskType[config.attn_mask_type],
-        config.dropout_p,
-        config.num_heads,
-        config.num_gqa_groups,
-        config.max_seqlen_q,
-        config.max_seqlen_kv,
-        config.head_dim,
-    )
-    if backend == FusedAttnBackend["FP8"]:
-        backends.append(backend)
-        return True, backends
-    if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
-        backends.append(backend)
-        return True, backends
-    if backend == FusedAttnBackend["F16_max512_seqlen"]:
-        backends.append(backend)
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-        backend = tex.get_fused_attn_backend(
-            TE_DType[dtype],
-            TE_DType[dtype],
-            QKVLayout[qkv_layout],
-            AttnBiasType[config.attn_bias_type],
-            AttnMaskType[config.attn_mask_type],
-            config.dropout_p,
-            config.num_heads,
-            config.num_gqa_groups,
-            config.max_seqlen_q,
-            config.max_seqlen_kv,
-            config.head_dim,
-        )
-        if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
-            backends.append(backend)
-        return True, backends
-    return False, backends
-
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_available() -> bool:
-    """Check if flash-attn 2.0+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2")
-
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_1() -> bool:
-    """Check if flash-attn 2.1+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2.1")
-
-
-@functools.lru_cache(maxsize=None)
-def _is_flash_attention_2_3() -> bool:
-    """Check if flash-attn 2.3+ is available"""
-    Version = packaging.version.Version
-    return Version(version("flash-attn")) >= Version("2.3")
-
-
-def _is_flash_attention_supported(config: ModelConfig) -> bool:
-    """Check if FlashAttention supports a model configuration"""
-    if get_device_compute_capability() < (8, 0):
-        return False
-    if config.attn_bias_type not in ["no_bias", "alibi"]:
-        return False
-    if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
-        return False
-    if "causal" in config.attn_mask_type and config.attn_type == "cross":
-        if _is_flash_attention_2_1():
-            # FAv2.1 implements causal mask for cross attention differently
-            # https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
-            return False
-    return True
-
-
-def _is_unfused_attention_supported(
-    config: ModelConfig,
-    qkv_format: str,
-) -> bool:
-    """Check if UnfusedDotProductAttention supports a model configuration"""
-    if "padding" in config.attn_mask_type:
-        return False
-    if "causal" in config.attn_mask_type and config.attn_type == "cross":
-        return False
-    if qkv_format == "thd":
-        return False
-    return True
+def _get_attention_backends(
+    config: ModelConfig,
+    qkv_dtype: torch.dtype,
+    qkv_layout: str,
+    window_size: Tuple[int, int] = (-1, -1),
+    pad_between_seqs: bool = False,
+    context_parallel: bool = False,
+    deterministic: bool = False,
+    fp8: bool = False,
+    fp8_meta: Optional[Dict[str, Any]] = None,
+) -> Tuple[List, List]:
+    """Check which attention backends support a model configuration"""
+    os.environ["NVTE_FLASH_ATTN"] = "1"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "1"
+
+    alibi_slopes_shape = None
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        if config.bias_shape == "1hss":
+            alibi_slopes_shape = [config.num_heads]
+        if config.bias_shape == "bhss":
+            alibi_slopes_shape = [config.batch_size, config.num_heads]
+
+    core_attention_bias_shape = (
+        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
+    )
+    core_attention_bias_requires_grad = False
+    # d=256 is supported by cuDNN 9.0+ for inference but not training
+    if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
+        core_attention_bias_requires_grad = True
+
+    fused_attn_backends = []
+    available_backends = None
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
+    _, _, _, available_backends, fused_attention_backend = get_attention_backend(
+        qkv_dtype=qkv_dtype,
+        qkv_layout=qkv_layout,
+        batch_size=config.batch_size,
+        num_heads=config.num_heads,
+        num_gqa_groups=config.num_gqa_groups,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        head_dim=config.head_dim,
+        attn_mask_type=config.attn_mask_type,
+        window_size=window_size,
+        alibi_slopes_shape=alibi_slopes_shape,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias_shape=core_attention_bias_shape,
+        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+        pad_between_seqs=pad_between_seqs,
+        attention_dropout=config.dropout_p,
+        context_parallel=context_parallel,
+        deterministic=deterministic,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
+    )
+    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+        fused_attn_backends.append(fused_attention_backend)
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+    _, _, _, available_backends, fused_attention_backend = get_attention_backend(
+        qkv_dtype=qkv_dtype,
+        qkv_layout=qkv_layout,
+        batch_size=config.batch_size,
+        num_heads=config.num_heads,
+        num_gqa_groups=config.num_gqa_groups,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        head_dim=config.head_dim,
+        attn_mask_type=config.attn_mask_type,
+        window_size=window_size,
+        alibi_slopes_shape=alibi_slopes_shape,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias_shape=core_attention_bias_shape,
+        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+        pad_between_seqs=pad_between_seqs,
+        attention_dropout=config.dropout_p,
+        context_parallel=context_parallel,
+        deterministic=deterministic,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
+    )
+    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+        fused_attn_backends.append(fused_attention_backend)
+    elif (
+        fused_attention_backend != FusedAttnBackend["No_Backend"]
+        and fused_attention_backend is not None
+    ):
+        fused_attn_backends.append(fused_attention_backend)
+    return available_backends, fused_attn_backends
 model_configs_base = {
@@ -255,25 +250,21 @@ def test_dot_product_attention(
     if "3" in qkv_layout and config.attn_type == "cross":
         pytest.skip("No need to test this layout for cross attention")

-    # Skip if only unfused backend is supported
-    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
-    unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
-    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+    # Test backend availability
+    window_size = (2, 2) if swa else (-1, -1)
+    available_backends, fused_attn_backends = _get_attention_backends(
         config,
-        dtype,
+        qkv_dtype=dtype,
         qkv_layout=qkv_layout,
+        window_size=window_size,
+        pad_between_seqs=pad_between_seqs,
     )
-    if swa:
-        fused_attn_supported = False
-    flash_attn_supported = _is_flash_attention_supported(config)
-    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
+    # Skip if only unfused backend is supported
+    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
-    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
-        pytest.skip("THD layout requires padding/padding_causal mask type.")
-    # d=256 is supported by cuDNN 9.0+ for inference but not training
+
     is_training = config.head_dim <= 128
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
@@ -296,7 +287,7 @@ def test_dot_product_attention(
     # FusedAttention backend
     if fused_attn_supported:
-        if len(fused_attn_backend) == 1:
+        if len(fused_attn_backends) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
@@ -308,7 +299,7 @@ def test_dot_product_attention(
                 pad_between_seqs,
                 is_training,
             )
-        if len(fused_attn_backend) == 2:
+        if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
@@ -363,7 +354,7 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
-    if fused_attn_supported and len(fused_attn_backend) == 2:
+    if fused_attn_supported and len(fused_attn_backends) == 2:
         logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i, _ in enumerate(fused_attn_bwd):
@@ -393,6 +384,14 @@ model_configs_mask = {
     "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
     "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+    "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_8_0": ModelConfig(
+        2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "mask_8_1": ModelConfig(
+        1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
 }
@@ -508,7 +507,7 @@ model_configs_swa = {
 }


-@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -530,7 +529,7 @@ model_configs_alibi_slopes = {
 }


-@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -994,17 +993,16 @@ def test_transformer_layer(
     tols = dict(atol=5e-1, rtol=5e-2)
     workspace_opt = True

-    # Skip if only unfused backend is supported
-    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
-        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+    # Test backend availability
+    available_backends, fused_attn_backends = _get_attention_backends(
         config,
-        dtype,
+        qkv_dtype=dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
     )
-    flash_attn_supported = _is_flash_attention_supported(config)
-    unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
-    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
+    # Skip if only unfused backend is supported
+    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")

     # UnfusedDotProductAttention backend
......
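The refactored `_get_attention_backends` helper above probes the two cuDNN fused sub-backends by toggling `NVTE_FUSED_ATTN_BACKEND` and collecting whichever backend the selector reports each time. A runnable sketch of that probe-and-collect pattern, using a hypothetical stand-in `select_fused_backend` in place of Transformer Engine's real `get_attention_backend` (the env-var name and backend names are real; the selector logic here is simplified for illustration):

```python
import os

# Hypothetical stand-in for get_attention_backend: returns a fused-backend
# name based on sequence length and the NVTE_FUSED_ATTN_BACKEND override.
def select_fused_backend(max_seqlen):
    forced = os.environ.get("NVTE_FUSED_ATTN_BACKEND")
    if forced == "0" and max_seqlen <= 512:
        return "F16_max512_seqlen"
    if max_seqlen >= 1:  # the arbitrary-seqlen backend handles any length
        return "F16_arbitrary_seqlen"
    return "No_Backend"

def probe_backends(max_seqlen):
    """Mirror the test helper: probe each override, collect the hits."""
    backends = []
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"  # prefer max512 backend
    b = select_fused_backend(max_seqlen)
    if b == "F16_max512_seqlen":
        backends.append(b)
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"  # prefer arbitrary-seqlen
    b = select_fused_backend(max_seqlen)
    if b == "F16_arbitrary_seqlen":
        backends.append(b)
    return backends
```

Short sequences can report both sub-backends (so the tests can compare backend 0 vs 1), while long sequences report only the arbitrary-seqlen one.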
@@ -5,9 +5,10 @@
 import os
 import pytest
 import subprocess
-from test_fused_attn import (
-    ModelConfig,
-    _is_flash_attention_2_available,
-)
+from test_fused_attn import ModelConfig
+from transformer_engine.pytorch.attention import (
+    _flash_attn_2_plus,
+    _flash_attn_2_3_plus,
+)
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
@@ -33,7 +34,7 @@ def get_bash_arguments(**kwargs):
     return args


-@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
+@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
......
@@ -1593,6 +1593,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
             ffn_hidden_size=4 * D,
             num_attention_heads=H,
             attn_input_format=input_format,
+            self_attn_mask_type="causal_bottom_right",
+            enc_dec_attn_mask_type="causal_bottom_right",
             layer_number=layer_number,
             attention_dropout=0.0,
             params_dtype=dtype,
@@ -1606,6 +1608,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
             qkv_format=input_format,
             layer_number=layer_number,
             attention_dropout=0.0,
+            attn_mask_type="causal_bottom_right",
             params_dtype=dtype,
         )
         .cuda()
......
@@ -670,7 +670,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
         apply_residual_connection_post_layernorm=True,
         output_layernorm=True,
         zero_centered_gamma=zero_centered_gamma,
-        self_attn_mask_type="padding",
+        self_attn_mask_type="causal",
         normalization=normalization,
         device="cuda",
     )
......
@@ -119,17 +119,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
       flag_m512 = true;
     }
-    if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
+    if (  // architecture
+        ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
          (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+        // sequence length
         ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
          (cudnn_runtime_version >= 90000)) &&
+        // number of heads
         ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
          (cudnn_runtime_version >= 8907)) &&
-        ((head_dim <= 128 && head_dim % 8 == 0)
+        // head dimension
+        ((head_dim <= 128 && head_dim % 8 == 0) ||
          // TODO (cyang): add is_training to nvte_get_fused_attn_backend
          // d=256 only supported for forward
-         || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
-             head_dim % 8 == 0)) &&
+         (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
+          head_dim % 8 == 0)) &&
+        // bias type
         ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
          ((cudnn_runtime_version >= 8906) &&
           (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
@@ -139,16 +144,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
            (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
          ((cudnn_runtime_version >= 90000) &&
           (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
+        // mask type
         ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
          ((cudnn_runtime_version >= 8906) &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
-           attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) &&
+           attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
+         ((cudnn_runtime_version >= 90300) &&
+          attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
+          bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
+          (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+          max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
+        // bias + mask combination
         (!(cudnn_runtime_version >= 8906 &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
           bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
+        // qkv format
        ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
         (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
          qkv_format == NVTE_QKV_Format::NVTE_THD) ||
......
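The newly added mask-type clause only admits the bottom-right causal mask under several restrictions visible in the diff above. Restating them as a standalone predicate (a reading aid with a hypothetical function name, not part of the library's API):

```python
# Restates the new backend-selection clause: the cuDNN arbitrary-seqlen
# backend accepts a bottom-right causal mask only when all of these hold.
def supports_causal_bottom_right(cudnn_version, bias_type, qkv_format,
                                 max_seqlen_q, max_seqlen_kv, dropout):
    return (
        cudnn_version >= 90300               # cuDNN 9.3+
        and bias_type == "no_bias"           # no attention bias
        and qkv_format in ("sbhd", "bshd")   # thd is not supported
        and max_seqlen_q <= max_seqlen_kv    # queries no longer than keys
        and dropout == 0.0                   # no attention dropout
    )
```

For example, a typical cross-attention setup on cuDNN 9.3 qualifies, while an older cuDNN or a `thd` layout does not.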
@@ -61,6 +61,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
   bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
                     (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
+                          (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
   bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
                      (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
   bool is_dropout = (is_training && dropout_probability != 0.0f);
@@ -203,6 +205,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                        .set_name("flash_attention")
                        .set_is_inference(false)
                        .set_causal_mask(is_causal)
+                       .set_causal_mask_bottom_right(is_bottom_right)
                        .set_attn_scale(attn_scale);
   sdpa_options.set_alibi_mask(is_alibi);
@@ -376,6 +379,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
   bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
                     (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
+                          (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
   bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
                      (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
   bool is_dropout = (dropout_probability != 0.0f);
@@ -544,6 +549,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                               .set_name("flash_attention_backward")
                               .set_causal_mask(is_causal)
+                              .set_causal_mask_bottom_right(is_bottom_right)
                               .set_attn_scale(attn_scale);
   sdpa_backward_options.set_alibi_mask(is_alibi);
......
@@ -93,10 +93,14 @@ enum NVTE_Mask_Type {
   NVTE_NO_MASK = 0,
   /*! Padding attention mask */
   NVTE_PADDING_MASK = 1,
-  /*! Causal attention mask */
+  /*! Causal attention mask (aligned to the top left corner) */
   NVTE_CAUSAL_MASK = 2,
-  /*! Padding and causal attention mask */
+  /*! Padding and causal attention mask (aligned to the top left corner) */
   NVTE_PADDING_CAUSAL_MASK = 3,
+  /*! Causal attention mask (aligned to the bottom right corner) */
+  NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4,
+  /*! Padding and causal attention mask (aligned to the bottom right corner) */
+  NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
 };

 /*! \enum NVTE_Fused_Attn_Backend
This diff is collapsed.
@@ -22,7 +22,15 @@ TE_DType = {
     torch.bfloat16: tex.DType.kBFloat16,
 }

-AttnMaskTypes = ("causal", "padding", "padding_causal", "arbitrary", "no_mask")
+AttnMaskTypes = (
+    "no_mask",
+    "padding",
+    "causal",
+    "padding_causal",
+    "causal_bottom_right",
+    "padding_causal_bottom_right",
+    "arbitrary",
+)

 AttnTypes = ("self", "cross")
......
@@ -64,6 +64,8 @@ AttnMaskType = {
    "padding": NVTE_Mask_Type.NVTE_PADDING_MASK,
    "causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK,
    "padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK,
    "causal_bottom_right": NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK,
    "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
FusedAttnBackend = {
...
@@ -311,7 +311,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
  py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
      .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
...
@@ -332,7 +332,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        if self.attn_mask_type == "arbitrary":
            return False  # Custom masks not supported
        if self.attn_mask_type == "causal_bottom_right" or (
            self.attn_mask_type == "causal" and sq == sk
        ):  # fused causal softmax kernel
            return True
        if (
@@ -361,7 +363,7 @@ class FusedScaleMaskSoftmax(nn.Module):
        """Fused masked softmax kernel"""
        scale = 1.0 if scale is None else scale
        if "causal" in self.attn_mask_type:
            return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)
        # input is 4D tensor (b, np, sq, sk)
@@ -379,7 +381,7 @@ class FusedScaleMaskSoftmax(nn.Module):
        if scale is not None:
            inp = inp * scale
        if "causal" in self.attn_mask_type:
            seq_len_q, seq_len_k = inp.size(2), inp.size(3)
            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
                assert self.kvcache_max_seq >= seq_len_k
...
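The kernel-selection change above can be summarized in a standalone sketch (a hypothetical helper, assuming the fused aligned-causal kernel pins the diagonal to the bottom right corner, so plain "causal" only matches it when the softmax matrix is square):

```python
def can_use_fused_causal_softmax(attn_mask_type: str, sq: int, sk: int) -> bool:
    """Simplified mirror of the dispatch logic in the hunk above."""
    if attn_mask_type == "arbitrary":
        return False  # custom masks are not supported by the fused kernel
    return attn_mask_type == "causal_bottom_right" or (
        attn_mask_type == "causal" and sq == sk
    )

# "causal" with sq != sk must fall back to the unfused path:
assert not can_use_fused_causal_softmax("causal", 2, 4)
assert can_use_fused_causal_softmax("causal_bottom_right", 2, 4)
```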
@@ -130,19 +130,26 @@ class TransformerLayer(torch.nn.Module):
    kv_channels: int, default = `None`
        number of query-key-value channels per attention head. defaults to
        :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                         'padding_causal_bottom_right', 'arbitrary'},
                         default = `causal`
        type of attention mask passed into softmax operation for encoder.
        Overridden by :attr:`self_attn_mask_type` in the `forward` method.
        The forward arg is useful for dynamically changing mask types, e.g.
        a different mask for training and inference. The init arg is useful
        for cases involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
        sliding window size for local attention in encoder, where query at position i
        attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
        - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
        no sliding window and "`causal`" mask specifically. Similar to
        :attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
        in `forward` as well.
    enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = `no_mask`
        type of attention mask passed into softmax operation for decoder.
    enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
        sliding window size for local attention in decoder.
    zero_centered_gamma : bool, default = 'False'
        if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
        the LayerNorm formula changes to
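The window_size formula in the docstring can be checked with a small sketch (`attended_keys` is a hypothetical helper for illustration, not part of the API):

```python
def attended_keys(i, seqlen_q, seqlen_k, window_size):
    """Inclusive [lo, hi] range of keys that query i attends to, per the
    docstring rule: [i + seqlen_k - seqlen_q - window_size[0],
    i + seqlen_k - seqlen_q + window_size[1]], clamped to valid key
    indices; -1 means unbounded on that side."""
    left, right = window_size
    diag = i + seqlen_k - seqlen_q  # bottom-right-aligned diagonal
    lo = 0 if left == -1 else max(0, diag - left)
    hi = seqlen_k - 1 if right == -1 else min(seqlen_k - 1, diag + right)
    return lo, hi

# (-1, -1) disables the window; (-1, 0) reproduces a bottom-right causal mask:
assert attended_keys(0, 2, 4, (-1, -1)) == (0, 3)
assert attended_keys(0, 2, 4, (-1, 0)) == (0, 2)
```

Note that the diagonal is bottom-right aligned, which is why the documented special case (-1, 0) corresponds to a causal mask in the bottom-right sense.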
@@ -238,6 +245,8 @@ class TransformerLayer(torch.nn.Module):
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
@@ -270,8 +279,11 @@ class TransformerLayer(torch.nn.Module):
        super().__init__()
        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
        self.enc_dec_window_size = check_set_window_size(
            enc_dec_attn_mask_type, enc_dec_window_size
        )
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
@@ -368,7 +380,7 @@ class TransformerLayer(torch.nn.Module):
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type=enc_dec_attn_mask_type,
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
@@ -502,6 +514,8 @@ class TransformerLayer(torch.nn.Module):
        window_size: Optional[Tuple[int, int]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
@@ -524,28 +538,37 @@ class TransformerLayer(torch.nn.Module):
        hidden_states : torch.Tensor
            Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
            Boolean tensor used to mask out self-attention softmax input. It should be
            in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
            to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
            mask. It should be `None` for causal masks and "`no_mask`" type.
            A `True` value means the corresponding position is masked out and
            a `False` means that position is allowed to participate in attention.
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
                            'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
                            default = `causal`
            Type of attention mask passed into softmax operation for encoder.
            By default, causal masks are aligned to the top left corner of
            the softmax matrix. When "`bottom_right`" is specified in the mask type,
            causal masks are aligned to the bottom right corner.
        window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in encoder.
        encoder_output : Optional[torch.Tensor], default = `None`
            Output of the encoder block to be fed into the decoder block if using
            `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
            default = `None`. Boolean tensors used to mask out inter-attention softmax input if
            using `layer_type="decoder"`. It should be a tuple of two masks in
            [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
            It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
            for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
            A `True` value means the corresponding position is masked out and a `False`
            means that position is allowed to participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                               default = `None`
            Type of attention mask passed into softmax operation for decoder.
        enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in decoder.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
@@ -582,16 +605,23 @@ class TransformerLayer(torch.nn.Module):
            to efficiently calculate and store the context during inference.
        """
        if self_attn_mask_type is None:
            self_attn_mask_type = self.self_attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = check_set_window_size(self_attn_mask_type, window_size)
        if enc_dec_attn_mask_type is None:
            enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
        if enc_dec_window_size is None:
            enc_dec_window_size = self.enc_dec_window_size
        enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"
        hidden_states = hidden_states.contiguous()
@@ -604,6 +634,12 @@ class TransformerLayer(torch.nn.Module):
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ) and attention_mask is not None:
            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
        if (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ) and enc_dec_attn_mask is not None:
            assert all(
                enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
            ), "Encoder-decoder attention mask must be boolean tensor(s)"
        # For AMP
        if torch.is_autocast_enabled():
@@ -614,7 +650,7 @@ class TransformerLayer(torch.nn.Module):
                hidden_states,
                attention_mask=attention_mask,
                attn_mask_type=self_attn_mask_type,
                window_size=enc_dec_window_size,
                inference_params=inference_params,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
...
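A minimal sketch of the boolean padding masks that the forward pass validates above, using nested lists in place of torch tensors (`make_padding_mask` is a hypothetical name for illustration):

```python
def make_padding_mask(seqlens, max_seqlen):
    """Shape [batch_size, 1, 1, max_seqlen]; True marks a padded position."""
    return [[[[pos >= length for pos in range(max_seqlen)]]]
            for length in seqlens]

# For a decoder layer, enc_dec_attn_mask is a tuple of two such masks,
# one over query positions and one over key/value positions, converted
# to torch.bool tensors before the call:
q_mask = make_padding_mask([3, 2], 4)
kv_mask = make_padding_mask([4, 1], 4)
```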