Unverified Commit 3b1f5a11 authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

[JAX] Add fast path for causal masking with segment IDs. (#1601)



Add fast path for causal masking with segment IDs.
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
parent 76187a5e
......@@ -378,6 +378,44 @@ def _mask_to_seqlens_offset(mask, max_segments_per_seq):
return q_seqlen, q_offset, kv_seqlen, kv_offset
def _fast_causal_adjust_seqlen_and_offsets(
segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
):
# The assumption is that for any segment tokens respect causal ordering except at the ends
# of the segment. This allows us to tweak the length and offset by only looking at the start
# and end tokens between segments.
is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)
q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
jnp.int32
)
q_len -= skip_start_token
q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)
q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
jnp.int32
)
kv_len -= skip_end_token
return q_len, kv_len, q_offset, kv_offset
def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
):
q_len, q_offset = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
kv_len, kv_offset = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
return _fast_causal_adjust_seqlen_and_offsets(
segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
)
def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
......@@ -387,6 +425,25 @@ def _segment_ids_pos_to_seqlens_offsets(
window_size,
max_segments_per_seq,
):
# TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
# Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
# Assumptions for cudnn causal mask correctness.
# 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
# 2. No intra-segment padding, only inter-segment paddding allowed
# 3. Only start or end token within a segment may violate the causal order relationship
# 1 5 9 0 4 8 10 0 4 8
# 0 x x
# 4 x x x x x
# 8 x x x x x x x x
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment