# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE Permutation API for JAX.

This module provides high-level token dispatch and combine operations for
Mixture of Experts (MoE) models with proper automatic differentiation support.

Token Dispatch (Permute):
    - Forward: Permute tokens according to routing map (scatter to experts)
    - Backward: Unpermute gradients (gather from experts)

Token Combine (Unpermute):
    - Forward: Unpermute tokens and merge with weights (gather from experts)
    - Backward: Permute gradients (scatter to experts)
"""

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from transformer_engine.jax.triton_extensions.permutation import (
    make_row_id_map,
    permute_with_mask_map,
    permute_with_mask_map_and_pad,
    unpermute_with_mask_map,
    unpermute_with_mask_map_and_unpad,
    unpermute_bwd_with_merging_probs,
    unpermute_bwd_with_merging_probs_and_unpad,
    make_chunk_sort_map,
    sort_chunks_by_map,
)

__all__ = [
    "token_dispatch",
    "token_combine",
    "sort_chunks_by_index",
]


def token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: Optional[jnp.ndarray] = None,
    align_size: Optional[int] = None,
) -> Tuple[
    jnp.ndarray,
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
    Optional[jnp.ndarray],
]:
    """
    Dispatch tokens to experts based on routing map.

    This is the forward pass of the MoE permutation. Tokens are scattered
    to their designated experts according to the routing map. The row_id_map
    is computed internally from the routing_map.

    Optionally supports fused padding for alignment when `align_size` is provided.
    This is useful for efficient matrix multiplications that require aligned tensor
    dimensions. The padding is computed internally from the routing_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
        Values: 1 = routed, 0 = not routed.
    num_out_tokens : int
        The number of output tokens after permutation (before padding). For the dropless
        case, this should be equal to the sum of routing_map. Must be provided explicitly
        for JIT compatibility since output shape must be known at compile time.
    probs : Optional[jnp.ndarray]
        Optional routing probabilities of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]. If provided, permuted_probs will be returned.
    align_size : Optional[int]
        Optional alignment size for padding; must be a positive integer when given.
        If provided, outputs will be padded to align each expert's tokens to a multiple
        of this size. The output buffer is allocated with the worst-case size, rounded
        down to a multiple of align_size:
        ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
        This enables full JIT compatibility.

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
        or [worst_case_padded_size, hidden_size] when using padding fusion.
        With padding, the portion of the buffer actually used may be smaller than the
        buffer itself; its size is the sum of target_tokens_per_expert.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
        or None if probs was not provided.
    row_id_map : jnp.ndarray
        Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
    pad_offsets : Optional[jnp.ndarray]
        Per-expert cumulative padding offsets of shape [num_experts] when using padding,
        None otherwise. Pass this to token_combine when unpadding is needed.
    target_tokens_per_expert : Optional[jnp.ndarray]
        Aligned token counts per expert of shape [num_experts] when using padding,
        None otherwise.

    Raises
    ------
    ValueError
        If `align_size` is provided but is not strictly positive.

    Note
    ----
    **JIT Compatibility:**

    This function is fully JIT-compatible. When using padding (align_size provided),
    the output buffer is allocated with a fixed worst-case size that depends only on
    compile-time constants (num_out_tokens, num_experts, align_size). The actual
    padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
    are computed internally from the routing_map and can be traced values.

    The worst-case output size is:
    ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
    Each expert needs at most (align_size - 1) padding rows, and because every
    per-expert padded count is a multiple of align_size, the bound can be rounded
    down to a multiple of align_size without losing any reachable size.
    """
    use_padding = align_size is not None

    # Reject unusable alignments up front: 0 would otherwise surface as a bare
    # ZeroDivisionError below and a negative value would silently yield a
    # meaningless buffer size.
    if use_padding and align_size <= 0:
        raise ValueError(f"align_size must be a positive integer, got {align_size}")

    num_experts = routing_map.shape[-1]

    if use_padding:
        # Worst-case output size; built only from static Python ints so it can be
        # used as an output shape under jit.
        worst_case_out_tokens = (
            (num_out_tokens + num_experts * (align_size - 1)) // align_size
        ) * align_size
    else:
        worst_case_out_tokens = num_out_tokens

    return _token_dispatch(
        inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
    )


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
    worst_case_out_tokens: int,
    align_size: Optional[int],
    use_padding: bool,
) -> Tuple[
    jnp.ndarray,
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
    Optional[jnp.ndarray],
]:
    """Custom-VJP core of token_dispatch.

    The primal computation is delegated to the forward rule; only its primal
    outputs are returned and the residuals are discarded. Positional arguments
    3-6 (num_out_tokens, worst_case_out_tokens, align_size, use_padding) are
    declared non-differentiable in the decorator.
    """
    primal_outputs, _unused_residuals = _token_dispatch_fwd_rule(
        inp,
        routing_map,
        probs,
        num_out_tokens,
        worst_case_out_tokens,
        align_size,
        use_padding,
    )
    output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = primal_outputs
    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert


def _token_dispatch_fwd_rule(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
    worst_case_out_tokens: int,
    align_size: Optional[int],
    use_padding: bool,
) -> Tuple[
    Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        Optional[jnp.ndarray],
    ],
    Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
    """Forward pass rule for token_dispatch.

    Builds the row_id_map from `routing_map`, permutes `inp` (and `probs` when
    given) and, when `use_padding` is set, also derives the per-expert padding
    layout from the routing map.

    Returns
    -------
    primals : tuple
        (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert).
        pad_offsets and target_tokens_per_expert are None when padding is disabled.
    residuals : tuple
        (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs),
        consumed by the backward rule.
    """
    # Validate input dimensions
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
    assert routing_map.ndim in [2, 3], f"routing_map must be 2D or 3D, got {routing_map.ndim}D"

    # Infer dimensions from input shapes (3D inputs are [batch, sequence, ...])
    num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
    hidden_size = inp.shape[-1]
    num_experts = routing_map.shape[-1]

    # Verify consistency between inp and routing_map
    routing_num_tokens = (
        routing_map.shape[0] * routing_map.shape[1]
        if routing_map.ndim == 3
        else routing_map.shape[0]
    )
    assert num_tokens == routing_num_tokens, (
        f"Token count mismatch: inp has {num_tokens} tokens, "
        f"routing_map has {routing_num_tokens} tokens"
    )

    # Always compute row_id_map internally from routing_map
    row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)

    with_probs = probs is not None

    if use_padding:
        # Count tokens per expert. The routing map is flattened to
        # [num_tokens, num_experts] first so that a 3D [batch, sequence, num_experts]
        # map also reduces to one count per expert instead of [sequence, num_experts].
        # This can be a traced value since the output shape uses worst_case_out_tokens.
        tokens_per_expert = jnp.sum(routing_map.reshape(-1, num_experts), axis=0).astype(
            jnp.int32
        )

        # Round each count up to the next multiple of align_size. Integer arithmetic
        # keeps the result exact; a float ceil(x / a) * a can lose precision for
        # large counts.
        target_tokens_per_expert = (
            (tokens_per_expert + (align_size - 1)) // align_size
        ) * align_size

        # Compute pad_offsets: cumulative padding for each expert
        # pad_offsets[i] = sum of (target - actual) for experts 0..i-1
        pad_lengths = target_tokens_per_expert - tokens_per_expert
        cum_pad = jnp.cumsum(pad_lengths)
        pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])

        # Use worst_case_out_tokens as the output buffer size (compile-time constant).
        # The actual used size is sum(target_tokens_per_expert), which may be smaller.
        # NOTE(review): the original comment states that unused positions are
        # zero-initialized by the kernel - confirm against the kernel implementation.
        output, permuted_probs = permute_with_mask_map_and_pad(
            inp,
            row_id_map,
            probs,
            pad_offsets,
            num_tokens,
            num_experts,
            worst_case_out_tokens,
            hidden_size,
            align_size=align_size,
        )
    else:
        # No padding
        pad_offsets = None
        target_tokens_per_expert = None

        output, permuted_probs = permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )

    # Return (primals, residuals)
    residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
    return (
        output,
        permuted_probs,
        row_id_map,
        pad_offsets,
        target_tokens_per_expert,
    ), residuals


def _token_dispatch_bwd_rule(
    _num_out_tokens: int,
    _worst_case_out_tokens: int,
    _align_size: Optional[int],
    _use_padding: bool,
    residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
    g: Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        Optional[jnp.ndarray],
    ],
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
    """Backward pass rule for token_dispatch.

    The leading four parameters are the non-differentiable arguments of
    `_token_dispatch`; they are accepted but not used here.

    Returns the cotangents for (inp, routing_map, probs). The routing_map
    entry is always None because routing is a discrete decision, and the probs
    entry is None when the forward pass ran without probs.
    """
    row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals

    # Only the cotangents of `output` and `permuted_probs` matter; those of
    # row_id_map, pad_offsets and target_tokens_per_expert are dropped.
    output_grad, permuted_probs_grad, _row_id_map_ct, _pad_offsets_ct, _target_ct = g

    probs_cotangent = permuted_probs_grad if with_probs else None

    # Undo the permutation on the gradients (gather from experts back to tokens),
    # using the unpadding variant when the forward pass padded its output.
    if pad_offsets is None:
        inp_grad, probs_grad = unpermute_with_mask_map(
            output_grad,
            row_id_map,
            None,  # no merging probs in the dispatch backward
            probs_cotangent,
            num_tokens,
            num_experts,
            hidden_size,
        )
    else:
        inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
            output_grad,
            row_id_map,
            None,  # no merging probs in the dispatch backward
            probs_cotangent,
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )

    if not with_probs:
        probs_grad = None
    return inp_grad, None, probs_grad


_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)


# =============================================================================
# Token Combine (Unpermute) with VJP
# =============================================================================


def token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray] = None,
    pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """
    Gather expert outputs back into the original token order.

    This is the forward pass of MoE unpermutation: rows produced by the experts
    are collected per token and merged, optionally weighted by `merging_probs`.
    When `pad_offsets` is given (as returned by token_dispatch with padding
    enabled), unpadding is fused into the same operation.

    Parameters
    ----------
    inp : jnp.ndarray
        Expert output of shape [num_out_tokens, hidden_size]
        (or [num_out_tokens_padded, hidden_size] when unpadding is used).
    row_id_map : jnp.ndarray
        Row ID map returned by token_dispatch, of shape
        [num_tokens, num_experts * 2 + 1].
    merging_probs : Optional[jnp.ndarray]
        Merging weights of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]. With weights, the contributions of the
        different experts are weighted-summed; with None they are summed directly.
    pad_offsets : Optional[jnp.ndarray]
        Per-expert cumulative padding offsets of shape [num_experts], i.e. the
        pad_offsets returned by token_dispatch when padding was used. Supplying
        it turns on fused unpadding.

    Returns
    -------
    output : jnp.ndarray
        Combined output tensor of shape [num_tokens, hidden_size].
    """
    combined = _token_combine(inp, row_id_map, merging_probs, pad_offsets)
    return combined


@jax.custom_vjp
def _token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    pad_offsets: Optional[jnp.ndarray],
) -> jnp.ndarray:
    """Custom-VJP core of token_combine.

    Runs the forward rule and returns only its primal output; the residuals
    are discarded.
    """
    primal_and_residuals = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
    return primal_and_residuals[0]


def _token_combine_fwd_rule(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    pad_offsets: Optional[jnp.ndarray],
) -> Tuple[
    jnp.ndarray,
    Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        int,
        int,
        int,
        int,
    ],
]:
    """Forward pass rule for token_combine.

    Returns the combined output together with the residuals
    (row_id_map, pad_offsets, inp, merging_probs, num_tokens, num_experts,
    hidden_size, num_out_tokens) needed by the backward rule.
    """
    # All sizes are read off the argument shapes; row_id_map is laid out as
    # [num_tokens, num_experts * 2 + 1].
    num_tokens, row_id_width = row_id_map.shape[0], row_id_map.shape[1]
    num_experts = (row_id_width - 1) // 2
    num_out_tokens = inp.shape[0]
    hidden_size = inp.shape[-1]

    # Pick the kernel: plain unpermute, or unpermute fused with unpadding.
    if pad_offsets is None:
        output, _ = unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,  # there are no permuted probs to unpermute here
            num_tokens,
            num_experts,
            hidden_size,
        )
    else:
        output, _ = unpermute_with_mask_map_and_unpad(
            inp,
            row_id_map,
            merging_probs,
            None,  # there are no permuted probs to unpermute here
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )

    # `inp` is kept as a residual because the backward pass with merging_probs
    # takes the forward input as an argument.
    saved_for_bwd = (
        row_id_map,
        pad_offsets,
        inp,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    )
    return output, saved_for_bwd


def _token_combine_bwd_rule(
    residuals: Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        int,
        int,
        int,
        int,
    ],
    g: jnp.ndarray,
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
    """Backward pass rule for token_combine.

    Produces the cotangents for (inp, row_id_map, merging_probs, pad_offsets).
    row_id_map and pad_offsets are integer arrays, so their entries are None;
    the merging_probs entry is None when no merging_probs were used in the
    forward pass.
    """
    (
        row_id_map,
        pad_offsets,
        fwd_input,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    ) = residuals
    output_grad = g

    padded = pad_offsets is not None

    if merging_probs is None:
        # Unweighted combine: the backward is simply a permute of the gradients.
        merging_probs_grad = None
        if padded:
            # align_size is left at 128 here: per the original note the buffer size
            # was already fixed in the forward pass and is carried in the residuals
            # as num_out_tokens.
            inp_grad, _ = permute_with_mask_map_and_pad(
                output_grad,
                row_id_map,
                None,
                pad_offsets,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
                align_size=128,  # Default, sizes already computed in forward
            )
        else:
            inp_grad, _ = permute_with_mask_map(
                output_grad,
                row_id_map,
                None,
                num_tokens,
                num_experts,
                num_out_tokens,
                hidden_size,
            )
    elif padded:
        # Weighted combine with padding: dedicated backward kernel that scales
        # by merging_probs and handles the padded layout.
        inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
            output_grad,
            row_id_map,
            fwd_input,
            merging_probs,
            pad_offsets,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
    else:
        # Weighted combine without padding.
        inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
            output_grad,
            row_id_map,
            fwd_input,
            merging_probs,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )

    if padded:
        # The kernels only write to rows that tokens map to; padded rows may hold
        # uninitialized (NaN) values, which are replaced with zeros.
        inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)

    return inp_grad, None, merging_probs_grad, None


_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)


# =============================================================================
# Chunk Sort with VJP
# =============================================================================


def sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Reorder contiguous chunks of tokens according to a chunk permutation.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
    split_sizes : jnp.ndarray
        Size of each chunk, shape [num_splits].
    sorted_indices : jnp.ndarray
        Permutation indices for the chunks, shape [num_splits].

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Row ID map for reversing the sort.
    """
    sorted_output, row_id_map = _sort_chunks_by_index(inp, split_sizes, sorted_indices)
    return sorted_output, row_id_map


# split_sizes and sorted_indices are arrays, so they are passed as regular
# custom_vjp arguments (with None cotangents in the backward rule) rather than
# through nondiff_argnums, which JAX does not allow for traced values.
@jax.custom_vjp
def _sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Internal sort_chunks_by_index with custom VJP.

    Runs the forward rule and returns only its primal outputs
    (output, row_id_map); the residuals are discarded.
    """
    (output, row_id_map), _ = _sort_chunks_by_index_fwd_rule(inp, split_sizes, sorted_indices)
    return output, row_id_map


def _sort_chunks_by_index_fwd_rule(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
    """Forward pass rule for sort_chunks_by_index.

    Returns ((output, row_id_map), residuals) where the residuals are
    (row_id_map, num_tokens, hidden_size).
    """
    # Validate input dimensions
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"

    # Infer dimensions from input shape (3D inputs are [batch, sequence, hidden_size])
    num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
    hidden_size = inp.shape[-1]
    num_splits = split_sizes.shape[0]

    row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, num_tokens, num_splits)

    output, _ = sort_chunks_by_map(
        inp,
        row_id_map,
        None,  # No probs
        num_tokens,
        hidden_size,
        is_forward=True,
    )

    # Return (primals, residuals)
    residuals = (row_id_map, num_tokens, hidden_size)
    return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
    residuals: Tuple[jnp.ndarray, int, int],
    g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray, None, None]:
    """Backward pass rule for sort_chunks_by_index.

    Returns the cotangents for (inp, split_sizes, sorted_indices). The last two
    describe the chunk layout rather than differentiable data, so their
    cotangents are None.
    """
    row_id_map, num_tokens, hidden_size = residuals
    output_grad, _ = g  # the cotangent of the returned row_id_map is ignored

    # Backward: reverse the sort
    inp_grad, _ = sort_chunks_by_map(
        output_grad,
        row_id_map,
        None,
        num_tokens,
        hidden_size,
        is_forward=False,
    )

    return inp_grad, None, None


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)