Unverified commit 7f5c784e authored by Reese Wang, committed by GitHub

[JAX] Fused attention unit tests fixes and refinements (#1352)



* Add util functions to attn_mask_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add util functions to qkv_layout
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the THD cross-reference code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the explicit segment_pad, encoding it into segment_ids
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add jax.jit, replace _token with segment_ids, rename the bias shape enum
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add a comment for make_mask
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Clean up code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstrings for the added functions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the cache on the FA determinism check, which caused unit test failures
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename a fixture to avoid conflicts
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent f4f35c2f
@@ -20,7 +20,7 @@ def clear_live_arrays():
 @pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
     """
     Enable fused attn for hopper+ arch.
     Fused attn kernels on pre-hopper arch are not deterministic.
......
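Note on the rename: pytest looks fixtures up by name, so two autouse fixtures that share a name across conftest.py and a test module can shadow each other, which is what the rename avoids. For orientation only, a generic sketch of a module-scoped autouse fixture that toggles an environment variable; the real fixture body is collapsed above, and the NVTE_FUSED_ATTN variable here is an assumption, not taken from this diff:

```python
import os
import pytest

@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """Illustrative sketch only: enable fused attention for every test in this module."""
    os.environ["NVTE_FUSED_ATTN"] = "1"  # assumed variable name, for illustration
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)
```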
@@ -20,7 +20,6 @@ from distributed_test_base import (
 from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )
@@ -32,7 +31,6 @@ from transformer_engine.jax.attention import (
     AttnMaskType,
     QKVLayout,
     QKVFormat,
-    get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
     CPStrategy,
@@ -421,7 +419,7 @@ class TestDistributedContextParallelSelfAttn:
         dropout_prob = 0.0
         is_training = True
         dp_size, cp_size, tp_size = mesh_shape
-        qkv_format = get_qkv_format(qkv_layout)
+        qkv_format = qkv_layout.get_qkv_format()
         batch, seqlen, num_head, hidden = data_shape
@@ -503,7 +501,7 @@ class TestDistributedContextParallelSelfAttn:
             # Gradient is small, use a gradient multiplier to amplify the gradient
             _, max_seq_len, num_heads, _ = data_shape
             gradient_multiplier = max_seq_len * num_heads
-            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+            if attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             ret_valid = func(*args, **kwargs)
             return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
......
@@ -28,7 +28,6 @@ from transformer_engine.jax.attention import (
     QKVFormat,
     fused_attn,
     fused_attn_thd,
-    get_qkv_format,
     make_swa_mask,
 )
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
@@ -50,6 +49,7 @@ def init():
     yield

+@partial(jax.jit, static_argnums=(5, 6, 7, 9))
 def general_dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
@@ -102,29 +102,36 @@ def general_dot_product_attention(
     return context

-def is_causal_mask(mask: AttnMaskType):
-    """
-    Check if the mask is a causal mask
-    """
-    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
-
-def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
+@jax.jit
+def make_causal_mask(
+    segment_ids_q: ArrayLike,
+    segment_ids_kv: ArrayLike,
+    segment_pos_q: ArrayLike = None,
+    segment_pos_kv: ArrayLike = None,
+) -> Array:
     """
     Create inverse padded causal mask where `True` means allowing the corresponding
     position to participate in attention and `False` means masking out that position.
+    If segment_pos is not provided, an arange of the segment_ids will be applied.
     """
-    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
-    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
-    inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
+    if segment_pos_q is None:
+        segment_pos_q = jnp.broadcast_to(
+            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
+        )
+    if segment_pos_kv is None:
+        segment_pos_kv = jnp.broadcast_to(
+            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
+        )
+    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
     return inv_causal_mask
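For intuition, a minimal standalone jax.numpy sketch (independent of the test helpers above) of what the position-based causal check computes when segments are packed: a query position may attend to a key position only if its in-segment position is greater than or equal to the key's.

```python
import jax.numpy as jnp

# Hypothetical packed batch: one sequence holding two segments of lengths 3 and 2.
segment_pos_q = jnp.array([[0, 1, 2, 0, 1]])
segment_pos_kv = segment_pos_q

# Broadcast query positions against key positions: True means "may attend".
inv_causal = segment_pos_q[..., :, None] >= segment_pos_kv[..., None, :]
print(inv_causal[0].astype(int))
# Within a segment this is lower-triangular; across segments it must still be
# combined with the segment-id equality mask (as make_mask does below) to block
# cross-segment attention.
```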
+@partial(jax.jit, static_argnums=(4, 5))
 def make_mask(
-    q_token: ArrayLike,
-    kv_token: ArrayLike,
-    segment_pad_q: ArrayLike,
-    segment_pad_kv: ArrayLike,
+    segment_ids_q: ArrayLike,
+    segment_ids_kv: ArrayLike,
+    segment_pos_q: ArrayLike,
+    segment_pos_kv: ArrayLike,
     attn_mask_type: AttnMaskType,
     window_size: Optional[Tuple[int, int]] = None,
 ) -> Array:
@@ -132,18 +139,31 @@ def make_mask(
     Create attention mask based on mask type. A `True` value in the mask means
     masking out the corresponding position and a `False` value means allowing
     that position to participate in attention.
+    - segment_ids should start from 1, with 0 used for padding.
+      Each segment is expected to start without padding.
+    - segment_pos marks the token position within each segment.
+    An example pair of segment_ids and segment_pos:
+        segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
+        segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
     """
     inv_mask = make_attention_mask(
-        q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
+        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
     )
-    if is_causal_mask(attn_mask_type):
-        inv_causal_mask = make_causal_mask(q_token, kv_token)
-        inv_mask = combine_masks(inv_causal_mask, inv_mask)
-    if segment_pad_q is not None and segment_pad_kv is not None:
-        inv_pad_mask = make_attention_mask(
-            segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
-        )
-        inv_mask = combine_masks(inv_pad_mask, inv_mask)
+    if attn_mask_type.is_causal():
+        if segment_pos_q is None:
+            segment_pos_q = jnp.broadcast_to(
+                jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
+            )
+        if segment_pos_kv is None:
+            segment_pos_kv = jnp.broadcast_to(
+                jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
+            )
+        inv_causal_mask = make_attention_mask(
+            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
+        )
+        inv_mask = combine_masks(inv_causal_mask, inv_mask)

     if window_size is not None:
         max_seqlen_q = inv_mask.shape[-2]
@@ -157,7 +177,8 @@ def make_mask(
     return mask
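As a rough illustration of the semantics (a standalone jax.numpy sketch, not the test code itself, and ignoring the sliding-window branch): the allowed-attention matrix is the AND of "same non-zero segment id" and, for causal mask types, "query position >= key position", and the returned mask is its logical NOT, matching the True-means-masked-out convention stated in the docstring.

```python
import jax.numpy as jnp

segment_ids = jnp.array([[1, 1, 1, 0, 2, 2]])  # 0 marks padding
segment_pos = jnp.array([[0, 1, 2, 3, 0, 1]])

same_segment = (segment_ids[..., :, None] == segment_ids[..., None, :]) & (
    segment_ids[..., :, None] != 0
)
causal = segment_pos[..., :, None] >= segment_pos[..., None, :]

allowed = same_segment & causal  # True = may attend
mask = ~allowed                  # True = masked out
print(mask[0].astype(int))
```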
-def get_seqlens_and_offsets(segment_ids, segment_pad):
+@jax.jit
+def get_seqlens_and_offsets(segment_ids):
     batch, max_seqlen = segment_ids.shape
     bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
     seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
@@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad):
     def _find_offsets(x):
         same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
-        first_column = jnp.ones((x.shape[0], 1), dtype=bool)
+        first_column = x[..., :1] != 0
         same_as_previous = jnp.hstack((first_column, same_as_previous))
         return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
             same_as_previous
@@ -173,13 +194,9 @@ def get_seqlens_and_offsets(segment_ids, segment_pad):
     offsets = _find_offsets(segment_ids)
     offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
-    if segment_pad is not None:
-        segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
-        padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
-        output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
-    else:
-        output = jnp.insert(seqlens, -1, values=0, axis=-1)
-    return output, offsets
+    seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
+    seqlens = jnp.where(seqlens, seqlens, -1)
+    return seqlens, offsets
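In plain NumPy terms, and independent of the jitted helper above: with segment_pad now folded into segment_ids, the per-segment sequence lengths are just the counts of each non-zero id (what the bincount computes) and the offsets are the indices where each segment starts (what the change-point scan in _find_offsets recovers). A small hypothetical example:

```python
import numpy as np

# Hypothetical THD batch: two packed segments followed by padding (id 0).
segment_ids = np.array([[1, 1, 1, 2, 2, 0, 0, 0]])

ids = segment_ids[0]
segments = [s for s in np.unique(ids) if s != 0]
seqlens = [int((ids == s).sum()) for s in segments]     # -> [3, 2]
offsets = [int(np.argmax(ids == s)) for s in segments]  # -> [0, 3]
print(seqlens, offsets)
```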
 @jax.jit
@@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
         query,
         key,
         value,
-        bias=bias,
-        mask=mask,
+        bias,
+        mask,
         deterministic=not kwargs["is_training"],
         scale_factor=kwargs["scaling_factor"],
         dropout_rate=kwargs["dropout_probability"],
@@ -228,7 +245,6 @@ def customcall_fused_dpa(
     TE customcall dot product attention implementation
     """
     qkv_layout = kwargs["qkv_layout"]
-    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
     match qkv_layout:
         case QKVLayout.BS3HD | QKVLayout.T3HD:
             query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
@@ -242,7 +258,7 @@ def customcall_fused_dpa(
             qkv_args = (query, key, value)
         case _:
             raise ValueError(f"Unsupported {qkv_layout=}")
-    if not is_thd:
+    if not qkv_layout.is_thd():
         kwargs.pop("max_segments_per_seq")
         return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
     return fused_attn_thd(
@@ -262,10 +278,10 @@ class BiasShape(Enum):
     Enum class to represent the different bias shapes used in the fused attention.
     """

-    BIAS_1HSS = "1HSS"
-    BIAS_B1SS = "B1SS"
-    BIAS_BHSS = "BHSS"
-    BIAS_11SS = "11SS"
+    _1HSS = "1HSS"
+    _B1SS = "B1SS"
+    _BHSS = "BHSS"
+    _11SS = "11SS"
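The rename works around a naming constraint rather than changing behavior: a Python identifier cannot start with a digit, so 1HSS cannot be an enum member name; dropping the BIAS_ prefix in favor of a bare leading underscore keeps the member name as close as possible to the literal shape string. A minimal sketch:

```python
from enum import Enum

class BiasShape(Enum):
    # "1HSS" is not a valid identifier, so a leading underscore is used instead
    # of the old BIAS_ prefix.
    _1HSS = "1HSS"
    _B1SS = "B1SS"

print(BiasShape._1HSS.value)  # -> "1HSS"
```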
 @dataclass
@@ -300,18 +316,12 @@ class FusedAttnRunner:
     def _check_configs(self):
         # TODO(rewang): probably adds this in is_fused_attn_available
-        if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
-            AttnMaskType.PADDING_MASK,
-            AttnMaskType.PADDING_CAUSAL_MASK,
-        ]:
+        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
             pytest.skip("THD format requires padding masks.")

-        qkv_format = get_qkv_format(self.qkv_layout)
-        if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
+        if self.qkv_layout.is_qkvpacked():
             if self.max_seqlen_q != self.max_seqlen_kv:
                 pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
-        if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
             if self.num_heads_q != self.num_heads_kv:
                 pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")
@@ -339,15 +349,11 @@ class FusedAttnRunner:
         if (
             self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
-            and self.bias_shape != BiasShape.BIAS_1HSS
+            and self.bias_shape != BiasShape._1HSS
         ):
-            if self.attn_mask_type not in [
-                AttnMaskType.NO_MASK,
-                AttnMaskType.CAUSAL_MASK,
-            ]:
+            if self.attn_mask_type.is_padding():
                 pytest.skip(
-                    "B1SS, BHSS and 11SS bias shapes are only supported for "
-                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
+                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                 )
             elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                 pytest.skip(
@@ -370,18 +376,18 @@ class FusedAttnRunner:
         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
-        elif self.bias_shape == BiasShape.BIAS_1HSS:
+        elif self.bias_shape == BiasShape._1HSS:
             bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
-        elif self.bias_shape == BiasShape.BIAS_B1SS:
+        elif self.bias_shape == BiasShape._B1SS:
             bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
-        elif self.bias_shape == BiasShape.BIAS_BHSS:
+        elif self.bias_shape == BiasShape._BHSS:
             bias_shape = (
                 self.batch_size,
                 self.num_heads_q,
                 self.max_seqlen_q,
                 self.max_seqlen_kv,
             )
-        elif self.bias_shape == BiasShape.BIAS_11SS:
+        elif self.bias_shape == BiasShape._11SS:
             bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
         else:
             pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
@@ -391,7 +397,7 @@ class FusedAttnRunner:
         self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

         if self.attn_bias_type != AttnBiasType.NO_BIAS:
-            if self.bias_shape == BiasShape.BIAS_1HSS:
+            if self.bias_shape == BiasShape._1HSS:
                 self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
             else:
                 # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
@@ -408,10 +414,10 @@ class FusedAttnRunner:
         else:
             self.bias = None

-        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
-            pad_ratio = 0.0
-        else:
+        if self.attn_mask_type.is_padding():
             pad_ratio = 0.3
+        else:
+            pad_ratio = 0.0

         def gen_valid(bs, max_seqlen, pad_ratio):
             pad_len = int(max_seqlen * pad_ratio)
@@ -425,6 +431,8 @@ class FusedAttnRunner:
             rng = np.random.default_rng(seed=seed)
             # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
             segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
+            segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
+            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
             # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
             segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
@@ -440,58 +448,62 @@ class FusedAttnRunner:
                         break
                     segment_end = current_pos + segment_size
                     segment_ids[i, current_pos:segment_end] = segment_id
+                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                     if with_segment_pad:
                         num_valid = rng.integers(1, segment_size + 1)
                         segment_pad[i, current_pos + num_valid : segment_end] = 1
                     current_pos = segment_end
                     segment_id += 1
                 segment_pad[i, current_pos:sequence_length] = 1
-            return segment_ids, segment_pad
+            segment_ids, segment_pos, segment_pad = map(
+                jnp.asarray, [segment_ids, segment_pos, segment_pad]
+            )
+            segment_ids = jnp.where(segment_pad, 0, segment_ids)
+            return segment_ids, segment_pos, segment_pad
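This is where the commit drops the separate segment_pad input: the jnp.where call above folds the pad flags into segment_ids, so padded tokens simply carry id 0 and downstream code only needs segment_ids. A small NumPy sketch of the encoding:

```python
import numpy as np

segment_ids = np.array([1, 1, 1, 2, 2, 2, 0, 0])
segment_pad = np.array([0, 0, 1, 0, 1, 1, 1, 1])  # 1 means pad

# Padded tokens get id 0, i.e. "not attendable", so segment_pad becomes redundant.
segment_ids = np.where(segment_pad, 0, segment_ids)
print(segment_ids)  # -> [1 1 0 2 0 0 0 0]
```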
-        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
+        if self.qkv_layout.is_thd():
             self.num_segments_per_seq = 2
-            self.token_q, self.segment_pad_q = generate_random_segment_ids(
+            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                 self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
             )
-            # TODO(rewang): Check if qkvpacked supported different q/kv
-            # TODO(rewang): Causal with different q/kv segment_id fails
-            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
-                self.token_kv = self.token_q
-                self.segment_pad_kv = self.segment_pad_q
+            if self.qkv_layout == QKVLayout.T3HD:
+                self.segment_ids_kv = self.segment_ids_q
+                self.segment_pos_kv = self.segment_pos_q
+                self.pad_kv = self.pad_q
             else:
-                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
+                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                     self.batch_size,
                     self.max_seqlen_kv,
                     self.num_segments_per_seq,
                     seed=2024,
                 )
-            self.pad_q = self.segment_pad_q
-            self.pad_kv = self.segment_pad_kv
+            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
+            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
         else:
             self.num_segments_per_seq = 1
-            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
-            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
-            self.segment_pad_q = self.segment_pad_kv = None
+            self.segment_ids_q, self.pad_q = gen_valid(
+                self.batch_size, self.max_seqlen_q, pad_ratio
+            )
+            self.segment_ids_kv, self.pad_kv = gen_valid(
+                self.batch_size, self.max_seqlen_kv, pad_ratio
+            )
+            self.segment_pos_q = self.segment_pos_kv = None
+            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

+        # For reference code
         self.mask = make_mask(
-            self.token_q,
-            self.token_kv,
-            self.segment_pad_q,
-            self.segment_pad_kv,
+            self.segment_ids_q,
+            self.segment_ids_kv,
+            self.segment_pos_q,
+            self.segment_pos_kv,
             self.attn_mask_type,
             self.window_size,
         )

-        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
-            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
-                self.token_q, self.segment_pad_q
-            )
-            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
-                self.token_kv, self.segment_pad_kv
-            )
+        if self.qkv_layout.is_thd():
             self.mask_for_customcall = None  # THD format doesn't support mask
         else:
-            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
             self.mask_for_customcall = self.mask

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
@@ -547,13 +559,11 @@ class FusedAttnRunner:
         """
         self._setup_inputs()

-        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
-            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

         def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
             gradient_multiplier = self.max_seqlen_q * self.num_heads_q
-            if is_causal_mask(self.attn_mask_type):
+            if self.attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             # Keep only valid result for the gradient
             ret_valid = jnp.where(
@@ -586,7 +596,7 @@ class FusedAttnRunner:
         }

         # We can compute dBias only for the [1, h, s, s] layout
-        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)
+        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)

         # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
         jitted_primitive = jit(
@@ -629,7 +639,7 @@ class FusedAttnRunner:
         check_dqkv(primitive_dk, reference_dk, self.pad_kv)
         check_dqkv(primitive_dv, reference_dv, self.pad_kv)

-        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
+        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
             primitive_dbias = primitive_dgrad[3]
             reference_dbias = reference_dgrad[3]
@@ -658,16 +668,6 @@ class FusedAttnRunner:
     )

-    @pytest.mark.parametrize(
-        "attn_bias_type, bias_shape",
-        [
-            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
-            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
-            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
-            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
-            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
-        ],
-    )
     @pytest.mark.parametrize(
         "attn_mask_type",
         [
@@ -736,6 +736,16 @@ class TestFusedAttn:
             pytest.param(False, id="INFERENCE"),
         ],
     )
+    @pytest.mark.parametrize(
+        "attn_bias_type, bias_shape",
+        [
+            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
+        ],
+    )
     def _test_forward(
         b,
         s_q,
@@ -779,6 +789,13 @@ class TestFusedAttn:
         runner.test_forward()

     @staticmethod
+    @pytest.mark.parametrize(
+        "attn_bias_type, bias_shape",
+        [
+            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
+        ],
+    )
     def test_backward(
         b,
         s_q,
......
@@ -19,7 +19,11 @@ from jax import lax, vmap
 from jax import nn as jax_nn
 from jax import random as jax_random

-from transformer_engine.jax.attention import AttnMaskType, make_swa_mask
+from transformer_engine.jax.attention import (
+    AttnMaskType,
+    canonicalize_attn_mask_type,
+    make_swa_mask,
+)
 from transformer_engine.jax.fp8 import DType as TEDType

 PRNGKey = Any
@@ -913,15 +917,7 @@ def apply_swa_mask(
     window_size: Tuple[int, int] = (-1, -1),
 ) -> Array:
     """Apply the sliding window mask to a given mask"""
-    mask_map = {
-        "no_mask": AttnMaskType.NO_MASK,
-        "padding": AttnMaskType.PADDING_MASK,
-        "causal": AttnMaskType.CAUSAL_MASK,
-        "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK,
-        "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-        "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
-    }
-    _attn_mask_type = mask_map.get(attn_mask_type, None)
+    _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
     assert _attn_mask_type is not None
     max_seqlen_q = original_mask.shape[-2]
     max_seqlen_kv = original_mask.shape[-1]
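Judging by the mask_map it replaces, canonicalize_attn_mask_type converts the string spellings of the mask types into AttnMaskType members. A usage sketch, assuming it accepts the same spellings the old dictionary did:

```python
from transformer_engine.jax.attention import AttnMaskType, canonicalize_attn_mask_type

# Assumption: the same string names the removed mask_map handled are accepted here.
mask_type = canonicalize_attn_mask_type("padding_causal")
print(mask_type == AttnMaskType.PADDING_CAUSAL_MASK)  # expected True
```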
......
@@ -46,6 +46,42 @@ class AttnMaskType(Enum):
     CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
     PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

+    def is_causal(self):
+        """Returns True if the mask is a causal mask"""
+        return self in [
+            AttnMaskType.CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
+    def is_padding(self):
+        """Returns True if the mask includes padding"""
+        return self in [
+            AttnMaskType.PADDING_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
+    def is_bottom_right(self):
+        """Returns True if the causal mask is calculated from the bottom-right section"""
+        return self in [
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
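A short usage sketch of the new predicates; the results follow directly from the membership lists above:

```python
from transformer_engine.jax.attention import AttnMaskType

mask = AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
print(mask.is_causal())        # True
print(mask.is_padding())       # True
print(mask.is_bottom_right())  # True

print(AttnMaskType.NO_MASK.is_causal())  # False
```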
+class QKVFormat(Enum):
+    """
+    SBHD: q,k,v memory layout with [s, b, ..., h, d]
+    BSHD: q,k,v memory layout with [b, s, ..., h, d]
+    THD: q,k,v memory layout is the same as BSHD, but allows multiple segments packed in a sequence.
+    """
+
+    SBHD = NVTE_QKV_Format.NVTE_SBHD
+    BSHD = NVTE_QKV_Format.NVTE_BSHD
+    THD = NVTE_QKV_Format.NVTE_THD
+

 class QKVLayout(Enum):
     """
@@ -66,17 +102,35 @@ class QKVLayout(Enum):
     THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
     THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD

+    def get_qkv_format(self):
+        """
+        Return the corresponding qkv_format (BSHD, SBHD, THD)
+        """
+        return QKVFormat(nvte_get_qkv_format(self.value))
+
+    def is_qkvpacked(self):
+        """
+        Return True if the query, key, value are packed
+        """
+        return self in [QKVLayout.BS3HD, QKVLayout.T3HD]
+
+    def is_kvpacked(self):
+        """
+        Return True if the key, value are packed
+        """
+        return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD]
+
+    def is_separate(self):
+        """
+        Return True if the query, key, value are three separate tensors
+        """
+        return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD]
+
+    def is_thd(self):
+        """
+        Return True if the layout belongs to THD
+        """
+        return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]

-class QKVFormat(Enum):
-    """
-    SBHD: q,k,v memory layout with [s, b, ..., h, d]
-    BSHD: q,k,v memory layout with [b, s, ..., h, d]
-    THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence.
-    """
-
-    SBHD = NVTE_QKV_Format.NVTE_SBHD
-    BSHD = NVTE_QKV_Format.NVTE_BSHD
-    THD = NVTE_QKV_Format.NVTE_THD
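Likewise for the QKVLayout helpers: the is_* predicates follow from the membership lists above, while get_qkv_format goes through the nvte_get_qkv_format binding and therefore needs the compiled Transformer Engine extension. A small sketch:

```python
from transformer_engine.jax.attention import QKVFormat, QKVLayout

layout = QKVLayout.THD_T2HD
print(layout.is_thd())        # True
print(layout.is_kvpacked())   # True
print(layout.is_qkvpacked())  # False

# Expected to resolve to the THD format for this layout.
print(layout.get_qkv_format() == QKVFormat.THD)
```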
 class CPStrategy(Enum):
@@ -92,13 +146,6 @@ class CPStrategy(Enum):
     RING = 2

-def get_qkv_format(qkv_layout):
-    """
-    Get qkv_format from qkv_layout
-    """
-    return QKVFormat(nvte_get_qkv_format(qkv_layout.value))

 def make_swa_mask(
     max_seqlen_q: int,
     max_seqlen_kv: int,
@@ -136,12 +183,8 @@ def make_swa_mask(
     swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
     if window_size is None:
         return swa_mask
-    bottom_right_masks = [
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
-    ]
     left_window, right_window = window_size
-    if attn_mask_type in bottom_right_masks:
+    if attn_mask_type.is_bottom_right():
         if left_window < 0:
             left_window = max_seqlen_kv
         if right_window < 0:
@@ -310,7 +353,7 @@ def fused_attn(
         (jnp.ndarray): The output tensor from the fused attention.
     """
     assert (
-        get_qkv_format(qkv_layout) != QKVFormat.THD
+        not qkv_layout.is_thd()
     ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

     # Check inputs qkv
@@ -327,11 +370,7 @@ def fused_attn(
     ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

     # convert the mask to seqlens, mask doesn't support ragged offsets
-    if attn_mask_type in [
-        AttnMaskType.NO_MASK,
-        AttnMaskType.CAUSAL_MASK,
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-    ]:
+    if not attn_mask_type.is_padding():
         batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
         q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
         kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
@@ -448,7 +487,7 @@ def fused_attn_thd(
         QKVLayout.T3HD, 0.125, 0, True, 3)
     """
     assert (
-        get_qkv_format(qkv_layout) == QKVFormat.THD
+        qkv_layout.is_thd()
     ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

     # Check inputs qkv
......
@@ -3,7 +3,7 @@
 # See LICENSE for license information.
 """JAX/TE custom ops for attention"""
 from dataclasses import dataclass
-from functools import partial, reduce, cache
+from functools import partial, reduce
 import operator
 import os
 from typing import Optional, Tuple
@@ -133,7 +133,6 @@ class FusedAttnHelper:
         )

     @staticmethod
-    @cache
     def is_non_deterministic_allowed():
         """Check if non-deterministic kernels are allowed"""
         return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
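This matches the commit note that caching this check broke unit tests: functools.cache memoizes the first result, so a test that later changes NVTE_ALLOW_NONDETERMINISTIC_ALGO (e.g. via monkeypatch) would still observe the stale value. A standalone sketch of the failure mode:

```python
import os
from functools import cache

@cache
def allowed_cached():
    return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

print(allowed_cached())                               # True (default)
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # e.g. a test tweaking the env
print(allowed_cached())                               # still True: the cached value wins
```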
......