Unverified Commit 85e60e64 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[JAX] Expose sliding window attn to TE-JAX API (#1205)



* Expose JAX sliding window attn API
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* No SWA in context parallel; fix RNG seed in test
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Handle SWA API discrepancy in cuDNN and Python
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add SWA API for flax, all tests passed

Will update tests/jax/test_praxis_layers.py next
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update test_praxis_layers.py for SWA, test passed
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Use tuple window_size; update for PR #1212
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add and adjust some pytest.skip
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revised following Reese Wang's comments

Still need further debugging:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:

These errors do not exist in the previous commit
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix no-SWA test case errors in previous commit
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Add Padding mask w/ sliding windows sanity tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use float32 for the reference code softmax calculation
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarHua Huang <huah@nvidia.com>
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarReese Wang <rewang@nvidia.com>
parent 5b6546c8
...@@ -6,6 +6,7 @@ from enum import Enum ...@@ -6,6 +6,7 @@ from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from math import sqrt from math import sqrt
from typing import Tuple, Optional
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -27,6 +28,7 @@ from transformer_engine.jax.attention import ( ...@@ -27,6 +28,7 @@ from transformer_engine.jax.attention import (
fused_attn, fused_attn,
fused_attn_thd, fused_attn_thd,
get_qkv_format, get_qkv_format,
make_swa_mask,
) )
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import ( from transformer_engine.transformer_engine_jax import (
...@@ -123,6 +125,7 @@ def make_mask( ...@@ -123,6 +125,7 @@ def make_mask(
segment_pad_q: ArrayLike, segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike, segment_pad_kv: ArrayLike,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
window_size: Optional[Tuple[int, int]] = None,
) -> Array: ) -> Array:
""" """
Create attention mask based on mask type. A `True` value in the mask means Create attention mask based on mask type. A `True` value in the mask means
...@@ -140,6 +143,15 @@ def make_mask( ...@@ -140,6 +143,15 @@ def make_mask(
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
) )
inv_mask = combine_masks(inv_pad_mask, inv_mask) inv_mask = combine_masks(inv_pad_mask, inv_mask)
if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
mask = jnp.logical_not(inv_mask) mask = jnp.logical_not(inv_mask)
return mask return mask
...@@ -274,6 +286,7 @@ class FusedAttnRunner: ...@@ -274,6 +286,7 @@ class FusedAttnRunner:
is_training: bool is_training: bool
qkv_layout: QKVLayout qkv_layout: QKVLayout
bias_shape: BiasShape bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
...@@ -298,6 +311,11 @@ class FusedAttnRunner: ...@@ -298,6 +311,11 @@ class FusedAttnRunner:
if self.max_seqlen_q != self.max_seqlen_kv: if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.") pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")
if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
self.backend = FusedAttnHelper( self.backend = FusedAttnHelper(
self.dtype, self.dtype,
self.dtype, self.dtype,
...@@ -310,6 +328,7 @@ class FusedAttnRunner: ...@@ -310,6 +328,7 @@ class FusedAttnRunner:
self.max_seqlen_q, self.max_seqlen_q,
self.max_seqlen_kv, self.max_seqlen_kv,
self.head_dim, self.head_dim,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend() ).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
...@@ -456,6 +475,7 @@ class FusedAttnRunner: ...@@ -456,6 +475,7 @@ class FusedAttnRunner:
self.segment_pad_q, self.segment_pad_q,
self.segment_pad_kv, self.segment_pad_kv,
self.attn_mask_type, self.attn_mask_type,
self.window_size,
) )
if get_qkv_format(self.qkv_layout) == QKVFormat.THD: if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
...@@ -500,6 +520,7 @@ class FusedAttnRunner: ...@@ -500,6 +520,7 @@ class FusedAttnRunner:
"is_training": self.is_training, "is_training": self.is_training,
"qkv_layout": self.qkv_layout, "qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(), "max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
} }
# Convert the outputs to float32 for the elementwise comparison # Convert the outputs to float32 for the elementwise comparison
...@@ -557,6 +578,7 @@ class FusedAttnRunner: ...@@ -557,6 +578,7 @@ class FusedAttnRunner:
"is_training": self.is_training, "is_training": self.is_training,
"qkv_layout": self.qkv_layout, "qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(), "max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
} }
# We can compute dBias only for the [1, h, s, s] layout # We can compute dBias only for the [1, h, s, s] layout
...@@ -668,7 +690,7 @@ class FusedAttnRunner: ...@@ -668,7 +690,7 @@ class FusedAttnRunner:
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"), pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param( pytest.param(
2, 2,
2048, 2048,
...@@ -677,7 +699,7 @@ class FusedAttnRunner: ...@@ -677,7 +699,7 @@ class FusedAttnRunner:
12, 12,
64, 64,
jnp.bfloat16, jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS", id="2-2048-1024-12-12-64-BF16-CROSS",
), ),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
...@@ -690,6 +712,13 @@ class FusedAttnRunner: ...@@ -690,6 +712,13 @@ class FusedAttnRunner:
pytest.param(0.1, id="DROP_0.1"), pytest.param(0.1, id="DROP_0.1"),
], ],
) )
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
pytest.param(True, id="SWA"),
],
)
class TestFusedAttn: class TestFusedAttn:
""" """
Fused attention tester Fused attention tester
...@@ -717,12 +746,16 @@ class TestFusedAttn: ...@@ -717,12 +746,16 @@ class TestFusedAttn:
is_training, is_training,
qkv_layout, qkv_layout,
bias_shape, bias_shape,
swa,
): ):
""" """
Test forward with parameterized configs Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging It is kept for development and debugging
""" """
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner( runner = FusedAttnRunner(
b, b,
s_q, s_q,
...@@ -737,6 +770,7 @@ class TestFusedAttn: ...@@ -737,6 +770,7 @@ class TestFusedAttn:
is_training, is_training,
qkv_layout, qkv_layout,
bias_shape, bias_shape,
window_size,
) )
runner.test_forward() runner.test_forward()
...@@ -754,10 +788,14 @@ class TestFusedAttn: ...@@ -754,10 +788,14 @@ class TestFusedAttn:
dtype, dtype,
qkv_layout, qkv_layout,
bias_shape, bias_shape,
swa,
): ):
""" """
Test backward with parameterized configs Test backward with parameterized configs
""" """
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner( runner = FusedAttnRunner(
b, b,
s_q, s_q,
...@@ -772,5 +810,6 @@ class TestFusedAttn: ...@@ -772,5 +810,6 @@ class TestFusedAttn:
True, True,
qkv_layout, qkv_layout,
bias_shape, bias_shape,
window_size,
) )
runner.test_backward() runner.test_backward()
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Test transformer_engine.jax.flax.TransformerLayer""" """Test transformer_engine.jax.flax.TransformerLayer"""
import os import os
from functools import partial from functools import partial
from typing import Dict from typing import Dict, Tuple
import flax import flax
import jax import jax
...@@ -61,6 +61,7 @@ _KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type" ...@@ -61,6 +61,7 @@ _KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits" _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias" _KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding" _KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
BASE_ATTRS = { BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
...@@ -70,6 +71,7 @@ BASE_ATTRS = { ...@@ -70,6 +71,7 @@ BASE_ATTRS = {
_KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm", _KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_WINDOW_SIZE: (-1, -1),
} }
ATTRS = [ ATTRS = [
...@@ -193,6 +195,19 @@ ATTRS = [ ...@@ -193,6 +195,19 @@ ATTRS = [
{ {
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")), _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
}, },
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: (64, 0), # Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
] ]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
...@@ -326,7 +341,7 @@ class EncoderRunner(BaseRunner): ...@@ -326,7 +341,7 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]: if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask mask = causal_mask
else: else:
mask = padded_mask mask = padded_mask
...@@ -379,7 +394,7 @@ class DecoderRunner(BaseRunner): ...@@ -379,7 +394,7 @@ class DecoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]: if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask self_mask = causal_mask
else: else:
self_mask = padded_mask self_mask = padded_mask
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
from functools import partial from functools import partial
from typing import Dict from typing import Dict, Tuple
import flax import flax
import jax import jax
...@@ -645,6 +645,7 @@ class DotProductAttnAttr: ...@@ -645,6 +645,7 @@ class DotProductAttnAttr:
NUM_GQA_GROUPS = "num_gqa_groups" NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = "transpose_batch_sequence" TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = "scale_factor" SCALE_FACTOR = "scale_factor"
WINDOW_SIZE = "window_size"
ATTRS = [ ATTRS = [
{ {
ATTN_MASK_TYPE: "padding", ATTN_MASK_TYPE: "padding",
...@@ -681,6 +682,12 @@ class DotProductAttnAttr: ...@@ -681,6 +682,12 @@ class DotProductAttnAttr:
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0, SCALE_FACTOR: 1.0,
}, },
{
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
] ]
...@@ -707,6 +714,7 @@ class TestDotProductAttn(TestLayer): ...@@ -707,6 +714,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups = num_attention_heads num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)
praxis_p = pax_fiddle.Config( praxis_p = pax_fiddle.Config(
DotProductAttention, DotProductAttention,
...@@ -717,6 +725,7 @@ class TestDotProductAttn(TestLayer): ...@@ -717,6 +725,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
) )
flax_cls = partial( flax_cls = partial(
flax_DotProductAttention, flax_DotProductAttention,
...@@ -726,6 +735,7 @@ class TestDotProductAttn(TestLayer): ...@@ -726,6 +735,7 @@ class TestDotProductAttn(TestLayer):
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
) )
return praxis_p, flax_cls return praxis_p, flax_cls
...@@ -750,6 +760,7 @@ class MultiHeadAttnAttr: ...@@ -750,6 +760,7 @@ class MultiHeadAttnAttr:
ENABLE_ROPE = "enable_rotary_pos_emb" ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope" LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [ ATTRS = [
{ {
USE_BIAS: True, USE_BIAS: True,
...@@ -858,6 +869,17 @@ class MultiHeadAttnAttr: ...@@ -858,6 +869,17 @@ class MultiHeadAttnAttr:
LORA_SCOPE: "all", LORA_SCOPE: "all",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, },
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
] ]
...@@ -899,6 +921,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -899,6 +921,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits = False scale_attn_logits = False
scaled_query_init = True scaled_query_init = True
float32_logits = False float32_logits = False
window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None)
praxis_p = pax_fiddle.Config( praxis_p = pax_fiddle.Config(
MultiHeadAttention, MultiHeadAttention,
...@@ -923,6 +946,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -923,6 +946,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
float32_logits=float32_logits, float32_logits=float32_logits,
window_size=window_size,
) )
flax_cls = partial( flax_cls = partial(
flax_MultiHeadAttention, flax_MultiHeadAttention,
...@@ -946,6 +970,7 @@ class TestMultiHeadAttn(TestLayer): ...@@ -946,6 +970,7 @@ class TestMultiHeadAttn(TestLayer):
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
float32_logits=float32_logits, float32_logits=float32_logits,
window_size=window_size,
) )
return praxis_p, flax_cls return praxis_p, flax_cls
...@@ -983,6 +1008,7 @@ class TransformerLayerAttr: ...@@ -983,6 +1008,7 @@ class TransformerLayerAttr:
ENABLE_ROPE = "enable_rotary_pos_emb" ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope" LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [ ATTRS = [
{ {
USE_BIAS: True, USE_BIAS: True,
...@@ -1246,6 +1272,28 @@ class TransformerLayerAttr: ...@@ -1246,6 +1272,28 @@ class TransformerLayerAttr:
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
LORA_SCOPE: "all", LORA_SCOPE: "all",
}, },
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
] ]
...@@ -1289,6 +1337,7 @@ class TestTransformer(TestLayer): ...@@ -1289,6 +1337,7 @@ class TestTransformer(TestLayer):
) )
drop_path = 0.0 drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None)
rel_embedding_init = RelativePositionBiases.generate_embedding_init( rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init, relative_embedding.embedding_init,
...@@ -1330,6 +1379,7 @@ class TestTransformer(TestLayer): ...@@ -1330,6 +1379,7 @@ class TestTransformer(TestLayer):
relative_embedding=relative_embedding, relative_embedding=relative_embedding,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
) )
flax_cls = partial( flax_cls = partial(
flax_TransformerLayer, flax_TransformerLayer,
...@@ -1358,6 +1408,7 @@ class TestTransformer(TestLayer): ...@@ -1358,6 +1408,7 @@ class TestTransformer(TestLayer):
low_rank_adaptation_scope=low_rank_adaptation_scope, low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
) )
return praxis_p, flax_cls return praxis_p, flax_cls
......
...@@ -18,6 +18,7 @@ from jax import lax, vmap ...@@ -18,6 +18,7 @@ from jax import lax, vmap
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from transformer_engine.jax.attention import AttnMaskType, make_swa_mask
from transformer_engine.jax.fp8 import DType as TEDType from transformer_engine.jax.fp8 import DType as TEDType
PRNGKey = Any PRNGKey = Any
...@@ -902,6 +903,33 @@ class RelativePositionBiases(nn.Module): ...@@ -902,6 +903,33 @@ class RelativePositionBiases(nn.Module):
return values[jnp.newaxis, ...] return values[jnp.newaxis, ...]
def apply_swa_mask(
    attn_mask_type: str,
    original_mask: Array,
    window_size: Tuple[int, int] = (-1, -1),
) -> Array:
    """Apply a sliding window mask on top of an existing attention mask.

    Parameters
    ----------
    attn_mask_type: str
        Name of the attention mask type; must be one of the keys of the
        mapping below (e.g. "causal", "padding_causal").
    original_mask: Array
        Mask of shape [..., seqlen_q, seqlen_kv] where 0 marks masked-out
        positions and 1 marks attended positions.
    window_size: Tuple[int, int], default = (-1, -1)
        (left, right) sliding window sizes; a negative entry means an
        infinite window on that side.

    Returns
    -------
    Array
        Combined mask with the same shape and dtype as `original_mask`:
        positions kept by `original_mask` take the sliding window mask
        value; everything else stays masked out (0).

    Raises
    ------
    ValueError
        If `attn_mask_type` is not a recognized mask type name.
    """
    mask_map = {
        "no_mask": AttnMaskType.NO_MASK,
        "padding": AttnMaskType.PADDING_MASK,
        "causal": AttnMaskType.CAUSAL_MASK,
        "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK,
        "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
    }
    # Validate with an explicit exception rather than `assert`, which is
    # silently stripped when Python runs with the -O flag.
    if attn_mask_type not in mask_map:
        raise ValueError(
            f"Unsupported attn_mask_type: {attn_mask_type!r}; "
            f"expected one of {sorted(mask_map)}"
        )
    _attn_mask_type = mask_map[attn_mask_type]
    max_seqlen_q = original_mask.shape[-2]
    max_seqlen_kv = original_mask.shape[-1]
    swa_mask = make_swa_mask(
        max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
    )
    # In swa_mask and original_mask 0 is masked out
    swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
    new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
    return new_mask
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
"""Transformer encoder layer.""" """Transformer encoder layer."""
...@@ -934,7 +962,8 @@ class EncoderLayer(nn.Module): ...@@ -934,7 +962,8 @@ class EncoderLayer(nn.Module):
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
self_attn_mask_type: Any = None self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self): def __post_init__(self):
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
...@@ -943,7 +972,13 @@ class EncoderLayer(nn.Module): ...@@ -943,7 +972,13 @@ class EncoderLayer(nn.Module):
@nn.compact @nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False): def __call__(self, inputs, encoder_mask=None, deterministic=False):
del self.self_attn_mask_type # dummy, just align to TE's impl # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
encoder_mask = apply_swa_mask(
self.self_attn_mask_type,
encoder_mask,
self.window_size,
)
# Relative position embedding as attention biases. # Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim batch_dim = 1 - sequence_dim
...@@ -1087,7 +1122,8 @@ class DecoderLayer(nn.Module): ...@@ -1087,7 +1122,8 @@ class DecoderLayer(nn.Module):
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
self_attn_mask_type: Any = None self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self): def __post_init__(self):
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
...@@ -1105,7 +1141,18 @@ class DecoderLayer(nn.Module): ...@@ -1105,7 +1141,18 @@ class DecoderLayer(nn.Module):
decode=False, decode=False,
max_decode_length=None, max_decode_length=None,
): ):
del self.self_attn_mask_type # dummy, just align to TE's impl decoder_mask = apply_swa_mask(
self.self_attn_mask_type,
decoder_mask,
self.window_size,
)
encoder_decoder_mask = apply_swa_mask(
"padding",
encoder_decoder_mask,
self.window_size,
)
# Relative position embedding as attention biases. # Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim batch_dim = 1 - sequence_dim
......
...@@ -86,6 +86,66 @@ def get_qkv_format(qkv_layout): ...@@ -86,6 +86,66 @@ def get_qkv_format(qkv_layout):
return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) return QKVFormat(nvte_get_qkv_format(qkv_layout.value))
def make_swa_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    window_size: Optional[Tuple[int, int]] = None,
    attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Generate sliding window mask. `True` or `1` means keep the element.
    For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
    the sliding window diagonal is aligned to the bottom right corner, and for other
    mask types, the top left corner.
    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    window_size: Optional[Tuple[int, int]] = None
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Negative number in window size means infinity window.
        `None` means no sliding window.
    attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
    dtype: jax.typing.DTypeLike, default=jnp.float32
        The mask data type.
    Returns
    ----------
    swa_mask: jax.numpy.tensor
        Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
        that will get attention, value 0 are the masked out positions.
    """
    full_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
    if window_size is None:
        # No sliding window requested: every position may attend everywhere.
        return full_mask

    if attn_mask_type in (
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
    ):
        # Align the window diagonal with the bottom right corner of the matrix.
        diag_shift = max_seqlen_kv - max_seqlen_q
        infinite_window = max_seqlen_kv
    else:
        # Top-left alignment: the diagonal starts at element (0, 0).
        diag_shift = 0
        infinite_window = max_seqlen_q

    left, right = window_size
    # A negative size on either side means an unbounded window on that side.
    left = infinite_window if left < 0 else left
    right = infinite_window if right < 0 else right

    # Keep only the band between the two shifted diagonals, zero elsewhere.
    banded = jnp.triu(full_mask, k=diag_shift - left)
    return jnp.tril(banded, k=diag_shift + right)
def canonicalize_attn_mask_type(attn_mask_type: str): def canonicalize_attn_mask_type(attn_mask_type: str):
"""Convert string attn_mask_type to AttnMaskType """Convert string attn_mask_type to AttnMaskType
TE-JAX currently fall back to the padding version kernels for the libraries integration. TE-JAX currently fall back to the padding version kernels for the libraries integration.
...@@ -129,6 +189,7 @@ def is_fused_attn_kernel_available( ...@@ -129,6 +189,7 @@ def is_fused_attn_kernel_available(
q_max_seqlen, q_max_seqlen,
kv_max_seqlen, kv_max_seqlen,
head_dim, head_dim,
window_size: Optional[Tuple[int, int]] = None,
): ):
""" """
To check whether the fused attention kernel is supported To check whether the fused attention kernel is supported
...@@ -145,6 +206,7 @@ def is_fused_attn_kernel_available( ...@@ -145,6 +206,7 @@ def is_fused_attn_kernel_available(
q_max_seqlen, q_max_seqlen,
kv_max_seqlen, kv_max_seqlen,
head_dim, head_dim,
(-1, -1) if window_size is None else window_size,
).is_fused_attn_kernel_available() ).is_fused_attn_kernel_available()
...@@ -247,6 +309,7 @@ def fused_attn( ...@@ -247,6 +309,7 @@ def fused_attn(
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
): ):
...@@ -275,6 +338,7 @@ def fused_attn( ...@@ -275,6 +338,7 @@ def fused_attn(
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
...@@ -332,6 +396,7 @@ def fused_attn( ...@@ -332,6 +396,7 @@ def fused_attn(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=1, max_segments_per_seq=1,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -354,6 +419,7 @@ def fused_attn_thd( ...@@ -354,6 +419,7 @@ def fused_attn_thd(
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
max_segments_per_seq: int = 1, max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
): ):
...@@ -394,6 +460,8 @@ def fused_attn_thd( ...@@ -394,6 +460,8 @@ def fused_attn_thd(
Indicating the maximum number of segments inside a sequence. This parameter is to Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`. time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]):
Sliding window size.
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
...@@ -451,6 +519,7 @@ def fused_attn_thd( ...@@ -451,6 +519,7 @@ def fused_attn_thd(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -458,7 +527,7 @@ def fused_attn_thd( ...@@ -458,7 +527,7 @@ def fused_attn_thd(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
def _fused_attn( def _fused_attn(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
...@@ -474,6 +543,7 @@ def _fused_attn( ...@@ -474,6 +543,7 @@ def _fused_attn(
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
max_segments_per_seq: int, max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_causal_load_balanced: bool, context_parallel_causal_load_balanced: bool,
context_parallel_axis: str, context_parallel_axis: str,
): ):
...@@ -492,6 +562,7 @@ def _fused_attn( ...@@ -492,6 +562,7 @@ def _fused_attn(
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
) )
...@@ -513,6 +584,7 @@ def _fused_attn_fwd_rule( ...@@ -513,6 +584,7 @@ def _fused_attn_fwd_rule(
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
): ):
...@@ -531,6 +603,7 @@ def _fused_attn_fwd_rule( ...@@ -531,6 +603,7 @@ def _fused_attn_fwd_rule(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -558,6 +631,7 @@ def _fused_attn_bwd_rule( ...@@ -558,6 +631,7 @@ def _fused_attn_bwd_rule(
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
ctx, ctx,
...@@ -592,6 +666,7 @@ def _fused_attn_bwd_rule( ...@@ -592,6 +666,7 @@ def _fused_attn_bwd_rule(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
......
...@@ -63,6 +63,7 @@ __all__ = [ ...@@ -63,6 +63,7 @@ __all__ = [
"dropout_probability", "dropout_probability",
"is_training", "is_training",
"max_segments_per_seq", "max_segments_per_seq",
"window_size",
"context_parallel_load_balanced", "context_parallel_load_balanced",
"cp_axis", "cp_axis",
], ],
...@@ -80,6 +81,7 @@ class _FusedAttnConfig: ...@@ -80,6 +81,7 @@ class _FusedAttnConfig:
dropout_probability: float dropout_probability: float
is_training: bool is_training: bool
max_segments_per_seq: int max_segments_per_seq: int
window_size: Tuple[int, int]
context_parallel_load_balanced: bool context_parallel_load_balanced: bool
cp_axis: str cp_axis: str
...@@ -101,6 +103,7 @@ class FusedAttnHelper: ...@@ -101,6 +103,7 @@ class FusedAttnHelper:
q_max_seqlen: int q_max_seqlen: int
kv_max_seqlen: int kv_max_seqlen: int
head_dim: int head_dim: int
window_size: Tuple[int, int]
def is_fused_attn_kernel_available(self): def is_fused_attn_kernel_available(self):
"""Check if there is available fused attention kernel""" """Check if there is available fused attention kernel"""
...@@ -120,6 +123,8 @@ class FusedAttnHelper: ...@@ -120,6 +123,8 @@ class FusedAttnHelper:
self.q_max_seqlen, self.q_max_seqlen,
self.kv_max_seqlen, self.kv_max_seqlen,
self.head_dim, self.head_dim,
self.window_size[0],
self.window_size[1],
) )
@staticmethod @staticmethod
...@@ -263,6 +268,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -263,6 +268,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
q_max_seqlen, q_max_seqlen,
kv_max_seqlen, kv_max_seqlen,
head_dim, head_dim,
config.window_size,
).get_fused_attn_backend() ).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
...@@ -309,6 +315,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -309,6 +315,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training, config.is_training,
config.max_segments_per_seq, config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
) )
wkspace_aval = q_aval.update( wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -388,6 +396,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -388,6 +396,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training, config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(), not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
) )
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -615,6 +625,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -615,6 +625,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.is_training, config.is_training,
deterministic, deterministic,
config.max_segments_per_seq, config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
) )
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
...@@ -714,6 +726,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -714,6 +726,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training, config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(), not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
) )
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -1042,6 +1056,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1042,6 +1056,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
def partition(config, mesh, arg_infos, result_infos): def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work. # Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
is_context_parallel and config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel: if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
...@@ -1136,6 +1153,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1136,6 +1153,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
def partition(config, mesh, arg_infos, result_infos): def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work. # Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
is_context_parallel and config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel: if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
...@@ -1284,6 +1304,7 @@ def fused_attn_fwd( ...@@ -1284,6 +1304,7 @@ def fused_attn_fwd(
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
max_segments_per_seq: int, max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
) -> jnp.ndarray: ) -> jnp.ndarray:
...@@ -1314,6 +1335,11 @@ def fused_attn_fwd( ...@@ -1314,6 +1335,11 @@ def fused_attn_fwd(
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
...@@ -1356,6 +1382,7 @@ def fused_attn_fwd( ...@@ -1356,6 +1382,7 @@ def fused_attn_fwd(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced, context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
) )
...@@ -1390,6 +1417,7 @@ def fused_attn_bwd( ...@@ -1390,6 +1417,7 @@ def fused_attn_bwd(
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
max_segments_per_seq: int, max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
): ):
...@@ -1421,6 +1449,11 @@ def fused_attn_bwd( ...@@ -1421,6 +1449,11 @@ def fused_attn_bwd(
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]): Sliding window size .
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
...@@ -1466,6 +1499,7 @@ def fused_attn_bwd( ...@@ -1466,6 +1499,7 @@ def fused_attn_bwd(
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced, context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
) )
......
...@@ -135,6 +135,8 @@ struct CustomCallFusedAttnDescriptor { ...@@ -135,6 +135,8 @@ struct CustomCallFusedAttnDescriptor {
DType wkspace_dtype; DType wkspace_dtype;
bool is_training; bool is_training;
bool deterministic; bool deterministic;
int64_t window_size_left;
int64_t window_size_right;
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
...@@ -143,7 +145,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( ...@@ -143,7 +145,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic); bool deterministic, int64_t window_size_left, int64_t window_size_right);
// Transpose // Transpose
...@@ -239,14 +241,15 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -239,14 +241,15 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads, size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim); size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes( pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq); size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -255,7 +258,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -255,7 +258,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq); bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -15,11 +15,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -15,11 +15,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads, size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) { size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1); head_dim, head_dim, window_size_left, window_size_right);
return backend; return backend;
} }
...@@ -105,7 +106,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -105,7 +106,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) { size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed // For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
...@@ -155,27 +156,28 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -155,27 +156,28 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(
o_tensor.data(), &aux_output_tensors, qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor,
scaling_factor, dropout_probability, qkv_layout, bias_type, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
mask_type, -1, -1, query_workspace_tensor.data(), nullptr); window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr); bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), nvte_fused_attn_fwd(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
-1, query_workspace_tensor.data(), nullptr); window_size_right, query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported QKVLayout."); NVTE_ERROR("Unsupported QKVLayout.");
} }
...@@ -223,6 +225,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -223,6 +225,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto is_training = descriptor.is_training; auto is_training = descriptor.is_training;
auto max_segments_per_seq = descriptor.max_segments_per_seq; auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;
/* Input tensors */ /* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -269,7 +273,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -269,7 +273,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1); head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -288,12 +292,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -288,12 +292,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto qkv = buffers[0]; auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_max_seqlen, is_training, descriptor.scaling_factor, rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -306,7 +310,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -306,7 +310,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, workspace_tensor.data(), stream); bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -322,8 +326,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -322,8 +326,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-1, workspace_tensor.data(), stream); window_size_left, window_size_right, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -336,7 +340,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -336,7 +340,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq) { bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed // For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
...@@ -398,8 +403,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -398,8 +403,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic, bias_type, mask_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr); deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
...@@ -408,8 +413,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -408,8 +413,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-1, deterministic, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, deterministic, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(), doutput_tensor.data(),
...@@ -419,8 +425,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -419,8 +425,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-1, deterministic, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, deterministic,
query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -470,6 +477,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -470,6 +477,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic; auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq; auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;
/* Input tensors */ /* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -513,7 +522,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -513,7 +522,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1); head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias); rng_state, bias);
...@@ -535,13 +544,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -535,13 +544,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>()); std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream); cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
} }
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -568,8 +578,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -568,8 +578,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
workspace_tensor.data(), stream); deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -604,8 +614,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -604,8 +614,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
deterministic, workspace_tensor.data(), stream); window_size_right, deterministic, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
......
...@@ -69,11 +69,15 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( ...@@ -69,11 +69,15 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic) { bool deterministic, int64_t window_size_left, int64_t window_size_right) {
return PackOpaque(CustomCallFusedAttnDescriptor{ return PackOpaque(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, CustomCallFusedAttnDescriptor{input_batch, bias_batch, q_max_seqlen,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, kv_max_seqlen, attn_heads, num_gqa_groups,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic}); bias_heads, head_dim, max_segments_per_seq,
wkspace_size, scaling_factor, dropout_probability,
bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training,
deterministic, window_size_left, window_size_right});
} }
} // namespace jax } // namespace jax
......
...@@ -25,7 +25,7 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -25,7 +25,7 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn from ..attention import fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import num_of_devices from ..sharding import num_of_devices
...@@ -118,6 +118,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -118,6 +118,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
float32_logits: bool = False float32_logits: bool = False
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
@nn.compact @nn.compact
def __call__( def __call__(
...@@ -193,11 +194,27 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -193,11 +194,27 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias attn_weights += bias
def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type)
# In swa_mask 0 is masked out, in original_mask 1 is masked out
swa_mask = 1 - swa_mask.astype(original_mask.dtype)
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask)
return new_mask
def convert_to_softmax_type(attn_mask_type, mask): def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType""" """Convert the attn_mask_type to SoftmaxType"""
# mask is ignored for no_mask and causal_mask # mask is ignored for no_mask and causal_mask without sliding window
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: if attn_mask_type == AttnMaskType.NO_MASK:
mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
mask = None mask = None
if mask is not None:
mask = apply_swa_mask(attn_mask_type, mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
...@@ -244,6 +261,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -244,6 +261,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
@nn.compact @nn.compact
def __call__( def __call__(
...@@ -289,6 +307,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -289,6 +307,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.attention_dropout, dropout_probability=self.attention_dropout,
is_training=not deterministic, is_training=not deterministic,
window_size=self.window_size,
) )
elif self.qkv_layout == QKVLayout.BSHD_BS2HD: elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat """kvpacked format, treat
...@@ -311,6 +330,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -311,6 +330,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.attention_dropout, dropout_probability=self.attention_dropout,
is_training=not deterministic, is_training=not deterministic,
window_size=self.window_size,
) )
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
...@@ -328,6 +348,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -328,6 +348,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.attention_dropout, dropout_probability=self.attention_dropout,
is_training=not deterministic, is_training=not deterministic,
window_size=self.window_size,
) )
else: else:
raise ValueError(f"Unsupported {self.qkv_layout=}.") raise ValueError(f"Unsupported {self.qkv_layout=}.")
...@@ -440,6 +461,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -440,6 +461,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Indicate whether the input tensors were switched axis of batch Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -459,6 +482,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -459,6 +482,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_layout: str = "bshd_bshd_bshd" qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
@nn.compact @nn.compact
def __call__( def __call__(
...@@ -532,6 +556,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -532,6 +556,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_q, seqlen_q,
seqlen_kv, seqlen_kv,
self.head_dim, self.head_dim,
self.window_size,
) )
use_fused_attn = enable_fused_attn and has_fused_attn_kernel use_fused_attn = enable_fused_attn and has_fused_attn_kernel
...@@ -577,6 +602,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -577,6 +602,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
else: else:
x = _FusedDotProductAttention( x = _FusedDotProductAttention(
...@@ -587,6 +613,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -587,6 +613,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
window_size=self.window_size,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
return x return x
...@@ -856,6 +883,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -856,6 +883,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
For fused attention backend, the accumulation is always float32 without the perf overhead. For fused attention backend, the accumulation is always float32 without the perf overhead.
fuse_qkv: bool, default = None fuse_qkv: bool, default = None
Deprecated. Please refer `fuse_qkv_params` Deprecated. Please refer `fuse_qkv_params`
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
""" """
head_dim: int head_dim: int
...@@ -886,6 +915,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -886,6 +915,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters # Deprecated parameters
num_heads: Optional[int] = None num_heads: Optional[int] = None
...@@ -1280,6 +1310,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1280,6 +1310,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_layout=qkv_layout.name, qkv_layout=qkv_layout.name,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)(*dpa_args, mask, bias, deterministic=deterministic) )(*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
...@@ -1555,6 +1586,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1555,6 +1586,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
:math:`\frac{alpha}{rank} * lora\_output`. None means no scaling. :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
enable_sequence_parallel: bool, default = False enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot. Whether to enable sequence parallelism to operations except dot.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -1618,6 +1651,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1618,6 +1651,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self): def __post_init__(self):
if self.mha_kernel_init is None: if self.mha_kernel_init is None:
...@@ -1771,6 +1805,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1771,6 +1805,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
name=mha_name, name=mha_name,
window_size=self.window_size,
)(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
def hidden_dropout(x, deterministic): def hidden_dropout(x, deterministic):
...@@ -1848,6 +1883,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1848,6 +1883,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
name="encoder_decoder_attention", name="encoder_decoder_attention",
window_size=self.window_size,
)(x, encoded, encoder_decoder_mask, deterministic=deterministic) )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes( y = with_sharding_constraint_by_logical_axes(
......
...@@ -80,6 +80,7 @@ class DotProductAttention(TransformerEngineBaseLayer): ...@@ -80,6 +80,7 @@ class DotProductAttention(TransformerEngineBaseLayer):
qkv_layout: str = "bshd_bshd_bshd" qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
def setup(self) -> None: def setup(self) -> None:
"""setup""" """setup"""
...@@ -102,6 +103,7 @@ class DotProductAttention(TransformerEngineBaseLayer): ...@@ -102,6 +103,7 @@ class DotProductAttention(TransformerEngineBaseLayer):
qkv_layout=self.qkv_layout, qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor, scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
) )
self.create_layer("dot_product_attention", dpa_cls) self.create_layer("dot_product_attention", dpa_cls)
...@@ -151,6 +153,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -151,6 +153,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters # Deprecated parameters
num_heads: Optional[int] = None num_heads: Optional[int] = None
...@@ -233,6 +236,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -233,6 +236,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
window_size=self.window_size,
) )
self.create_layer("multi_head_attn", mha_cls) self.create_layer("multi_head_attn", mha_cls)
...@@ -292,6 +296,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -292,6 +296,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self): def __post_init__(self):
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
...@@ -371,6 +376,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -371,6 +376,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_sequence_parallel=self.enable_sequence_parallel, enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
window_size=self.window_size,
) )
self.create_layer("transformerlayer", transformerlayer_cls) self.create_layer("transformerlayer", transformerlayer_cls)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment