Unverified Commit ca6fedcf authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add BRCM support for THD (#2242)



* Add BRCM support when creating a test mask for fused attn
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add support for BRCM to correctly generate the mask needed for calculating the seqlens and offsets for THD
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Skip drop=0 and no_bias case for BRCM as cuDNN does not support this
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Skip BRCM test cases where max_seqlen_q > max_seqlen_kv
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Refactor the segment id run length code for BRCM seqoffset and seqlens calculations
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix the drop inequality skip condition in fused attn
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* nit: Adjust the BRCM id name in the test to make it consistent
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix the brcm mask condition.
Fix the condition for cross attn type pattern to only apply for brcm
Change the num segments per sequence to 3 instead of 2
Reduce one test pattern data size and make it such that it triggers brcm
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint errors
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix incorrectly changed dtype to numpy bool_ rather than native python bool
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Restore the numsegments to earlier value
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add example for THD BRCM
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dfacd9f7
...@@ -32,6 +32,7 @@ from transformer_engine.jax.attention import ( ...@@ -32,6 +32,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing, reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing, inverse_reorder_causal_load_balancing,
fused_attn, fused_attn,
run_length_fill,
make_swa_mask, make_swa_mask,
SequenceDescriptor, SequenceDescriptor,
CPStrategy, CPStrategy,
...@@ -172,15 +173,34 @@ def make_mask( ...@@ -172,15 +173,34 @@ def make_mask(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
) )
# causal mask if attn_mask_type.is_bottom_right():
if attn_mask_type.is_causal(): run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
elif attn_mask_type.is_causal():
inv_causal_mask = make_attention_mask( inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
) )
inv_mask = combine_masks(inv_causal_mask, inv_mask) inv_mask = combine_masks(inv_causal_mask, inv_mask)
# sliding window mask # sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) inv_swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
)
inv_mask = combine_masks(inv_mask, inv_swa_mask) inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask) mask = jnp.logical_not(inv_mask)
return mask return mask
...@@ -338,6 +358,16 @@ class FusedAttnRunner: ...@@ -338,6 +358,16 @@ class FusedAttnRunner:
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.") pytest.skip("THD format requires padding masks.")
if self.attn_mask_type.is_bottom_right():
if self.max_seqlen_q > self.max_seqlen_kv:
pytest.skip(
f"BRCM requires cross attn type pattern, i.e.max_seqlen_kv >= max_seqlen_q"
)
if self.attn_bias_type is not AttnBiasType.NO_BIAS:
pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM")
if self.dropout_prob != 0.0:
pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM")
if self.qkv_layout.is_qkvpacked(): if self.qkv_layout.is_qkvpacked():
if self.max_seqlen_q != self.max_seqlen_kv: if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
...@@ -526,7 +556,11 @@ class FusedAttnRunner: ...@@ -526,7 +556,11 @@ class FusedAttnRunner:
self.pad_kv = self.pad_q self.pad_kv = self.pad_q
else: else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q min_segment_len = None
if (
self.window_size is not None or self.attn_mask_type.is_bottom_right()
): # SWA or BRCM requires kv_len >= q_len
min_segment_len = self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size, self.batch_size,
self.max_seqlen_kv, self.max_seqlen_kv,
...@@ -937,6 +971,9 @@ class FusedAttnRunner: ...@@ -937,6 +971,9 @@ class FusedAttnRunner:
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -958,14 +995,14 @@ class FusedAttnRunner: ...@@ -958,14 +995,14 @@ class FusedAttnRunner:
), ),
pytest.param( pytest.param(
2, 2,
2048, 512,
1024, 1024,
12, 12,
12, 12,
64, 64,
64, 64,
jnp.bfloat16, jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS", id="2-512-1024-12-12-64-64-BF16-CROSS",
), ),
pytest.param( pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
......
...@@ -209,6 +209,8 @@ def make_swa_mask( ...@@ -209,6 +209,8 @@ def make_swa_mask(
segment_pos_kv: jnp.ndarray, segment_pos_kv: jnp.ndarray,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
dtype: jax.typing.DTypeLike = jnp.float32, dtype: jax.typing.DTypeLike = jnp.float32,
segment_ids_q: jnp.ndarray = None,
segment_ids_kv: jnp.ndarray = None,
): ):
""" """
Generate a sliding window mask (1 = attend, 0 = masked). Generate a sliding window mask (1 = attend, 0 = masked).
...@@ -227,6 +229,10 @@ def make_swa_mask( ...@@ -227,6 +229,10 @@ def make_swa_mask(
Defaults to None. Defaults to None.
dtype (jax.typing.DTypeLike, optional): dtype (jax.typing.DTypeLike, optional):
Mask data type. Defaults to jnp.float32. Mask data type. Defaults to jnp.float32.
segment_ids_q (jnp.ndarray):
Query segment id that each token belongs to
segment_ids_kv (jnp.ndarray):
Key/value segment id that each token belongs to
Returns: Returns:
jnp.ndarray: jnp.ndarray:
...@@ -240,6 +246,18 @@ def make_swa_mask( ...@@ -240,6 +246,18 @@ def make_swa_mask(
right_window = jnp.inf if right_window < 0 else right_window right_window = jnp.inf if right_window < 0 else right_window
pos_q = jnp.expand_dims(segment_pos_q, axis=-1) pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
# For Bottom Right Causal Mask (BRCM)
if segment_ids_q is not None and segment_ids_kv is not None:
run_length_q = run_length_fill(segment_ids_q)
run_length_kv = run_length_fill(segment_ids_kv)
run_length_q_exp = jnp.expand_dims(run_length_q, axis=-1)
run_length_kv_exp = jnp.expand_dims(run_length_kv, axis=-2)
bottom_right_inv_swa_mask = (
run_length_q_exp - pos_q + left_window >= run_length_kv_exp - pos_kv
)
bottom_right_inv_swa_mask = jnp.expand_dims(bottom_right_inv_swa_mask, axis=-3)
return bottom_right_inv_swa_mask.astype(dtype)
# All other cases other than BRCM
inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
return inv_swa_mask.astype(dtype) return inv_swa_mask.astype(dtype)
...@@ -420,6 +438,42 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path( ...@@ -420,6 +438,42 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
) )
def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray:
    """
    Return, for each position of a 1D segment-id array, the length of the run
    of identical ids that position belongs to; padding positions (id 0) get 0.

    Example:
        Input:  [1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0]
        Output: [2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0]
    """
    ids = segment_ids_flattened
    # A new run starts at index 0 and wherever an id differs from its predecessor.
    is_run_start = jnp.concatenate([jnp.ones((1,), dtype=bool), ids[1:] != ids[:-1]])
    # Label every position with the index of its run (0-based, non-decreasing).
    run_index = jnp.cumsum(is_run_start) - 1
    # In the worst case every element opens its own run, so cap at the sequence length.
    run_sizes = jnp.bincount(run_index, length=ids.shape[-1])
    # Broadcast each run's size back to its members; padding (id 0) is forced to 0.
    return jnp.where(ids == 0, 0, run_sizes[run_index])
def run_length_fill(segment_ids) -> jnp.ndarray:
    """
    Compute per-token run lengths of segment ids along the last axis,
    preserving the input shape. Padding positions (id 0) map to 0.

    Example:
        Input:  [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]]
        Output: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]]
    """
    # Collapse all leading axes into one batch axis so a single vmap over
    # rows handles inputs of any rank; restore the shape afterwards.
    seq_len = segment_ids.shape[-1]
    rows = segment_ids.reshape(-1, seq_len)
    per_row = jax.vmap(run_length_fill_flattened, in_axes=0)(rows)
    return per_row.reshape(segment_ids.shape)
def _segment_ids_pos_to_seqlens_offsets( def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q, segment_ids_q,
segment_ids_kv, segment_ids_kv,
...@@ -443,7 +497,10 @@ def _segment_ids_pos_to_seqlens_offsets( ...@@ -443,7 +497,10 @@ def _segment_ids_pos_to_seqlens_offsets(
# #
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements. # examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1): # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
) )
...@@ -459,8 +516,41 @@ def _segment_ids_pos_to_seqlens_offsets( ...@@ -459,8 +516,41 @@ def _segment_ids_pos_to_seqlens_offsets(
segment_ids_kv, segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x, lambda x, y: jnp.equal(x, y) * x,
) )
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
attn_mask = segment_mask attn_mask = segment_mask
if attn_mask_type.is_causal(): if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
# Example for brcm:
# run_length_out_q: [3 3 3 0 4 4 4 4]
# segment_pos_q: [0 1 2 3 0 1 2 3]
# segment_ids_q: [1 1 1 0 2 2 2 2]
# run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
# segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
# segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
# brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
# attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
elif attn_mask_type.is_causal():
causal_mask = make_attention_mask( causal_mask = make_attention_mask(
segment_pos_q, segment_pos_q,
segment_pos_kv, segment_pos_kv,
...@@ -468,7 +558,19 @@ def _segment_ids_pos_to_seqlens_offsets( ...@@ -468,7 +558,19 @@ def _segment_ids_pos_to_seqlens_offsets(
) )
attn_mask = jnp.logical_and(segment_mask, causal_mask) attn_mask = jnp.logical_and(segment_mask, causal_mask)
swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
)
attn_mask = jnp.logical_and(attn_mask, swa_mask) attn_mask = jnp.logical_and(attn_mask, swa_mask)
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
...@@ -1125,5 +1227,4 @@ def fused_attn( ...@@ -1125,5 +1227,4 @@ def fused_attn(
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name, context_checkpoint_name=context_checkpoint_name,
) )
return output return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment