permutation.py 31.6 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
"""MoE Permutation API"""
6
import warnings
7
from typing import Optional, Tuple
8
9
10
import torch

import transformer_engine_torch as tex
11
12
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
13
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
14
15
16
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
17
18
19
20

__all__ = [
    "moe_permute",
    "moe_unpermute",
21
    "moe_sort_chunks_by_index",
22
23
24
]


25
26
class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

    # Class-level workspace buffers reused across calls of tex.moe_permute_fwd.
    workspace = None
    # Largest expanded token count (tokens * topK) seen so far; sizing key for `workspace`.
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check: nothing to permute, return an empty row_id_map.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check: one row of expert indices per input token.
        assert inp.size(0) == index.size(0), "Permute not possible"

        # Data type check: the kernel expects int32 indices; convert with a warning.
        dtype = TE_DType[inp.dtype]
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        # Grow the cached workspace if this call needs more expanded tokens than any
        # previous call; resetting to [] lets tex.moe_permute_fwd reallocate it.
        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        # Keep what backward needs: the sorted-row mapping and the logical sizes.
        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        dtype = TE_DType[permuted_act_grad.dtype]
        act_grad = None
        if ctx.needs_input_grad[0]:
            # Backward of a permute is an unpermute over the same row_id_map;
            # the empty tensor fills the (unused) probs slot of the kernel.
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )

        # index, num_out_tokens, and max_token_num receive no gradient.
        return act_grad, None, None, None
102
103


104
105
class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check: stash probs so backward can return it unchanged.
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            # The kernel expects float32 probabilities; convert with a warning.
            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            # Without probs, topK is treated as 1 and token count is taken
            # from the row_id_map; the kernel gets an empty probs tensor.
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check
        dtype = TE_DType[inp.dtype]
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        # The forward input is needed to compute the probs gradient in backward.
        ctx.save_for_backward(inp, row_id_map, probs)
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        dtype = TE_DType[unpermuted_act_grad.dtype]
        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        if ctx.needs_input_grad[0]:
            # The kernel produces both gradients in one pass.
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
        # Drop the probs gradient when the caller did not request it.
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad
182
183


184
185
186
187
188
189
190
191
192
class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
        probs: torch.Tensor,
        pad_offsets: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check: return empty map/probs placeholders and stash probs
        # so backward can hand it back unchanged.
        if not inp.numel():
            ctx.probs = probs
            return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)

        # Device checks
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
        if pad_offsets is not None:
            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."

        # Shape checks: one routing row per token.
        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

        # Detect the quantization recipe of the input (if any). The specific
        # recipe decides how raw data and scales are extracted below.
        fp8 = isinstance(inp, QuantizedTensor)
        per_tensor_recipe = isinstance(inp, Float8Tensor)
        blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
        mxfp8_recipe = isinstance(inp, MXFP8Tensor)

        if fp8:
            fp8_dtype = inp._fp8_dtype
            fake_dtype = inp.dtype
            # blockwise scaling
            if blockwise_recipe:
                # Transpose so the scales are laid out per-token in dim 0.
                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # mxfp8 scaling
            elif mxfp8_recipe:
                fp8_scale = inp._rowwise_scale_inv.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # per-tensor scaling
            elif per_tensor_recipe:
                # Kernel does not need scale in per-tensor scaling
                fp8_scale = None
                scale_hidden_dim = None
                fp8_scale_inv = inp._scale_inv
                inp = inp._data
            else:
                raise ValueError("Unsupported FP8 recipe")
        else:
            fp8_scale = None
            fp8_dtype = None
            scale_hidden_dim = None

        # Fused kernel permutes tokens, their scales, and their probs in one pass.
        output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            fp8_scale,
            pad_offsets,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            scale_hidden_dim,
        )

        if fp8:
            # Re-wrap the raw permuted data (and permuted scales) in the same
            # quantized tensor type the input arrived in.
            if per_tensor_recipe:
                output = Float8Tensor(
                    data=output,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=output.shape,
                    dtype=fake_dtype,
                )
            elif blockwise_recipe:
                output = Float8BlockwiseQTensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.T.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    fp8_dtype=fp8_dtype,
                    quantizer=None,
                    is_2D_scaled=False,
                    requires_grad=output.requires_grad,
                )
            elif mxfp8_recipe:
                output = MXFP8Tensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    fp8_dtype=fp8_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    quantizer=None,
                    requires_grad=output.requires_grad,
                    with_gemm_swizzled_scales=False,
                )

        ctx.save_for_backward(row_id_map, pad_offsets)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check: mirror forward's empty-path outputs.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, ctx.probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            row_id_map, pad_offsets = ctx.saved_tensors
            assert not isinstance(
                permuted_act_grad, QuantizedTensor
            ), "The backward of moe_permute does not support FP8."
            # Backward of a permute is an unpermute over the same row_id_map.
            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                permuted_probs_grad,
                pad_offsets,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
            )
        # Drop the probs gradient when the caller did not request it.
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad, None
337
338
339
340
341
342
343
344
345
346


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        merging_probs: Optional[torch.Tensor],
        restore_shape: Optional[torch.Size],
        pad_offsets: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check: stash merging_probs so backward returns it unchanged.
        if not inp.numel():
            ctx.merging_probs = merging_probs
            return inp

        # Default to the permuted input's own shape when no restore shape given.
        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        # The row_id_map layout stores 2 * num_experts + 1 columns.
        num_experts = (row_id_map.size(1) - 1) // 2

        with_probs = merging_probs is not None
        if with_probs:
            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
        if pad_offsets is not None:
            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."

        assert not isinstance(
            inp, QuantizedTensor
        ), "The forward of moe_unpermute does not support FP8."
        unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )

        # Backward needs the forward input only when merging probs are used
        # (to compute the probs gradient); save the minimum otherwise.
        if with_probs:
            ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
        else:
            ctx.save_for_backward(row_id_map, pad_offsets)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.merging_probs, None, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
            else:
                row_id_map, pad_offsets = ctx.saved_tensors

            # Detect the quantization recipe of the incoming gradient (if any).
            fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
            per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
            blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
            mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)

            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fake_dtype = unpermuted_act_grad.dtype
                # per-tensor scaling
                if per_tensor_recipe:
                    # Kernel does not need scale in per-tensor scaling
                    fp8_scale = None
                    scale_hidden_dim = None
                    fp8_scale_inv = unpermuted_act_grad._scale_inv
                    unpermuted_act_grad = unpermuted_act_grad._data
                # blockwise scaling
                elif blockwise_recipe:
                    # Transpose so the scales are laid out per-token in dim 0.
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                # mxfp8 scaling
                elif mxfp8_recipe:
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                else:
                    raise ValueError("Unsupported FP8 recipe")
            else:
                scale_hidden_dim = None
                fp8_dtype = None
                fp8_scale = None

            if ctx.with_probs:
                assert (
                    not fp8
                ), "The backward of moe_unpermute with merging probs does not support FP8."
                # Dedicated fused kernel produces both the activation gradient
                # and the merging-probs gradient.
                act_grad, probs_grad = (
                    triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
                        unpermuted_act_grad,
                        row_id_map,
                        fwd_input,
                        merging_probs,
                        pad_offsets,
                        ctx.num_tokens,
                        ctx.num_experts,
                        ctx.num_permuted_tokens,
                        ctx.hidden_size,
                    )
                )
            else:
                # Backward of an unpermute (without probs) is a permute.
                act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    None,
                    fp8_scale,
                    pad_offsets,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    scale_hidden_dim,
                )

            if fp8:
                # Re-wrap the raw gradient data (and permuted scales) in the
                # same quantized tensor type the gradient arrived in.
                if per_tensor_recipe:
                    act_grad = Float8Tensor(
                        data=act_grad,
                        fp8_dtype=fp8_dtype,
                        fp8_scale_inv=fp8_scale_inv,
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                    )
                elif blockwise_recipe:
                    act_grad = Float8BlockwiseQTensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.T.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        fp8_dtype=fp8_dtype,
                        quantizer=None,
                        is_2D_scaled=False,
                        requires_grad=act_grad.requires_grad,
                    )
                elif mxfp8_recipe:
                    act_grad = MXFP8Tensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        fp8_dtype=fp8_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        quantizer=None,
                        requires_grad=act_grad.requires_grad,
                        with_gemm_swizzled_scales=False,
                    )

        # Drop the probs gradient when the caller did not request it.
        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None, None
514
515


516
517
def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Group tokens by their destination expert according to ``routing_map``.

    Tokens routed to the same expert end up adjacent in the output. The second
    return value is the row-id map required later by `moe_unpermute`.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map : torch.Tensor
        The token to expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values in it are the routed expert indices.
    num_out_tokens : int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num : int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the calculation of the size of workspace is
        automatically taken over by the operator. Only used when map_type is 'index'.
    map_type : str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.
        Refer to `routing_map` for more details.
    """
    # Dispatch on the routing-map representation.
    if map_type == "mask":
        permuted_inp, row_id_map, _ = _moe_permute_mask_map.apply(
            inp, routing_map, num_out_tokens, None, None
        )
        return permuted_inp, row_id_map
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    raise ValueError("map_type should be one of 'mask' or 'index'")
558
559


560
561
562
563
564
565
566
567
568
569
570
571
572
573
def moe_permute_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map.
    Token with the same index will be grouped together.
    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs : torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map : torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    num_out_tokens : int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.

    Returns
    -------
    Tuple of (permuted tokens, permuted probs, row_id_map). The row_id_map is
    the mapping table later consumed by `moe_unpermute`.
    """
    # Fix: the annotation previously declared a 2-tuple, but three tensors are returned.
    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, num_out_tokens, probs, None
    )
    return output, permuted_probs, row_id_map


593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
def moe_permute_and_pad_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map, padding each expert's
    token chunk up to a multiple of ``align_size``.
    Token with the same index will be grouped together.
    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    tokens_per_expert : torch.Tensor
        Tensor of shape `[num_experts]` containing actual token counts per expert.
    align_size : int
        the alignment size for the input tensor.

    Returns
    -------
    Tuple of (permuted tokens, permuted probs, row_id_map, pad_offsets,
    target_tokens_per_expert). ``pad_offsets`` is None when no padding was needed.
    """
    assert (
        tokens_per_expert is not None
    ), "tokens_per_expert must be provided to the fused permute padding function."
    assert align_size > 0, f"align_size must be positive, got {align_size}"

    # Ensure tokens_per_expert is on the same device as input to avoid device transfers
    if tokens_per_expert.device != inp.device:
        tokens_per_expert = tokens_per_expert.to(inp.device)

    # Round each expert's token count up to a multiple of align_size.
    # Integer arithmetic is exact; the previous torch.ceil(t / align_size)
    # promoted integer counts to float and could lose precision for counts
    # beyond the float mantissa range.
    target_tokens_per_expert = (
        (tokens_per_expert + align_size - 1) // align_size * align_size
    ).long()

    if torch.equal(tokens_per_expert, target_tokens_per_expert):
        # Already aligned: no padding rows, signalled downstream by None.
        pad_offsets = None
    else:
        # Cumulative padding inserted before each expert's chunk (exclusive scan).
        pad_lengths = target_tokens_per_expert - tokens_per_expert
        cum_pad = torch.cumsum(pad_lengths, dim=0)
        pad_offsets = torch.cat(
            [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
        )

    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
    )
    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert


649
650
651
def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Optional[torch.Tensor] = None,
    restore_shape: Optional[torch.Size] = None,
    map_type: str = "mask",
    probs: Optional[torch.Tensor] = None,
    pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map : torch.Tensor
        The tensor of a mapping table for sorted indices used to unpermute the tokens,
        which is the second output tensor of `Permute`.
    merging_probs : torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to None, meaning the tokens are directly merged by accumulation.
    restore_shape : torch.Size, default = None
        The output shape after the unpermute operation.
    map_type : str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.
    probs : torch.Tensor, default = None
        Deprecated; renamed to merging_probs. Kept for backward compatibility.
    pad_offsets : torch.Tensor, default = None
        Tensor of per-expert cumulative padding offsets used to remove padding added
        during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
        and is required when unpermuting padded outputs.
        NOTE(review): only honored on the 'mask' path; ignored for map_type 'index'.
    """
    # Handle the deprecated `probs` alias; providing both is an error.
    if probs is not None:
        if merging_probs is not None:
            raise ValueError(
                "Both merging_probs and probs kwarg are provided. probs is deprecated."
            )
        warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
        merging_probs = probs
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(
            inp, row_id_map, merging_probs, restore_shape, pad_offsets
        )
    raise ValueError("map_type should be one of 'mask' or 'index'")


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input: nothing to reorder, pass both tensors through unchanged.
        if not inp.numel():
            return inp, probs

        # Device checks.
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        # One sorted index is required per chunk.
        assert num_splits == sorted_idxs.size(0)

        # FP8 tensors are sorted on their raw byte payload (`_data`); stash the
        # quantization metadata so the output can be re-wrapped afterwards.
        is_fp8 = isinstance(inp, Float8Tensor)
        if is_fp8:
            fp8_dtype, fp8_scale_inv = inp._fp8_dtype, inp._scale_inv
            orig_dtype = inp.dtype
            inp = inp._data

        # Build the destination-row -> source-row map for the chunk reorder,
        # then gather rows (and probs, if given) according to it.
        row_id_map = triton_permutation.make_chunk_sort_map(
            split_sizes,
            sorted_idxs,
            num_tokens,
            num_splits,
        )
        output, permuted_probs = triton_permutation.sort_chunks_by_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            hidden_size,
            is_forward=True,
        )

        if is_fp8:
            # Re-attach the FP8 metadata captured above.
            output = Float8Tensor(
                data=output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=output.shape,
                dtype=orig_dtype,
            )

        # Keep the row map and sizes around for the backward pass.
        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty gradient: nothing to scatter back.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, permuted_probs_grad

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            # Same FP8 unwrap/re-wrap dance as in forward.
            is_fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if is_fp8:
                fp8_dtype, fp8_scale_inv = (
                    permuted_act_grad._fp8_dtype,
                    permuted_act_grad._scale_inv,
                )
                orig_dtype = permuted_act_grad.dtype
                permuted_act_grad = permuted_act_grad._data
            # Undo the forward reorder by replaying the same row map in reverse.
            act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.hidden_size,
                is_forward=False,
            )
            if is_fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=orig_dtype,
                )
        # split_sizes / sorted_idxs are integer routing inputs: no grads.
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad
800
801
802
803
804
805
806
807
808
809
810
811
812
813


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and then sorted
    according to sorted_index.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes : torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index : torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    torch.Tensor
        The chunk-sorted tensor, same shape as `inp`.
    """
    # No probs to carry along: pass None and discard the placeholder second output.
    output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
    return output


def moe_sort_chunks_by_index_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split and sort the input tensor and probs based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and then sorted
    according to sorted_index.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs : torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens]. It will be permuted with the tokens according to
        split_sizes and sorted_index.
    split_sizes : torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index : torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The chunk-sorted tensor and the correspondingly permuted probs.
    """
    output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
    return output, permuted_probs