Unverified Commit 27aa609c authored by cyanguwa, committed by GitHub

[PyTorch] Add sliding window support to FlashAttention (#551)



* add sliding window to FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix forward logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bert test to causal as unfused does not support padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FlashAttention for v2-2.3 versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* verify FA swa works
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix mask related restrictions and duplicate code after merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix swa test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for get_swa func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move repeated code into a function
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert mask change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism filter and fix FA warning message
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add message for determinism filter
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify check_set_window_size()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix check_set_window_size in transformer layers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent 4a147e0f
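
A quick usage sketch of the new argument (illustrative only; everything except window_size follows the existing DotProductAttention API, and the shapes, dtypes, and sizes below are arbitrary):

import torch
from transformer_engine.pytorch import DotProductAttention

# Causal sliding window: each query attends to at most the 64 previous keys.
dpa = DotProductAttention(num_attention_heads=16, kv_channels=64,
                          attn_mask_type="causal", window_size=(64, 0))
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")  # [sq, b, h, d]
k = torch.randn_like(q)
v = torch.randn_like(q)
out = dpa(q, k, v)                       # uses window_size from __init__
out = dpa(q, k, v, window_size=(64, 0))  # or override per forward call, as added in this PR
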
@@ -152,10 +152,16 @@ def _is_flash_attention_2_available() -> bool:
@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.0+ is available"""
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
@@ -192,6 +198,17 @@ if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
ml = ~ ml
return w, ml
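# Worked example (illustrative, not part of this diff): with seq_q = seq_kv = 4 and
# w = (1, 0), query i may attend to keys [i-1, i], so the boolean mask returned by
# get_swa (True = masked out) is:
#   [[False, True,  True,  True ],
#    [False, False, True,  True ],
#    [True,  False, False, True ],
#    [True,  True,  False, False]]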
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -199,7 +216,8 @@ param_types_lean = [torch.bfloat16]
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
@pytest.mark.parametrize("swa", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
"""Test DotProductAttention module"""
# Get configs
@@ -224,36 +242,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
if swa:
config.attn_mask_type = attn_mask_type
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
if unfused_attn_supported and fused_attn_supported:
@@ -279,7 +304,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None)
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -303,7 +328,7 @@ model_configs_mask = {
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -339,7 +364,22 @@ model_configs_bias = {
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
@@ -367,7 +407,7 @@ model_configs_layout = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)
def _run_dot_product_attention(
dtype: torch.dtype,
@@ -376,6 +416,7 @@ def _run_dot_product_attention(
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
@@ -433,6 +474,10 @@ def _run_dot_product_attention(
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
else:
window_size, attention_mask = None, None
# Create input tensors
dim_to_num = {
@@ -515,6 +560,7 @@ def _run_dot_product_attention(
# Run a forward and backward pass
out = block(inp[0], inp[1], inp[2],
window_size=window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
@@ -57,6 +57,7 @@ _flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.6")
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
if _flash_attn_2_available:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
@@ -1248,6 +1249,24 @@ def _get_qkv_layout(
return qkv_layout, q, k, v
def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
):
"""Check if sliding window size is compliant with mask type and if not,
assert or set it to the appropriate size
"""
if "causal" in attn_mask_type:
if window_size is None:
window_size = (-1, 0)
else:
assert (
window_size[1] == 0
), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
else:
if window_size is None:
window_size = (-1, -1)
return window_size
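# Expected behavior of check_set_window_size, for illustration (derived from the logic above):
#   check_set_window_size("causal", None)             -> (-1, 0)   default causal window
#   check_set_window_size("padding_causal", (128, 0)) -> (128, 0)  left window kept
#   check_set_window_size("no_mask", None)             -> (-1, -1)  no sliding window
#   check_set_window_size("causal", (128, 128))        -> AssertionError, right window must be 0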
class FlashAttention(torch.nn.Module):
"""Dot product attention, using HazyResearch flash-attn package:
@@ -1286,12 +1305,15 @@ class FlashAttention(torch.nn.Module):
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
) -> torch.Tensor:
"""flash-attn fprop"""
window_size = check_set_window_size(attn_mask_type, window_size)
assert (
query_layer.dtype in [torch.float16, torch.bfloat16]
and key_layer.dtype in [torch.float16, torch.bfloat16]
@@ -1402,6 +1424,9 @@ class FlashAttention(torch.nn.Module):
max_seqlen_kv = seqlens_kv.max().item()
if context_parallel:
assert (
window_size in ((-1, -1), (-1, 0))
), "Sliding window attention is not supported with context parallelism."
with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp(
query_layer, key_layer, value_layer,
@@ -1417,6 +1442,8 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
output = flash_attn_forward_func(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
@@ -1875,6 +1902,12 @@ class DotProductAttention(torch.nn.Module):
:attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
need to provide a mask that is broadcastable to the shape of softmax input.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask respectively. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
attention_type: str, default = `self`
type of attention, either "`self`" or "`cross`".
layer_number: int, default = `None`
@@ -1918,6 +1951,7 @@ class DotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
qkv_format: str = "sbhd",
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
@@ -1935,6 +1969,8 @@ class DotProductAttention(torch.nn.Module):
if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(attn_mask_type, self.window_size)
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
@@ -1969,8 +2005,8 @@ class DotProductAttention(torch.nn.Module):
if _flash_attn_2_available and self.deterministic:
self.use_flash_attention = False
warnings.warn(
"Disabling usage of FlashAttention since version 2 does not support deterministic"
"execution. In order to use FA with deterministic behavior, please install"
"Disabling usage of FlashAttention since version 2 does not support deterministic "
"execution. In order to use FA with deterministic behavior, please install "
"FlashAttention version 1."
)
@@ -2065,6 +2101,7 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
@@ -2138,6 +2175,8 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
`arbitrary`}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
@@ -2159,6 +2198,8 @@ class DotProductAttention(torch.nn.Module):
assert (key_layer.shape == value_layer.shape
), "Keys and values must have the same shape!"
if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
@@ -2169,6 +2210,9 @@ class DotProductAttention(torch.nn.Module):
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if window_size is None:
window_size = self.window_size
if qkv_format is None:
qkv_format = self.qkv_format
@@ -2220,6 +2264,7 @@ class DotProductAttention(torch.nn.Module):
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
use_unfused_attention = True
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
@@ -2249,9 +2294,11 @@ class DotProductAttention(torch.nn.Module):
and self.device_compute_capability not in ((8, 0), (9, 0)))):
use_flash_attention = False
# Filter: MQA/GQA.
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
use_flash_attention = False
# Filter: cross attention + causal mask.
if (_flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv):
@@ -2262,9 +2309,19 @@ class DotProductAttention(torch.nn.Module):
)
use_flash_attention = False
# Filter: bias.
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
# Filter: sliding window attention.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask.
if window_size not in ((-1, -1), (-1, 0)):
use_fused_attention = False
context_parallel = (self.cp_group is not None
and get_distributed_world_size(self.cp_group) != 1)
if (not _flash_attn_2_3_plus) or context_parallel:
use_flash_attention = False
# Filter: ONNX export.
if is_in_onnx_export_mode():
use_flash_attention = False
@@ -2282,6 +2339,8 @@ class DotProductAttention(torch.nn.Module):
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
use_unfused_attention = False
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
@@ -2303,6 +2362,24 @@ class DotProductAttention(torch.nn.Module):
use_fused_attention = (use_fused_attention
and is_backend_avail)
# Filter: determinism.
# backend | deterministic
# ---------------------------------------------------------
# flash-attn v1 | yes
# flash-attn v2 | no
# FusedAttnBackend["F16_max512_seqlen"] | yes
# FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
# UnfusedDotProductAttention | yes
#
# Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
# on sm90 architectures.
#
if (use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and self.deterministic
and self.device_compute_capability != (9, 0)):
use_fused_attention = False
# Select FusedAttention on sm90 and FlashAttention on others for performance
if (use_flash_attention
and use_fused_attention
@@ -2321,6 +2398,7 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
window_size=window_size,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream)
@@ -2360,6 +2438,7 @@ class DotProductAttention(torch.nn.Module):
if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA")
if use_unfused_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
@@ -2384,6 +2463,8 @@ class DotProductAttention(torch.nn.Module):
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias)
raise Exception("No dot product attention support for the provided inputs!")
class MultiheadAttention(torch.nn.Module):
r"""
@@ -2427,6 +2508,12 @@ class MultiheadAttention(torch.nn.Module):
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask respectively. Similar to :attr:`attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
@@ -2518,6 +2605,7 @@ class MultiheadAttention(torch.nn.Module):
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
@@ -2546,6 +2634,8 @@ class MultiheadAttention(torch.nn.Module):
super().__init__()
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(attn_mask_type, self.window_size)
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
@@ -2759,6 +2849,7 @@ class MultiheadAttention(torch.nn.Module):
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
@@ -2789,6 +2880,8 @@ class MultiheadAttention(torch.nn.Module):
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
default = `None`
type of attention mask passed into softmax operation.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
@@ -2823,8 +2916,12 @@ class MultiheadAttention(torch.nn.Module):
"""
# hidden_states: [sq, b, h]
if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
if window_size is None:
window_size = self.window_size
if "padding" in attn_mask_type and attention_mask is not None:
for i,_ in enumerate(attention_mask):
@@ -3037,6 +3134,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens_kv=None,
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
window_size=window_size,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
@@ -12,7 +12,11 @@ import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import InferenceParams, MultiheadAttention
from transformer_engine.pytorch.attention import (
InferenceParams,
MultiheadAttention,
check_set_window_size,
)
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
@@ -134,6 +138,12 @@ class TransformerLayer(torch.nn.Module):
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask respectively. Similar to :attr:`self_attn_mask_type`, it can
be overridden by :attr:`window_size` in `forward` as well.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
@@ -220,6 +230,7 @@ class TransformerLayer(torch.nn.Module):
layer_number: Optional[int] = None,
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
@@ -251,6 +262,8 @@ class TransformerLayer(torch.nn.Module):
), "Userbuffer communication backend not available."
self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size
self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
@@ -491,6 +504,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
is_first_microbatch: Optional[bool] = None,
@@ -521,6 +535,8 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
@@ -562,8 +578,12 @@ class TransformerLayer(torch.nn.Module):
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is not None:
window_size = check_set_window_size(self_attn_mask_type, window_size)
if self_attn_mask_type is None:
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
assert (
self_attn_mask_type in AttnMaskTypes
@@ -594,6 +614,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states,
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=window_size,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
@@ -619,6 +640,7 @@ class TransformerLayer(torch.nn.Module):
inter_attention_outputs = self.inter_attention(
hidden_states,
attention_mask=enc_dec_attn_mask,
window_size=window_size,
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
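# Hypothetical end-to-end sketch (module sizes, dtypes, and window values are illustrative,
# not taken from this diff): a TransformerLayer built with a causal 256-token sliding window;
# the per-call window_size override added above works the same way.
import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
                         self_attn_mask_type="causal", window_size=(256, 0),
                         params_dtype=torch.bfloat16).cuda()
x = torch.randn(2048, 2, 1024, dtype=torch.bfloat16, device="cuda")  # [sq, b, h]
y = layer(x, window_size=(256, 0))       # override the window for this call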