Unverified Commit 7f5c784e authored by Reese Wang, committed by GitHub

[JAX] Fused attention unit tests fixes and refinements (#1352)



* Add util functions to attn_mask_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add util functions to qkv_layout
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix THD cross reference code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove explicit segment_pad, encoding it to segment_ids
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add jax.jit, replace _token with segment_ids, rename bias shape enum
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add comment for make_mask
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Clean code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add doc strings for the added functions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove cache for fa deterministic which causes UT failed
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename fixture to avoid conflict
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent f4f35c2f
......@@ -20,7 +20,7 @@ def clear_live_arrays():
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
def enable_fused_attn_after_hopper():
"""
    Enable fused attn for Hopper+ arch.
    Fused attn kernels on pre-Hopper arch are not deterministic.
......
......@@ -20,7 +20,6 @@ from distributed_test_base import (
from utils import (
make_causal_mask,
make_self_mask,
assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
......@@ -32,7 +31,6 @@ from transformer_engine.jax.attention import (
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
......@@ -421,7 +419,7 @@ class TestDistributedContextParallelSelfAttn:
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)
qkv_format = qkv_layout.get_qkv_format()
batch, seqlen, num_head, hidden = data_shape
......@@ -503,7 +501,7 @@ class TestDistributedContextParallelSelfAttn:
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
if attn_mask_type.is_causal():
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
......
......@@ -28,7 +28,6 @@ from transformer_engine.jax.attention import (
QKVFormat,
fused_attn,
fused_attn_thd,
get_qkv_format,
make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
......@@ -50,6 +49,7 @@ def init():
yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
query: ArrayLike,
key: ArrayLike,
......@@ -102,29 +102,36 @@ def general_dot_product_attention(
return context
def is_causal_mask(mask: AttnMaskType):
"""
Check if the mask is a causal mask
"""
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
@jax.jit
def make_causal_mask(
segment_ids_q: ArrayLike,
segment_ids_kv: ArrayLike,
segment_pos_q: ArrayLike = None,
segment_pos_kv: ArrayLike = None,
) -> Array:
"""
Create inverse padded causal mask where `True` means allowing the corresponding
position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the last axis of segment_ids is used.
"""
q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
return inv_causal_mask
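For illustration, a minimal standalone sketch of the position-based comparison this helper performs (the packed segment_pos row below is hypothetical; inside make_mask the result is further combined with the segment-id equality mask, so the position test alone does not prevent cross-segment attention):

```python
import jax.numpy as jnp

# Hypothetical packed row holding two segments; positions restart at 0 per segment.
segment_pos = jnp.array([[0, 1, 2, 0, 1, 2]])

# Causal rule used above: a query position may attend to any key position
# that is not ahead of it, i.e. greater_equal on the position indices.
q_pos = segment_pos[..., :, None]   # [batch, q_len, 1]
kv_pos = segment_pos[..., None, :]  # [batch, 1, kv_len]
inv_causal_mask = jnp.greater_equal(q_pos, kv_pos)  # True means "allowed to attend"

print(inv_causal_mask[0].astype(jnp.int32))
```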
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
q_token: ArrayLike,
kv_token: ArrayLike,
segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike,
segment_ids_q: ArrayLike,
segment_ids_kv: ArrayLike,
segment_pos_q: ArrayLike,
segment_pos_kv: ArrayLike,
attn_mask_type: AttnMaskType,
window_size: Optional[Tuple[int, int]] = None,
) -> Array:
......@@ -132,18 +139,31 @@ def make_mask(
Create attention mask based on mask type. A `True` value in the mask means
masking out the corresponding position and a `False` value means allowing
that position to participate in attention.
    - segment_ids should start from 1, with 0 reserved for padding.
      Each segment is expected to start without padding.
    - segment_pos marks the token position within each segment.
    An example pair of segment_ids and segment_pos:
segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
"""
inv_mask = make_attention_mask(
q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
)
if is_causal_mask(attn_mask_type):
inv_causal_mask = make_causal_mask(q_token, kv_token)
inv_mask = combine_masks(inv_causal_mask, inv_mask)
if segment_pad_q is not None and segment_pad_kv is not None:
inv_pad_mask = make_attention_mask(
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
if attn_mask_type.is_causal():
if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_pad_mask, inv_mask)
inv_mask = combine_masks(inv_causal_mask, inv_mask)
if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
......@@ -157,7 +177,8 @@ def make_mask(
return mask
def get_seqlens_and_offsets(segment_ids, segment_pad):
@jax.jit
def get_seqlens_and_offsets(segment_ids):
batch, max_seqlen = segment_ids.shape
bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
......@@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad):
def _find_offsets(x):
same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
first_column = jnp.ones((x.shape[0], 1), dtype=bool)
first_column = x[..., :1] != 0
same_as_previous = jnp.hstack((first_column, same_as_previous))
return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
same_as_previous
......@@ -173,13 +194,9 @@ def get_seqlens_and_offsets(segment_ids, segment_pad):
offsets = _find_offsets(segment_ids)
offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
if segment_pad is not None:
segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
else:
output = jnp.insert(seqlens, -1, values=0, axis=-1)
return output, offsets
seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
seqlens = jnp.where(seqlens, seqlens, -1)
return seqlens, offsets
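To show what the simplified helper now produces without an explicit segment_pad input, here is a small worked sketch with a hypothetical packed row (same intent as the code above, not the exact test utility):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical packed row: segments of length 3 and 2, padding encoded as id 0.
segment_ids = jnp.array([[1, 1, 1, 2, 2, 0, 0, 0]], dtype=jnp.int32)
max_seqlen = segment_ids.shape[-1]

# Per-segment lengths: bincount per row, dropping the count for id 0 (padding).
seqlens = jax.vmap(partial(jnp.bincount, length=max_seqlen))(segment_ids)[..., 1:]
# seqlens[0] starts with [3, 2, 0, ...]; the helper additionally maps unused slots to -1.

# Segment start offsets: positions where a new non-zero id begins.
is_start = jnp.concatenate(
    [
        segment_ids[..., :1] != 0,
        jnp.logical_and(segment_ids[..., 1:] != segment_ids[..., :-1], segment_ids[..., 1:] != 0),
    ],
    axis=-1,
)
offsets = jax.vmap(partial(jnp.argwhere, size=max_seqlen, fill_value=-1))(is_start)[..., 0]
# offsets[0] starts with [0, 3, -1, ...]
```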
@jax.jit
......@@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query,
key,
value,
bias=bias,
mask=mask,
bias,
mask,
deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"],
......@@ -228,7 +245,6 @@ def customcall_fused_dpa(
TE customcall dot product attention implementation
"""
qkv_layout = kwargs["qkv_layout"]
is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
match qkv_layout:
case QKVLayout.BS3HD | QKVLayout.T3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
......@@ -242,7 +258,7 @@ def customcall_fused_dpa(
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
if not is_thd:
if not qkv_layout.is_thd():
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
......@@ -262,10 +278,10 @@ class BiasShape(Enum):
Enum class to represent the different bias shapes used in the fused attention.
"""
BIAS_1HSS = "1HSS"
BIAS_B1SS = "B1SS"
BIAS_BHSS = "BHSS"
BIAS_11SS = "11SS"
_1HSS = "1HSS"
_B1SS = "B1SS"
_BHSS = "BHSS"
_11SS = "11SS"
@dataclass
......@@ -300,18 +316,12 @@ class FusedAttnRunner:
def _check_configs(self):
    # TODO(rewang): probably add this to is_fused_attn_available
if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
AttnMaskType.PADDING_MASK,
AttnMaskType.PADDING_CAUSAL_MASK,
]:
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
qkv_format = get_qkv_format(self.qkv_layout)
if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
if self.qkv_layout.is_qkvpacked():
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
if self.num_heads_q != self.num_heads_kv:
pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")
......@@ -339,15 +349,11 @@ class FusedAttnRunner:
if (
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
and self.bias_shape != BiasShape._1HSS
):
if self.attn_mask_type not in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
]:
if self.attn_mask_type.is_padding():
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
"B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
)
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip(
......@@ -370,18 +376,18 @@ class FusedAttnRunner:
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
elif self.bias_shape == BiasShape.BIAS_1HSS:
elif self.bias_shape == BiasShape._1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_B1SS:
elif self.bias_shape == BiasShape._B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_BHSS:
elif self.bias_shape == BiasShape._BHSS:
bias_shape = (
self.batch_size,
self.num_heads_q,
self.max_seqlen_q,
self.max_seqlen_kv,
)
elif self.bias_shape == BiasShape.BIAS_11SS:
elif self.bias_shape == BiasShape._11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
......@@ -391,7 +397,7 @@ class FusedAttnRunner:
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS:
if self.bias_shape == BiasShape._1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
......@@ -408,10 +414,10 @@ class FusedAttnRunner:
else:
self.bias = None
if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
if self.attn_mask_type.is_padding():
pad_ratio = 0.3
else:
pad_ratio = 0.0
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
......@@ -425,6 +431,8 @@ class FusedAttnRunner:
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
# [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
......@@ -440,58 +448,62 @@ class FusedAttnRunner:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
if with_segment_pad:
num_valid = rng.integers(1, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
segment_pad[i, current_pos:sequence_length] = 1
return segment_ids, segment_pad
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
segment_ids, segment_pos, segment_pad = map(
jnp.asarray, [segment_ids, segment_pos, segment_pad]
)
segment_ids = jnp.where(segment_pad, 0, segment_ids)
return segment_ids, segment_pos, segment_pad
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.token_q, self.segment_pad_q = generate_random_segment_ids(
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
# TODO(rewang): Check if qkvpacked supported different q/kv
# TODO(rewang): Causal with different q/kv segment_id fails
if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
self.token_kv = self.token_q
self.segment_pad_kv = self.segment_pad_q
if self.qkv_layout == QKVLayout.T3HD:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
)
self.pad_q = self.segment_pad_q
self.pad_kv = self.segment_pad_kv
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.segment_pad_q = self.segment_pad_kv = None
self.segment_ids_q, self.pad_q = gen_valid(
self.batch_size, self.max_seqlen_q, pad_ratio
)
self.segment_ids_kv, self.pad_kv = gen_valid(
self.batch_size, self.max_seqlen_kv, pad_ratio
)
self.segment_pos_q = self.segment_pos_kv = None
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
# For reference code
self.mask = make_mask(
self.token_q,
self.token_kv,
self.segment_pad_q,
self.segment_pad_kv,
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
self.window_size,
)
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
self.token_q, self.segment_pad_q
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
self.token_kv, self.segment_pad_kv
)
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
else:
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
self.mask_for_customcall = self.mask
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
......@@ -547,13 +559,11 @@ class FusedAttnRunner:
"""
self._setup_inputs()
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if is_causal_mask(self.attn_mask_type):
if self.attn_mask_type.is_causal():
gradient_multiplier /= 10
# Keep only valid result for the gradient
ret_valid = jnp.where(
......@@ -586,7 +596,7 @@ class FusedAttnRunner:
}
# We can compute dBias only for the [1, h, s, s] layout
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
......@@ -629,7 +639,7 @@ class FusedAttnRunner:
check_dqkv(primitive_dk, reference_dk, self.pad_kv)
check_dqkv(primitive_dv, reference_dv, self.pad_kv)
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
primitive_dbias = primitive_dgrad[3]
reference_dbias = reference_dgrad[3]
......@@ -658,16 +668,6 @@ class FusedAttnRunner:
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
......@@ -736,6 +736,16 @@ class TestFusedAttn:
pytest.param(False, id="INFERENCE"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
],
)
def _test_forward(
b,
s_q,
......@@ -779,6 +789,13 @@ class TestFusedAttn:
runner.test_forward()
@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
......
......@@ -19,7 +19,11 @@ from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
from transformer_engine.jax.attention import AttnMaskType, make_swa_mask
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
from transformer_engine.jax.fp8 import DType as TEDType
PRNGKey = Any
......@@ -913,15 +917,7 @@ def apply_swa_mask(
window_size: Tuple[int, int] = (-1, -1),
) -> Array:
"""Apply the sliding window mask to a given mask"""
mask_map = {
"no_mask": AttnMaskType.NO_MASK,
"padding": AttnMaskType.PADDING_MASK,
"causal": AttnMaskType.CAUSAL_MASK,
"padding_causal": AttnMaskType.PADDING_CAUSAL_MASK,
"causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
"padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
_attn_mask_type = mask_map.get(attn_mask_type, None)
_attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
assert _attn_mask_type is not None
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
......
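The deleted mask_map shows the string spellings that were previously handled by hand; a minimal usage sketch of the replacement, assuming canonicalize_attn_mask_type maps the same strings onto AttnMaskType members:

```python
from transformer_engine.jax.attention import AttnMaskType, canonicalize_attn_mask_type

# Same spellings the removed dict enumerated explicitly.
assert canonicalize_attn_mask_type("causal") is AttnMaskType.CAUSAL_MASK
assert canonicalize_attn_mask_type("padding_causal") is AttnMaskType.PADDING_CAUSAL_MASK
assert canonicalize_attn_mask_type("no_mask") is AttnMaskType.NO_MASK
```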
......@@ -46,6 +46,42 @@ class AttnMaskType(Enum):
CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
def is_causal(self):
"""Returns True if the mask is a causal mask"""
return self in [
AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK,
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
def is_padding(self):
"""Returns True if the mask includes padding"""
return self in [
AttnMaskType.PADDING_MASK,
AttnMaskType.PADDING_CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
def is_bottom_right(self):
"""Returns True if the causal mask is calculated from the bottom-right section"""
return self in [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
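A brief usage sketch of the new predicates, with the expected results following directly from the membership lists above:

```python
from transformer_engine.jax.attention import AttnMaskType

mask = AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
assert mask.is_causal() and mask.is_padding() and mask.is_bottom_right()

# A plain causal mask is causal but involves neither padding nor bottom-right alignment.
assert AttnMaskType.CAUSAL_MASK.is_causal()
assert not AttnMaskType.CAUSAL_MASK.is_padding()
assert not AttnMaskType.CAUSAL_MASK.is_bottom_right()
```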
class QKVFormat(Enum):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is the same as BSHD, but allows multiple segments to be packed in a sequence.
"""
SBHD = NVTE_QKV_Format.NVTE_SBHD
BSHD = NVTE_QKV_Format.NVTE_BSHD
THD = NVTE_QKV_Format.NVTE_THD
class QKVLayout(Enum):
"""
......@@ -66,17 +102,35 @@ class QKVLayout(Enum):
THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD
def get_qkv_format(self):
"""
Return the corresponding qkv_format (BSHD, SBHD, THD)
"""
return QKVFormat(nvte_get_qkv_format(self.value))
def is_qkvpacked(self):
"""
        Return True if query, key and value are packed together
"""
return self in [QKVLayout.BS3HD, QKVLayout.T3HD]
class QKVFormat(Enum):
def is_kvpacked(self):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
BSHD: q,k,v memory layout with [b, s, ..., h, d]
THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence.
        Return True if key and value are packed together
"""
return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD]
SBHD = NVTE_QKV_Format.NVTE_SBHD
BSHD = NVTE_QKV_Format.NVTE_BSHD
THD = NVTE_QKV_Format.NVTE_THD
def is_separate(self):
"""
Return True if the query, key, value are three separate tensors
"""
return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD]
def is_thd(self):
"""
Return True if the layout belongs to THD
"""
return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
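And a matching sketch for the QKVLayout helpers; the expectations follow from the definitions above (THD_T2HD is a THD-format layout with key and value packed):

```python
from transformer_engine.jax.attention import QKVFormat, QKVLayout

layout = QKVLayout.THD_T2HD
assert layout.get_qkv_format() is QKVFormat.THD
assert layout.is_thd() and layout.is_kvpacked()
assert not layout.is_qkvpacked() and not layout.is_separate()
```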
class CPStrategy(Enum):
......@@ -92,13 +146,6 @@ class CPStrategy(Enum):
RING = 2
def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
"""
return QKVFormat(nvte_get_qkv_format(qkv_layout.value))
def make_swa_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
......@@ -136,12 +183,8 @@ def make_swa_mask(
swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
if window_size is None:
return swa_mask
bottom_right_masks = [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
left_window, right_window = window_size
if attn_mask_type in bottom_right_masks:
if attn_mask_type.is_bottom_right():
if left_window < 0:
left_window = max_seqlen_kv
if right_window < 0:
......@@ -310,7 +353,7 @@ def fused_attn(
(jnp.ndarray): The output tensor from the fused attention.
"""
assert (
get_qkv_format(qkv_layout) != QKVFormat.THD
not qkv_layout.is_thd()
), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."
# Check inputs qkv
......@@ -327,11 +370,7 @@ def fused_attn(
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
# convert the mask to seqlens, mask doesn't support ragged offsets
if attn_mask_type in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
]:
if not attn_mask_type.is_padding():
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
......@@ -448,7 +487,7 @@ def fused_attn_thd(
QKVLayout.T3HD, 0.125, 0, True, 3)
"""
assert (
get_qkv_format(qkv_layout) == QKVFormat.THD
qkv_layout.is_thd()
), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."
# Check inputs qkv
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from functools import partial, reduce, cache
from functools import partial, reduce
import operator
import os
from typing import Optional, Tuple
......@@ -133,7 +133,6 @@ class FusedAttnHelper:
)
@staticmethod
@cache
def is_non_deterministic_allowed():
"""Check if non-deterministic kernels are allowed"""
return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
......