permutation.py 26.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
"""MoE Permutation API"""
6
7
8
9
10
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
11
12
13
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
14
15
16
17
18


__all__ = [
    "moe_permute",
    "moe_unpermute",
19
    "moe_sort_chunks_by_index",
20
21
22
]


23
24
class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

    # Workspace buffers cached at class level and reused across calls.
    # They are (re)allocated by the C++ op whenever more capacity is needed.
    workspace = None
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Nothing to permute: hand the input back with an empty row-id map.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check: one routing row per input token.
        assert inp.size(0) == index.size(0), "Permute not possible"

        # FP8 tensors are unwrapped to their raw byte payload; scale-inv and the
        # "fake" high-precision dtype are kept to re-wrap the output below.
        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            assert (
                inp._quantizer.scale.ndim == 0
            ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]

        # The fused op expects int32 indices; convert with a warning otherwise.
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        # Grow the cached workspace if this call needs more capacity than before.
        required_capacity = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < required_capacity:
            _moe_permute_index_map.max_expanded_token_num = required_capacity
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        if is_fp8:
            permuted_act = Float8Tensor(
                data=permuted_act,
                fp8_dtype=dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=permuted_act.shape,
                dtype=fake_dtype,
            )

        # Stash what backward needs; row_id_map is also returned to the caller.
        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        ctx.fp8 = is_fp8
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty gradient: nothing to unpermute.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                permuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
            dtype = permuted_act_grad._fp8_dtype
            fp8_scale_inv = permuted_act_grad._scale_inv
            fake_dtype = permuted_act_grad.dtype
            permuted_act_grad = permuted_act_grad._data
        else:
            dtype = TE_DType[permuted_act_grad.dtype]

        act_grad = None
        if ctx.needs_input_grad[0]:
            # The backward of a permute is the matching unpermute.
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )
            if ctx.fp8:
                # scale_inv is multiplied by topK here, mirroring the topK-way
                # fan-out of each token in forward.
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.topK,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )

        return act_grad, None, None, None
class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Nothing to unpermute: remember probs for backward and return as-is.
        if not inp.numel():
            ctx.probs = probs
            return inp

        # Optional probs: derive token count / topK from them when present,
        # otherwise fall back to a single route per token.
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # FP8 tensors are unwrapped; metadata is kept to re-wrap the output.
        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]

        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        if is_fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output,
                fp8_dtype=dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=unpermuted_output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(inp, row_id_map, probs)
        ctx.fp8 = is_fp8
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty gradient: pass straight through.
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                unpermuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
            dtype = unpermuted_act_grad._fp8_dtype
            fp8_scale_inv = unpermuted_act_grad._scale_inv
            fake_dtype = unpermuted_act_grad.dtype
            unpermuted_act_grad = unpermuted_act_grad._data
        else:
            dtype = TE_DType[unpermuted_act_grad.dtype]

        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        if ctx.needs_input_grad[0]:
            # One fused call yields both the activation and the probs gradients.
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
            if ctx.fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad
class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input: remember probs for backward, return empty map/probs.
        if not inp.numel():
            ctx.probs = probs
            return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)

        # Device checks
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        # Shape checks
        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

        # Build the token->destination-row mapping from the boolean/int mask.
        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

        # FP8 tensors are unwrapped; metadata is kept to re-wrap the output.
        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        output, permuted_probs = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        if is_fp8:
            output = Float8Tensor(
                data=output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, ctx.probs

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            is_fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if is_fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                fake_dtype = permuted_act_grad.dtype
                permuted_act_grad = permuted_act_grad._data
            else:
                fp8_dtype = None
            # The backward of the mask-map permute is the matching unpermute.
            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
                fp8_dtype,
            )
            if is_fp8:
                # scale_inv is multiplied by num_experts here, mirroring the
                # per-expert fan-out accumulated in the gradient.
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        merging_probs: torch.Tensor,
        restore_shape: torch.Size,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input: remember merging_probs for backward, return as-is.
        if not inp.numel():
            ctx.merging_probs = merging_probs
            return inp

        # Fall back to the input's own shape when no restore shape was given.
        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        num_experts = row_id_map.size(0)

        with_probs = merging_probs is not None
        if with_probs:
            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            fp8_dtype = inp._fp8_dtype
            # Without merging probs the rows are merged by plain accumulation,
            # so the scale-inv absorbs the num_experts factor up front.
            if not with_probs:
                fp8_scale_inv = inp._scale_inv * num_experts
            else:
                fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        else:
            fp8_dtype = None
        unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,
            num_tokens,
            num_experts,
            hidden_size,
            fp8_dtype=fp8_dtype,
        )
        if is_fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=unpermuted_output.shape,
                dtype=fake_dtype,
            )

        # The probs-weighted backward needs the forward input as well.
        if with_probs:
            ctx.save_for_backward(inp, row_id_map, merging_probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.merging_probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, merging_probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors

            is_fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
            if is_fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fp8_scale_inv = unpermuted_act_grad._scale_inv
                fake_dtype = unpermuted_act_grad.dtype
                unpermuted_act_grad = unpermuted_act_grad._data
            else:
                fp8_dtype = None

            if ctx.with_probs:
                # Dedicated kernel: produces both the activation gradient and
                # the merging-probs gradient.
                act_grad, probs_grad = (
                    triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
                        unpermuted_act_grad,
                        row_id_map,
                        fwd_input,
                        merging_probs,
                        ctx.num_tokens,
                        ctx.num_experts,
                        ctx.num_permuted_tokens,
                        ctx.hidden_size,
                        fp8_dtype,
                    )
                )
            else:
                # Without probs, the backward of an unpermute is a permute.
                act_grad, _ = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    None,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                )

            if is_fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )

        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None
def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map. Tokens with the same designated
    expert will be grouped together. The routing_map indicates which experts were
    selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map: torch.Tensor
        The token to expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values in it are the routed expert indices.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the calculation of the size of workspace is
        automatically taken over by the operator.
        Only used when map_type is 'index'.
    map_type: str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.
        Refer to `routing_map` for more details.

    Returns
    -------
    Tuple of (permuted tokens, row_id_map); the row_id_map is the mapping table
    later consumed by `moe_unpermute`.

    Raises
    ------
    ValueError
        If `map_type` is neither 'mask' nor 'index'.
    """
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        # The mask-map autograd function also returns permuted probs; this entry
        # point does not take probs, so that third output is dropped.
        output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
        return output, row_id_map
    raise ValueError("map_type should be one of 'mask' or 'index'")
def moe_permute_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map, grouping together the
    tokens routed to the same expert. The routing_map indicates which experts
    were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    """
    permuted_inp, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, num_out_tokens, probs
    )
    # Note the reordering: probs come before the row-id map in this API.
    return permuted_inp, permuted_probs, row_id_map
def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: torch.Tensor = None,
    restore_shape: torch.Size = None,
    map_type: str = "mask",
    probs: torch.Tensor = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map: torch.Tensor
        The tensor of a mapping table for sorted indices used to unpermute the tokens,
        which is the second output tensor of `Permute`.
    merging_probs: torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to `None`, which means that the tokens are directly merged by
        accumulation.
    restore_shape: torch.Size, default = None
        The output shape after the unpermute operation. If None, it is inferred
        from `inp` (only used when map_type is 'mask').
    map_type: str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.
    probs: torch.Tensor, default = None
        Deprecated; renamed to merging_probs. Kept for backward compatibility.

    Raises
    ------
    ValueError
        If both `merging_probs` and `probs` are provided, or if `map_type` is
        neither 'mask' nor 'index'.
    """
    # Backward-compatibility shim for the old `probs` kwarg.
    if probs is not None:
        if merging_probs is not None:
            raise ValueError(
                "Both merging_probs and probs kwarg are provided. probs is deprecated."
            )
        warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
        merging_probs = probs
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
    raise ValueError("map_type should be one of 'mask' or 'index'")


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Nothing to sort: return inputs untouched.
        if not inp.numel():
            return inp, probs

        # Device checks
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        # One chunk index per chunk size.
        assert num_splits == sorted_idxs.size(0)

        # FP8 tensors are unwrapped; metadata is kept to re-wrap the output.
        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
            inp,
            split_sizes,
            sorted_idxs,
            probs,
            num_tokens,
            hidden_size,
            num_splits,
        )
        if is_fp8:
            output = Float8Tensor(
                data=output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, permuted_probs_grad

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            is_fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if is_fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                fake_dtype = permuted_act_grad.dtype
                permuted_act_grad = permuted_act_grad._data
            # Backward re-sorts via the saved row-id map (the inverse mapping).
            act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.hidden_size,
            )
            if is_fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and then
    sorted according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    The chunk-sorted tensor.
    """
    # No probs variant: the permuted-probs output of the autograd op is dropped.
    output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
    return output


def moe_sort_chunks_by_index_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split and sort the input tensor and probs based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and then
    sorted according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens]. It will be permuted with the tokens according to
        the split_sizes and sorted_index.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    Tuple of (chunk-sorted tensor, correspondingly permuted probs).
    """
    output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
    return output, permuted_probs