Unverified Commit ca6fedcf authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add BRCM support for THD (#2242)



* Add BRCM support when creating a test mask for fused attn
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add support for BRCM to correctly generate the mask needed for calculating the seqlens and offsets for THD
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Skip drop=0 and no_bias case for BRCM as cuDNN does not support this
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Skip BRCM test cases where max_seqlen_q > max_seqlen_kv
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Refactor the segment id run length code for BRCM seqoffset and seqlens calculations
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix the drop inequality skip condition in fused attn
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* nit: Adjust the BRCM id name in the test to make it consistent
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix the brcm mask condition.
Fix the condition for cross attn type pattern to only apply for brcm
Change the num segments per sequence to 3 instead of 2
Reduce one test pattern data size and make it such that it triggers brcm
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint errors
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix incorrectly changed dtype to numpy bool_ rather than native python bool
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Restore the numsegments to earlier value
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add example for THD BRCM
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dfacd9f7
......@@ -32,6 +32,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
run_length_fill,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
......@@ -172,15 +173,34 @@ def make_mask(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
# causal mask
if attn_mask_type.is_causal():
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
elif attn_mask_type.is_causal():
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_causal_mask, inv_mask)
# sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
inv_swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
)
inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask)
return mask
......@@ -338,6 +358,16 @@ class FusedAttnRunner:
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
if self.attn_mask_type.is_bottom_right():
if self.max_seqlen_q > self.max_seqlen_kv:
pytest.skip(
f"BRCM requires cross attn type pattern, i.e.max_seqlen_kv >= max_seqlen_q"
)
if self.attn_bias_type is not AttnBiasType.NO_BIAS:
pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM")
if self.dropout_prob != 0.0:
pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM")
if self.qkv_layout.is_qkvpacked():
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
......@@ -526,7 +556,11 @@ class FusedAttnRunner:
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
min_segment_len = None
if (
self.window_size is not None or self.attn_mask_type.is_bottom_right()
): # SWA or BRCM requires kv_len >= q_len
min_segment_len = self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
......@@ -937,6 +971,9 @@ class FusedAttnRunner:
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
......@@ -958,14 +995,14 @@ class FusedAttnRunner:
),
pytest.param(
2,
2048,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS",
id="2-512-1024-12-12-64-64-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
......
......@@ -209,6 +209,8 @@ def make_swa_mask(
segment_pos_kv: jnp.ndarray,
window_size: Optional[Tuple[int, int]] = None,
dtype: jax.typing.DTypeLike = jnp.float32,
segment_ids_q: jnp.ndarray = None,
segment_ids_kv: jnp.ndarray = None,
):
"""
Generate a sliding window mask (1 = attend, 0 = masked).
......@@ -227,6 +229,10 @@ def make_swa_mask(
Defaults to None.
dtype (jax.typing.DTypeLike, optional):
Mask data type. Defaults to jnp.float32.
segment_ids_q (jnp.ndarray):
Query segment id that each token belongs to
segment_ids_kv (jnp.ndarray):
Key/value segment id that each token belongs to
Returns:
jnp.ndarray:
......@@ -240,6 +246,18 @@ def make_swa_mask(
right_window = jnp.inf if right_window < 0 else right_window
pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
# For Bottom Right Causal Mask (BRCM)
if segment_ids_q is not None and segment_ids_kv is not None:
run_length_q = run_length_fill(segment_ids_q)
run_length_kv = run_length_fill(segment_ids_kv)
run_length_q_exp = jnp.expand_dims(run_length_q, axis=-1)
run_length_kv_exp = jnp.expand_dims(run_length_kv, axis=-2)
bottom_right_inv_swa_mask = (
run_length_q_exp - pos_q + left_window >= run_length_kv_exp - pos_kv
)
bottom_right_inv_swa_mask = jnp.expand_dims(bottom_right_inv_swa_mask, axis=-3)
return bottom_right_inv_swa_mask.astype(dtype)
# All other cases other than BRCM
inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
return inv_swa_mask.astype(dtype)
......@@ -420,6 +438,42 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
)
def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray:
    """
    Return, for each token of a 1-D segment-id array, the length of the
    contiguous run of identical ids that token belongs to; positions whose
    segment id is 0 (padding) are filled with 0.

    Example:
        input:  [1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0]
        output: [2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0]
    """
    # A run starts at index 0 and at every position whose id differs from
    # its predecessor.
    is_run_start = jnp.concatenate(
        [jnp.broadcast_to(True, (1,)), segment_ids_flattened[1:] != segment_ids_flattened[:-1]]
    )
    # Label every position with the index of the run containing it.
    run_index = jnp.cumsum(is_run_start) - 1
    # Worst case every element opens its own run, so size the histogram to
    # the token count (bincount needs a static length under jit/vmap).
    max_runs = segment_ids_flattened.shape[-1]
    run_sizes = jnp.bincount(run_index, length=max_runs)
    # Scatter each run's size back onto all of its positions, then zero out
    # the padding (segment id 0) positions.
    filled = run_sizes[run_index]
    return jnp.where(segment_ids_flattened == 0, 0, filled)
def run_length_fill(segment_ids) -> jnp.ndarray:
    """
    Fill every position with the length of its run of identical segment ids,
    preserving the input shape; padding positions (segment id 0) become 0.

    Example:
        input:  [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0]]
        output: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0]]
    """
    # Collapse every leading dimension so vmap sees a batch of 1-D rows.
    last_dim = segment_ids.shape[-1]
    rows = segment_ids.reshape(-1, last_dim)
    filled_rows = jax.vmap(run_length_fill_flattened, in_axes=0)(rows)
    # Restore the caller's original shape.
    return filled_rows.reshape(segment_ids.shape)
def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
......@@ -443,7 +497,10 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1):
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
......@@ -459,8 +516,41 @@ def _segment_ids_pos_to_seqlens_offsets(
segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x,
)
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
attn_mask = segment_mask
if attn_mask_type.is_causal():
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
# Example for brcm:
# run_length_out_q: [3 3 3 0 4 4 4 4]
# segment_pos_q: [0 1 2 3 0 1 2 3]
# segment_ids_q: [1 1 1 0 2 2 2 2]
# run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
# segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
# segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
# brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
# attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
elif attn_mask_type.is_causal():
causal_mask = make_attention_mask(
segment_pos_q,
segment_pos_kv,
......@@ -468,7 +558,19 @@ def _segment_ids_pos_to_seqlens_offsets(
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)
swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
)
attn_mask = jnp.logical_and(attn_mask, swa_mask)
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
......@@ -1125,5 +1227,4 @@ def fused_attn(
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
)
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment