permutation.py 21.4 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#
# See LICENSE for license information.

"""MoE Permutation API for JAX.

This module provides high-level token dispatch and combine operations for
Mixture of Experts (MoE) models with proper automatic differentiation support.

Token Dispatch (Permute):
    - Forward: Permute tokens according to routing map (scatter to experts)
    - Backward: Unpermute gradients (gather from experts)

Token Combine (Unpermute):
    - Forward: Unpermute tokens and merge with weights (gather from experts)
    - Backward: Permute gradients (scatter to experts)
"""

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from transformer_engine.jax.triton_extensions.permutation import (
    make_row_id_map,
    permute_with_mask_map,
28
    permute_with_mask_map_and_pad,
29
    unpermute_with_mask_map,
30
    unpermute_with_mask_map_and_unpad,
31
    unpermute_bwd_with_merging_probs,
32
    unpermute_bwd_with_merging_probs_and_unpad,
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    make_chunk_sort_map,
    sort_chunks_by_map,
)

__all__ = [
    "token_dispatch",
    "token_combine",
    "sort_chunks_by_index",
]


def token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: Optional[jnp.ndarray] = None,
49
50
51
52
53
54
    align_size: Optional[int] = None,
) -> Tuple[
    jnp.ndarray,
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
55
    jnp.ndarray,
56
]:
57
58
59
60
61
62
63
    """
    Dispatch tokens to experts based on routing map.

    This is the forward pass of the MoE permutation. Tokens are scattered
    to their designated experts according to the routing map. The row_id_map
    is computed internally from the routing_map.

64
65
66
67
    Optionally supports fused padding for alignment when `align_size` is provided.
    This is useful for efficient matrix multiplications that require aligned tensor
    dimensions. The padding is computed internally from the routing_map.

68
69
70
71
72
73
74
75
    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
        Values: 1 = routed, 0 = not routed.
    num_out_tokens : int
76
77
78
        The number of output tokens after permutation (before padding). For the dropless
        case, this should be equal to the sum of routing_map. Must be provided explicitly
        for JIT compatibility since output shape must be known at compile time.
79
80
81
    probs : Optional[jnp.ndarray]
        Optional routing probabilities of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]. If provided, permuted_probs will be returned.
82
83
84
85
86
87
    align_size : Optional[int]
        Optional alignment size for padding. If provided, outputs will be padded to
        align each expert's tokens to a multiple of this size. The output buffer is
        allocated with worst-case size, rounded down to align_size:
        ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
        This enables full JIT compatibility.
88
89
90
91

    Returns
    -------
    output : jnp.ndarray
92
93
94
95
        Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
        or [worst_case_padded_size, hidden_size] when using padding fusion.
        With padding, the actual used portion may be smaller than the buffer; check
        actual_num_out_tokens (sum of target_tokens_per_expert) for the actual size.
96
    permuted_probs : Optional[jnp.ndarray]
97
98
        Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
        or None if probs was not provided.
99
100
    row_id_map : jnp.ndarray
        Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
101
102
103
    pad_offsets : Optional[jnp.ndarray]
        Per-expert cumulative padding offsets of shape [num_experts] when using padding,
        None otherwise. Pass this to token_combine when unpadding is needed.
104
105
106
107
108
    tokens_per_expert : jnp.ndarray
        Token counts per expert of shape [num_experts]:
        - Without padding: actual token counts (sum of routing_map columns)
        - With padding: aligned token counts (ceil(actual / align_size) * align_size)
        This gives the effective number of tokens per expert in the output buffer.
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

    Note
    ----
    **JIT Compatibility:**

    This function is fully JIT-compatible. When using padding (align_size provided),
    the output buffer is allocated with a fixed worst-case size that depends only on
    compile-time constants (num_out_tokens, num_experts, align_size). The actual
    padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
    are computed internally from the routing_map and can be traced values.

    The worst-case output size is:
    ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
    This accounts for the maximum possible padding when each expert needs (align_size - 1)
    extra tokens to align, rounded down to align_size for buffer alignment.
124
    """
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
    use_padding = align_size is not None
    num_experts = routing_map.shape[-1]

    if use_padding:
        # Compute worst-case output size (compile-time constant)
        # This is the maximum possible size when each expert needs max padding
        worst_case_out_tokens = (
            (num_out_tokens + num_experts * (align_size - 1)) // align_size
        ) * align_size
    else:
        worst_case_out_tokens = num_out_tokens

    return _token_dispatch(
        inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
    )
140
141


142
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
143
144
145
146
147
def _token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
148
149
150
151
152
153
154
155
    worst_case_out_tokens: int,
    align_size: Optional[int],
    use_padding: bool,
) -> Tuple[
    jnp.ndarray,
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
156
    jnp.ndarray,
157
]:
158
    """Internal token_dispatch with custom VJP."""
159
    (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
160
161
162
163
164
165
166
167
168
        _token_dispatch_fwd_rule(
            inp,
            routing_map,
            probs,
            num_out_tokens,
            worst_case_out_tokens,
            align_size,
            use_padding,
        )
169
    )
170
    return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert
171
172
173
174
175
176
177


def _token_dispatch_fwd_rule(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
178
179
180
    worst_case_out_tokens: int,
    align_size: Optional[int],
    use_padding: bool,
181
) -> Tuple[
182
183
184
185
186
    Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
187
        jnp.ndarray,
188
189
    ],
    Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
]:
    """Forward pass rule for token_dispatch."""
    # Validate input dimensions
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
    assert routing_map.ndim in [2, 3], f"routing_map must be 2D or 3D, got {routing_map.ndim}D"

    # Infer dimensions from input shapes
    num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
    hidden_size = inp.shape[-1]
    num_experts = routing_map.shape[-1]

    # Verify consistency between inp and routing_map
    routing_num_tokens = (
        routing_map.shape[0] * routing_map.shape[1]
        if routing_map.ndim == 3
        else routing_map.shape[0]
    )
    assert num_tokens == routing_num_tokens, (
        f"Token count mismatch: inp has {num_tokens} tokens, "
        f"routing_map has {routing_num_tokens} tokens"
    )

    # Always compute row_id_map internally from routing_map
    row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)

    with_probs = probs is not None

217
218
219
    # Compute tokens_per_expert from routing_map (actual counts)
    # This is well-optimized by XLA as a simple column-wise reduction
    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
220

221
    if use_padding:
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
        # Calculate aligned token counts per expert
        target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
            jnp.int32
        )

        # Compute pad_offsets: cumulative padding for each expert
        # pad_offsets[i] = sum of (target - actual) for experts 0..i-1
        pad_lengths = target_tokens_per_expert - tokens_per_expert
        cum_pad = jnp.cumsum(pad_lengths)
        pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])

        # Use worst_case_out_tokens as the output buffer size (compile-time constant)
        # The actual used size is sum(target_tokens_per_expert), which may be smaller.
        # Unused positions will be zero-initialized by the kernel.
        output, permuted_probs = permute_with_mask_map_and_pad(
            inp,
            row_id_map,
            probs,
            pad_offsets,
            num_tokens,
            num_experts,
            worst_case_out_tokens,
            hidden_size,
245
            align_size=align_size,
246
        )
247
248
249

        # Return aligned counts when using padding
        out_tokens_per_expert = target_tokens_per_expert
250
251
252
253
254
255
256
257
258
259
260
261
262
    else:
        # No padding
        pad_offsets = None

        output, permuted_probs = permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
263

264
265
266
        # Return actual counts when not using padding
        out_tokens_per_expert = tokens_per_expert

267
    # Return (primals, residuals)
268
269
270
    # out_tokens_per_expert is:
    #   - target_tokens_per_expert (aligned) when using padding
    #   - tokens_per_expert (actual) when not using padding
271
272
273
274
275
276
    residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
    return (
        output,
        permuted_probs,
        row_id_map,
        pad_offsets,
277
        out_tokens_per_expert,
278
    ), residuals
279
280
281
282


def _token_dispatch_bwd_rule(
    _num_out_tokens: int,
283
284
285
286
287
288
289
290
291
292
293
    _worst_case_out_tokens: int,
    _align_size: Optional[int],
    _use_padding: bool,
    residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
    g: Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        Optional[jnp.ndarray],
    ],
294
295
296
297
298
299
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
    """Backward pass rule for token_dispatch.

    Returns gradients for (inp, routing_map, probs).
    routing_map gradient is None since it's a discrete routing decision.
    """
300
301
    row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
    output_grad, permuted_probs_grad, _, _, _ = g  # Ignore row_id_map, pad_offsets, target grads
302
303

    # Backward: unpermute gradients (gather from experts back to tokens)
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
    if pad_offsets is not None:
        inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
            output_grad,
            row_id_map,
            None,  # No merging probs
            permuted_probs_grad if with_probs else None,
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )
    else:
        inp_grad, probs_grad = unpermute_with_mask_map(
            output_grad,
            row_id_map,
            None,  # No merging probs
            permuted_probs_grad if with_probs else None,
            num_tokens,
            num_experts,
            hidden_size,
        )
325

326
327
328
    # Return gradients for (inp, routing_map, probs)
    # routing_map is non-differentiable (discrete routing), so return None
    return inp_grad, None, probs_grad if with_probs else None
329
330
331
332
333
334
335
336
337
338
339
340
341
342


_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)


# =============================================================================
# Token Combine (Unpermute) with VJP
# =============================================================================


def token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray] = None,
343
    pad_offsets: Optional[jnp.ndarray] = None,
344
345
346
347
348
349
350
) -> jnp.ndarray:
    """
    Combine tokens from experts back to original token positions.

    This is the forward pass of MoE unpermutation. Tokens are gathered from
    experts and merged (optionally weighted by merging_probs).

351
352
353
    Optionally supports fused unpadding when `pad_offsets` is provided (from
    token_dispatch with padding enabled).

354
355
356
    Parameters
    ----------
    inp : jnp.ndarray
357
358
        Input tensor from experts of shape [num_out_tokens, hidden_size]
        (or [num_out_tokens_padded, hidden_size] when using unpadding).
359
360
361
362
363
364
    row_id_map : jnp.ndarray
        Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : Optional[jnp.ndarray]
        Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
        If provided, tokens from different experts are weighted-summed.
        If None, tokens are summed directly.
365
366
367
368
    pad_offsets : Optional[jnp.ndarray]
        Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch.
        If provided, fused unpadding will be performed. This should be the pad_offsets
        returned by token_dispatch when using padding.
369
370
371
372
373
374

    Returns
    -------
    output : jnp.ndarray
        Combined output tensor of shape [num_tokens, hidden_size].
    """
375
    return _token_combine(inp, row_id_map, merging_probs, pad_offsets)
376
377


378
@jax.custom_vjp
379
380
381
382
def _token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
383
    pad_offsets: Optional[jnp.ndarray],
384
385
) -> jnp.ndarray:
    """Internal token_combine with custom VJP."""
386
    output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
387
388
389
390
391
392
393
    return output


def _token_combine_fwd_rule(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
394
395
396
397
398
399
400
401
402
403
404
405
406
407
    pad_offsets: Optional[jnp.ndarray],
) -> Tuple[
    jnp.ndarray,
    Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        int,
        int,
        int,
        int,
    ],
]:
408
409
410
411
412
413
414
    """Forward pass rule for token_combine."""
    # Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
    num_tokens = row_id_map.shape[0]
    num_experts = (row_id_map.shape[1] - 1) // 2
    hidden_size = inp.shape[-1]
    num_out_tokens = inp.shape[0]

415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    # Call triton extension with or without unpadding
    if pad_offsets is not None:
        output, _ = unpermute_with_mask_map_and_unpad(
            inp,
            row_id_map,
            merging_probs,
            None,  # No permuted probs to unpermute
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )
    else:
        output, _ = unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,  # No permuted probs to unpermute
            num_tokens,
            num_experts,
            hidden_size,
        )
437
438
439
440
441

    # Return (primal, residuals)
    # Include inp in residuals for backward with merging_probs
    residuals = (
        row_id_map,
442
        pad_offsets,
443
444
445
446
447
448
449
450
451
452
453
        inp,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    )
    return output, residuals


def _token_combine_bwd_rule(
454
455
456
457
458
459
460
461
462
463
    residuals: Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        int,
        int,
        int,
        int,
    ],
464
    g: jnp.ndarray,
465
466
467
468
469
470
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
    """Backward pass rule for token_combine.

    Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets)
    row_id_map and pad_offsets are integer arrays, so their gradients are None.
    """
471
472
    (
        row_id_map,
473
        pad_offsets,
474
475
476
477
478
479
480
481
482
483
484
485
486
        fwd_input,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    ) = residuals
    output_grad = g

    with_merging_probs = merging_probs is not None

    if with_merging_probs:
        # Use specialized backward kernel that properly scales by merging_probs
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
        if pad_offsets is not None:
            inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
                output_grad,
                row_id_map,
                fwd_input,
                merging_probs,
                pad_offsets,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
            )
            # The backward kernel only writes to positions that tokens map to.
            # Padded positions may contain uninitialized (NaN) values - replace with zeros.
            inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
        else:
            inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
                output_grad,
                row_id_map,
                fwd_input,
                merging_probs,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
            )
513
514
    else:
        # Simple case: just permute gradients back
515
        if pad_offsets is not None:
516
517
            # Note: align_size uses default (128) since buffer sizes are already
            # determined from forward pass (stored in residuals as num_out_tokens)
518
519
520
521
522
523
524
525
526
            inp_grad, _ = permute_with_mask_map_and_pad(
                output_grad,
                row_id_map,
                None,
                pad_offsets,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
527
                align_size=128,  # Default, sizes already computed in forward
528
529
530
531
532
533
534
535
536
537
538
539
540
541
            )
            # The permute kernel only writes to positions that tokens map to.
            # Padded positions may contain uninitialized (NaN) values - replace with zeros.
            inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
        else:
            inp_grad, _ = permute_with_mask_map(
                output_grad,
                row_id_map,
                None,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
            )
542
543
        merging_probs_grad = None

544
545
546
    # Return gradients for: inp, row_id_map, merging_probs, pad_offsets
    # row_id_map and pad_offsets are integer arrays, so their gradients are None
    return inp_grad, None, merging_probs_grad, None
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648


_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)


# =============================================================================
# Chunk Sort with VJP
# =============================================================================


def sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Sort chunks of tokens according to sorted indices.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
    split_sizes : jnp.ndarray
        Sizes of each chunk of shape [num_splits].
    sorted_indices : jnp.ndarray
        Permutation indices for chunks of shape [num_splits].

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Row ID map for reversing the sort.
    """
    return _sort_chunks_by_index(inp, split_sizes, sorted_indices)


@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def _sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Internal sort_chunks_by_index with custom VJP."""
    (output, row_id_map), _ = _sort_chunks_by_index_fwd_rule(inp, split_sizes, sorted_indices)
    return output, row_id_map


def _sort_chunks_by_index_fwd_rule(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
    """Forward pass rule for sort_chunks_by_index."""
    # Validate input dimensions
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"

    # Infer dimensions from input shape
    num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
    hidden_size = inp.shape[-1]
    num_splits = split_sizes.shape[0]

    row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, num_tokens, num_splits)

    output, _ = sort_chunks_by_map(
        inp,
        row_id_map,
        None,  # No probs
        num_tokens,
        hidden_size,
        is_forward=True,
    )

    # Return (primals, residuals)
    residuals = (row_id_map, num_tokens, hidden_size)
    return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
    _split_sizes: jnp.ndarray,
    _sorted_indices: jnp.ndarray,
    residuals: Tuple[jnp.ndarray, int, int],
    g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
    """Backward pass rule for sort_chunks_by_index."""
    row_id_map, num_tokens, hidden_size = residuals
    output_grad, _ = g

    # Backward: reverse the sort
    inp_grad, _ = sort_chunks_by_map(
        output_grad,
        row_id_map,
        None,
        num_tokens,
        hidden_size,
        is_forward=False,
    )

    return (inp_grad,)


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)