"git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "c312f1756760a7cf66f24833dc2bf27be2e40433"
Unverified Commit 26c82db6 authored by Kshitij Lakhani, committed by GitHub
Browse files

[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API (#2523)



* Fix incorrect calculation of segment pos from segment ids for THD cases and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Correct the assert condition
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Modify fused attn tests to pass new args to from_segment_ids_and_pos()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Calculate seg ids before pos
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* 1. Change the signature for from_segment_ids_and_pos()
2. Add support for THD in from_segment_ids_and_pos()
3. Assert if load balanced segment_ids is passed to generate a segment_pos
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Pass keyword-only args by name
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* nit: Fix typo to use seg_ids instead of segment_ids
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* nit: Fix comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Modify the function call to differentiate between load balancing and actually reordered segment_ids and segment_pos
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the is_segment_ids_reordered to be set only when CP and load balancing
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix comments for from_segment_ids_and_pos()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Code clean up

for more information, see https://pre-commit.ci



Fix lint errors
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5ba01faa
......@@ -668,14 +668,24 @@ class FusedAttnRunner:
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
# Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
# if no CP and load balancing, else explicitly pass the segment_pos
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
)
if self.cp_size > 1 and self.cp_load_balanced
else None
),
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=(
True if self.cp_size > 1 and self.cp_load_balanced else False
),
)
case _:
......@@ -704,6 +714,8 @@ class FusedAttnRunner:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=False,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
......
......@@ -658,7 +658,7 @@ class SequenceDescriptor:
- SequenceDescriptor.from_seqlens_and_offsets
For THD (packed) cases, where each batch may have not only 1 sequence.
- SequenceDescriptor.from_segment_ids_and_pos
Experimental feature for THD (packed) cases with context parallelism.
Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering
"""
seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
......@@ -796,9 +796,14 @@ class SequenceDescriptor:
cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_segment_ids_reordered: bool,
) -> SequenceDescriptor:
"""
Experimental factory method for inputs with segment IDs and optional positions. (THD)
Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None to be used only for: BSHD with or without load balancing and,
THD without load balancing
Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray):
......@@ -812,22 +817,84 @@ class SequenceDescriptor:
The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
Only THD with load balancing is expected to have this flag set to True
Return:
A SequenceDescriptor with segment_ids/segment_pos initialized.
"""
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)
if segment_pos is not None:
segment_pos = cls._expand_to_pair(segment_pos)
else:
# Using defaults : segment pos has to be generated.
if segment_pos is None:
# THD + load balanced segment_ids are not supported in this function
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
if is_segment_ids_reordered:
assert not is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered"
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
" using convenience function reorder_causal_load_balancing()"
)
assert is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered (Dual"
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
" balanced reordered. The reordering for these is performed within the"
" primitive"
)
def generate_default_pos(segment_ids):
seqlen = segment_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)
# Generate the default pos for THD and BSHD non-reordered segment_ids
def generate_default_pos(seg_ids):
if is_thd:
batch_size, seq_size = seg_ids.shape
# Assume that the first token belongs to a segment and is not a padded token
first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
# Get segment start positions
segment_start = jnp.concatenate(
[
first_is_segment,
(seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
],
axis=-1,
)
# Get offset for location where new segment starts
segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(
segment_start
)
segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx)
# Get the last non-zero index - after this everything is padding
# (B,)
last_nonzero_idx = jax.vmap(
lambda segids_row: jnp.max(
jnp.where(segids_row != 0, jnp.arange(seq_size), -1)
)
)(seg_ids)
seg_pos_no_thd = jnp.arange(seq_size)
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
mask = seg_pos_no_thd <= last_nonzero_idx[:, None]
# Get the unmasked seg_pos for the THD sequence
seg_pos = (
jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape)
- segment_start_offsets
)
# Use the mask to zero out the padding at the end (after the non-zero index)
segment_pos = jax.vmap(
lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0)
)(seg_pos, mask)
return segment_pos
seqlen = seg_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)
q_seg_pos = generate_default_pos(q_seg_ids)
kv_seg_pos = generate_default_pos(kv_seg_ids)
segment_pos = (q_seg_pos, kv_seg_pos)
# Explicitly passed segment_pos
else:
segment_pos = cls._expand_to_pair(segment_pos)
return cls(
segment_ids=(q_seg_ids, kv_seg_ids),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment