# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for permutation Triton kernels and high-level APIs"""

import functools

9
10
11
12
13
14
15
16
17
18
import jax
import jax.numpy as jnp
import pytest

# High-level API with VJP support
from transformer_engine.jax.permutation import (
    token_dispatch,
    token_combine,
    sort_chunks_by_index,
)
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from utils import assert_allclose, pytest_parametrize_wrapper


# =============================================================================
# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
# =============================================================================

# All dispatch/combine test cases
# Dispatch/combine shapes, consumed as
# (num_tokens, num_experts, hidden_size, tokens_per_expert).
ALL_DISPATCH_COMBINE_CASES = [
    (128, 5, 128, 3),
    (1024, 8, 128, 8),
    (4096, 32, 1280, 2),
    (4096, 256, 4096, 6),
]
# "L0" keeps only the first two (smallest) shapes; "L2" runs everything.
DISPATCH_COMBINE_CASES = {
    "L0": ALL_DISPATCH_COMBINE_CASES[:2],
    "L2": ALL_DISPATCH_COMBINE_CASES,
}

# Chunk-sort shapes, consumed as (num_splits, total_tokens, hidden_size).
ALL_SORT_CHUNKS_CASES = [
    (8, 4096, 1280),
    (64, 4096, 4096),
    (256, 4096, 9216),
]
SORT_CHUNKS_CASES = {
    "L0": ALL_SORT_CHUNKS_CASES[:2],
    "L2": ALL_SORT_CHUNKS_CASES,
}

# Dispatch/combine-with-padding shapes, consumed as
# (num_tokens, num_experts, hidden_size, topk, align_size).
ALL_DISPATCH_COMBINE_PADDING_CASES = [
    (128, 5, 128, 3, 8),
    (1024, 8, 128, 8, 16),
    (4096, 32, 1280, 2, 128),
    (4096, 256, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
    "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[:2],
    "L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}

# Element dtypes exercised at every level.
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
    "L0": ALL_DTYPES,
    "L2": ALL_DTYPES,
}

# Whether probabilities are passed along; the fast level only covers the with-probs path.
ALL_WITH_PROBS = [True, False]
WITH_PROBS = {
    "L0": [True],
    "L2": ALL_WITH_PROBS,
}
74
75
76
77
78
79


def reference_make_row_id_map(
    routing_map: jnp.ndarray,
) -> jnp.ndarray:
    """
    Vectorized reference implementation of make_row_id_map using JAX primitives.

    Every routed (token, expert) pair is assigned a destination row in an
    expert-major layout: all rows of expert 0 first, then expert 1, and so on,
    with tokens kept in their original order inside each expert.

    Parameters
    ----------
    routing_map : jnp.ndarray
        Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts
        are routed to which tokens (1 = routed, 0 = not routed).

    Returns
    -------
    row_id_map : jnp.ndarray
        The int32 row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1].
        Per token it holds, in order: the destination rows sorted in descending order
        (-1 for non-routed experts, placed last), the expert indices in that same order,
        and the number of routed experts.
    """
    num_tokens, num_experts = routing_map.shape

    # Running count of routed tokens per expert: the 1-based position of each
    # token inside its expert's segment.
    cumsum_per_expert = jnp.cumsum(routing_map, axis=0)

    # Start row of every expert's segment (exclusive prefix sum of the per-expert totals).
    tokens_per_expert = jnp.sum(routing_map, axis=0)
    expert_offsets = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
    )

    # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
    dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * (
        1 - routing_map
    )

    # Count routed experts per token
    n_routed_per_token = jnp.sum(routing_map, axis=1)

    # Sort each token's experts by descending destination row. Negating the row
    # turns that into an ascending sort; non-routed experts get the largest
    # int32 key so they end up after every routed expert.
    sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max)
    sorted_expert_indices = jnp.argsort(sort_keys, axis=1)

    # Row index companion for the per-token gather below.
    token_idx = jnp.broadcast_to(
        jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
    )
    sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]

    # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
    row_id_map = jnp.concatenate(
        [
            sorted_dest_rows.astype(jnp.int32),
            sorted_expert_indices.astype(jnp.int32),
            n_routed_per_token.astype(jnp.int32)[:, None],
        ],
        axis=1,
    )

    return row_id_map


def _reference_permute_impl(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    num_out_tokens: int,
) -> tuple:
    """
145
    Vectorized internal helper for reference permutation implementation.
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    probs : jnp.ndarray
        The probabilities of the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Permuted probabilities if probs was provided, None otherwise.
    """
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    num_tokens, hidden_size = inp.shape
    num_experts = (row_id_map.shape[1] - 1) // 2

    # Extract destination rows, expert indices, and n_routed from row_id_map
    dest_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]

    # Create mask for valid entries: slot_idx < n_routed[token]
    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
    # (slots beyond n_routed may contain garbage, not -1)
    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]

    # Flatten for scatter operations
    flat_dest_rows = dest_rows.flatten()  # [num_tokens * num_experts]
    flat_valid_mask = valid_mask.flatten()
    flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts)
    flat_expert_indices = expert_indices.flatten()

    # Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
    # This avoids overwriting valid entries at index 0 with zeros
    flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens)

    # Gather input tokens and scatter to output
190
    output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
    gathered_inp = inp[flat_token_indices]  # [num_tokens * num_experts, hidden_size]

    # Use segment_sum-like operation via scatter
    # For each valid (token, expert) pair, write inp[token] to output[dest_row]
    # Invalid entries target num_out_tokens and get dropped by mode="drop"
    output = output.at[flat_dest_rows_clamped].set(
        gathered_inp,
        mode="drop",
    )

    permuted_probs = None
    if probs is not None:
        permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype)

        # Vectorized approach: gather probs and scatter to permuted_probs
        if probs.ndim == 1:
            flat_probs = probs[flat_token_indices]
        else:
            # Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
            # The result for invalid entries will be ignored anyway since they target num_out_tokens
            # Cast to int32 explicitly for consistent indexing behavior
            flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype(
                jnp.int32
            )
            flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped]

        # Invalid entries target num_out_tokens and get dropped by mode="drop"
        permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set(
            flat_probs,
            mode="drop",
        )
222
223
224
225
226
227
228
229
230
231
232

    return output, permuted_probs


def _reference_unpermute_impl(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
    permuted_probs: jnp.ndarray,
) -> tuple:
    """
233
    Vectorized internal helper for reference unpermutation implementation.
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        The merging probabilities for weighted reduction.
    permuted_probs : jnp.ndarray
        The permuted probabilities.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape [num_tokens, hidden_size].
    unpermuted_probs : jnp.ndarray
        Unpermuted probabilities if permuted_probs was provided, None otherwise.
    """
253
254
    num_tokens = row_id_map.shape[0]
    num_experts = (row_id_map.shape[1] - 1) // 2
255

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
    # Extract source rows, expert indices, and n_routed from row_id_map
    src_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]

    # Create mask for valid entries: slot_idx < n_routed[token]
    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]

    # Clamp invalid src_rows to 0 (they won't be used due to masking)
    src_rows_clamped = jnp.where(valid_mask, src_rows, 0)

    # Gather input from permuted positions
    gathered_inp = inp[src_rows_clamped]  # [num_tokens, num_experts, hidden_size]

    # Apply merging probs if provided
    if merging_probs is not None:
        # Gather the merging weights for each (token, expert) pair using advanced indexing
        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
        weights = merging_probs[token_idx, expert_indices]  # [num_tokens, num_experts]
        gathered_inp = gathered_inp * weights[:, :, None]

    # Mask out invalid entries and sum across experts
    gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0)
    output = jnp.sum(gathered_inp, axis=1)  # [num_tokens, hidden_size]

    unpermuted_probs = None
    if permuted_probs is not None:
        gathered_probs = permuted_probs[src_rows_clamped]  # [num_tokens, num_experts]
        unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
        unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set(
            jnp.where(valid_mask, gathered_probs, 0.0)
        )
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

    return output, unpermuted_probs


def reference_token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: jnp.ndarray = None,
) -> tuple:
    """
    Reference implementation of token_dispatch using JAX primitives.

    Builds the row_id_map from ``routing_map`` and uses it to permute ``inp``
    (and ``probs`` when given).

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        Routing mask of shape [num_tokens, num_experts].
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    probs : jnp.ndarray, optional
        The probabilities of shape [num_tokens, num_experts].

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray or None
        Permuted probabilities of shape [num_out_tokens], or None if probs not provided.
    row_id_map : jnp.ndarray
        The row_id_map for the permutation.
    """
    row_id_map = reference_make_row_id_map(routing_map)
    output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens)

    return output, permuted_probs, row_id_map


def reference_token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
) -> jnp.ndarray:
    """
    Reference implementation of token_combine using JAX primitives.

    Thin wrapper over ``_reference_unpermute_impl`` that passes no permuted
    probabilities and returns only the unpermuted tensor.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        The merging probabilities for weighted reduction.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape [num_tokens, hidden_size].
    """
    output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None)

    return output


def reference_make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
) -> jnp.ndarray:
    """
    Vectorized reference implementation of make_chunk_sort_map using JAX primitives.

    The token sequence is viewed as consecutive chunks of ``split_sizes``. The chunks
    are reordered so that output chunk ``k`` is input chunk ``sorted_indices[k]``; the
    returned map gives, for every input position, its position after the reordering.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        The sizes of the chunks of shape [num_splits,].
    sorted_indices : jnp.ndarray
        The indices of the sorted chunks of shape [num_splits,].
    num_tokens : int
        Number of tokens.

    Returns
    -------
    row_id_map : jnp.ndarray
        Row ID map (int32) for chunk sorting of shape [num_tokens,].
    """
    zero = jnp.array([0], dtype=jnp.int32)

    # Source chunk boundaries: src_cumsum[c] is the first position of chunk c.
    src_cumsum = jnp.concatenate([zero, jnp.cumsum(split_sizes).astype(jnp.int32)])

    # Destination chunk boundaries, following the sorted chunk order.
    sorted_sizes = split_sizes[sorted_indices]
    dest_cumsum = jnp.concatenate([zero, jnp.cumsum(sorted_sizes).astype(jnp.int32)])

    # inverse_indices[i] = position of chunk i in sorted order, so
    # dest_offsets[i] is where source chunk i starts in the output.
    inverse_indices = jnp.argsort(sorted_indices)
    dest_offsets = dest_cumsum[inverse_indices]

    position_indices = jnp.arange(num_tokens, dtype=jnp.int32)

    # chunk_ids[i] = which chunk position i belongs to. side="right" assigns a
    # position sitting exactly on a boundary to the chunk that starts there.
    chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")

    # within_chunk_offset[i] = position i's offset within its chunk
    within_chunk_offset = position_indices - src_cumsum[chunk_ids]

    # destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i]
    row_id_map = dest_offsets[chunk_ids] + within_chunk_offset

    return row_id_map.astype(jnp.int32)
409
410
411
412
413
414
415
416
417


def reference_sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    is_forward: bool,
) -> tuple:
    """
    Vectorized reference implementation of sort_chunks_by_map using JAX primitives.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to destination mapping of shape [num_tokens,].
    probs : jnp.ndarray
        The probabilities (one per token), or None.
    is_forward : bool
        Whether this is forward or backward. Forward scatters row ``i`` to
        ``row_id_map[i]``; backward gathers row ``row_id_map[i]`` into position ``i``.

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Sorted probabilities if probs was provided, None otherwise.
    """
    num_tokens = inp.shape[0]
    hidden_size = inp.shape[1]

    permuted_probs = None
    if is_forward:
        # Forward: scatter inp[src] to output[dest] where dest = row_id_map[src]
        output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
        output = output.at[row_id_map].set(inp)
        if probs is not None:
            permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype)
            permuted_probs = permuted_probs.at[row_id_map].set(probs)
    else:
        # Backward: gather output[dest] = inp[src] where src = row_id_map[dest]
        output = inp[row_id_map]
        if probs is not None:
            permuted_probs = probs[row_id_map]

    return output, permuted_probs


class TestHighLevelPermutationAPI:
    """Test high-level permutation APIs (token_dispatch, token_combine, etc.)

    These tests compare the high-level APIs against reference implementations
    to verify correctness of both forward and backward passes.
    """

    @staticmethod
    def generate_routing_map(
        num_tokens: int,
        num_experts: int,
        tokens_per_expert: int = 2,
        key: jax.Array = None,
    ):
        """Generate random routing map for testing"""
        if key is None:
            key = jax.random.PRNGKey(0)

        routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
        for token_idx in range(num_tokens):
            key, subkey = jax.random.split(key)
            expert_indices = jax.random.choice(
                subkey, num_experts, shape=(tokens_per_expert,), replace=False
            )
            routing_map = routing_map.at[token_idx, expert_indices].set(1)

        return routing_map

489
    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_token_dispatch(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Individual test for token_dispatch forward and backward passes.

        This test validates dispatch in isolation to catch errors that might be
        masked when combined with token_combine in the roundtrip test.

        Uses value_and_grad to validate both forward (via loss comparison) and
        backward (via gradient comparison) passes against reference implementation.
        """
        key = jax.random.PRNGKey(42)

        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))

        # Generate input data
        key, inp_key, prob_key = jax.random.split(key, 3)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )

        # The reference path uses its own row_id_map, built independently of token_dispatch.
        ref_row_id_map = reference_make_row_id_map(routing_map)

        # =====================================================================
        # Test forward and backward pass using value_and_grad
        # (value validates forward, grad validates backward)
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)

            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)

            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)

    # =========================================================================
    # Consolidated dispatch + combine tests
    # =========================================================================

    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Comprehensive test for token_dispatch and token_combine.

        Tests:
        1. Dispatch forward pass against reference (element-by-element)
        2. Dispatch backward pass against reference
        3. Combine forward pass against reference (element-by-element)
        4. Combine backward pass against reference
        5. Roundtrip: dispatch + combine recovers original input
        6. row_id_map n_routed column validation
        7. Probs permutation (when with_probs=True)
        """
        key = jax.random.PRNGKey(42)

        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))

        # Generate input data
        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )

        # Generate merging probs (normalized per token)
        merging_probs = jax.random.uniform(
            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
        # The 1e-8 floor guards the division for a token whose weights sum to zero.
        merging_probs = merging_probs / jnp.maximum(
            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
        )

        # =====================================================================
        # Test 1: Dispatch forward pass
        # =====================================================================
        output, permuted_probs, row_id_map, _, _ = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs
        )
        # The reference consumes the row_id_map returned by token_dispatch, so this
        # checks the permutation itself rather than the map construction.
        ref_output, ref_permuted_probs = _reference_permute_impl(
            inp, row_id_map, probs, num_out_tokens
        )

        # Validate row_id_map structure: n_routed column should match routing_map sum
        n_routed_actual = row_id_map[:, -1]
        n_routed_expected = jnp.sum(routing_map, axis=1)
        assert jnp.array_equal(
            n_routed_actual, n_routed_expected
        ), "make_row_id_map n_routed column mismatch"

        # Compare dispatch output
        assert_allclose(output, ref_output, dtype=dtype)
        if with_probs:
            assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)

        # =====================================================================
        # Test 2: Dispatch backward pass
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            _, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            _, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            _, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            _, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)

        # =====================================================================
        # Test 3: Combine forward pass
        # =====================================================================
        combined = token_combine(output, row_id_map, merging_probs)
        ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
        assert_allclose(combined, ref_combined, dtype=dtype)

        # =====================================================================
        # Test 4: Combine backward pass
        # =====================================================================

        @jax.jit
        def combine_loss(x):
            return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)

        @jax.jit
        def ref_combine_loss(x):
            return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)

        _, combine_grad = jax.value_and_grad(combine_loss)(output)
        _, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
        assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)

        # =====================================================================
        # Test 5: Roundtrip (dispatch + combine = original)
        # =====================================================================
        # Use uniform merging probs for perfect roundtrip: each routed copy of a
        # token gets weight 1 / n_routed, so the weighted sum reproduces the token.
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
            return token_combine(dispatched, rid_map, uniform_merging_probs)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

    # =========================================================================
    # sort_chunks_by_index tests
    # =========================================================================

    @pytest_parametrize_wrapper(
        "num_splits,total_tokens,hidden_size",
        SORT_CHUNKS_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
        """Test sort_chunks_by_index forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)

        # Generate random split sizes; the last chunk absorbs the remainder so the
        # sizes sum to total_tokens.
        key, size_key = jax.random.split(key)
        split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits)
        split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1]))

        # Generate sorted indices
        key, sort_key = jax.random.split(key)
        sorted_indices = jax.random.permutation(sort_key, num_splits)

        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Get reference row_id_map
        row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)

        # Define loss functions (JIT compiled for performance)
        @jax.jit
        def loss_fn(x):
            output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
            return jnp.sum(output**2)

        @jax.jit
        def ref_loss_fn(x):
            output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
            return jnp.sum(output**2)

        # Test forward pass
        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
        ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)

        # Test backward pass with JIT
        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)

        # Compare forward and backward
        assert_allclose(output, ref_output)
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
786
    # Consolidated dispatch + combine with padding tests
787
788
    # =========================================================================

789
790
791
    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,topk,align_size",
        DISPATCH_COMBINE_PADDING_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine_with_padding(
        self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
    ):
        """
        End-to-end checks for token_dispatch / token_combine when align_size padding is used.

        Stages:
        1. Padded dispatch, forward: output shape, permuted-probs shape, per-expert alignment
        2. Padded dispatch, backward: gradient shapes and absence of NaNs
        3. Combine with pad offsets, forward: output shape
        4. Combine with pad offsets, backward: gradient shape and absence of NaNs
        5. Roundtrip: dispatch followed by combine reproduces the input; its gradient is NaN-free
        6. When with_probs=True, probs are passed through dispatch (covered in stages 1 and 2)
        """
        rng = jax.random.PRNGKey(42)

        # Routing map comes from the class helper, seeded with the initial key
        routing_map = self.generate_routing_map(num_tokens, num_experts, topk, rng)
        num_out_tokens = int(jnp.sum(routing_map))

        # Worst-case padded size: (align_size - 1) extra rows per expert, rounded down
        # to a multiple of align_size
        padded_upper_bound = num_out_tokens + num_experts * (align_size - 1)
        worst_case_size = (padded_upper_bound // align_size) * align_size

        # One split of the seed provides the keys for input, probs and merging probs
        rng, data_rng, probs_rng, merge_rng = jax.random.split(rng, 4)
        inp = jax.random.uniform(
            data_rng, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Optional probs; minval is kept above zero so the kernel's special handling of
        # prob == 0 is not triggered
        probs = (
            jax.random.uniform(
                probs_rng, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )
            if with_probs
            else None
        )

        # Merging probs: random, masked by the routing map, then normalized per token
        raw_merging = jax.random.uniform(
            merge_rng, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        masked_merging = raw_merging * routing_map.astype(dtype)
        per_token_total = jnp.sum(masked_merging, axis=1, keepdims=True)
        merging_probs = masked_merging / jnp.maximum(per_token_total, 1e-8)

        # ---------------------------------------------------------------------
        # Stage 1: padded dispatch, forward
        # ---------------------------------------------------------------------
        output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
        )

        assert output.shape == (worst_case_size, hidden_size)
        if not with_probs:
            assert permuted_probs is None
        else:
            assert permuted_probs is not None
            assert permuted_probs.shape == (worst_case_size,)

        # Every expert's token count must be a multiple of align_size (or zero)
        for expert_idx in range(num_experts):
            tokens_for_expert = int(target_tokens_per_expert[expert_idx])
            assert tokens_for_expert % align_size == 0 or tokens_for_expert == 0

        # ---------------------------------------------------------------------
        # Stage 2: padded dispatch, backward
        # ---------------------------------------------------------------------
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                dispatched, dispatched_probs, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, probs=p, align_size=align_size
                )
                return jnp.sum(dispatched**2) + jnp.sum(dispatched_probs**2)

            grad_wrt_inp_and_probs = jax.grad(dispatch_loss, argnums=(0, 1))
            inp_grad, probs_grad = grad_wrt_inp_and_probs(inp, probs)
            assert inp_grad.shape == inp.shape
            assert probs_grad.shape == probs.shape
            assert not jnp.any(jnp.isnan(inp_grad))
            assert not jnp.any(jnp.isnan(probs_grad))
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                dispatched, _, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, align_size=align_size
                )
                return jnp.sum(dispatched**2)

            inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
            assert inp_grad.shape == inp.shape
            assert not jnp.any(jnp.isnan(inp_grad))

        # ---------------------------------------------------------------------
        # Stage 3: combine with pad offsets, forward
        # ---------------------------------------------------------------------
        combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
        assert combined.shape == (num_tokens, hidden_size)

        # ---------------------------------------------------------------------
        # Stage 4: combine with pad offsets, backward
        # ---------------------------------------------------------------------

        @jax.jit
        def combine_loss(x):
            merged = token_combine(x, row_id_map, merging_probs, pad_offsets)
            return jnp.sum(merged**2)

        combine_grad = jax.grad(combine_loss)(output)
        assert combine_grad.shape == output.shape
        assert not jnp.any(jnp.isnan(combine_grad))

        # ---------------------------------------------------------------------
        # Stage 5: roundtrip (dispatch followed by combine should give back the input)
        # ---------------------------------------------------------------------
        # Uniform merging probs: each routed expert gets 1 / (number of routed experts),
        # which is what makes an exact roundtrip possible
        experts_per_token = jnp.sum(routing_map, axis=1, keepdims=True)
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(experts_per_token, 1.0)

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, p_offsets, _ = token_dispatch(
                x, routing_map, num_out_tokens, align_size=align_size
            )
            return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

        # Gradient through the full roundtrip
        @jax.jit
        def roundtrip_loss(x):
            recovered = roundtrip(x)
            return jnp.sum(recovered**2)

        roundtrip_grad = jax.grad(roundtrip_loss)(inp)
        assert roundtrip_grad.shape == inp.shape
        assert not jnp.any(jnp.isnan(roundtrip_grad))