# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for permutation Triton kernels and high-level APIs"""

import functools

import jax
import jax.numpy as jnp
import pytest

# High-level API with VJP support
from transformer_engine.jax.permutation import (
    token_dispatch,
    token_combine,
    sort_chunks_by_index,
)
from utils import assert_allclose, pytest_parametrize_wrapper


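# Test-case tables below appear to be keyed by test level: "L0" holds a reduced
# smoke-test subset and "L2" the full sweep (level selection is assumed to be
# handled by the pytest_parametrize_wrapper helper from utils). Dispatch/combine
# cases are (num_tokens, num_experts, hidden_size, tokens_per_expert) tuples.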
ALL_DISPATCH_COMBINE_CASES = [
    (128, 5, 128, 3),
    (1024, 8, 128, 8),
    (4096, 32, 1280, 2),
    (4096, 64, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
    "L0": ALL_DISPATCH_COMBINE_CASES[0:2],
    "L2": ALL_DISPATCH_COMBINE_CASES,
}

ALL_SORT_CHUNKS_CASES = [
    (8, 4096, 1280),
    (64, 4096, 4096),
    (256, 4096, 9216),
]
SORT_CHUNKS_CASES = {
    "L0": ALL_SORT_CHUNKS_CASES[0:2],
    "L2": ALL_SORT_CHUNKS_CASES,
}

ALL_DISPATCH_COMBINE_PADDING_CASES = [
    (128, 5, 128, 3, 8),
    (1024, 8, 128, 8, 16),
    (4096, 32, 1280, 2, 128),
    (4096, 64, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
    "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
    "L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}

ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
    "L0": ALL_DTYPES,
    "L2": ALL_DTYPES,
}

ALL_WITH_PROBS = [True, False]
WITH_PROBS = {
    "L0": [True],
    "L2": ALL_WITH_PROBS,
}


def reference_make_row_id_map(
    routing_map: jnp.ndarray,
) -> jnp.ndarray:
    """
    Vectorized reference implementation of make_row_id_map using JAX primitives.

    Parameters
    ----------
    routing_map : jnp.ndarray
        Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts
        are routed to which tokens (1 = routed, 0 = not routed).

    Returns
    -------
    row_id_map : jnp.ndarray
        The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1].
    """
    num_tokens, num_experts = routing_map.shape

    # For each expert, compute cumulative sum to get destination indices
    cumsum_per_expert = jnp.cumsum(routing_map, axis=0)

    # Compute total tokens per expert and expert offsets
    tokens_per_expert = jnp.sum(routing_map, axis=0)
    expert_offsets = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
    )

    # Compute destination rows for all (token, expert) pairs
    # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
    dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * (
        1 - routing_map
    )

    # Count routed experts per token
    n_routed_per_token = jnp.sum(routing_map, axis=1)

    # For each token, sort slots by descending dest_row and pack into row_id_map.
    # Routed entries use the negated dest_row (so ascending argsort yields
    # descending destinations); non-routed entries get int32 max so they sort last.
    sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max)
    sorted_expert_indices = jnp.argsort(sort_keys, axis=1)

    # Gather the sorted destination rows and expert indices using advanced indexing
    # Create indices for gathering
    token_idx = jnp.broadcast_to(
        jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
    )
    sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]

    # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
    row_id_map = jnp.concatenate(
        [
            sorted_dest_rows.astype(jnp.int32),
            sorted_expert_indices.astype(jnp.int32),
            n_routed_per_token.astype(jnp.int32)[:, None],
        ],
        axis=1,
    )

    return row_id_map


def _reference_permute_impl(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    num_out_tokens: int,
) -> tuple:
    """
    Vectorized internal helper for reference permutation implementation.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    probs : jnp.ndarray
        The probabilities of the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Permuted probabilities if probs was provided, None otherwise.
    """
    num_tokens, hidden_size = inp.shape
    num_experts = (row_id_map.shape[1] - 1) // 2

    # Extract destination rows, expert indices, and n_routed from row_id_map
    dest_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]

    # Create mask for valid entries: slot_idx < n_routed[token]
    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
    # (slots beyond n_routed may contain garbage, not -1)
    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]

    # Flatten for scatter operations
    flat_dest_rows = dest_rows.flatten()  # [num_tokens * num_experts]
    flat_valid_mask = valid_mask.flatten()
    flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts)
    flat_expert_indices = expert_indices.flatten()

    # Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
    # This avoids overwriting valid entries at index 0 with zeros
    flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens)

    # Gather input tokens and scatter to output
    output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
    gathered_inp = inp[flat_token_indices]  # [num_tokens * num_experts, hidden_size]

    # Scatter: for each valid (token, expert) pair, write inp[token] to
    # output[dest_row]; invalid entries target num_out_tokens and are dropped
    # by mode="drop". Destination rows are unique, so no accumulation is needed.
    output = output.at[flat_dest_rows_clamped].set(
        gathered_inp,
        mode="drop",
    )
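    # Note on mode="drop": out-of-bounds scatter indices are silently skipped,
    # e.g. jnp.zeros(2).at[jnp.array([0, 2])].set(jnp.array([5.0, 7.0]), mode="drop")
    # writes 5.0 to index 0 and drops the write to index 2.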

    permuted_probs = None
    if probs is not None:
        permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype)

        # Vectorized approach: gather probs and scatter to permuted_probs
        if probs.ndim == 1:
            flat_probs = probs[flat_token_indices]
        else:
            # Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
            # The result for invalid entries will be ignored anyway since they target num_out_tokens
            # Cast to int32 explicitly for consistent indexing behavior
            flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype(
                jnp.int32
            )
            flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped]

        # Invalid entries target num_out_tokens and get dropped by mode="drop"
        permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set(
            flat_probs,
            mode="drop",
        )

    return output, permuted_probs


def _reference_unpermute_impl(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
    permuted_probs: jnp.ndarray,
) -> tuple:
    """
    Vectorized internal helper for reference unpermutation implementation.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        The merging probabilities for weighted reduction.
    permuted_probs : jnp.ndarray
        The permuted probabilities.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape [num_tokens, hidden_size].
    unpermuted_probs : jnp.ndarray
        Unpermuted probabilities if permuted_probs was provided, None otherwise.
    """
    num_tokens = row_id_map.shape[0]
    num_experts = (row_id_map.shape[1] - 1) // 2

    # Extract source rows, expert indices, and n_routed from row_id_map
    src_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]

    # Create mask for valid entries: slot_idx < n_routed[token]
    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]

    # Clamp invalid src_rows to 0 (they won't be used due to masking)
    src_rows_clamped = jnp.where(valid_mask, src_rows, 0)

    # Gather input from permuted positions
    gathered_inp = inp[src_rows_clamped]  # [num_tokens, num_experts, hidden_size]

    # Apply merging probs if provided
    if merging_probs is not None:
        # Gather the merging weights for each (token, expert) pair using advanced indexing
        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
        weights = merging_probs[token_idx, expert_indices]  # [num_tokens, num_experts]
        gathered_inp = gathered_inp * weights[:, :, None]

    # Mask out invalid entries and sum across experts
    gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0)
    output = jnp.sum(gathered_inp, axis=1)  # [num_tokens, hidden_size]
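    # Equivalent per-token formula (with w = merging_probs, or 1 if absent):
    #   output[t] = sum_{k < n_routed[t]} w[t, expert_indices[t, k]] * inp[src_rows[t, k]]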

    unpermuted_probs = None
    if permuted_probs is not None:
        gathered_probs = permuted_probs[src_rows_clamped]  # [num_tokens, num_experts]
        unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
        unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set(
            jnp.where(valid_mask, gathered_probs, 0.0)
        )

    return output, unpermuted_probs


def reference_token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: jnp.ndarray = None,
) -> tuple:
    """
    Reference implementation of token_dispatch using JAX primitives.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        Routing mask of shape [num_tokens, num_experts].
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    probs : jnp.ndarray, optional
        The probabilities of shape [num_tokens, num_experts].

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray or None
        Permuted probabilities of shape [num_out_tokens], or None if probs not provided.
    row_id_map : jnp.ndarray
        The row_id_map for the permutation.
    """
    row_id_map = reference_make_row_id_map(routing_map)
    output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens)

    return output, permuted_probs, row_id_map


def reference_token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
) -> jnp.ndarray:
    """
    Reference implementation of token_combine using JAX primitives.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        The merging probabilities for weighted reduction.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape [num_tokens, hidden_size].
    """
    output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None)

    return output


def reference_make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
) -> jnp.ndarray:
    """
    Vectorized reference implementation of make_chunk_sort_map using JAX primitives.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        The sizes of the chunks of shape [num_splits,].
    sorted_indices : jnp.ndarray
        The indices of the sorted chunks of shape [num_splits,].
    num_tokens : int
        Number of tokens.

    Returns
    -------
    row_id_map : jnp.ndarray
        Row ID map for chunk sorting of shape [num_tokens,].
    """
    # Compute source chunk boundaries (cumulative sum of original split_sizes)
    src_cumsum = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
    )

    # Compute destination chunk boundaries based on sorted order
    sorted_sizes = split_sizes[sorted_indices]
    dest_cumsum = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
    )

    # For each source chunk, compute its destination offset
    # inverse_indices[i] = position of chunk i in sorted order
    inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32)
    dest_offsets = dest_cumsum[inverse_indices]

    # Create row_id_map: for each token position, compute its destination
    # First, figure out which chunk each position belongs to
    position_indices = jnp.arange(num_tokens, dtype=jnp.int32)

    # chunk_ids[i] = which chunk position i belongs to
    chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32)

    # within_chunk_offset[i] = position i's offset within its chunk
    within_chunk_offset = position_indices - src_cumsum[chunk_ids]

    # destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i]
    row_id_map = dest_offsets[chunk_ids] + within_chunk_offset

    return row_id_map.astype(jnp.int32)


def reference_sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    is_forward: bool,
) -> tuple:
    """
    Vectorized reference implementation of sort_chunks_by_map using JAX primitives.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to destination mapping of shape [num_tokens,].
    probs : jnp.ndarray
        The probabilities.
    is_forward : bool
        If True, scatter rows to their destinations (forward direction);
        if False, gather them back (backward direction, the inverse permutation).

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Sorted probabilities if probs was provided, None otherwise.
    """
    num_tokens = inp.shape[0]
    hidden_size = inp.shape[1]

    if is_forward:
        # Forward: scatter inp[src] to output[dest] where dest = row_id_map[src]
        output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
        output = output.at[row_id_map].set(inp)
        if probs is not None:
            permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype)
            permuted_probs = permuted_probs.at[row_id_map].set(probs)
        else:
            permuted_probs = None
    else:
        # Backward: gather output[dest] = inp[src] where src = row_id_map[dest]
        output = inp[row_id_map]
        if probs is not None:
            permuted_probs = probs[row_id_map]
        else:
            permuted_probs = None

    return output, permuted_probs


class TestHighLevelPermutationAPI:
    """Test high-level permutation APIs (token_dispatch, token_combine, etc.)

    These tests compare the high-level APIs against reference implementations
    to verify correctness of both forward and backward passes.
    """

    @staticmethod
    def generate_routing_map(
        num_tokens: int,
        num_experts: int,
        tokens_per_expert: int = 2,
        key: jax.Array = None,
    ):
        """Generate random routing map for testing"""
        if key is None:
            key = jax.random.PRNGKey(0)

        routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
        for token_idx in range(num_tokens):
            key, subkey = jax.random.split(key)
            expert_indices = jax.random.choice(
                subkey, num_experts, shape=(tokens_per_expert,), replace=False
            )
            routing_map = routing_map.at[token_idx, expert_indices].set(1)

        return routing_map

    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_token_dispatch(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Individual test for token_dispatch forward and backward passes.

        This test validates dispatch in isolation to catch errors that might be
        masked when combined with token_combine in the roundtrip test.

        Uses value_and_grad to validate both forward (via loss comparison) and
        backward (via gradient comparison) passes against reference implementation.
        """
        key = jax.random.PRNGKey(42)

        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))

        # Generate input data
        key, inp_key, prob_key = jax.random.split(key, 3)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )

        # Generate reference row_id_map for comparison
        ref_row_id_map = reference_make_row_id_map(routing_map)

        # =====================================================================
        # Test forward and backward pass using value_and_grad
        # (value validates forward, grad validates backward)
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)

            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)

            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)

    # =========================================================================
    # Consolidated dispatch + combine tests
    # =========================================================================

    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Comprehensive test for token_dispatch and token_combine.

        Tests:
        1. Dispatch forward pass against reference (element-by-element)
        2. Dispatch backward pass against reference
        3. Combine forward pass against reference (element-by-element)
        4. Combine backward pass against reference
        5. Roundtrip: dispatch + combine recovers original input
        6. row_id_map n_routed column validation
        7. Probs permutation (when with_probs=True)
        """
        key = jax.random.PRNGKey(42)

        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))

        # Generate input data
        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )

        # Generate merging probs (normalized per token)
        merging_probs = jax.random.uniform(
            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
        merging_probs = merging_probs / jnp.maximum(
            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
        )

        # =====================================================================
        # Test 1: Dispatch forward pass
        # =====================================================================
        output, permuted_probs, row_id_map, _, _ = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs
        )
        ref_output, ref_permuted_probs = _reference_permute_impl(
            inp, row_id_map, probs, num_out_tokens
        )

        # Validate row_id_map structure: n_routed column should match routing_map sum
        n_routed_actual = row_id_map[:, -1]
        n_routed_expected = jnp.sum(routing_map, axis=1)
        assert jnp.array_equal(
            n_routed_actual, n_routed_expected
        ), "make_row_id_map n_routed column mismatch"

        # Compare dispatch output
        assert_allclose(output, ref_output, dtype=dtype)
        if with_probs:
            assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)

        # =====================================================================
        # Test 2: Dispatch backward pass
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            _, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            _, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            _, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            _, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)

        # =====================================================================
        # Test 3: Combine forward pass
        # =====================================================================
        combined = token_combine(output, row_id_map, merging_probs)
        ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
        assert_allclose(combined, ref_combined, dtype=dtype)

        # =====================================================================
        # Test 4: Combine backward pass
        # =====================================================================

        @jax.jit
        def combine_loss(x):
            return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)

        @jax.jit
        def ref_combine_loss(x):
            return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)

        _, combine_grad = jax.value_and_grad(combine_loss)(output)
        _, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
        assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)

        # =====================================================================
        # Test 5: Roundtrip (dispatch + combine = original)
        # =====================================================================
        # Use uniform merging probs for perfect roundtrip
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )
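        # With n = n_routed[t], combine averages the n identical dispatched
        # copies of each token: sum_k (1/n) * x[t] = x[t], so the roundtrip is
        # exact up to dtype rounding.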

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
            return token_combine(dispatched, rid_map, uniform_merging_probs)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

    # =========================================================================
    # sort_chunks_by_index tests
    # =========================================================================

    @pytest_parametrize_wrapper(
        "num_splits,total_tokens,hidden_size",
        SORT_CHUNKS_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
        """Test sort_chunks_by_index forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)

        # Generate random split sizes
        key, size_key = jax.random.split(key)
        split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits)
        split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1]))

        # Generate sorted indices
        key, sort_key = jax.random.split(key)
        sorted_indices = jax.random.permutation(sort_key, num_splits)

        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Get reference row_id_map
        row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)

        # Define loss functions (JIT compiled for performance)
        @jax.jit
        def loss_fn(x):
            output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
            return jnp.sum(output**2)

        @jax.jit
        def ref_loss_fn(x):
            output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
            return jnp.sum(output**2)

        # Test forward pass
        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
        ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)

        # Test backward pass with JIT
        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)

        # Compare forward and backward
        assert_allclose(output, ref_output)
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
    # Consolidated dispatch + combine with padding tests
    # =========================================================================

    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,topk,align_size",
        DISPATCH_COMBINE_PADDING_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine_with_padding(
        self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
    ):
        """
        Comprehensive test for token_dispatch and token_combine with padding/unpadding.

        Tests:
        1. Dispatch with padding: output shape and alignment
        2. Dispatch backward pass with padding
        3. Combine with unpad: output shape
        4. Combine backward pass with unpad
        5. Roundtrip with padding: dispatch + combine recovers original
        6. Probs permutation with padding (when with_probs=True)
        """
        key = jax.random.PRNGKey(42)

        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
        num_out_tokens = int(jnp.sum(routing_map))

        # Compute worst-case padded size
        worst_case_size = (
            (num_out_tokens + num_experts * (align_size - 1)) // align_size
        ) * align_size
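        # Each expert chunk is padded to a multiple of align_size, adding at most
        # (align_size - 1) rows per expert; rounding that bound down to a multiple
        # of align_size gives the fixed worst-case buffer size above.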

        # Generate input data
        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )

        # Generate merging probs (normalized per token)
        merging_probs = jax.random.uniform(
            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
        merging_probs = merging_probs / jnp.maximum(
            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
        )

        # =====================================================================
        # Test 1: Dispatch with padding - forward pass
        # =====================================================================
        output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
        )

        # Check output shape
        assert output.shape == (worst_case_size, hidden_size)
        if with_probs:
            assert permuted_probs is not None
            assert permuted_probs.shape == (worst_case_size,)
        else:
            assert permuted_probs is None

        # Check alignment: each expert's tokens should be aligned
        for expert_idx in range(num_experts):
            expert_tokens = int(target_tokens_per_expert[expert_idx])
            assert expert_tokens % align_size == 0 or expert_tokens == 0

        # =====================================================================
        # Test 2: Dispatch with padding - backward pass
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, probs=p, align_size=align_size
                )
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs)
            assert inp_grad.shape == inp.shape
            assert probs_grad.shape == probs.shape
            assert not jnp.any(jnp.isnan(inp_grad))
            assert not jnp.any(jnp.isnan(probs_grad))
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, align_size=align_size
                )
                return jnp.sum(out**2)

            inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
            assert inp_grad.shape == inp.shape
            assert not jnp.any(jnp.isnan(inp_grad))

        # =====================================================================
        # Test 3: Combine with unpad - forward pass
        # =====================================================================
        combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
        assert combined.shape == (num_tokens, hidden_size)

        # =====================================================================
        # Test 4: Combine with unpad - backward pass
        # =====================================================================

        @jax.jit
        def combine_loss(x):
            return jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2)

        combine_grad = jax.grad(combine_loss)(output)
        assert combine_grad.shape == output.shape
        assert not jnp.any(jnp.isnan(combine_grad))

        # =====================================================================
        # Test 5: Roundtrip with padding (dispatch + combine = original)
        # =====================================================================
        # Use uniform merging probs for perfect roundtrip
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, p_offsets, _ = token_dispatch(
                x, routing_map, num_out_tokens, align_size=align_size
            )
            return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

        # Test roundtrip gradient
        @jax.jit
        def roundtrip_loss(x):
            return jnp.sum(roundtrip(x) ** 2)

        roundtrip_grad = jax.grad(roundtrip_loss)(inp)
        assert roundtrip_grad.shape == inp.shape
        assert not jnp.any(jnp.isnan(roundtrip_grad))