# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE Permutaion API"""
import warnings
from typing import Optional, Tuple
import torch

import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor

__all__ = [
    "moe_permute",
    "moe_unpermute",
    "moe_sort_chunks_by_index",
]


class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

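    # Class-level workspace cache shared across calls; it is re-allocated only
    # when a larger expanded token count (tokens * topK) is observed.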
    workspace = None
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check
        assert inp.size(0) == index.size(0), "Permute not possible"

        # Data type check
        dtype = TE_DType[inp.dtype]
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        dtype = TE_DType[permuted_act_grad.dtype]
        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )

        return act_grad, None, None, None


class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check
        dtype = TE_DType[inp.dtype]
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        ctx.save_for_backward(inp, row_id_map, probs)
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        dtype = TE_DType[unpermuted_act_grad.dtype]
        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        if ctx.needs_input_grad[0]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad


class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            ctx.probs = probs
            return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

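        # Build a map from each (token, expert) pair selected in routing_map to
        # its destination row in the permuted output.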
        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

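        # Quantized (FP8) inputs are unwrapped before the kernel: the raw data
        # (and, for blockwise/MXFP8 recipes, the rowwise scales) are permuted
        # separately and re-wrapped into the matching tensor class afterwards.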
        fp8 = isinstance(inp, QuantizedTensor)
        per_tensor_recipe = isinstance(inp, Float8Tensor)
        blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
        mxfp8_recipe = isinstance(inp, MXFP8Tensor)

        if fp8:
            fp8_dtype = inp._fp8_dtype
            fake_dtype = inp.dtype
            # blockwise scaling
            if blockwise_recipe:
                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # mxfp8 scaling
            elif mxfp8_recipe:
                fp8_scale = inp._rowwise_scale_inv.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # per-tensor scaling
            elif per_tensor_recipe:
                # Kernel does not need scale in per-tensor scaling
                fp8_scale = None
                scale_hidden_dim = None
                fp8_scale_inv = inp._scale_inv
                inp = inp._data
            else:
                raise ValueError("Unsupported FP8 recipe")
        else:
            fp8_scale = None
            fp8_dtype = None
            scale_hidden_dim = None

        output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            fp8_scale,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            scale_hidden_dim,
        )

        if fp8:
            if per_tensor_recipe:
                output = Float8Tensor(
                    data=output,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=output.shape,
                    dtype=fake_dtype,
                )
            elif blockwise_recipe:
                output = Float8BlockwiseQTensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.T.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    fp8_dtype=fp8_dtype,
                    quantizer=None,
                    is_2D_scaled=False,
                    requires_grad=output.requires_grad,
                )
            elif mxfp8_recipe:
                output = MXFP8Tensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    fp8_dtype=fp8_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    quantizer=None,
                    requires_grad=output.requires_grad,
                )

        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, ctx.probs

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            assert not isinstance(
                permuted_act_grad, QuantizedTensor
            ), "The backward of moe_permute does not support FP8."
            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
            )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        merging_probs: Optional[torch.Tensor],
        restore_shape: Optional[torch.Size],
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            ctx.merging_probs = merging_probs
            return inp

        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
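        # The mask-map row_id_map has width 2 * num_experts + 1, so the expert
        # count can be recovered from its shape.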
        num_experts = (row_id_map.size(1) - 1) // 2

        with_probs = merging_probs is not None
        if with_probs:
            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        assert not isinstance(
            inp, QuantizedTensor
        ), "The forward of moe_unpermute does not support FP8."
        unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,
            num_tokens,
            num_experts,
            hidden_size,
        )

        if with_probs:
            ctx.save_for_backward(inp, row_id_map, merging_probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.merging_probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, merging_probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors

            fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
            per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
            blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
            mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)

            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fake_dtype = unpermuted_act_grad.dtype
                # per-tensor scaling
                if per_tensor_recipe:
                    # Kernel does not need scale in per-tensor scaling
                    fp8_scale = None
                    scale_hidden_dim = None
                    fp8_scale_inv = unpermuted_act_grad._scale_inv
                    unpermuted_act_grad = unpermuted_act_grad._data
                # blockwise scaling
                elif blockwise_recipe:
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                # mxfp8 scaling
                elif mxfp8_recipe:
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                else:
                    raise ValueError("Unsupported FP8 recipe")
            else:
                scale_hidden_dim = None
                fp8_dtype = None
                fp8_scale = None

            if ctx.with_probs:
                assert (
                    not fp8
                ), "The backward of moe_unpermute with merging probs does not support FP8."
                act_grad, probs_grad = (
                    triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
                        unpermuted_act_grad,
                        row_id_map,
                        fwd_input,
                        merging_probs,
                        ctx.num_tokens,
                        ctx.num_experts,
                        ctx.num_permuted_tokens,
                        ctx.hidden_size,
                    )
                )
            else:
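                # Without merging probs, the gradient of an unpermute is itself
                # a permute of the incoming gradient rows back to expert order.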
                act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    None,
                    fp8_scale,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    scale_hidden_dim,
                )

            if fp8:
                if per_tensor_recipe:
                    act_grad = Float8Tensor(
                        data=act_grad,
                        fp8_dtype=fp8_dtype,
                        fp8_scale_inv=fp8_scale_inv,
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                    )
                elif blockwise_recipe:
                    act_grad = Float8BlockwiseQTensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.T.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        fp8_dtype=fp8_dtype,
                        quantizer=None,
                        is_2D_scaled=False,
                        requires_grad=act_grad.requires_grad,
                    )
                elif mxfp8_recipe:
                    act_grad = MXFP8Tensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        fp8_dtype=fp8_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        quantizer=None,
                        requires_grad=act_grad.requires_grad,
                    )

        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None


def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map. Tokens routed to the same expert
    will be grouped together. The routing_map indicates which experts were selected
    by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map: torch.Tensor
        The token to expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        A value of 1 means the token is routed to the corresponding expert; 0 means it is not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values in it are the routed expert indices.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the operator computes the workspace size
        automatically.
    map_type: str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.
        Refer to `routing_map` for more details.
    """
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
        return output, row_id_map
    raise ValueError("map_type should be one of 'mask' or 'index'")
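
# A minimal usage sketch for `moe_permute` with the default mask map (shapes and
# values are illustrative; assumes a CUDA device, 4 tokens, 2 experts):
#
#     inp = torch.randn(4, 16, dtype=torch.bfloat16, device="cuda")
#     routing_map = torch.tensor(
#         [[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.int32, device="cuda"
#     )
#     permuted, row_id_map = moe_permute(inp, routing_map, num_out_tokens=4)
#     # `permuted` groups token rows by expert; `row_id_map` is later consumed
#     # by `moe_unpermute` to restore the original order.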


def moe_permute_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map.
    Tokens routed to the same expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities of shape [num_tokens, num_experts] corresponding
        to the tokens. It will be permuted with the tokens according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        A value of 1 means the token is routed to the corresponding expert; 0 means it is not.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    """
    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, num_out_tokens, probs
    )
    return output, permuted_probs, row_id_map


def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Optional[torch.Tensor] = None,
    restore_shape: Optional[torch.Size] = None,
    map_type: str = "mask",
    probs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map: torch.Tensor
        The mapping table used to unpermute the tokens, returned as the second
        output of `moe_permute`.
    merging_probs: torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to None, meaning the tokens are merged by simple accumulation.
    restore_shape: torch.Size, default = None
        The output shape after the unpermute operation.
    map_type: str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.
    probs: torch.Tensor, default = None
        Deprecated alias for merging_probs, kept for backward compatibility.
    """
    if probs is not None:
        if merging_probs is not None:
            raise ValueError(
                "Both merging_probs and probs kwarg are provided. probs is deprecated."
            )
        warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
        merging_probs = probs
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
    raise ValueError("map_type should be one of 'mask' or 'index'")
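
# A round-trip sketch pairing `moe_permute_with_probs` with `moe_unpermute`
# (illustrative shapes; assumes a CUDA device and top-1 routing so each token
# appears exactly once in the permuted output):
#
#     inp = torch.randn(4, 16, dtype=torch.bfloat16, device="cuda")
#     routing_map = torch.tensor(
#         [[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.int32, device="cuda"
#     )
#     probs = torch.rand(4, 2, dtype=torch.float32, device="cuda") * routing_map
#     permuted, permuted_probs, row_id_map = moe_permute_with_probs(
#         inp, probs, routing_map, num_out_tokens=4
#     )
#     restored = moe_unpermute(permuted, row_id_map, merging_probs=probs)
#     # Each restored token is the probability-weighted merge of its expert copies.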


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            return inp, probs

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        assert num_splits == sorted_idxs.size(0)

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data

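        # Two steps: build a row-id map from the chunk sizes and the target
        # order, then gather tokens (and probs, if given) through that map.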
        row_id_map = triton_permutation.make_chunk_sort_map(
            split_sizes,
            sorted_idxs,
            num_tokens,
            num_splits,
        )
        output, permuted_probs = triton_permutation.sort_chunks_by_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            hidden_size,
            is_forward=True,
        )
        if fp8:
            output = Float8Tensor(
                data=output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, permuted_probs_grad

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                fake_dtype = permuted_act_grad.dtype
                permuted_act_grad = permuted_act_grad._data
            act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.hidden_size,
                is_forward=False,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to split_sizes and the resulting
    chunks are reordered according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.
    """
    output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
    return output
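
# A usage sketch for `moe_sort_chunks_by_index` (illustrative: three chunks of
# sizes 2, 1, and 3 reordered to [2, 0, 1]; assumes a CUDA device):
#
#     inp = torch.randn(6, 16, device="cuda")
#     split_sizes = torch.tensor([2, 1, 3], dtype=torch.int32, device="cuda")
#     sorted_index = torch.tensor([2, 0, 1], dtype=torch.int32, device="cuda")
#     out = moe_sort_chunks_by_index(inp, split_sizes, sorted_index)
#     # Rows of `out` hold the chunks of `inp` laid out in the order 2, 0, 1.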


def moe_sort_chunks_by_index_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split and sort the input tensor and probs based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to split_sizes and the resulting
    chunks are reordered according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities of shape [num_tokens] corresponding to the tokens.
        It will be permuted with the tokens according to split_sizes and sorted_index.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.
    """
    output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
    return output, permuted_probs