permutation.py 35.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX/TE custom ops for permutation in MOE using Triton kernels."""

from typing import Optional, Tuple

import jax
import jax.numpy as jnp
import triton

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
from transformer_engine.common.triton.permutation import (
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
    _row_id_map_pass_3_kernel,
    _permute_kernel,
    _unpermute_kernel,
    _unpermute_bwd_with_merging_probs_kernel,
    _make_chunk_sort_map_kernel,
    _sort_chunks_by_map_kernel,
)
from .utils import triton_call_lowering


# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "make_row_id_map",
    "permute_with_mask_map",
    "unpermute_with_mask_map",
    "unpermute_bwd_with_merging_probs",
    "make_chunk_sort_map",
    "sort_chunks_by_map",
]

# Token block size that `make_row_id_map` passes to row_id_map passes 1 and 2.
DEFAULT_BLOCK_SIZE = 1024


def _get_min_block_size(kernel, default=128):
    """
    Return the smallest ``BLOCK_SIZE`` found among a kernel's autotune configs.

    Parameters
    ----------
    kernel : object
        Kernel object. When it exposes a non-empty ``configs`` attribute, every entry is
        expected to provide a ``kwargs`` mapping.
    default : int
        Value used for a config that has no ``BLOCK_SIZE`` entry, and returned when the
        kernel has no configs at all.

    Returns
    -------
    int
        The minimum ``BLOCK_SIZE`` across ``kernel.configs``, or ``default``.
    """
    configs = getattr(kernel, "configs", None)
    # A missing, ``None`` or empty ``configs`` would make ``min()`` raise; fall back instead.
    if not configs:
        return default
    return min(config.kwargs.get("BLOCK_SIZE", default) for config in configs)


class RowIdMapPass1Primitive(BasePrimitive):
    """
    Pass 1 of row_id_map generation: block cumsum.

    For each expert, compute the cumsum of every block_size tokens.

    The primitive takes a `[num_tokens, num_experts]` routing map and declares two int32
    outputs: a `[num_tokens, num_experts * 2 + 1]` row_id_map and a
    `[num_experts, cdiv(num_tokens, block_size)]` workspace.
    """

    name = "te_row_id_map_pass1_triton"
    multiple_results = True
    impl_static_args = (1, 2, 3)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
        """
        Shape/dtype inference for pass 1.

        Parameters
        ----------
        routing_map_aval
            Abstract value of the routing map; its shape must be `(num_tokens, num_experts)`.
        num_tokens : int
            Number of tokens.
        num_experts : int
            Number of experts.
        block_size : int
            Number of tokens handled per program along the token axis of the grid.

        Returns
        -------
        tuple
            Abstract values of the int32 row_id_map and the int32 workspace.
        """
        assert routing_map_aval.shape == (
            num_tokens,
            num_experts,
        ), f"routing_map shape mismatch: expected ({num_tokens}, {num_experts})"

        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        # Size the workspace like the launch grid built in `lowering`, so it follows the
        # block_size actually requested instead of a fixed module-level default.
        workspace_shape = (
            num_experts,
            triton.cdiv(num_tokens, block_size),
        )

        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
            jax.core.ShapedArray(workspace_shape, jnp.int32),
        )

    @staticmethod
    def impl(routing_map, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass1Primitive.inner_primitive is not None
        return RowIdMapPass1Primitive.inner_primitive.bind(
            routing_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering."""
        # Strides are written out for dense row-major `[num_tokens, num_experts]` and
        # `[num_tokens, num_experts * 2 + 1]` layouts.
        routing_stride_token = num_experts
        routing_stride_expert = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        # One program per (expert, block of block_size tokens).
        grid = (num_experts, triton.cdiv(num_tokens, block_size))

        # All scalar arguments must be passed as constexprs
        return triton_call_lowering(
            ctx,
            _row_id_map_pass_1_kernel,
            routing_map,  # Only tensor arguments here
            grid=grid,
            constexprs={
                "num_tokens": num_tokens,
                "stride_routing_map_token": routing_stride_token,
                "stride_routing_map_expert": routing_stride_expert,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(RowIdMapPass1Primitive)


class RowIdMapPass2Primitive(BasePrimitive):
    """
    Pass 2 of row_id_map generation: cumsum all and process the mask.

    Both inputs (row_id_map and workspace) are aliased to the two outputs in the lowering,
    so the pass is declared as an in-place update of them.
    """

    name = "te_row_id_map_pass2_triton"
    multiple_results = True
    impl_static_args = (2, 3, 4)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
        """
        Shape/dtype inference for pass 2 (in-place operation).

        Parameters
        ----------
        row_id_map_aval, workspace_aval
            Abstract values of the inputs; they are not inspected.
        num_tokens : int
            Number of tokens.
        num_experts : int
            Number of experts.
        block_size : int
            Number of tokens handled per program along the token axis of the grid.

        Returns
        -------
        tuple
            Abstract values of the int32 row_id_map and the int32 workspace.
        """
        del row_id_map_aval, workspace_aval

        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        # Keep the workspace shape in step with the grid and WORKSPACE_LOAD_WIDTH computed in
        # `lowering`, both of which are derived from block_size.
        workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))

        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
            jax.core.ShapedArray(workspace_shape, jnp.int32),
        )

    @staticmethod
    def impl(row_id_map, workspace, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass2Primitive.inner_primitive is not None
        return RowIdMapPass2Primitive.inner_primitive.bind(
            row_id_map,
            workspace,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, row_id_map, workspace, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering."""
        # Strides for a dense row-major `[num_tokens, num_experts * 2 + 1]` row_id_map.
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1

        # One program per (expert, block of block_size tokens).
        grid = (num_experts, triton.cdiv(num_tokens, block_size))
        # Total number of workspace entries, rounded up to a power of two.
        workspace_load_width = triton.next_power_of_2(
            num_experts * triton.cdiv(num_tokens, block_size)
        )

        return triton_call_lowering(
            ctx,
            _row_id_map_pass_2_kernel,
            row_id_map,
            workspace,
            grid=grid,
            # Input i is aliased to output i for both operands.
            input_output_aliases={0: 0, 1: 1},
            constexprs={
                "num_tokens": num_tokens,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "WORKSPACE_LOAD_WIDTH": workspace_load_width,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(RowIdMapPass2Primitive)


class RowIdMapPass3Primitive(BasePrimitive):
    """
    Third row_id_map pass: convert the row_id_map from its sparse to its dense structure.

    The single input is aliased to the single output in the lowering.
    """

    name = "te_row_id_map_pass3_triton"
    multiple_results = False
    impl_static_args = (1, 2)  # num_tokens, num_experts
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, *, num_tokens, num_experts):
        """Describe the int32 `[num_tokens, num_experts * 2 + 1]` output; input is ignored."""
        del row_id_map_aval
        return jax.core.ShapedArray((num_tokens, num_experts * 2 + 1), jnp.int32)

    @staticmethod
    def impl(row_id_map, num_tokens, num_experts):
        """Bind the registered inner primitive with the static sizes as keywords."""
        primitive = RowIdMapPass3Primitive.inner_primitive
        assert primitive is not None
        return primitive.bind(row_id_map, num_tokens=num_tokens, num_experts=num_experts)

    @staticmethod
    def lowering(ctx, row_id_map, *, num_tokens, num_experts):
        """Lower to a Triton call with one program per token."""
        # Scalars handed to the kernel; strides describe a dense row-major row_id_map.
        kernel_constants = {
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "num_experts": num_experts,
            "LOAD_SIZE": triton.next_power_of_2(num_experts),
        }

        return triton_call_lowering(
            ctx,
            _row_id_map_pass_3_kernel,
            row_id_map,
            grid=(num_tokens,),
            input_output_aliases={0: 0},
            constexprs=kernel_constants,
        )


register_primitive(RowIdMapPass3Primitive)


class PermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Permute the input tensor according to the row_id_map.

    `scale` and `permuted_scale` are placeholder operands: the lowering always sets
    PERMUTE_SCALE=False, but the kernel call still needs them in its argument list.
    """

    name = "te_permute_with_mask_map_triton"
    multiple_results = True
    # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs
    impl_static_args = (5, 6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        probs_aval,
        scale_aval,  # placeholder operand
        permuted_scale_aval,  # placeholder operand
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """
        Infer the output avals of the permutation.

        Returns the permuted tensor aval `[num_out_tokens, hidden_size]` (dtype of `inp`) and
        the permuted probs aval: `[num_out_tokens]` in the dtype of `probs` when `with_probs`
        is set, otherwise an empty `(0,)` array in the dtype of `inp`.
        """
        del row_id_map_aval, scale_aval, permuted_scale_aval
        del num_tokens, num_experts

        permuted_aval = jax.core.ShapedArray((num_out_tokens, hidden_size), inp_aval.dtype)
        probs_out_aval = (
            jax.core.ShapedArray((num_out_tokens,), probs_aval.dtype)
            if with_probs
            else jax.core.ShapedArray((0,), inp_aval.dtype)
        )
        return permuted_aval, probs_out_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """Bind the registered inner primitive, turning the static args into keywords."""
        primitive = PermuteWithMaskMapPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """Lower to a Triton call of the permute kernel."""
        del num_out_tokens

        # probs strides: zero when probs are unused, otherwise chosen by the rank of the
        # probs operand (2D [num_tokens, num_experts] versus 1D [num_tokens]).
        if not with_probs:
            probs_strides = (0, 0)
        elif len(ctx.avals_in[2].shape) > 1:
            probs_strides = (num_experts, 1)
        else:
            probs_strides = (1, 1)
        stride_probs_token, stride_probs_expert = probs_strides

        # Launch with the smallest autotune BLOCK_SIZE so that the grid
        # (num_tokens, cdiv(hidden_size, BLOCK_SIZE)) covers every element.
        min_block = _get_min_block_size(_permute_kernel)
        launch_grid = (num_tokens, triton.cdiv(hidden_size, min_block))

        # Remaining strides are written out for dense row-major layouts.
        kernel_constants = {
            "scale_hidden_dim": 0,
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "stride_input_token": hidden_size,
            "stride_input_hidden": 1,
            "stride_output_token": hidden_size,
            "stride_output_hidden": 1,
            "stride_probs_token": stride_probs_token,
            "stride_probs_expert": stride_probs_expert,
            "stride_scale_token": hidden_size,
            "stride_scale_hidden": 1,
            "stride_permuted_probs_token": 1,
            "stride_permuted_scale_token": hidden_size,
            "stride_permuted_scale_hidden": 1,
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "PERMUTE_PROBS": with_probs,
            "PERMUTE_SCALE": False,
            "BLOCK_SIZE": min_block,
        }

        return triton_call_lowering(
            ctx,
            _permute_kernel,
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            grid=launch_grid,
            constexprs=kernel_constants,
        )


register_primitive(PermuteWithMaskMapPrimitive)


class UnpermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Unpermute the input tensor according to the row_id_map.
    """

    name = "te_unpermute_with_mask_map_triton"
    multiple_results = True
    # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
    impl_static_args = (4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        merging_probs_aval,
        permuted_probs_aval,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """
        Infer the output avals of the unpermutation.

        Returns the unpermuted tensor aval `[num_tokens, hidden_size]` (dtype of `inp`) and
        the unpermuted probs aval: `[num_tokens, num_experts]` in the dtype of
        `permuted_probs` when `with_probs` is set, otherwise an empty `(0,)` array in the
        dtype of `inp`.
        """
        del row_id_map_aval, merging_probs_aval, with_merging_probs

        unpermuted_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)
        if not with_probs:
            return unpermuted_aval, jax.core.ShapedArray((0,), inp_aval.dtype)

        probs_out_aval = jax.core.ShapedArray(
            (num_tokens, num_experts), permuted_probs_aval.dtype
        )
        return unpermuted_aval, probs_out_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Bind the registered inner primitive, turning the static args into keywords."""
        primitive = UnpermuteWithMaskMapPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            num_tokens=num_tokens,
            num_experts=num_experts,
            hidden_size=hidden_size,
            with_merging_probs=with_merging_probs,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Lower to a Triton call of the unpermute kernel."""
        # merging_probs strides are zeroed when merging probs are not used.
        if with_merging_probs:
            stride_merging_token, stride_merging_expert = num_experts, 1
        else:
            stride_merging_token, stride_merging_expert = 0, 0

        # Launch with the smallest autotune BLOCK_SIZE so the grid covers all of hidden_size.
        min_block = _get_min_block_size(_unpermute_kernel)
        launch_grid = (num_tokens, triton.cdiv(hidden_size, min_block))

        # Remaining strides are written out for dense row-major layouts.
        kernel_constants = {
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "stride_input_token": hidden_size,
            "stride_input_hidden": 1,
            "stride_output_token": hidden_size,
            "stride_output_hidden": 1,
            "stride_merging_probs_token": stride_merging_token,
            "stride_merging_probs_expert": stride_merging_expert,
            "stride_permuted_probs_token": 1,
            "stride_unpermuted_probs_token": num_experts,
            "stride_unpermuted_probs_expert": 1,
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
            "WITH_MERGING_PROBS": with_merging_probs,
            "PERMUTE_PROBS": with_probs,
            "BLOCK_SIZE": min_block,
        }

        return triton_call_lowering(
            ctx,
            _unpermute_kernel,
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            grid=launch_grid,
            constexprs=kernel_constants,
        )


register_primitive(UnpermuteWithMaskMapPrimitive)


class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
    """
    Backward of unpermute when merging probabilities are used.

    Produces the gradient for the forward input and the gradient for merging_probs.
    """

    name = "te_unpermute_bwd_with_merging_probs_triton"
    multiple_results = True
    impl_static_args = (4, 5, 6, 7)  # num_tokens, num_experts, num_out_tokens, hidden_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        fwd_output_grad_aval,
        fwd_input_aval,
        merging_probs_aval,
        row_id_map_aval,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """
        Infer the avals of both gradients.

        The input gradient is `[num_out_tokens, hidden_size]` in the dtype of
        `fwd_output_grad`; the merging_probs gradient is `[num_tokens, num_experts]` in the
        dtype of `merging_probs`.
        """
        del fwd_input_aval, row_id_map_aval

        input_grad_aval = jax.core.ShapedArray(
            (num_out_tokens, hidden_size), fwd_output_grad_aval.dtype
        )
        probs_grad_aval = jax.core.ShapedArray(
            (num_tokens, num_experts), merging_probs_aval.dtype
        )
        return input_grad_aval, probs_grad_aval

    @staticmethod
    def impl(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Bind the registered inner primitive, turning the static args into keywords."""
        primitive = UnpermuteBwdWithMergingProbsPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
        )

    @staticmethod
    def lowering(
        ctx,
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Lower to a Triton call with one program per token."""
        del num_out_tokens

        # BLOCK_SIZE is taken as the smallest value among the kernel's autotune configs.
        min_block = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)

        # Strides are written out for dense row-major layouts.
        kernel_constants = {
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "stride_fwd_output_grad_token": hidden_size,
            "stride_fwd_output_grad_hidden": 1,
            "stride_fwd_input_grad_token": hidden_size,
            "stride_fwd_input_grad_hidden": 1,
            "stride_fwd_input_token": hidden_size,
            "stride_fwd_input_hidden": 1,
            "stride_merging_probs_token": num_experts,
            "stride_merging_probs_expert": 1,
            "stride_merging_probs_grad_token": num_experts,
            "stride_merging_probs_grad_expert": 1,
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
            "BLOCK_SIZE": min_block,
        }

        # Tensor operands follow the order fwd_output_grad, fwd_input, merging_probs,
        # row_id_map.
        return triton_call_lowering(
            ctx,
            _unpermute_bwd_with_merging_probs_kernel,
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            grid=(num_tokens,),
            constexprs=kernel_constants,
        )


register_primitive(UnpermuteBwdWithMergingProbsPrimitive)


def unpermute_bwd_with_merging_probs(
    fwd_output_grad: jnp.ndarray,
    row_id_map: jnp.ndarray,
    fwd_input: jnp.ndarray,
    merging_probs: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Compute the gradients of unpermute when merging probabilities were applied.

    Gradients are produced for the forward input tensor and for merging_probs.

    Parameters
    ----------
    fwd_output_grad : jnp.ndarray
        Gradient of the forward output, shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        Token to expert mapping tensor, shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : jnp.ndarray
        Input tensor of the forward pass, shape `[num_out_tokens, hidden_size]`.
    merging_probs : jnp.ndarray
        Merging probabilities, shape `[num_tokens, num_experts]`.
    num_tokens : int
        Number of tokens in the unpermuted tensor.
    num_experts : int
        Number of experts.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size.

    Returns
    -------
    fwd_input_grad : jnp.ndarray
        Gradient w.r.t. the input tensor, shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad : jnp.ndarray
        Gradient w.r.t. merging_probs, shape `[num_tokens, num_experts]`.
    """
    static_sizes = {
        "num_tokens": num_tokens,
        "num_experts": num_experts,
        "num_out_tokens": num_out_tokens,
        "hidden_size": hidden_size,
    }
    # The primitive takes its operands in a different order from this function's signature:
    # fwd_output_grad, fwd_input, merging_probs, row_id_map.
    operands = (fwd_output_grad, fwd_input, merging_probs, row_id_map)
    return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(*operands, **static_sizes)


class MakeChunkSortMapPrimitive(BasePrimitive):
    """
    Build the row_id_map used for chunk sorting.
    """

    name = "te_make_chunk_sort_map_triton"
    multiple_results = False
    impl_static_args = (2, 3)  # num_tokens, num_splits
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(split_sizes_aval, sorted_indices_aval, *, num_tokens, num_splits):
        """Check that split_sizes is `(num_splits,)` and return an int32 `(num_tokens,)` aval."""
        del sorted_indices_aval
        expected_shape = (num_splits,)
        assert split_sizes_aval.shape == expected_shape
        return jax.core.ShapedArray((num_tokens,), jnp.int32)

    @staticmethod
    def impl(split_sizes, sorted_indices, num_tokens, num_splits):
        """Bind the registered inner primitive, turning the static args into keywords."""
        primitive = MakeChunkSortMapPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            split_sizes, sorted_indices, num_tokens=num_tokens, num_splits=num_splits
        )

    @staticmethod
    def lowering(ctx, split_sizes, sorted_indices, *, num_tokens, num_splits):
        """Lower to a Triton call with one program per token."""
        kernel_constants = {
            "num_splits": num_splits,
            "IDX_LOAD_WIDTH": triton.next_power_of_2(num_splits),
        }
        return triton_call_lowering(
            ctx,
            _make_chunk_sort_map_kernel,
            split_sizes,
            sorted_indices,
            grid=(num_tokens,),
            constexprs=kernel_constants,
        )


register_primitive(MakeChunkSortMapPrimitive)


class SortChunksByMapPrimitive(BasePrimitive):
    """
    Sort chunks according to a row_id_map.
    """

    name = "te_sort_chunks_by_map_triton"
    multiple_results = True
    impl_static_args = (3, 4, 5, 6)  # num_tokens, hidden_size, is_forward, with_probs
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs
    ):
        """
        Infer the output avals.

        Returns the sorted tensor aval `[num_tokens, hidden_size]` (dtype of `inp`) and the
        probs aval: `[num_tokens]` in the dtype of `probs` when `with_probs` is set,
        otherwise an empty `(0,)` array in the dtype of `inp`.
        """
        del row_id_map_aval, is_forward

        sorted_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)
        probs_out_aval = (
            jax.core.ShapedArray((num_tokens,), probs_aval.dtype)
            if with_probs
            else jax.core.ShapedArray((0,), inp_aval.dtype)
        )
        return sorted_aval, probs_out_aval

    @staticmethod
    def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs):
        """Bind the registered inner primitive, turning the static args into keywords."""
        primitive = SortChunksByMapPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            inp,
            row_id_map,
            probs,
            num_tokens=num_tokens,
            hidden_size=hidden_size,
            is_forward=is_forward,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs):
        """Lower to a Triton call of the chunk-sorting kernel."""
        # Launch with the smallest autotune BLOCK_SIZE so the grid covers all of hidden_size.
        min_block = _get_min_block_size(_sort_chunks_by_map_kernel)
        launch_grid = (num_tokens, triton.cdiv(hidden_size, min_block))

        # Strides are written out for dense row-major layouts.
        kernel_constants = {
            "stride_input_token": hidden_size,
            "stride_input_hidden": 1,
            "stride_output_token": hidden_size,
            "stride_output_hidden": 1,
            "stride_probs_token": 1,
            "stride_permuted_probs_token": 1,
            "hidden_size": hidden_size,
            "PERMUTE_PROBS": with_probs,
            "FORWARD": is_forward,
            "BLOCK_SIZE": min_block,
        }

        return triton_call_lowering(
            ctx,
            _sort_chunks_by_map_kernel,
            inp,
            row_id_map,
            probs,
            grid=launch_grid,
            constexprs=kernel_constants,
        )


register_primitive(SortChunksByMapPrimitive)


def make_row_id_map(
    routing_map: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
) -> jnp.ndarray:
    """
    Build the row_id_map used by the mask-map permutation.

    Three primitive passes are chained: a block-wise cumsum, a full cumsum that also
    processes the mask, and a sparse-to-dense conversion. Between passes 2 and 3 the
    columns `[num_experts:]` are filled with -1.

    Parameters
    ----------
    routing_map : jnp.ndarray
        Mask tensor of shape `[num_tokens, num_experts]`. A value of 1 means the token is
        routed to that expert, 0 means it is not.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map : jnp.ndarray
        Tensor of shape `[num_tokens, num_experts * 2 + 1]`.
        For each token, the last item holds the number of routed experts (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        Items in `[num_experts, num_experts + n_routed)` are the expert indices matching
        those first n_routed row indices.
    """
    shape_params = {"num_tokens": num_tokens, "num_experts": num_experts}

    # Pass 1: block-wise cumsum; also produces the workspace consumed by pass 2.
    blockwise_map, workspace = RowIdMapPass1Primitive.outer_primitive.bind(
        routing_map,
        **shape_params,
        block_size=DEFAULT_BLOCK_SIZE,
    )

    # Pass 2: cumsum over everything and process the mask (second output is unused).
    cumsum_map, _ = RowIdMapPass2Primitive.outer_primitive.bind(
        blockwise_map,
        workspace,
        **shape_params,
        block_size=DEFAULT_BLOCK_SIZE,
    )

    # Passes 1/2 only wrote columns [0:num_experts]; the reference implementation
    # expects -1 for invalid entries, so the remaining columns are set to -1.
    sparse_map = cumsum_map.at[:, num_experts:].set(-1)

    # Pass 3: convert the row_id_map from the sparse to the dense structure.
    return RowIdMapPass3Primitive.outer_primitive.bind(sparse_map, **shape_params)


def permute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Apply the row_id_map permutation to the input tensor.

    Parameters
    ----------
    inp : jnp.ndarray
        Tensor of shape `[num_tokens, hidden_size]` to be permuted.
    row_id_map : jnp.ndarray
        Token-to-expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs : Optional[jnp.ndarray]
        Probabilities belonging to the input tensor. Permuted as well when not None.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size of the input tensor.

    Returns
    -------
    output : jnp.ndarray
        Permuted tensor of shape `[num_out_tokens, hidden_size]`.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities when `probs` was given, otherwise None.
    """
    has_probs = probs is not None

    # The primitive always takes a probs operand; an empty tensor stands in for None.
    probs_operand = probs if has_probs else jnp.zeros((0,), dtype=inp.dtype)

    # The kernel signature requires scale / permuted-scale operands even though they are
    # unused when PERMUTE_SCALE=False, so `inp` is passed in both of those positions.
    results = PermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs_operand,
        inp,
        inp,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_probs=has_probs,
    )
    output, permuted_probs = results

    # Hide the placeholder probs output from callers that did not supply probs.
    return output, (permuted_probs if has_probs else None)


def unpermute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    permuted_probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Reverse the row_id_map permutation on the input tensor.

    Parameters
    ----------
    inp : jnp.ndarray
        Permuted tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        Token-to-expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs : Optional[jnp.ndarray]
        Merging probabilities. When not None, they are used as weights to reduce the
        unpermuted tokens.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities. When not None, they are unpermuted as well.
    num_tokens : int
        Number of tokens in the unpermuted output (first dimension of `row_id_map` and of
        the returned tensor).
    num_experts : int
        Number of experts (see the second dimension of `row_id_map`).
    hidden_size : int
        Hidden size of `inp` and of the returned tensor.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs : Optional[jnp.ndarray]
        Unpermuted probabilities when `permuted_probs` was given, otherwise None.
    """
    use_merging = merging_probs is not None
    use_probs = permuted_probs is not None

    # The primitive always takes both probability operands; empty tensors stand in for None.
    merging_operand = merging_probs if use_merging else jnp.zeros((0,), dtype=inp.dtype)
    probs_operand = permuted_probs if use_probs else jnp.zeros((0,), dtype=inp.dtype)

    results = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        merging_operand,
        probs_operand,
        num_tokens=num_tokens,
        num_experts=num_experts,
        hidden_size=hidden_size,
        with_merging_probs=use_merging,
        with_probs=use_probs,
    )
    output, unpermuted_probs = results

    # Hide the placeholder probs output from callers that did not supply permuted_probs.
    return output, (unpermuted_probs if use_probs else None)


def make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
    num_splits: int,
) -> jnp.ndarray:
    """
    Build the row_id_map used for chunk sorting.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        Chunk sizes, shape `[num_splits,]`.
    sorted_indices : jnp.ndarray
        Indices of the sorted chunks, shape `[num_splits,]`.
    num_tokens : int
        Number of tokens in the input tensor.
    num_splits : int
        Number of splits, i.e. the length of `split_sizes` and `sorted_indices`.

    Returns
    -------
    row_id_map : jnp.ndarray
        Row ID map for chunk sorting of shape `[num_tokens,]`.
    """
    chunk_sort_primitive = MakeChunkSortMapPrimitive.outer_primitive
    static_params = {"num_tokens": num_tokens, "num_splits": num_splits}
    row_id_map = chunk_sort_primitive.bind(split_sizes, sorted_indices, **static_params)
    return row_id_map


def sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Reorder chunks of the input according to row_id_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        Row ID map of shape `[num_tokens,]`.
    probs : Optional[jnp.ndarray]
        Probabilities belonging to the input tensor. Sorted as well when not None.
    num_tokens : int
        Number of tokens in the input tensor.
    hidden_size : int
        Hidden size of the input tensor.
    is_forward : bool
        Whether the sort is for the forward or the backward direction.

    Returns
    -------
    output : jnp.ndarray
        Sorted tensor of shape `[num_tokens, hidden_size]`.
    permuted_probs : Optional[jnp.ndarray]
        Sorted probabilities when `probs` was given, otherwise None.
    """
    has_probs = probs is not None

    # The primitive always takes a probs operand; an empty tensor stands in for None.
    probs_operand = probs if has_probs else jnp.zeros((0,), dtype=inp.dtype)

    static_params = {
        "num_tokens": num_tokens,
        "hidden_size": hidden_size,
        "is_forward": is_forward,
        "with_probs": has_probs,
    }
    output, permuted_probs = SortChunksByMapPrimitive.outer_primitive.bind(
        inp, row_id_map, probs_operand, **static_params
    )

    # Hide the placeholder probs output from callers that did not supply probs.
    return output, (permuted_probs if has_probs else None)