Commit 27aa609c authored by cyanguwa, committed by GitHub

[PyTorch] Add sliding window support to FlashAttention (#551)



* add sliding window to FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix forward logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bert test to causal as unfused does not support padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FlashAttention for v2-2.3 versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* verify FA swa works
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix mask related restrictions and duplicate code after merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix swa test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for get_swa func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move repeated code into a function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert mask change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism filter and fix FA warning message
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add message for determinism filter
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify check_set_window_size()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix check_set_window_size in transformer layers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent 4a147e0f
@@ -152,10 +152,16 @@ def _is_flash_attention_2_available() -> bool:
 @functools.cache
 def _is_flash_attention_2_1() -> bool:
-    """Check if flash-attn 2.0+ is available"""
+    """Check if flash-attn 2.1+ is available"""
     Version = packaging.version.Version
     return Version(version("flash-attn")) >= Version("2.1")

+@functools.cache
+def _is_flash_attention_2_3() -> bool:
+    """Check if flash-attn 2.3+ is available"""
+    Version = packaging.version.Version
+    return Version(version("flash-attn")) >= Version("2.3")
+
 def _is_flash_attention_supported(config: ModelConfig) -> bool:
     """Check if FlashAttention supports a model configuration"""
     if get_device_compute_capability() < (8, 0):
@@ -192,6 +198,17 @@ if torch.cuda.is_bf16_supported():
     param_types.append(torch.bfloat16)
 param_types_lean = [torch.bfloat16]

+def get_swa(seq_q, seq_kv, w=None):
+    """Generate a random sliding window size (left, right) if w is None,
+    and create its equivalent attention mask in [seq_q, seq_kv] shape"""
+    if w is None:
+        w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
+    m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
+    mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
+    ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
+    ml = ~ml
+    return w, ml
+
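As a small worked illustration of the mask get_swa builds (reusing the helper above on a CUDA device; the (1, 0) window is an arbitrary choice): with seq_q == seq_kv == 4, query i may only attend to keys i-1 and i, so every other entry is True, i.e. masked out.

import torch

w = torch.tensor([1, 0], dtype=torch.int32, device="cuda")  # (left, right) window
_, mask = get_swa(4, 4, w)
# mask.cpu() ==
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [ True, False, False,  True],
#         [ True,  True, False, False]])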
 @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -199,7 +216,8 @@ param_types_lean = [torch.bfloat16]
 @pytest.mark.parametrize("ckpt_attn", [False])
 @pytest.mark.parametrize("workspace_opt", [True, False])
 @pytest.mark.parametrize("qkv_layout", [None])
-def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
+@pytest.mark.parametrize("swa", [False])
+def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
     """Test DotProductAttention module"""

     # Get configs
@@ -224,36 +242,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
     fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
         config, dtype, qkv_layout=qkv_layout,
     )
+    if swa:
+        fused_attn_supported = False
     flash_attn_supported = _is_flash_attention_supported(config)
     if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")

     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
+        if swa:
+            attn_mask_type = config.attn_mask_type
+            config.attn_mask_type = "arbitrary"
         unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
-            dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt,
+            dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
         )
+        if swa:
+            config.attn_mask_type = attn_mask_type

     # FusedAttention backend
     if fused_attn_supported:
         if len(fused_attn_backend) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
             )
         if len(fused_attn_backend) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
             fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
-                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
+                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
             )

     # FlashAttention backend
     if flash_attn_supported:
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt,
+            dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
         )

     if unfused_attn_supported and fused_attn_supported:
@@ -279,7 +304,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
 @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
 def test_dpa_checkpoint(dtype, model_configs, model):
     """Test DotProductAttention module with checkpointing"""
-    test_dot_product_attention(dtype, model_configs, model, True, True, None)
+    test_dot_product_attention(dtype, model_configs, model, True, True, None, False)

 model_configs_mask = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
@@ -303,7 +328,7 @@ model_configs_mask = {
 @pytest.mark.parametrize("model", model_configs_mask.keys())
 def test_dpa_mask(dtype, model_configs, model):
     """Test DotProductAttention module with different mask types"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

 model_configs_bias = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
@@ -339,7 +364,22 @@ model_configs_bias = {
 @pytest.mark.parametrize("model", model_configs_bias.keys())
 def test_dpa_bias(dtype, model_configs, model):
     """Test DotProductAttention module with different bias types"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
+
+model_configs_swa = {
+    # test: b, h, hg, d, sq, skv, p, mask, bias
+    "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+    "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
+    "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+}
+
+@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_swa])
+@pytest.mark.parametrize("model", model_configs_swa.keys())
+def test_dpa_sliding_window(dtype, model_configs, model):
+    """Test DotProductAttention module with sliding window attention"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
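The new test exercises the same public API a user would call. A minimal usage sketch (hypothetical sizes, bf16 on a CUDA device, flash-attn 2.3+ for the FlashAttention path):

import torch
from transformer_engine.pytorch import DotProductAttention

# 16 heads, head dim 64; attend to the previous 64 tokens plus the current one.
attn = DotProductAttention(16, 64, attn_mask_type="causal", window_size=(64, 0))
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v)                 # default qkv_format "sbhd": [seq, batch, heads, dim]
out.backward(torch.ones_like(out))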
 qkv_layouts = [
     'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
@@ -367,7 +407,7 @@ model_configs_layout = {
 @pytest.mark.parametrize("qkv_layout", qkv_layouts)
 def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with different QKV layouts"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout)
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)

 def _run_dot_product_attention(
     dtype: torch.dtype,
@@ -376,6 +416,7 @@ def _run_dot_product_attention(
     ckpt_attn: bool,
     qkv_layout: str,
     workspace_opt: bool,
+    swa: bool,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
@@ -433,6 +474,10 @@ def _run_dot_product_attention(
             .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
         attention_mask = (
             attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
+    if swa:
+        window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
+    else:
+        window_size, attention_mask = None, None

     # Create input tensors
     dim_to_num = {
@@ -515,6 +560,7 @@ def _run_dot_product_attention(
     # Run a forward and backward pass
     out = block(inp[0], inp[1], inp[2],
+                window_size=window_size,
                 attention_mask=attention_mask,
                 qkv_format=qkv_format,
                 cu_seqlens_q=cu_seqlens_q,
...
@@ -57,6 +57,7 @@ _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
+_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")

 if _flash_attn_2_available:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
@@ -1248,6 +1249,24 @@ def _get_qkv_layout(
     return qkv_layout, q, k, v

+def check_set_window_size(
+        attn_mask_type: str,
+        window_size: Tuple[int, int] = None,
+    ):
+    """Check if sliding window size is compliant with mask type and if not,
+    assert or set it to the appropriate size
+    """
+    if "causal" in attn_mask_type:
+        if window_size is None:
+            window_size = (-1, 0)
+        else:
+            assert (
+                window_size[1] == 0
+            ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
+    else:
+        if window_size is None:
+            window_size = (-1, -1)
+    return window_size
+
 class FlashAttention(torch.nn.Module):
     """Dot product attention, using HazyResearch flash-attn package:
@@ -1286,12 +1305,15 @@ class FlashAttention(torch.nn.Module):
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
         attn_mask_type: str = "causal",
+        window_size: Optional[Tuple[int, int]] = None,
         cp_group: Optional[dist_group_type] = None,
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""

+        window_size = check_set_window_size(attn_mask_type, window_size)
+
         assert (
             query_layer.dtype in [torch.float16, torch.bfloat16]
             and key_layer.dtype in [torch.float16, torch.bfloat16]
@@ -1402,6 +1424,9 @@ class FlashAttention(torch.nn.Module):
             max_seqlen_kv = seqlens_kv.max().item()

         if context_parallel:
+            assert (
+                window_size in ((-1, -1), (-1, 0))
+            ), "Sliding window attention is not supported with context parallelism."
             with self.attention_dropout_ctx():
                 output = flash_attn_forward_func_with_cp(
                     query_layer, key_layer, value_layer,
@@ -1417,6 +1442,8 @@ class FlashAttention(torch.nn.Module):
             fa_optional_forward_kwargs = {}
             if not _flash_attn_2_available:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            if _flash_attn_2_3_plus:
+                fa_optional_forward_kwargs["window_size"] = window_size
             output = flash_attn_forward_func(
                 query_layer, key_layer, value_layer,
                 cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
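With flash-attn 2.3+, the window tuple is forwarded straight to the kernel. A self-contained sketch of that underlying call (illustrative sizes; flash_attn_varlen_func takes packed [total_tokens, heads, head_dim] inputs):

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# Two sequences of lengths 5 and 3 packed into one tensor; fp16, 16 heads, head dim 64.
q = torch.randn(8, 16, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
out = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, 5, 5,
    dropout_p=0.0, causal=True,
    window_size=(2, 0),   # each query sees at most the 2 previous keys plus itself
)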
@@ -1875,6 +1902,12 @@ class DotProductAttention(torch.nn.Module):
           :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
           in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
           need to provide a mask that is broadcastable to the shape of softmax input.
+    window_size: Optional[Tuple[int, int]], default = `None`
+                sliding window size for local attention, where query at position i attends to keys
+                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
+                be overridden by :attr:`window_size` in `forward` as well.
     attention_type: str, default = `self`
                    type of attention, either "`self`" and "`cross`".
     layer_number: int, default = `None`
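To make the window arithmetic in the docstring above concrete, a small sketch (allowed_keys is a hypothetical helper, not part of the module) that returns the inclusive range of key indices a query may attend to:

def allowed_keys(i, seqlen_q, seqlen_kv, window_size):
    left, right = window_size
    diag = i + seqlen_kv - seqlen_q              # the "diagonal" key for query i
    lo = 0 if left < 0 else diag - left
    hi = seqlen_kv - 1 if right < 0 else diag + right
    return max(lo, 0), min(hi, seqlen_kv - 1)

allowed_keys(10, 128, 256, (4, 0))   # -> (134, 138): 4 keys left of the diagonal, plus self
allowed_keys(10, 128, 128, (-1, 0))  # -> (0, 10): causal, unbounded to the left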
@@ -1918,6 +1951,7 @@ class DotProductAttention(torch.nn.Module):
         attention_dropout: float = 0.0,
         qkv_format: str = "sbhd",
         attn_mask_type: str = "causal",
+        window_size: Optional[Tuple[int, int]] = None,
         sequence_parallel: bool = False,
         tp_size: int = 1,
         get_rng_state_tracker: Optional[Callable] = None,
@@ -1935,6 +1969,8 @@ class DotProductAttention(torch.nn.Module):
         if attn_mask_type == "causal_padding":
             attn_mask_type = "padding_causal"
         self.attn_mask_type = attn_mask_type
+        self.window_size = window_size
+        self.window_size = check_set_window_size(attn_mask_type, self.window_size)
         self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
         self.tp_group = tp_group
         self.get_rng_state_tracker = get_rng_state_tracker
@@ -1969,8 +2005,8 @@ class DotProductAttention(torch.nn.Module):
         if _flash_attn_2_available and self.deterministic:
             self.use_flash_attention = False
             warnings.warn(
-                "Disabling usage of FlashAttention since version 2 does not support deterministic"
-                "execution. In order to use FA with deterministic behavior, please install"
+                "Disabling usage of FlashAttention since version 2 does not support deterministic "
+                "execution. In order to use FA with deterministic behavior, please install "
                 "FlashAttention version 1."
             )
@@ -2065,6 +2101,7 @@ class DotProductAttention(torch.nn.Module):
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[str] = None,
+        window_size: Optional[Tuple[int, int]] = None,
         checkpoint_core_attention: bool = False,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
@@ -2138,6 +2175,8 @@ class DotProductAttention(torch.nn.Module):
         attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                         `arbitrary`}, default = `None`. Type of attention mask passed into
                         softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
+        window_size: Optional[Tuple[int, int]], default = `None`
+                    sliding window size for local attention.
         checkpoint_core_attention : bool, default = `False`
                                    If true, forward activations for attention are recomputed
                                    during the backward pass in order to save memory that would
@@ -2159,6 +2198,8 @@ class DotProductAttention(torch.nn.Module):
         assert (key_layer.shape == value_layer.shape
             ), "Keys and values must have the same shape!"

+        if attn_mask_type is not None:
+            window_size = check_set_window_size(attn_mask_type, window_size)
         if attn_mask_type is None:
             attn_mask_type = self.attn_mask_type
         else:
@@ -2169,6 +2210,9 @@ class DotProductAttention(torch.nn.Module):
         assert (attn_mask_type in AttnMaskTypes
             ), f"Attention mask type {attn_mask_type} is not supported!"

+        if window_size is None:
+            window_size = self.window_size
+
         if qkv_format is None:
             qkv_format = self.qkv_format
@@ -2220,6 +2264,7 @@ class DotProductAttention(torch.nn.Module):
         # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
         use_flash_attention = self.use_flash_attention
         use_fused_attention = self.use_fused_attention
+        use_unfused_attention = True

         # The following section filters out some backends based on
         # certain asserts before executing the forward pass.
@@ -2249,9 +2294,11 @@ class DotProductAttention(torch.nn.Module):
                 and self.device_compute_capability not in ((8, 0), (9, 0)))):
             use_flash_attention = False

+        # Filter: MQA/GQA.
         if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
             use_flash_attention = False

+        # Filter: cross attention + causal mask.
         if (_flash_attn_2_1_plus
                 and "causal" in attn_mask_type
                 and max_seqlen_q != max_seqlen_kv):
@@ -2262,9 +2309,19 @@ class DotProductAttention(torch.nn.Module):
             )
             use_flash_attention = False

+        # Filter: bias.
         if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
             use_flash_attention = False

+        # Filter: sliding window attention.
+        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
+        if window_size not in ((-1, -1), (-1, 0)):
+            use_fused_attention = False
+            context_parallel = (self.cp_group is not None
+                and get_distributed_world_size(self.cp_group) != 1)
+            if (not _flash_attn_2_3_plus) or context_parallel:
+                use_flash_attention = False
+
         # Filter: ONNX export.
         if is_in_onnx_export_mode():
             use_flash_attention = False
@@ -2282,6 +2339,8 @@ class DotProductAttention(torch.nn.Module):
         if attn_mask_type == "arbitrary":
             use_flash_attention = False
             use_fused_attention = False
+        if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
+            use_unfused_attention = False

         if use_fused_attention:
             fused_attention_backend = tex.get_fused_attn_backend(
@@ -2303,6 +2362,24 @@ class DotProductAttention(torch.nn.Module):
             use_fused_attention = (use_fused_attention
                                    and is_backend_avail)
+
+            # Filter: determinism.
+            # backend                                  | deterministic
+            # ---------------------------------------------------------
+            # flash-attn v1                            | yes
+            # flash-attn v2                            | no
+            # FusedAttnBackend["F16_max512_seqlen"]    | yes
+            # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
+            # UnfusedDotProductAttention               | yes
+            #
+            # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
+            # on sm90 architectures.
+            #
+            if (use_fused_attention
+                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
+                and self.deterministic
+                and self.device_compute_capability != (9, 0)):
+                use_fused_attention = False
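For completeness, the `deterministic` flag consulted by this filter follows the usual switches; a minimal sketch, assuming TE reads `NVTE_ALLOW_NONDETERMINISTIC_ALGO` and `torch.are_deterministic_algorithms_enabled()` as elsewhere in this module (set either before constructing the attention module):

import os
import torch

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"   # or:
torch.use_deterministic_algorithms(True)
# Either switch makes DotProductAttention prefer deterministic backends, e.g. skipping
# flash-attn v2 and, off sm90, the F16_arbitrary_seqlen fused backend per the table above.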
         # Select FusedAttention on sm90 and FlashAttention on others for performance
         if (use_flash_attention
             and use_fused_attention
@@ -2321,6 +2398,7 @@ class DotProductAttention(torch.nn.Module):
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_kv=cu_seqlens_kv,
                 attn_mask_type=attn_mask_type,
+                window_size=window_size,
                 cp_group=self.cp_group,
                 cp_global_ranks=self.cp_global_ranks,
                 cp_stream=self.cp_stream)
@@ -2360,6 +2438,7 @@ class DotProductAttention(torch.nn.Module):
         if _NVTE_DEBUG:
             print("[DotProductAttention]: using unfused DPA")
+        if use_unfused_attention:
             if checkpoint_core_attention:
                 return self._checkpointed_attention_forward(
                     self.unfused_attention,
@@ -2384,6 +2463,8 @@ class DotProductAttention(torch.nn.Module):
                 core_attention_bias_type = core_attention_bias_type,
                 core_attention_bias = core_attention_bias)

+        raise Exception("No dot product attention support for the provided inputs!")
+
 class MultiheadAttention(torch.nn.Module):
     r"""
@@ -2427,6 +2508,12 @@ class MultiheadAttention(torch.nn.Module):
                    arg is useful for dynamically changing mask types, e.g. a different
                    mask for training and inference. The init arg is useful for cases
                    involving compilation/tracing, e.g. ONNX export.
+    window_size: Optional[Tuple[int, int]], default = `None`
+                sliding window size for local attention, where query at position i attends to keys
+                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
+                be overridden by :attr:`window_size` in `forward` as well.
     num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
@@ -2518,6 +2605,7 @@ class MultiheadAttention(torch.nn.Module):
         output_layer_init_method: Optional[Callable] = None,
         layer_number: Optional[int] = None,
         attn_mask_type: str = "causal",
+        window_size: Optional[Tuple[int, int]] = None,
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
         num_gqa_groups: Optional[int] = None,
@@ -2546,6 +2634,8 @@ class MultiheadAttention(torch.nn.Module):
         super().__init__()
         self.attn_mask_type = attn_mask_type
+        self.window_size = window_size
+        self.window_size = check_set_window_size(attn_mask_type, self.window_size)
         self.layer_number = layer_number
         self.input_layernorm = input_layernorm
         self.attention_type = attention_type
@@ -2759,6 +2849,7 @@ class MultiheadAttention(torch.nn.Module):
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         encoder_output: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[str] = None,
+        window_size: Optional[Tuple[int, int]] = None,
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
         inference_params: Optional[InferenceParams] = None,
@@ -2789,6 +2880,8 @@ class MultiheadAttention(torch.nn.Module):
         attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                        default = `None`
                        type of attention mask passed into softmax operation.
+        window_size: Optional[Tuple[int, int]], default = `None`
+                    sliding window size for local attention.
         encoder_output : Optional[torch.Tensor], default = `None`
                         Output of the encoder block to be fed into the decoder block if using
                         `layer_type="decoder"`.
@@ -2823,8 +2916,12 @@ class MultiheadAttention(torch.nn.Module):
         """
         # hidden_states: [sq, b, h]

+        if attn_mask_type is not None:
+            window_size = check_set_window_size(attn_mask_type, window_size)
         if attn_mask_type is None:
             attn_mask_type = self.attn_mask_type
+        if window_size is None:
+            window_size = self.window_size

         if "padding" in attn_mask_type and attention_mask is not None:
             for i,_ in enumerate(attention_mask):
@@ -3037,6 +3134,7 @@ class MultiheadAttention(torch.nn.Module):
                 cu_seqlens_kv=None,
                 attention_mask=attention_mask,
                 attn_mask_type=attn_mask_type,
+                window_size=window_size,
                 checkpoint_core_attention=checkpoint_core_attention,
                 core_attention_bias_type=core_attention_bias_type,
                 core_attention_bias=core_attention_bias,
...
@@ -12,7 +12,11 @@ import torch
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
-from transformer_engine.pytorch.attention import InferenceParams, MultiheadAttention
+from transformer_engine.pytorch.attention import (
+    InferenceParams,
+    MultiheadAttention,
+    check_set_window_size,
+)
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
     warmup_jit_bias_dropout_add_all_dtypes,
@@ -134,6 +138,12 @@ class TransformerLayer(torch.nn.Module):
                         arg is useful for dynamically changing mask types, e.g. a different
                         mask for training and inference. The init arg is useful for cases
                         involving compilation/tracing, e.g. ONNX export.
+    window_size: Optional[Tuple[int, int]], default = `None`
+                sliding window size for local attention, where query at position i attends to keys
+                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+                window and causal mask specifically. Similar to :attr:`self_attn_mask_type`, it can
+                be overridden by :attr:`window_size` in `forward` as well.
     zero_centered_gamma : bool, default = 'False'
                          if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                          the LayerNorm formula changes to
@@ -220,6 +230,7 @@ class TransformerLayer(torch.nn.Module):
         layer_number: Optional[int] = None,
         kv_channels: Optional[int] = None,
         self_attn_mask_type: str = "causal",
+        window_size: Optional[Tuple[int, int]] = None,
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
         params_dtype: Optional[torch.dtype] = None,
@@ -251,6 +262,8 @@ class TransformerLayer(torch.nn.Module):
         ), "Userbuffer communication backend not available."
         self.self_attn_mask_type = self_attn_mask_type
+        self.window_size = window_size
+        self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
         ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
@@ -491,6 +504,7 @@ class TransformerLayer(torch.nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         self_attn_mask_type: Optional[str] = None,
+        window_size: Optional[Tuple[int, int]] = None,
         encoder_output: Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         is_first_microbatch: Optional[bool] = None,
@@ -521,6 +535,8 @@ class TransformerLayer(torch.nn.Module):
         self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                             default = `causal`
                             Type of attention mask passed into softmax operation.
+        window_size: Optional[Tuple[int, int]], default = `None`
+                    sliding window size for local attention.
         encoder_output : Optional[torch.Tensor], default = `None`
                         Output of the encoder block to be fed into the decoder block if using
                         `layer_type="decoder"`.
@@ -562,8 +578,12 @@ class TransformerLayer(torch.nn.Module):
                    to efficienly calculate and store the context during inference.
         """

+        if self_attn_mask_type is not None:
+            window_size = check_set_window_size(self_attn_mask_type, window_size)
         if self_attn_mask_type is None:
             self_attn_mask_type = self.self_attn_mask_type
+        if window_size is None:
+            window_size = self.window_size

         assert (
             self_attn_mask_type in AttnMaskTypes
@@ -594,6 +614,7 @@ class TransformerLayer(torch.nn.Module):
             hidden_states,
             attention_mask=attention_mask,
             attn_mask_type=self_attn_mask_type,
+            window_size=window_size,
             inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
@@ -619,6 +640,7 @@ class TransformerLayer(torch.nn.Module):
         inter_attention_outputs = self.inter_attention(
             hidden_states,
             attention_mask=enc_dec_attn_mask,
+            window_size=window_size,
             encoder_output=encoder_output,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
...
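End to end, a minimal usage sketch at the layer level (hypothetical sizes; flash-attn 2.3+ so the FlashAttention backend can apply the window):

import torch
from transformer_engine.pytorch import TransformerLayer

# hidden 1024, FFN 4096, 16 heads; causal attention limited to the previous 256 tokens.
layer = TransformerLayer(1024, 4096, 16,
                         self_attn_mask_type="causal",
                         window_size=(256, 0),
                         params_dtype=torch.bfloat16).cuda()
x = torch.randn(2048, 2, 1024, dtype=torch.bfloat16, device="cuda")  # [seq, batch, hidden]
y = layer(x)   # the window can also be overridden per call: layer(x, window_size=(128, 0))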