Unverified Commit 85e60e64 authored by Hua Huang, committed by GitHub

[JAX] Expose sliding window attn to TE-JAX API (#1205)



* Expose JAX sliding window attn API
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* No SWA in context parallel; fix RNG seed in test
Signed-off-by: Hua Huang <huah@nvidia.com>

* Handle SWA API discrepancy in cuDNN and Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add SWA API for flax, all tests passed

Will update tests/jax/test_praxis_layers.py next
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update test_praxis_layers.py for SWA, test passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Use tuple window_size; update for PR #1212
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add and adjust some pytest.skip
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revised following Reese Wang's comments

Still need further debugging:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:

These errors did not exist in the previous commit
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix no-SWA test case errors in previous commit
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add Padding mask w/ sliding windows sanity tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use float32 for the reference code softmax calculation
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Reese Wang <rewang@nvidia.com>
parent 5b6546c8
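For orientation before the diff, here is a minimal usage sketch of the new `window_size` argument via the flax `DotProductAttention` layer. It is illustrative only: the shapes, dtype, mask construction, and the constructor arguments other than `window_size` are assumptions made for the example, not taken from this commit.

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention

batch, seqlen, heads, head_dim = 2, 128, 16, 64  # illustrative shapes only

layer = DotProductAttention(
    head_dim=head_dim,
    num_attention_heads=heads,
    attn_mask_type="causal",
    transpose_batch_sequence=False,  # inputs are (batch, seqlen, heads, head_dim)
    window_size=(64, 0),             # at most 64 keys to the left, none to the right
)

q = k = v = jnp.zeros((batch, seqlen, heads, head_dim), dtype=jnp.bfloat16)
# Same mask convention as the tests below: 1 means "masked out".
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)

params = layer.init(jax.random.PRNGKey(0), q, k, v, causal_mask, deterministic=True)
out = layer.apply(params, q, k, v, causal_mask, deterministic=True)
```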
......@@ -6,6 +6,7 @@ from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import jax
import jax.numpy as jnp
......@@ -27,6 +28,7 @@ from transformer_engine.jax.attention import (
fused_attn,
fused_attn_thd,
get_qkv_format,
make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
......@@ -123,6 +125,7 @@ def make_mask(
segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike,
attn_mask_type: AttnMaskType,
window_size: Optional[Tuple[int, int]] = None,
) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
......@@ -140,6 +143,15 @@ def make_mask(
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
)
inv_mask = combine_masks(inv_pad_mask, inv_mask)
if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
mask = jnp.logical_not(inv_mask)
return mask
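As a rough sketch of the combination just added to `make_mask`: a position survives only if both the existing inverse mask and the sliding-window mask keep it, and the result is then inverted. The toy padding layout and sizes below are made up for illustration, not taken from the test.

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import AttnMaskType, make_swa_mask

# One toy sequence of length 4 whose last position is padding (1 = keep, 0 = drop).
inv_pad = jnp.array([[1, 1, 1, 0]], dtype=jnp.float32)
inv_mask = inv_pad[:, None, None, :] * jnp.ones((1, 1, 4, 4))  # [batch, 1, sq, skv]

# Window of one key to the left, diagonal aligned to the top-left corner.
inv_swa_mask = make_swa_mask(4, 4, (1, 0), AttnMaskType.PADDING_CAUSAL_MASK)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)

# Keep a position only if both masks keep it, then invert so True means "masked out".
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
mask = jnp.logical_not(inv_mask)
```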
......@@ -274,6 +286,7 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
......@@ -298,6 +311,11 @@ class FusedAttnRunner:
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")
if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
......@@ -310,6 +328,7 @@ class FusedAttnRunner:
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
......@@ -456,6 +475,7 @@ class FusedAttnRunner:
self.segment_pad_q,
self.segment_pad_kv,
self.attn_mask_type,
self.window_size,
)
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
......@@ -500,6 +520,7 @@ class FusedAttnRunner:
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}
# Convert the outputs to float32 for the elementwise comparison
......@@ -557,6 +578,7 @@ class FusedAttnRunner:
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}
# We can compute dBias only for the [1, h, s, s] layout
......@@ -668,7 +690,7 @@ class FusedAttnRunner:
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param(
2,
2048,
......@@ -677,7 +699,7 @@ class FusedAttnRunner:
12,
64,
jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS",
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
......@@ -690,6 +712,13 @@ class FusedAttnRunner:
pytest.param(0.1, id="DROP_0.1"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
pytest.param(True, id="SWA"),
],
)
class TestFusedAttn:
"""
Fused attention tester
......@@ -717,12 +746,16 @@ class TestFusedAttn:
is_training,
qkv_layout,
bias_shape,
swa,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner(
b,
s_q,
......@@ -737,6 +770,7 @@ class TestFusedAttn:
is_training,
qkv_layout,
bias_shape,
window_size,
)
runner.test_forward()
......@@ -754,10 +788,14 @@ class TestFusedAttn:
dtype,
qkv_layout,
bias_shape,
swa,
):
"""
Test backward with parameterized configs
"""
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner(
b,
s_q,
......@@ -772,5 +810,6 @@ class TestFusedAttn:
True,
qkv_layout,
bias_shape,
window_size,
)
runner.test_backward()
......@@ -4,7 +4,7 @@
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict
from typing import Dict, Tuple
import flax
import jax
......@@ -61,6 +61,7 @@ _KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -70,6 +71,7 @@ BASE_ATTRS = {
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_WINDOW_SIZE: (-1, -1),
}
ATTRS = [
......@@ -193,6 +195,19 @@ ATTRS = [
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: (64, 0), # Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -326,7 +341,7 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
......@@ -379,7 +394,7 @@ class DecoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask
else:
self_mask = padded_mask
......
......@@ -4,7 +4,7 @@
import os
from functools import partial
from typing import Dict
from typing import Dict, Tuple
import flax
import jax
......@@ -645,6 +645,7 @@ class DotProductAttnAttr:
NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = "scale_factor"
WINDOW_SIZE = "window_size"
ATTRS = [
{
ATTN_MASK_TYPE: "padding",
......@@ -681,6 +682,12 @@ class DotProductAttnAttr:
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
{
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]
......@@ -707,6 +714,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)
praxis_p = pax_fiddle.Config(
DotProductAttention,
......@@ -717,6 +725,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
flax_cls = partial(
flax_DotProductAttention,
......@@ -726,6 +735,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
return praxis_p, flax_cls
......@@ -750,6 +760,7 @@ class MultiHeadAttnAttr:
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [
{
USE_BIAS: True,
......@@ -858,6 +869,17 @@ class MultiHeadAttnAttr:
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]
......@@ -899,6 +921,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None)
praxis_p = pax_fiddle.Config(
MultiHeadAttention,
......@@ -923,6 +946,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
window_size=window_size,
)
flax_cls = partial(
flax_MultiHeadAttention,
......@@ -946,6 +970,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
window_size=window_size,
)
return praxis_p, flax_cls
......@@ -983,6 +1008,7 @@ class TransformerLayerAttr:
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [
{
USE_BIAS: True,
......@@ -1246,6 +1272,28 @@ class TransformerLayerAttr:
TRANSPOSE_BS: False,
LORA_SCOPE: "all",
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]
......@@ -1289,6 +1337,7 @@ class TestTransformer(TestLayer):
)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None)
rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init,
......@@ -1330,6 +1379,7 @@ class TestTransformer(TestLayer):
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
flax_cls = partial(
flax_TransformerLayer,
......@@ -1358,6 +1408,7 @@ class TestTransformer(TestLayer):
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
return praxis_p, flax_cls
......
......@@ -18,6 +18,7 @@ from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
from transformer_engine.jax.attention import AttnMaskType, make_swa_mask
from transformer_engine.jax.fp8 import DType as TEDType
PRNGKey = Any
......@@ -902,6 +903,33 @@ class RelativePositionBiases(nn.Module):
return values[jnp.newaxis, ...]
def apply_swa_mask(
attn_mask_type: str,
original_mask: Array,
window_size: Tuple[int, int] = (-1, -1),
) -> Array:
"""Apply the sliding window mask to a given mask"""
mask_map = {
"no_mask": AttnMaskType.NO_MASK,
"padding": AttnMaskType.PADDING_MASK,
"causal": AttnMaskType.CAUSAL_MASK,
"padding_causal": AttnMaskType.PADDING_CAUSAL_MASK,
"causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
"padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
_attn_mask_type = mask_map.get(attn_mask_type, None)
assert _attn_mask_type is not None
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(
max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
)
# In swa_mask and original_mask 0 is masked out
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
return new_mask
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
......@@ -934,7 +962,8 @@ class EncoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -943,7 +972,13 @@ class EncoderLayer(nn.Module):
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
del self.self_attn_mask_type # dummy, just align to TE's impl
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
encoder_mask = apply_swa_mask(
self.self_attn_mask_type,
encoder_mask,
self.window_size,
)
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
......@@ -1087,7 +1122,8 @@ class DecoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1105,7 +1141,18 @@ class DecoderLayer(nn.Module):
decode=False,
max_decode_length=None,
):
del self.self_attn_mask_type # dummy, just align to TE's impl
decoder_mask = apply_swa_mask(
self.self_attn_mask_type,
decoder_mask,
self.window_size,
)
encoder_decoder_mask = apply_swa_mask(
"padding",
encoder_decoder_mask,
self.window_size,
)
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
......
......@@ -86,6 +86,66 @@ def get_qkv_format(qkv_layout):
return QKVFormat(nvte_get_qkv_format(qkv_layout.value))
def make_swa_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
window_size: Optional[Tuple[int, int]] = None,
attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
dtype: jax.typing.DTypeLike = jnp.float32,
):
"""
Generate sliding window mask. `True` or `1` means keep the element.
For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
the sliding window diagonal is aligned to the bottom right corner, and for other
mask types, the top left corner.
Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
window_size: Optional[Tuple[int, int]] = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Negative number in window size means infinity window.
`None` means no sliding window.
attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
dtype: jax.typing.DTypeLike, default=jnp.float32
The mask data type.
Returns
----------
swa_mask: jax.numpy.tensor
Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
that will get attention, value 0 are the masked out positions.
"""
swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
if window_size is None:
return swa_mask
bottom_right_masks = [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
left_window, right_window = window_size
if attn_mask_type in bottom_right_masks:
if left_window < 0:
left_window = max_seqlen_kv
if right_window < 0:
right_window = max_seqlen_kv
bottom_right_shift = max_seqlen_kv - max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
else:
if left_window < 0:
left_window = max_seqlen_q
if right_window < 0:
right_window = max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window)
swa_mask = jnp.tril(swa_mask, k=right_window)
return swa_mask
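A tiny worked example of the bottom-right alignment described in the docstring, following the `triu`/`tril` construction above; the sizes and window are chosen purely for illustration.

```python
from transformer_engine.jax.attention import AttnMaskType, make_swa_mask

# Cross attention with 2 queries and 4 keys, window_size=(1, 0):
# the sliding-window diagonal is aligned to the bottom-right corner.
swa = make_swa_mask(2, 4, (1, 0), AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
# Expected (1 = attended, 0 = masked out):
# [[0. 1. 1. 0.]
#  [0. 0. 1. 1.]]
print(swa)
```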
def canonicalize_attn_mask_type(attn_mask_type: str):
"""Convert string attn_mask_type to AttnMaskType
TE-JAX currently fall back to the padding version kernels for the libraries integration.
......@@ -129,6 +189,7 @@ def is_fused_attn_kernel_available(
q_max_seqlen,
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
):
"""
To check whether the fused attention kernel is supported
......@@ -145,6 +206,7 @@ def is_fused_attn_kernel_available(
q_max_seqlen,
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
).is_fused_attn_kernel_available()
......@@ -247,6 +309,7 @@ def fused_attn(
scaling_factor: float,
dropout_probability: float,
is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
......@@ -275,6 +338,7 @@ def fused_attn(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
......@@ -332,6 +396,7 @@ def fused_attn(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=1,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -354,6 +419,7 @@ def fused_attn_thd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
......@@ -394,6 +460,8 @@ def fused_attn_thd(
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]):
Sliding window size.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
......@@ -451,6 +519,7 @@ def fused_attn_thd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -458,7 +527,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -474,6 +543,7 @@ def _fused_attn(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
......@@ -492,6 +562,7 @@ def _fused_attn(
dropout_probability,
is_training,
max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
......@@ -513,6 +584,7 @@ def _fused_attn_fwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
......@@ -531,6 +603,7 @@ def _fused_attn_fwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -558,6 +631,7 @@ def _fused_attn_bwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
......@@ -592,6 +666,7 @@ def _fused_attn_bwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......
......@@ -63,6 +63,7 @@ __all__ = [
"dropout_probability",
"is_training",
"max_segments_per_seq",
"window_size",
"context_parallel_load_balanced",
"cp_axis",
],
......@@ -80,6 +81,7 @@ class _FusedAttnConfig:
dropout_probability: float
is_training: bool
max_segments_per_seq: int
window_size: Tuple[int, int]
context_parallel_load_balanced: bool
cp_axis: str
......@@ -101,6 +103,7 @@ class FusedAttnHelper:
q_max_seqlen: int
kv_max_seqlen: int
head_dim: int
window_size: Tuple[int, int]
def is_fused_attn_kernel_available(self):
"""Check if there is available fused attention kernel"""
......@@ -120,6 +123,8 @@ class FusedAttnHelper:
self.q_max_seqlen,
self.kv_max_seqlen,
self.head_dim,
self.window_size[0],
self.window_size[1],
)
@staticmethod
......@@ -263,6 +268,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
q_max_seqlen,
kv_max_seqlen,
head_dim,
config.window_size,
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
......@@ -309,6 +315,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -388,6 +396,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -615,6 +625,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.is_training,
deterministic,
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
......@@ -714,6 +726,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -1042,6 +1056,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
    not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......@@ -1136,6 +1153,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
    not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......@@ -1284,6 +1304,7 @@ def fused_attn_fwd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
) -> jnp.ndarray:
......@@ -1314,6 +1335,11 @@ def fused_attn_fwd(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
......@@ -1356,6 +1382,7 @@ def fused_attn_fwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
......@@ -1390,6 +1417,7 @@ def fused_attn_bwd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
......@@ -1421,6 +1449,11 @@ def fused_attn_bwd(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
......@@ -1466,6 +1499,7 @@ def fused_attn_bwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
......
......@@ -135,6 +135,8 @@ struct CustomCallFusedAttnDescriptor {
DType wkspace_dtype;
bool is_training;
bool deterministic;
int64_t window_size_left;
int64_t window_size_right;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
......@@ -143,7 +145,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic);
bool deterministic, int64_t window_size_left, int64_t window_size_right);
// Transpose
......@@ -239,14 +241,15 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -255,7 +258,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq);
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -15,11 +15,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) {
size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1);
head_dim, head_dim, window_size_left, window_size_right);
return backend;
}
......@@ -105,7 +106,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
......@@ -155,27 +156,28 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
......@@ -223,6 +225,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dtype = descriptor.dtype;
auto is_training = descriptor.is_training;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -269,7 +273,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1);
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -288,12 +292,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
workspace_tensor.data(), stream);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -306,7 +310,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, workspace_tensor.data(), stream);
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -322,8 +326,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, workspace_tensor.data(), stream);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -336,7 +340,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq) {
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
......@@ -398,8 +403,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic,
query_workspace_tensor.data(), nullptr);
bias_type, mask_type, window_size_left, window_size_right,
deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -408,8 +413,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, deterministic, query_workspace_tensor.data(), nullptr);
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, deterministic, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......@@ -419,8 +425,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, deterministic, query_workspace_tensor.data(), nullptr);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, deterministic,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -470,6 +477,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -513,7 +522,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1);
head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......@@ -535,13 +544,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -568,8 +578,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic,
workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -604,8 +614,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
deterministic, workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, deterministic, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......
......@@ -69,11 +69,15 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic});
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
return PackOpaque(
CustomCallFusedAttnDescriptor{input_batch, bias_batch, q_max_seqlen,
kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, max_segments_per_seq,
wkspace_size, scaling_factor, dropout_probability,
bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training,
deterministic, window_size_left, window_size_right});
}
} // namespace jax
......
......@@ -25,7 +25,7 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
......@@ -118,6 +118,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
@nn.compact
def __call__(
......@@ -193,11 +194,27 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias
def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type)
# In swa_mask 0 is masked out, in original_mask 1 is masked out
swa_mask = 1 - swa_mask.astype(original_mask.dtype)
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask)
return new_mask
def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType"""
# mask is ignored for no_mask and causal_mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
# mask is ignored for no_mask and causal_mask without sliding window
if attn_mask_type == AttnMaskType.NO_MASK:
mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
mask = None
if mask is not None:
mask = apply_swa_mask(attn_mask_type, mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
......@@ -244,6 +261,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
@nn.compact
def __call__(
......@@ -289,6 +307,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat
......@@ -311,6 +330,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence:
......@@ -328,6 +348,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
......@@ -440,6 +461,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
Optimization parameters
-----------------------
......@@ -459,6 +482,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
@nn.compact
def __call__(
......@@ -532,6 +556,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_q,
seqlen_kv,
self.head_dim,
self.window_size,
)
use_fused_attn = enable_fused_attn and has_fused_attn_kernel
......@@ -577,6 +602,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
else:
x = _FusedDotProductAttention(
......@@ -587,6 +613,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
window_size=self.window_size,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
return x
......@@ -856,6 +883,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
For fused attention backend, the accumulation is always float32 without the perf overhead.
fuse_qkv: bool, default = None
Deprecated. Please refer `fuse_qkv_params`
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
"""
head_dim: int
......@@ -886,6 +915,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters
num_heads: Optional[int] = None
......@@ -1280,6 +1310,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_layout=qkv_layout.name,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)(*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
......@@ -1555,6 +1586,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
:math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
Optimization parameters
-----------------------
......@@ -1618,6 +1651,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self):
if self.mha_kernel_init is None:
......@@ -1771,6 +1805,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
name=mha_name,
window_size=self.window_size,
)(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
def hidden_dropout(x, deterministic):
......@@ -1848,6 +1883,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
name="encoder_decoder_attention",
window_size=self.window_size,
)(x, encoded, encoder_decoder_mask, deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes(
......
......@@ -80,6 +80,7 @@ class DotProductAttention(TransformerEngineBaseLayer):
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
def setup(self) -> None:
"""setup"""
......@@ -102,6 +103,7 @@ class DotProductAttention(TransformerEngineBaseLayer):
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)
self.create_layer("dot_product_attention", dpa_cls)
......@@ -151,6 +153,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters
num_heads: Optional[int] = None
......@@ -233,6 +236,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits,
window_size=self.window_size,
)
self.create_layer("multi_head_attn", mha_cls)
......@@ -292,6 +296,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -371,6 +376,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
window_size=self.window_size,
)
self.create_layer("transformerlayer", transformerlayer_cls)
......