Unverified Commit 4654b70a authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Remove unnecessary SWA calculation in _segment_ids_pos_to_seqlens_offsets() (#2201)



* Remove unnecessary SWA calculation from _segment_ids_pos_to_seqlens_offsets
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent beed55b9
......@@ -530,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
......@@ -591,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
)
attn_mask = jnp.logical_and(attn_mask, swa_mask)
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment