# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX/TE custom ops for permutation in MOE using Triton kernels."""

from typing import Optional, Tuple

import jax
import jax.numpy as jnp
11
12
from jax.sharding import PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
13
14
15
import triton

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
16
17
from transformer_engine.jax.cpp_extensions.misc import get_padded_spec, NamedSharding
from transformer_engine.jax.sharding import get_mesh_axis_size
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformer_engine.common.triton.permutation import (
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
    _row_id_map_pass_3_kernel,
    _permute_kernel,
    _unpermute_kernel,
    _unpermute_bwd_with_merging_probs_kernel,
    _make_chunk_sort_map_kernel,
    _sort_chunks_by_map_kernel,
)
from .utils import triton_call_lowering


__all__ = [
    "make_row_id_map",
    "permute_with_mask_map",
    "permute_with_mask_map_and_pad",
    "unpermute_with_mask_map",
    "unpermute_with_mask_map_and_unpad",
    "unpermute_bwd_with_merging_probs",
    "unpermute_bwd_with_merging_probs_and_unpad",
    "make_chunk_sort_map",
    "sort_chunks_by_map",
]

# Default BLOCK_SIZE (tokens per program along the token dimension) for the
# row_id_map generation passes.
DEFAULT_BLOCK_SIZE = 1024


def _get_min_block_size(kernel, default=128):
    if hasattr(kernel, "configs"):
        return min(config.kwargs.get("BLOCK_SIZE", default) for config in kernel.configs)
    return default


class RowIdMapPass1Primitive(BasePrimitive):
    """
    Pass 1 of row_id_map generation: block cumsum.

    For each expert, compute the cumsum of every ``block_size`` tokens.

    Operand:
        routing_map: array of shape ``(num_tokens, num_experts)``.

    Results:
        row_id_map: int32 array of shape ``(num_tokens, num_experts * 2 + 1)``.
        workspace: int32 array of shape
            ``(num_experts, cdiv(num_tokens, block_size))``, i.e. shaped like
            the launch grid; presumably consumed by the pass-2 primitive.
    """

    name = "te_row_id_map_pass1_triton"
    multiple_results = True
    impl_static_args = (1, 2, 3)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 1.

        The workspace's second dimension is derived from ``block_size`` (not a
        fixed default) so it always matches the launch grid built in
        ``lowering``.
        """
        assert routing_map_aval.shape == (
            num_tokens,
            num_experts,
        ), f"routing_map shape mismatch: expected ({num_tokens}, {num_experts})"

        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        # Same shape as the (expert, token-block) launch grid in lowering().
        workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))

        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
            jax.core.ShapedArray(workspace_shape, jnp.int32),
        )

    @staticmethod
    def impl(routing_map, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass1Primitive.inner_primitive is not None
        return RowIdMapPass1Primitive.inner_primitive.bind(
            routing_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering.

        Launches one program per (expert, token-block) pair. The strides below
        assume contiguous row-major routing_map and row_id_map.
        """
        routing_stride_token = num_experts
        routing_stride_expert = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        grid = (num_experts, triton.cdiv(num_tokens, block_size))

        return triton_call_lowering(
            ctx,
            _row_id_map_pass_1_kernel,
            routing_map,
            grid=grid,
            constexprs={
                "num_tokens": num_tokens,
                "stride_routing_map_token": routing_stride_token,
                "stride_routing_map_expert": routing_stride_expert,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def _out_shardings(mesh, token_axis):
        """Build the [row_id_map, workspace] output shardings for a token mesh axis."""
        # row_id_map follows routing_map's token-dim sharding; its column dim
        # (num_experts * 2 + 1) is never sharded.
        row_id_map_sharding = NamedSharding(
            mesh,
            PartitionSpec(token_axis, None),
            desc="RowIdMapPass1.row_id_map_sharding",
        )
        # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)), replicated.
        workspace_sharding = NamedSharding(
            mesh,
            PartitionSpec(None, None),
            desc="RowIdMapPass1.workspace_sharding",
        )
        return [row_id_map_sharding, workspace_sharding]

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens, num_experts, block_size, mesh, arg_infos, result_infos
    ):
        """Infer output sharding from input sharding."""
        del num_tokens, num_experts, block_size, result_infos
        routing_map_spec = get_padded_spec(arg_infos[0])
        return RowIdMapPass1Primitive._out_shardings(mesh, routing_map_spec[0])

    @staticmethod
    def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos):
        """Row id map 1st pass partition.

        Each shard runs the kernel on its local tokens. Only the token dim of
        routing_map may stay sharded: ``abstract`` asserts the local shard
        still has all ``num_experts`` columns, so the expert dim is forced to
        be replicated.
        """
        del num_tokens, result_infos
        routing_map_spec = get_padded_spec(arg_infos[0])
        token_axis = routing_map_spec[0]

        arg_shardings = (
            NamedSharding(
                mesh,
                PartitionSpec(token_axis, None),
                desc="RowIdMapPass1.routing_map_sharding",
            ),
        )
        out_shardings = RowIdMapPass1Primitive._out_shardings(mesh, token_axis)

        # NOTE(review): the workspace is declared replicated with the global
        # shape (num_experts, cdiv(num_tokens, block_size)), but each shard
        # produces (num_experts, cdiv(local_num_tokens, block_size)). These only
        # agree when both round up to the same block count -- confirm how
        # token-sharded callers consume this workspace.

        def sharded_impl(routing_map):
            # Each shard processes its local tokens
            local_num_tokens = routing_map.shape[0]
            return RowIdMapPass1Primitive.impl(
                routing_map,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
                block_size=block_size,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, result_types):
        """Shardy sharding rule for this primitive."""
        del num_tokens, num_experts, block_size, mesh, value_types, result_types
        prefix = "RowIdMapPass1"
        # routing_map shape: (num_tokens, num_experts)
        input_spec = (f"{prefix}_tokens", f"{prefix}_experts")
        # row_id_map shape: (num_tokens, num_experts * 2 + 1)
        # Note: row_id_cols != experts since it's num_experts * 2 + 1
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
        # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
        workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks")
        return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec))


register_primitive(RowIdMapPass1Primitive)


class RowIdMapPass2Primitive(BasePrimitive):
    """
    Pass 2 of row_id_map generation: cumsum all and process the mask.

    Operates in place: operand 0 (row_id_map) and operand 1 (workspace) are
    aliased to results 0 and 1 in the lowering.
    """

    name = "te_row_id_map_pass2_triton"
    multiple_results = True
    impl_static_args = (2, 3, 4)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 2 (in-place operation).

        The workspace shape is derived from ``block_size`` so it agrees with the
        launch grid and ``WORKSPACE_LOAD_WIDTH`` computed in ``lowering``.
        """
        del row_id_map_aval, workspace_aval

        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))

        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
            jax.core.ShapedArray(workspace_shape, jnp.int32),
        )

    @staticmethod
    def impl(row_id_map, workspace, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass2Primitive.inner_primitive is not None
        return RowIdMapPass2Primitive.inner_primitive.bind(
            row_id_map,
            workspace,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, row_id_map, workspace, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering.

        Launches one program per (expert, token-block) pair. The row_id_map
        strides assume a contiguous row-major layout.
        """
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        grid = (num_experts, triton.cdiv(num_tokens, block_size))
        # Power of two >= total workspace element count; presumably lets the
        # kernel load the whole workspace with a single tl.arange.
        workspace_load_width = triton.next_power_of_2(
            num_experts * triton.cdiv(num_tokens, block_size)
        )

        return triton_call_lowering(
            ctx,
            _row_id_map_pass_2_kernel,
            row_id_map,
            workspace,
            grid=grid,
            input_output_aliases={0: 0, 1: 1},
            constexprs={
                "num_tokens": num_tokens,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "WORKSPACE_LOAD_WIDTH": workspace_load_width,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def _shardings(mesh, token_axis):
        """Build [row_id_map, workspace] shardings shared by operands and results.

        The operation is in place (operands aliased to results), so operands and
        results use identical shardings. The row_id_map column dim is kept
        replicated because the kernel is compiled with the global
        ``num_experts``.
        """
        row_id_map_sharding = NamedSharding(
            mesh,
            PartitionSpec(token_axis, None),
            desc="RowIdMapPass2.row_id_map_sharding",
        )
        workspace_sharding = NamedSharding(
            mesh,
            PartitionSpec(None, None),
            desc="RowIdMapPass2.workspace_sharding",
        )
        return [row_id_map_sharding, workspace_sharding]

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens, num_experts, block_size, mesh, arg_infos, result_infos
    ):
        """Infer output sharding from input sharding."""
        del num_tokens, num_experts, block_size, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[0])
        return RowIdMapPass2Primitive._shardings(mesh, row_id_map_spec[0])

    @staticmethod
    def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos):
        """Partition the primitive for distributed execution."""
        del num_tokens, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[0])

        shardings = RowIdMapPass2Primitive._shardings(mesh, row_id_map_spec[0])
        arg_shardings = tuple(shardings)
        out_shardings = shardings

        def sharded_impl(row_id_map, workspace):
            local_num_tokens = row_id_map.shape[0]
            return RowIdMapPass2Primitive.impl(
                row_id_map,
                workspace,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
                block_size=block_size,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, result_types):
        """Shardy sharding rule for this primitive (outputs mirror the aliased inputs)."""
        del num_tokens, num_experts, block_size, mesh, value_types, result_types
        prefix = "RowIdMapPass2"
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
        workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")
        return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec))


register_primitive(RowIdMapPass2Primitive)


class RowIdMapPass3Primitive(BasePrimitive):
    """
    Pass 3 of row_id_map generation: convert the row_id_map from a sparse to a
    dense structure.

    Operates in place: operand 0 is aliased to the single result.
    """

    name = "te_row_id_map_pass3_triton"
    multiple_results = False
    impl_static_args = (1, 2)  # num_tokens, num_experts
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, *, num_tokens, num_experts):
        """Shape/dtype inference for pass 3 (in-place operation)."""
        del row_id_map_aval
        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        return jax.core.ShapedArray(row_id_map_shape, jnp.int32)

    @staticmethod
    def impl(row_id_map, num_tokens, num_experts):
        """Forward to inner primitive."""
        assert RowIdMapPass3Primitive.inner_primitive is not None
        return RowIdMapPass3Primitive.inner_primitive.bind(
            row_id_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
        )

    @staticmethod
    def lowering(ctx, row_id_map, *, num_tokens, num_experts):
        """MLIR lowering using triton_call_lowering.

        Launches one program per token. The strides assume a contiguous
        row-major row_id_map.
        """
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        grid = (num_tokens,)
        # Power of two >= num_experts; presumably the kernel's tl.arange width.
        load_size = triton.next_power_of_2(num_experts)

        return triton_call_lowering(
            ctx,
            _row_id_map_pass_3_kernel,
            row_id_map,
            grid=grid,
            input_output_aliases={0: 0},
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "num_experts": num_experts,
                "LOAD_SIZE": load_size,
            },
        )

    @staticmethod
    def _row_id_map_sharding(mesh, token_axis):
        """Token-dim sharding for the (aliased) row_id_map; columns stay replicated.

        The kernel is compiled with the global ``num_experts`` and reads whole
        rows, so the column dim must not be split across devices.
        """
        return NamedSharding(
            mesh,
            PartitionSpec(token_axis, None),
            desc="RowIdMapPass3.row_id_map_sharding",
        )

    @staticmethod
    def infer_sharding_from_operands(num_tokens, num_experts, mesh, arg_infos, result_infos):
        """Infer output sharding from input sharding."""
        del num_tokens, num_experts, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[0])
        return RowIdMapPass3Primitive._row_id_map_sharding(mesh, row_id_map_spec[0])

    @staticmethod
    def partition(num_tokens, num_experts, mesh, arg_infos, result_infos):
        """Partition the primitive for distributed execution."""
        del num_tokens, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[0])

        # In-place: operand and result share one sharding.
        out_sharding = RowIdMapPass3Primitive._row_id_map_sharding(mesh, row_id_map_spec[0])
        arg_shardings = (out_sharding,)

        def sharded_impl(row_id_map):
            local_num_tokens = row_id_map.shape[0]
            return RowIdMapPass3Primitive.impl(
                row_id_map,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
            )

        return mesh, sharded_impl, out_sharding, arg_shardings

    @staticmethod
    def shardy_sharding_rule(num_tokens, num_experts, mesh, value_types, result_types):
        """Shardy sharding rule for this primitive (output mirrors the aliased input)."""
        del num_tokens, num_experts, mesh, value_types, result_types
        prefix = "RowIdMapPass3"
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
        return SdyShardingRule((row_id_map_spec,), (row_id_map_spec,))


register_primitive(RowIdMapPass3Primitive)


class PermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Permute the input tensor based on the row_id_map, optionally with fused padding.

    The outer primitive takes 6 tensor operands
    ``(inp, row_id_map, probs, scale, permuted_scale, pad_offsets)``. Its
    ``impl`` allocates two extra buffers ``(output_buf, permuted_probs_buf)``
    and binds the inner primitive with all 8 operands. When ``with_pad`` is
    set, those buffers are zero-filled and aliased to the outputs so that
    positions the kernel does not write (the padding) remain zero.
    """

    name = "te_permute_with_mask_map_triton"
    multiple_results = True

    # Positions of the static args in the outer impl() signature, following the
    # 6 tensor operands: num_tokens, num_experts, num_out_tokens, hidden_size,
    # with_probs, with_pad, align_size.
    impl_static_args = (6, 7, 8, 9, 10, 11, 12)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        probs_aval,
        scale_aval,  # dummy, same shape as inp
        permuted_scale_aval,  # dummy, same shape as inp
        pad_offsets_aval,
        output_buf_aval=None,  # Output buffer (inner primitive only)
        permuted_probs_buf_aval=None,  # permuted_probs buffer (inner primitive only)
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,
    ):
        """Shape/dtype inference for permute.

        Returns:
            output: ``(num_out_tokens, hidden_size)`` with ``inp``'s dtype.
            permuted_probs: ``(num_out_tokens,)`` with ``probs``'s dtype when
                ``with_probs``; otherwise an empty ``(0,)`` placeholder with
                ``inp``'s dtype.
        """
        del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval
        del num_tokens, num_experts, with_pad, align_size
        del output_buf_aval, permuted_probs_buf_aval  # Used for input_output_aliases only

        output_aval = jax.core.ShapedArray((num_out_tokens, hidden_size), inp_aval.dtype)

        if with_probs:
            permuted_probs_aval = jax.core.ShapedArray((num_out_tokens,), probs_aval.dtype)
        else:
            permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)

        return output_aval, permuted_probs_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        pad_offsets,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,  # align_size is only used for sharding, but must be passed since abstract() requires it
    ):
        """Allocate the output buffers and forward to the inner primitive."""
        assert PermuteWithMaskMapPrimitive.inner_primitive is not None

        # With fused padding the buffers must be zero-filled: they are aliased to
        # the outputs in lowering(), and padding positions are not written by
        # the kernel. Without padding the kernel ignores them and the lowering
        # skips aliasing, so jnp.empty suffices.
        alloc = jnp.zeros if with_pad else jnp.empty
        output_buf = alloc((num_out_tokens, hidden_size), dtype=inp.dtype)
        if with_probs:
            permuted_probs_buf = alloc((num_out_tokens,), dtype=probs.dtype)
        else:
            permuted_probs_buf = alloc((0,), dtype=inp.dtype)

        return PermuteWithMaskMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            pad_offsets,
            output_buf,
            permuted_probs_buf,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
            with_probs=with_probs,
            with_pad=with_pad,
            align_size=align_size,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        pad_offsets,
        output_buf,  # Output buffer (for input_output_aliases)
        permuted_probs_buf,  # permuted_probs buffer (for input_output_aliases)
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,
    ):
        """MLIR lowering using triton_call_lowering.

        All strides assume contiguous row-major operands. ``probs`` may be 2D
        ``(num_tokens, num_experts)`` or 1D ``(num_tokens,)``; its strides are
        picked from its rank.
        """
        del align_size

        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        permuted_probs_stride_token = 1

        if not with_probs:
            probs_stride_token, probs_stride_expert = 0, 0
        elif len(ctx.avals_in[2].shape) > 1:
            probs_stride_token, probs_stride_expert = num_experts, 1
        else:
            probs_stride_token, probs_stride_expert = 1, 1

        # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)).
        # Use the minimum BLOCK_SIZE from the autotune configs so the grid covers
        # all elements.
        block_size = _get_min_block_size(_permute_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

        # Alias the pre-zeroed buffers to the outputs when padding, so padding
        # positions (never written by the kernel) stay zero.
        # Input indices: 0=inp, 1=row_id_map, 2=probs, 3=scale, 4=permuted_scale,
        #                5=pad_offsets, 6=output_buf, 7=permuted_probs_buf
        # Output indices: 0=output, 1=permuted_probs
        if with_pad:
            input_output_aliases = {6: 0}
            if with_probs:
                input_output_aliases[7] = 1
        else:
            input_output_aliases = None

        return triton_call_lowering(
            ctx,
            _permute_kernel,
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            pad_offsets,
            output_buf,
            permuted_probs_buf,
            grid=grid,
            input_output_aliases=input_output_aliases,
            constexprs={
                "scale_hidden_dim": 0,
                "num_tokens": num_tokens,
                "num_out_tokens": num_out_tokens,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_probs_token": probs_stride_token,
                "stride_probs_expert": probs_stride_expert,
                "stride_scale_token": hidden_size,
                "stride_scale_hidden": 1,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "stride_permuted_scale_token": hidden_size,
                "stride_permuted_scale_hidden": 1,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                "PERMUTE_PROBS": with_probs,
                "PERMUTE_SCALE": False,
                "FUSION_PAD": with_pad,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def _out_shardings(mesh, token_axis, with_probs):
        """Build [output, permuted_probs] shardings following the input token axis."""
        output_sharding = NamedSharding(
            mesh,
            PartitionSpec(token_axis, None),
            desc="PermuteWithMaskMap.output_sharding",
        )
        if with_probs:
            permuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(token_axis),
                desc="PermuteWithMaskMap.permuted_probs_sharding",
            )
        else:
            # (0,)-shaped placeholder: nothing to shard.
            permuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(None),
                desc="PermuteWithMaskMap.permuted_probs_sharding_empty",
            )
        return [output_sharding, permuted_probs_sharding]

    @staticmethod
    def _num_shards(mesh, axis):
        """Number of devices along a PartitionSpec entry.

        ``axis`` may be None (unsharded), a single mesh axis name, or a tuple or
        list of mesh axis names (sharded over their product).
        """
        if axis is None:
            return 1
        if isinstance(axis, (tuple, list)):
            count = 1
            for axis_name in axis:
                count *= get_mesh_axis_size(axis_name, mesh)
            return count
        return get_mesh_axis_size(axis, mesh)

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Infer output sharding from input sharding.

        For batch-dimension partitioning:
        - Input (num_tokens, hidden_size) is sharded on token dim
        - Output (num_out_tokens, hidden_size) gets same token dim sharding
        - Permuted probs (num_out_tokens,) gets same token dim sharding
        """
        del align_size  # Used only in partition
        del num_tokens, num_experts, num_out_tokens, hidden_size, with_pad, result_infos
        inp_spec = get_padded_spec(arg_infos[0])
        return PermuteWithMaskMapPrimitive._out_shardings(mesh, inp_spec[0], with_probs)

    @staticmethod
    def partition(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Partition the primitive for distributed execution.

        For batch-dimension partitioning, each GPU processes its local tokens
        independently. The row_id_map contains local destination indices,
        so no inter-GPU communication is needed.

        The token dim may be sharded over a single mesh axis or a tuple of mesh
        axes; ``num_out_tokens`` must divide evenly across those shards.
        """
        del num_tokens, result_infos
        inp_spec = get_padded_spec(arg_infos[0])
        batch_axis = inp_spec[0]

        # Input shardings - preserve original shardings.
        # NOTE(review): a sharded hidden dim on inp would not match the global
        # hidden_size passed to the local kernel -- confirm callers never do that.
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = PermuteWithMaskMapPrimitive._out_shardings(mesh, batch_axis, with_probs)

        num_dp_devices = PermuteWithMaskMapPrimitive._num_shards(mesh, batch_axis)
        assert num_out_tokens % num_dp_devices == 0, (
            f"num_out_tokens ({num_out_tokens}) must be divisible by the number of"
            f" data-parallel shards ({num_dp_devices})"
        )

        # Data parallelism only (no expert parallelism): every device holds all
        # experts and permutes its local tokens. Local output is ordered by
        # expert, [E0 | E1 | ... | EN]; any global assembly happens outside this
        # primitive.
        #
        # Output sizing: the global num_out_tokens (already a worst-case buffer
        # size) is split evenly across shards. With fused padding each shard
        # needs its own per-expert alignment slack (E = num_experts,
        # A = align_size), so the caller must pass a global num_out_tokens with
        #     num_out_tokens // num_dp_devices >= ((local_raw_out + E*(A-1)) // A) * A
        # This primitive does not verify that bound.
        local_num_out_tokens = num_out_tokens // num_dp_devices

        def sharded_impl(inp, row_id_map, probs, scale, permuted_scale, pad_offsets):
            # Each shard processes its local tokens independently.
            local_num_tokens = inp.shape[0]
            local_output, local_permuted_probs = PermuteWithMaskMapPrimitive.impl(
                inp,
                row_id_map,
                probs,
                scale,
                permuted_scale,
                pad_offsets,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
                num_out_tokens=local_num_out_tokens,
                hidden_size=hidden_size,
                with_probs=with_probs,
                with_pad=with_pad,
                align_size=align_size,
            )
            return local_output, local_permuted_probs

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
        with_pad,
        align_size,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for this primitive.

        ``probs`` may be 2D ``(num_tokens, num_experts)`` or 1D ``(num_tokens,)``
        (both accepted by ``lowering``), so its rule is chosen from the operand
        rank.
        """
        del (
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            align_size,
            mesh,
            result_types,
        )
        prefix = "PermuteWithMaskMap"
        tokens = f"{prefix}_tokens"
        out_tokens = f"{prefix}_out_tokens"
        hidden = f"{prefix}_hidden"
        experts = f"{prefix}_experts"

        # inp: (num_tokens, hidden_size)
        inp_spec = (tokens, hidden)
        # row_id_map: (num_tokens, num_experts * 2 + 1)
        row_id_map_spec = (tokens, f"{prefix}_row_id_cols")
        # probs: (num_tokens, num_experts), (num_tokens,), or (0,) placeholder
        if not with_probs:
            probs_spec = (f"{prefix}_empty",)
        elif len(value_types[2].shape) == 1:
            probs_spec = (tokens,)
        else:
            probs_spec = (tokens, experts)
        # scale: (num_tokens, hidden_size) - same shape as inp, permuted together
        scale_spec = (tokens, hidden)
        # permuted_scale: (num_out_tokens, hidden_size) - same shape as output
        permuted_scale_spec = (out_tokens, hidden)
        # pad_offsets: (num_experts,) or (0,) - uses same experts factor as probs
        pad_offsets_spec = (experts,) if with_pad else (f"{prefix}_pad_empty",)
        # output: (num_out_tokens, hidden_size)
        output_spec = (out_tokens, hidden)
        # permuted_probs: (num_out_tokens,) or (0,)
        permuted_probs_spec = (out_tokens,) if with_probs else (f"{prefix}_empty2",)

        return SdyShardingRule(
            (
                inp_spec,
                row_id_map_spec,
                probs_spec,
                scale_spec,
                permuted_scale_spec,
                pad_offsets_spec,
            ),
            (output_spec, permuted_probs_spec),
        )


register_primitive(PermuteWithMaskMapPrimitive)


class UnpermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Unpermute the input tensor based on the row_id_map, optionally with fused unpadding.
    """

    name = "te_unpermute_with_mask_map_triton"
    multiple_results = True
    # Outer primitive has 5 tensor inputs: inp, row_id_map, merging_probs, permuted_probs, pad_offsets
    # Static args for outer primitive: num_tokens, num_experts, hidden_size,
    #                                  with_merging_probs, with_probs, with_unpad
    # The inner primitive additionally takes output_buf and unpermuted_probs_buf
    # (dummy buffers created in `impl`) after the 5 tensor inputs.
    impl_static_args = (
        5,
        6,
        7,
        8,
        9,
        10,
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        merging_probs_aval,
        permuted_probs_aval,
        pad_offsets_aval,
        output_buf_aval=None,  # Dummy (inner primitive only)
        unpermuted_probs_buf_aval=None,  # Dummy (inner primitive only)
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
    ):
        """Shape/dtype inference for unpermute.

        Returns
        -------
        output_aval
            Shape ``(num_tokens, hidden_size)`` with the dtype of ``inp``.
        unpermuted_probs_aval
            Shape ``(num_tokens, num_experts)`` with the dtype of ``permuted_probs``
            when ``with_probs`` is set; otherwise an empty ``(0,)`` placeholder.
        """
        del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval, with_unpad
        del output_buf_aval, unpermuted_probs_buf_aval

        output_shape = (num_tokens, hidden_size)
        output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)

        if with_probs:
            unpermuted_probs_shape = (num_tokens, num_experts)
            unpermuted_probs_aval = jax.core.ShapedArray(
                unpermuted_probs_shape, permuted_probs_aval.dtype
            )
        else:
            unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)

        return output_aval, unpermuted_probs_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
    ):
        """Forward to inner primitive, appending the dummy output buffers."""
        assert UnpermuteWithMaskMapPrimitive.inner_primitive is not None

        # Create dummy buffers for kernel signature consistency with _permute_kernel.
        # These are not used for pre-zeroing since unpermute writes to all output positions.
        output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype)
        if with_probs:
            unpermuted_probs_buf = jnp.empty((num_tokens, num_experts), dtype=permuted_probs.dtype)
        else:
            unpermuted_probs_buf = jnp.empty((0,), dtype=inp.dtype)

        return UnpermuteWithMaskMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            pad_offsets,
            output_buf,
            unpermuted_probs_buf,
            num_tokens=num_tokens,
            num_experts=num_experts,
            hidden_size=hidden_size,
            with_merging_probs=with_merging_probs,
            with_probs=with_probs,
            with_unpad=with_unpad,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        output_buf,  # Dummy for kernel signature consistency
        unpermuted_probs_buf,  # Dummy for kernel signature consistency
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
    ):
        """MLIR lowering using triton_call_lowering."""
        # Compute strides (all tensors are contiguous, row-major)
        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        if with_merging_probs:
            merging_probs_stride_token = num_experts
            merging_probs_stride_expert = 1
        else:
            # merging_probs is an empty placeholder; strides are never used.
            merging_probs_stride_token = 0
            merging_probs_stride_expert = 0

        permuted_probs_stride_token = 1
        unpermuted_probs_stride_token = num_experts
        unpermuted_probs_stride_expert = 1

        # Grid - use minimum BLOCK_SIZE from autotune configs
        block_size = _get_min_block_size(_unpermute_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

        return triton_call_lowering(
            ctx,
            _unpermute_kernel,
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            pad_offsets,
            output_buf,
            unpermuted_probs_buf,
            grid=grid,
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_merging_probs_token": merging_probs_stride_token,
                "stride_merging_probs_expert": merging_probs_stride_expert,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "stride_unpermuted_probs_token": unpermuted_probs_stride_token,
                "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
                "WITH_MERGING_PROBS": with_merging_probs,
                "PERMUTE_PROBS": with_probs,
                "FUSION_UNPAD": with_unpad,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Infer output sharding from input sharding.

        For batch-dimension partitioning:
        - row_id_map (num_tokens, num_experts*2+1) is sharded on token dim
        - Output (num_tokens, hidden_size) gets same token dim sharding
        """
        del num_tokens, num_experts, hidden_size, with_merging_probs, with_unpad, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[1])
        # Output has same token dimension sharding as row_id_map
        output_sharding = NamedSharding(
            mesh,
            PartitionSpec(row_id_map_spec[0], None),
            desc="UnpermuteWithMaskMap.output_sharding",
        )
        if with_probs:
            unpermuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(row_id_map_spec[0], None),
                desc="UnpermuteWithMaskMap.unpermuted_probs_sharding",
            )
        else:
            unpermuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(None),
                desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty",
            )
        return [output_sharding, unpermuted_probs_sharding]

    @staticmethod
    def partition(
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Partition the primitive for distributed execution.

        Each shard runs the kernel on its local rows of row_id_map; outputs are
        token-sharded like row_id_map.
        """
        del num_tokens, result_infos
        row_id_map_spec = get_padded_spec(arg_infos[1])

        # Input shardings - preserve original shardings
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        # Output shardings
        output_sharding = NamedSharding(
            mesh,
            PartitionSpec(row_id_map_spec[0], None),
            desc="UnpermuteWithMaskMap.output_sharding",
        )
        if with_probs:
            unpermuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(row_id_map_spec[0], None),
                desc="UnpermuteWithMaskMap.unpermuted_probs_sharding",
            )
        else:
            unpermuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(None),
                desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty",
            )
        out_shardings = [output_sharding, unpermuted_probs_sharding]

        def sharded_impl(inp, row_id_map, merging_probs, permuted_probs, pad_offsets):
            # Each shard processes its local tokens
            local_num_tokens = row_id_map.shape[0]
            return UnpermuteWithMaskMapPrimitive.impl(
                inp,
                row_id_map,
                merging_probs,
                permuted_probs,
                pad_offsets,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
                hidden_size=hidden_size,  # hidden_size is not sharded
                with_merging_probs=with_merging_probs,
                with_probs=with_probs,
                with_unpad=with_unpad,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
        with_unpad,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for this primitive."""
        del num_tokens, num_experts, hidden_size, mesh, value_types, result_types
        prefix = "UnpermuteWithMaskMap"
        # inp: (num_out_tokens, hidden_size)
        inp_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
        # row_id_map: (num_tokens, num_experts * 2 + 1)
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
        # merging_probs: (num_tokens, num_experts) or (0,)
        merging_probs_spec = (
            (f"{prefix}_tokens", f"{prefix}_experts")
            if with_merging_probs
            else (f"{prefix}_empty",)
        )
        # permuted_probs: (num_out_tokens,) or (0,)
        permuted_probs_spec = (f"{prefix}_out_tokens",) if with_probs else (f"{prefix}_empty2",)
        # pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise
        pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",)
        # output: (num_tokens, hidden_size)
        output_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
        # unpermuted_probs: (num_tokens, num_experts) or (0,)
        unpermuted_probs_spec = (
            (f"{prefix}_tokens", f"{prefix}_experts") if with_probs else (f"{prefix}_empty3",)
        )

        return SdyShardingRule(
            (inp_spec, row_id_map_spec, merging_probs_spec, permuted_probs_spec, pad_offsets_spec),
            (output_spec, unpermuted_probs_spec),
        )


register_primitive(UnpermuteWithMaskMapPrimitive)


class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
    """
    Backward pass for unpermute with merging probabilities, optionally with fused unpadding.

    This kernel computes gradients for both the input and merging_probs.
    """

    name = "te_unpermute_bwd_with_merging_probs_triton"
    multiple_results = True
    # Tensor inputs: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets
    impl_static_args = (
        5,
        6,
        7,
        8,
        9,
    )  # num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        fwd_output_grad_aval,
        fwd_input_aval,
        merging_probs_aval,
        row_id_map_aval,
        pad_offsets_aval,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
    ):
        """Shape/dtype inference for unpermute backward with merging probs.

        Returns
        -------
        fwd_input_grad_aval
            Shape ``(num_out_tokens, hidden_size)`` with the dtype of ``fwd_output_grad``.
        merging_probs_grad_aval
            Shape ``(num_tokens, num_experts)`` with the dtype of ``merging_probs``.
        """
        del fwd_input_aval, row_id_map_aval, pad_offsets_aval, with_unpad

        # fwd_input_grad has same shape as fwd_input
        fwd_input_grad_shape = (num_out_tokens, hidden_size)
        fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype)

        # merging_probs_grad has same shape as merging_probs
        merging_probs_grad_shape = (num_tokens, num_experts)
        merging_probs_grad_aval = jax.core.ShapedArray(
            merging_probs_grad_shape, merging_probs_aval.dtype
        )

        return fwd_input_grad_aval, merging_probs_grad_aval

    @staticmethod
    def impl(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
    ):
        """Forward to inner primitive."""
        assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None
        return UnpermuteBwdWithMergingProbsPrimitive.inner_primitive.bind(
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            pad_offsets,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
            with_unpad=with_unpad,
        )

    @staticmethod
    def lowering(
        ctx,
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
    ):
        """MLIR lowering using triton_call_lowering."""
        del num_out_tokens

        # Compute strides (all tensors are contiguous, row-major)
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        fwd_output_grad_stride_token = hidden_size
        fwd_output_grad_stride_hidden = 1
        fwd_input_grad_stride_token = hidden_size
        fwd_input_grad_stride_hidden = 1
        fwd_input_stride_token = hidden_size
        fwd_input_stride_hidden = 1
        merging_probs_stride_token = num_experts
        merging_probs_stride_expert = 1
        merging_probs_grad_stride_token = num_experts
        merging_probs_grad_stride_expert = 1

        # Grid - one program per token
        grid = (num_tokens,)

        # Get min block size from autotune configs for consistency
        block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)

        return triton_call_lowering(
            ctx,
            _unpermute_bwd_with_merging_probs_kernel,
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            pad_offsets,
            grid=grid,
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_fwd_output_grad_token": fwd_output_grad_stride_token,
                "stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden,
                "stride_fwd_input_grad_token": fwd_input_grad_stride_token,
                "stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden,
                "stride_fwd_input_token": fwd_input_stride_token,
                "stride_fwd_input_hidden": fwd_input_stride_hidden,
                "stride_merging_probs_token": merging_probs_stride_token,
                "stride_merging_probs_expert": merging_probs_stride_expert,
                "stride_merging_probs_grad_token": merging_probs_grad_stride_token,
                "stride_merging_probs_grad_expert": merging_probs_grad_stride_expert,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
                "FUSION_UNPAD": with_unpad,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Infer output sharding from input sharding.

        - fwd_input_grad (num_out_tokens, hidden_size) follows fwd_input's token sharding.
          Its local size in ``partition`` is ``fwd_input.shape[0]``, and the Shardy rule
          ties it to fwd_input's ``out_tokens`` factor.
        - merging_probs_grad (num_tokens, num_experts) follows merging_probs' token sharding.
        """
        del num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad, result_infos
        fwd_input_spec = get_padded_spec(arg_infos[1])
        merging_probs_spec = get_padded_spec(arg_infos[2])
        # fwd_input_grad is the gradient of fwd_input, so it shares fwd_input's token sharding
        fwd_input_grad_sharding = NamedSharding(
            mesh,
            PartitionSpec(fwd_input_spec[0], None),
            desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding",
        )
        # merging_probs_grad has same sharding as merging_probs
        merging_probs_grad_sharding = NamedSharding(
            mesh,
            PartitionSpec(merging_probs_spec[0], None),
            desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding",
        )
        return [fwd_input_grad_sharding, merging_probs_grad_sharding]

    @staticmethod
    def partition(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Partition the primitive for distributed execution."""
        del num_tokens, num_out_tokens, result_infos
        fwd_input_spec = get_padded_spec(arg_infos[1])
        merging_probs_spec = get_padded_spec(arg_infos[2])

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        # The local fwd_input_grad has fwd_input's local row count (see sharded_impl),
        # so its declared sharding must be fwd_input's token sharding.
        fwd_input_grad_sharding = NamedSharding(
            mesh,
            PartitionSpec(fwd_input_spec[0], None),
            desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding",
        )
        merging_probs_grad_sharding = NamedSharding(
            mesh,
            PartitionSpec(merging_probs_spec[0], None),
            desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding",
        )
        out_shardings = [fwd_input_grad_sharding, merging_probs_grad_sharding]

        def sharded_impl(fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets):
            local_num_tokens = row_id_map.shape[0]
            # NOTE: local_num_out_tokens is obtained from the actual tensor shape,
            # which reflects the data-dependent output size from the forward pass.
            local_num_out_tokens = fwd_input.shape[0]
            return UnpermuteBwdWithMergingProbsPrimitive.impl(
                fwd_output_grad,
                fwd_input,
                merging_probs,
                row_id_map,
                pad_offsets,
                num_tokens=local_num_tokens,
                num_experts=num_experts,
                num_out_tokens=local_num_out_tokens,
                hidden_size=hidden_size,  # hidden_size is not sharded
                with_unpad=with_unpad,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_unpad,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for this primitive."""
        del num_tokens, num_experts, num_out_tokens, hidden_size, mesh, value_types, result_types
        prefix = "UnpermuteBwdWithMergingProbs"
        fwd_output_grad_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
        fwd_input_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
        merging_probs_spec = (f"{prefix}_tokens", f"{prefix}_experts")
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
        # pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise
        pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",)
        fwd_input_grad_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
        merging_probs_grad_spec = (f"{prefix}_tokens", f"{prefix}_experts")

        return SdyShardingRule(
            (
                fwd_output_grad_spec,
                fwd_input_spec,
                merging_probs_spec,
                row_id_map_spec,
                pad_offsets_spec,
            ),
            (fwd_input_grad_spec, merging_probs_grad_spec),
        )


register_primitive(UnpermuteBwdWithMergingProbsPrimitive)


def unpermute_bwd_with_merging_probs(
    fwd_output_grad: jnp.ndarray,
    row_id_map: jnp.ndarray,
    fwd_input: jnp.ndarray,
    merging_probs: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Backward pass for unpermute with merging probabilities.

    This computes gradients for both the input tensor and merging_probs.

    Parameters
    ----------
    fwd_output_grad : jnp.ndarray
        Gradient of the forward output of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : jnp.ndarray
        The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs : jnp.ndarray
        The merging probabilities of shape `[num_tokens, num_experts]`.
    num_tokens : int
        Number of tokens in the unpermuted tensor.
    num_experts : int
        Number of experts.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size.

    Returns
    -------
    fwd_input_grad : jnp.ndarray
        Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad : jnp.ndarray
        Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
    """
    # Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature)
    dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
    # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets
    return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        dummy_pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_unpad=False,
    )


def unpermute_bwd_with_merging_probs_and_unpad(
    fwd_output_grad: jnp.ndarray,
    row_id_map: jnp.ndarray,
    fwd_input: jnp.ndarray,
    merging_probs: jnp.ndarray,
    pad_offsets: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Backward pass for unpermute with merging probabilities and fused unpadding.

    This computes gradients for both the input tensor and merging_probs,
    while handling padded outputs.

    Parameters
    ----------
    fwd_output_grad : jnp.ndarray
        Gradient of the forward output of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : jnp.ndarray
        The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs : jnp.ndarray
        The merging probabilities of shape `[num_tokens, num_experts]`.
    pad_offsets : jnp.ndarray
        Per-expert cumulative padding offsets of shape `[num_experts]`.
    num_tokens : int
        Number of tokens in the unpermuted tensor.
    num_experts : int
        Number of experts.
    num_out_tokens : int
        Number of tokens in the permuted tensor (including padding).
    hidden_size : int
        Hidden size.

    Returns
    -------
    fwd_input_grad : jnp.ndarray
        Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad : jnp.ndarray
        Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
    """
    # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets
    return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_unpad=True,
    )


class MakeChunkSortMapPrimitive(BasePrimitive):
    """
    Make a row_id_map for chunk sort.
    """

    name = "te_make_chunk_sort_map_triton"
    multiple_results = False
    impl_static_args = (2, 3)  # num_tokens, num_splits
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(split_sizes_aval, sorted_indices_aval, *, num_tokens, num_splits):
        """Shape/dtype inference: the result is an int32 vector of shape ``(num_tokens,)``."""
        del sorted_indices_aval
        assert split_sizes_aval.shape == (num_splits,)
        return jax.core.ShapedArray((num_tokens,), jnp.int32)

    @staticmethod
    def impl(split_sizes, sorted_indices, num_tokens, num_splits):
        """Forward to inner primitive."""
        assert MakeChunkSortMapPrimitive.inner_primitive is not None
        return MakeChunkSortMapPrimitive.inner_primitive.bind(
            split_sizes,
            sorted_indices,
            num_tokens=num_tokens,
            num_splits=num_splits,
        )

    @staticmethod
    def lowering(ctx, split_sizes, sorted_indices, *, num_tokens, num_splits):
        """MLIR lowering using triton_call_lowering (one program per token)."""
        grid = (num_tokens,)

        return triton_call_lowering(
            ctx,
            _make_chunk_sort_map_kernel,
            split_sizes,
            sorted_indices,
            grid=grid,
            constexprs={
                "num_splits": num_splits,
                "IDX_LOAD_WIDTH": triton.next_power_of_2(num_splits),
            },
        )

    @staticmethod
    def infer_sharding_from_operands(num_tokens, num_splits, mesh, arg_infos, result_infos):
        """Infer output sharding from input sharding."""
        del num_tokens, num_splits, result_infos, arg_infos
        # row_id_map is replicated since split_sizes and sorted_indices are typically small
        return NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="MakeChunkSortMap.row_id_map_sharding",
        )

    @staticmethod
    def partition(num_tokens, num_splits, mesh, arg_infos, result_infos):
        """Partition the primitive for distributed execution.

        The kernel reads the complete ``split_sizes`` / ``sorted_indices`` to build the
        global row_id_map of length ``num_tokens``, and ``abstract`` checks the split
        count against the global ``num_splits``. Both inputs are therefore requested
        replicated (JAX gathers them if they arrive sharded) and every device computes
        the full, replicated row_id_map.
        """
        del result_infos

        # Both inputs are 1-D of shape (num_splits,) (see shardy_sharding_rule).
        arg_shardings = tuple(
            NamedSharding(
                mesh,
                PartitionSpec(None),
                desc="MakeChunkSortMap.input_sharding",
            )
            for _ in arg_infos
        )

        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="MakeChunkSortMap.row_id_map_sharding",
        )

        def sharded_impl(split_sizes, sorted_indices):
            return MakeChunkSortMapPrimitive.impl(
                split_sizes,
                sorted_indices,
                num_tokens=num_tokens,
                num_splits=num_splits,
            )

        return mesh, sharded_impl, out_sharding, arg_shardings

    @staticmethod
    def shardy_sharding_rule(num_tokens, num_splits, mesh, value_types, result_types):
        """Shardy sharding rule for this primitive."""
        del num_tokens, num_splits, mesh, value_types, result_types
        prefix = "MakeChunkSortMap"
        split_sizes_spec = (f"{prefix}_splits",)
        sorted_indices_spec = (f"{prefix}_splits",)
        row_id_map_spec = (f"{prefix}_tokens",)

        return SdyShardingRule(
            (split_sizes_spec, sorted_indices_spec),
            (row_id_map_spec,),
        )


register_primitive(MakeChunkSortMapPrimitive)


class SortChunksByMapPrimitive(BasePrimitive):
    """
    Sort chunks with row_id_map.

    Operands are ``inp`` (``[num_tokens, hidden_size]``), ``row_id_map`` and ``probs``
    (a zero-length placeholder when ``with_probs`` is False). Results are the sorted
    ``[num_tokens, hidden_size]`` tensor and the sorted probs (``[num_tokens]``, or a
    zero-length placeholder when ``with_probs`` is False).
    """

    name = "te_sort_chunks_by_map_triton"
    multiple_results = True
    impl_static_args = (3, 4, 5, 6)  # num_tokens, hidden_size, is_forward, with_probs
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs
    ):
        """Shape/dtype inference."""
        del row_id_map_aval, is_forward

        output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)

        if with_probs:
            permuted_probs_aval = jax.core.ShapedArray((num_tokens,), probs_aval.dtype)
        else:
            # Zero-length placeholder so the primitive always has two results.
            permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)

        return output_aval, permuted_probs_aval

    @staticmethod
    def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs):
        """Forward to inner primitive."""
        assert SortChunksByMapPrimitive.inner_primitive is not None
        return SortChunksByMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            probs,
            num_tokens=num_tokens,
            hidden_size=hidden_size,
            is_forward=is_forward,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs):
        """MLIR lowering using triton_call_lowering.

        Operands and results are addressed as contiguous row-major tensors: the
        token stride is ``hidden_size`` and the hidden stride is 1.
        """
        # Compute strides
        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        probs_stride_token = 1
        permuted_probs_stride_token = 1

        # Grid - use minimum BLOCK_SIZE from autotune configs.
        # Axis 0 spans num_tokens, axis 1 spans hidden_size in BLOCK_SIZE-wide slices.
        block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

        return triton_call_lowering(
            ctx,
            _sort_chunks_by_map_kernel,
            inp,
            row_id_map,
            probs,
            grid=grid,
            constexprs={
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_probs_token": probs_stride_token,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "hidden_size": hidden_size,
                "PERMUTE_PROBS": with_probs,
                "FORWARD": is_forward,
                "BLOCK_SIZE": block_size,
            },
        )

    @staticmethod
    def _out_shardings(mesh, inp_spec, with_probs):
        """Output shardings shared by infer_sharding_from_operands and partition.

        The sorted tensor (and the sorted probs, when present) follow ``inp``'s
        token-axis sharding with the hidden axis unsharded; the zero-length probs
        placeholder is replicated. Keeping this in one place guarantees that the
        inferred shardings and the partitioned shardings agree.
        """
        output_sharding = NamedSharding(
            mesh,
            PartitionSpec(inp_spec[0], None),
            desc="SortChunksByMap.output_sharding",
        )
        if with_probs:
            permuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(inp_spec[0]),
                desc="SortChunksByMap.permuted_probs_sharding",
            )
        else:
            permuted_probs_sharding = NamedSharding(
                mesh,
                PartitionSpec(None),
                desc="SortChunksByMap.permuted_probs_sharding_empty",
            )
        return [output_sharding, permuted_probs_sharding]

    @staticmethod
    def infer_sharding_from_operands(
        num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos
    ):
        """Infer output sharding from input sharding."""
        del num_tokens, hidden_size, is_forward, result_infos
        inp_spec = get_padded_spec(arg_infos[0])
        return SortChunksByMapPrimitive._out_shardings(mesh, inp_spec, with_probs)

    @staticmethod
    def partition(num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos):
        """Partition the primitive for distributed execution."""
        del num_tokens, result_infos
        inp_spec = get_padded_spec(arg_infos[0])

        # NOTE(review): operand shardings are taken as-is. If inp is sharded on the
        # token axis, each shard runs with its local token count, while row_id_map
        # values are presumably global row indices — confirm callers keep the token
        # axis unsharded (or localize row_id_map) before relying on this path.
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = SortChunksByMapPrimitive._out_shardings(mesh, inp_spec, with_probs)

        def sharded_impl(inp, row_id_map, probs):
            local_num_tokens = inp.shape[0]
            return SortChunksByMapPrimitive.impl(
                inp,
                row_id_map,
                probs,
                num_tokens=local_num_tokens,
                hidden_size=hidden_size,  # hidden_size is not sharded
                is_forward=is_forward,
                with_probs=with_probs,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        num_tokens, hidden_size, is_forward, with_probs, mesh, value_types, result_types
    ):
        """Shardy sharding rule for this primitive."""
        del num_tokens, hidden_size, is_forward, mesh, value_types, result_types
        prefix = "SortChunksByMap"
        inp_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
        row_id_map_spec = (f"{prefix}_tokens",)
        # Zero-length placeholders get their own factors, unrelated to the token axis.
        probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty",)
        output_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
        permuted_probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty2",)

        return SdyShardingRule(
            (inp_spec, row_id_map_spec, probs_spec),
            (output_spec, permuted_probs_spec),
        )

register_primitive(SortChunksByMapPrimitive)


def make_row_id_map(
    routing_map: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
) -> jnp.ndarray:
    """
    Build the row_id_map consumed by the mask-map permutation ops.

    Three Triton kernel passes are chained: a per-block cumsum (pass 1), a global
    cumsum that also processes the mask (pass 2), and a conversion of the map from
    a sparse to a dense structure (pass 3).

    Parameters
    ----------
    routing_map : jnp.ndarray
        Mask of shape `[num_tokens, num_experts]`. A 1 at `[t, e]` means token `t`
        is routed to expert `e`; a 0 means it is not.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts.

    Returns
    -------
    row_id_map : jnp.ndarray
        Map of shape `[num_tokens, num_experts * 2 + 1]`. For each token the last
        entry is n_routed, the number of experts the token is routed to. The first
        n_routed entries are destination row indices in the permuted tokens, and
        entries `[num_experts, num_experts + n_routed)` are the indices of the
        experts matching those rows.
    """
    sizes = {"num_tokens": num_tokens, "num_experts": num_experts}

    # Pass 1: per-block cumsum; also returns the workspace that pass 2 consumes.
    pass1_map, workspace = RowIdMapPass1Primitive.outer_primitive.bind(
        routing_map, block_size=DEFAULT_BLOCK_SIZE, **sizes
    )

    # Pass 2: global cumsum plus mask processing (its second result is not needed).
    pass2_map, _unused = RowIdMapPass2Primitive.outer_primitive.bind(
        pass1_map, workspace, block_size=DEFAULT_BLOCK_SIZE, **sizes
    )

    # Passes 1/2 only write columns [0, num_experts); the reference implementation
    # expects -1 in every remaining (invalid) slot.
    prefilled_map = pass2_map.at[:, num_experts:].set(-1)

    # Pass 3: sparse -> dense row_id_map.
    return RowIdMapPass3Primitive.outer_primitive.bind(prefilled_map, **sizes)


def permute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Permute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs : Optional[jnp.ndarray]
        The probabilities of the input tensor. If it is not None, it will be permuted.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size of the input tensor.

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape `[num_out_tokens, hidden_size]`.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities if probs was provided, None otherwise.
    """
    with_probs = probs is not None

    # The primitive always takes a probs operand; use a zero-length placeholder when absent.
    if not with_probs:
        probs = jnp.zeros((0,), dtype=inp.dtype)

    # Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature)
    dummy_scale = inp
    dummy_permuted_scale = inp
    # Create dummy pad_offsets (not used when FUSION_PAD=False, but required by kernel signature)
    dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)

    output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs,
        dummy_scale,
        dummy_permuted_scale,
        dummy_pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_probs=with_probs,
        with_pad=False,
        align_size=128,  # Default value, no-op for non-padding case
    )

    # The placeholder probs result is not meaningful; report it as None.
    if not with_probs:
        permuted_probs = None

    return output, permuted_probs


def permute_with_mask_map_and_pad(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    pad_offsets: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    align_size: int = 128,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Permute the input tensor based on the row_id_map with fused padding.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs : Optional[jnp.ndarray]
        The probabilities of the input tensor. If it is not None, it will be permuted.
    pad_offsets : jnp.ndarray
        Per-expert cumulative padding offsets of shape `[num_experts]`.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor (including padding).
    hidden_size : int
        Hidden size of the input tensor.
    align_size : int
        Alignment size for padding (default: 128). Used for distributed sharding
        to correctly compute local buffer sizes.

    Returns
    -------
    output : jnp.ndarray
        Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`.
        Padding positions are zero-filled.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities if probs was provided, None otherwise.
        Padding positions are zero-filled.
    """
    with_probs = probs is not None

    # The primitive always takes a probs operand; use a zero-length placeholder when absent.
    if not with_probs:
        probs = jnp.zeros((0,), dtype=inp.dtype)

    # Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature)
    dummy_scale = inp
    dummy_permuted_scale = inp

    output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs,
        dummy_scale,
        dummy_permuted_scale,
        pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_probs=with_probs,
        with_pad=True,
        align_size=align_size,
    )

    # Note: Zero-filling of padding positions is handled by pre-zeroing the output
    # buffers in impl() using jnp.zeros(), then aliasing them to the kernel's outputs
    # via input_output_aliases. The kernel only writes to valid positions, leaving
    # padding positions at zero.

    # The placeholder probs result is not meaningful; report it as None.
    if not with_probs:
        permuted_probs = None

    return output, permuted_probs


def unpermute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    permuted_probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Unpermute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs : Optional[jnp.ndarray]
        The merging probabilities of the input tensor. If it is not None, it will be used as weights
        to reduce the unpermuted tokens.
    permuted_probs : Optional[jnp.ndarray]
        The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
    num_tokens : int
        Number of tokens in the unpermuted output tensor (its first dimension).
    num_experts : int
        Number of experts.
    hidden_size : int
        Hidden size of the tensor.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs : Optional[jnp.ndarray]
        Unpermuted probabilities if permuted_probs was provided, None otherwise.
    """
    with_merging_probs = merging_probs is not None
    with_probs = permuted_probs is not None

    # The primitive always takes both probs operands; use zero-length placeholders when absent.
    if not with_merging_probs:
        merging_probs = jnp.zeros((0,), dtype=inp.dtype)
    if not with_probs:
        permuted_probs = jnp.zeros((0,), dtype=inp.dtype)
    # Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature)
    dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)

    output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        dummy_pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        hidden_size=hidden_size,
        with_merging_probs=with_merging_probs,
        with_probs=with_probs,
        with_unpad=False,
    )

    # The placeholder probs result is not meaningful; report it as None.
    if not with_probs:
        unpermuted_probs = None

    return output, unpermuted_probs


def unpermute_with_mask_map_and_unpad(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    permuted_probs: Optional[jnp.ndarray],
    pad_offsets: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Unpermute the input tensor based on the row_id_map with fused unpadding.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_out_tokens, hidden_size]` (including padding).
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs : Optional[jnp.ndarray]
        The merging probabilities of the input tensor. If it is not None, it will be used as weights
        to reduce the unpermuted tokens.
    permuted_probs : Optional[jnp.ndarray]
        The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
    pad_offsets : jnp.ndarray
        Per-expert cumulative padding offsets of shape `[num_experts]`.
    num_tokens : int
        Number of tokens in the unpermuted tensor.
    num_experts : int
        Number of experts.
    hidden_size : int
        Hidden size of the tensor.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs : Optional[jnp.ndarray]
        Unpermuted probabilities if permuted_probs was provided, None otherwise.
    """
    with_merging_probs = merging_probs is not None
    with_probs = permuted_probs is not None

    # The primitive always takes both probs operands; use zero-length placeholders when absent.
    if not with_merging_probs:
        merging_probs = jnp.zeros((0,), dtype=inp.dtype)
    if not with_probs:
        permuted_probs = jnp.zeros((0,), dtype=inp.dtype)

    output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        num_tokens=num_tokens,
        num_experts=num_experts,
        hidden_size=hidden_size,
        with_merging_probs=with_merging_probs,
        with_probs=with_probs,
        with_unpad=True,
    )

    # The placeholder probs result is not meaningful; report it as None.
    if not with_probs:
        unpermuted_probs = None

    return output, unpermuted_probs


def make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
    num_splits: int,
) -> jnp.ndarray:
    """
    Build the row_id_map used to reorder chunks of tokens.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        Size of each chunk, shape `[num_splits,]`.
    sorted_indices : jnp.ndarray
        Chunk indices in their sorted order, shape `[num_splits,]`.
    num_tokens : int
        Number of tokens in the input tensor.
    num_splits : int
        Length of both split_sizes and sorted_indices.

    Returns
    -------
    row_id_map : jnp.ndarray
        Chunk-sort row ID map of shape `[num_tokens,]`.
    """
    chunk_sort_prim = MakeChunkSortMapPrimitive.outer_primitive
    return chunk_sort_prim.bind(
        split_sizes, sorted_indices, num_tokens=num_tokens, num_splits=num_splits
    )


def sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Reorder token chunks according to a chunk-sort row_id_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        Chunk-sort map of shape `[num_tokens,]` (as built by make_chunk_sort_map).
    probs : Optional[jnp.ndarray]
        Per-token probabilities; when given, they are reordered alongside inp.
    num_tokens : int
        Number of tokens in the input tensor.
    hidden_size : int
        Hidden size of the input tensor.
    is_forward : bool
        True for the forward sort, False for the backward sort.

    Returns
    -------
    output : jnp.ndarray
        Sorted tensor of shape `[num_tokens, hidden_size]`.
    permuted_probs : Optional[jnp.ndarray]
        Sorted probabilities when probs was given, otherwise None.
    """
    with_probs = probs is not None
    # The primitive always expects a probs operand; a zero-length array stands in when absent.
    probs_operand = probs if with_probs else jnp.zeros((0,), dtype=inp.dtype)

    output, sorted_probs = SortChunksByMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs_operand,
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        is_forward=is_forward,
        with_probs=with_probs,
    )
    return output, (sorted_probs if with_probs else None)