# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE Permutation API"""
import warnings
from typing import Optional, Tuple

import torch

import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor

__all__ = [
    "moe_permute",
    "moe_permute_with_probs",
    "moe_permute_and_pad_with_probs",
    "moe_unpermute",
    "moe_sort_chunks_by_index",
    "moe_sort_chunks_by_index_with_probs",
]


class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

    workspace = None
    max_expanded_token_num = 0
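
    # `workspace` and `max_expanded_token_num` are class-level caches shared across
    # calls: the workspace is rebuilt only when a call requires a larger expanded
    # token count (max(max_token_num, num_tokens) * topK) than any previous call,
    # so repeated calls with stable shapes reuse the same buffers.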

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check
        assert inp.size(0) == index.size(0), "Permute not possible"

        # Data type check
        dtype = TE_DType[inp.dtype]
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        dtype = TE_DType[permuted_act_grad.dtype]
        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )

        return act_grad, None, None, None


class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check
        dtype = TE_DType[inp.dtype]
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        ctx.save_for_backward(inp, row_id_map, probs)
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        dtype = TE_DType[unpermuted_act_grad.dtype]
        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        if ctx.needs_input_grad[0]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad


class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
        probs: torch.Tensor,
        pad_offsets: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            ctx.probs = probs
            return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
        if pad_offsets is not None:
            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."

        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
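        # Judging from the unpermute path, row_id_map appears to have
        # num_experts * 2 + 1 columns per token; the unpermute side recovers
        # num_experts as (row_id_map.size(1) - 1) // 2.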

        fp8 = isinstance(inp, QuantizedTensor)
        per_tensor_recipe = isinstance(inp, Float8Tensor)
        blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
        mxfp8_recipe = isinstance(inp, MXFP8Tensor)

        if fp8:
            fp8_dtype = inp._fp8_dtype
            fake_dtype = inp.dtype
            # blockwise scaling
            if blockwise_recipe:
                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # mxfp8 scaling
            elif mxfp8_recipe:
                fp8_scale = inp._rowwise_scale_inv.contiguous()
                scale_hidden_dim = fp8_scale.shape[1]
                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                inp = inp._rowwise_data
            # per-tensor scaling
            elif per_tensor_recipe:
                # Kernel does not need scale in per-tensor scaling
                fp8_scale = None
                scale_hidden_dim = None
                fp8_scale_inv = inp._scale_inv
                inp = inp._data
            else:
                raise ValueError("Unsupported FP8 recipe")
        else:
            fp8_scale = None
            fp8_dtype = None
            scale_hidden_dim = None

        output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            fp8_scale,
            pad_offsets,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            scale_hidden_dim,
        )

        if fp8:
            if per_tensor_recipe:
                output = Float8Tensor(
                    data=output,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=output.shape,
                    dtype=fake_dtype,
                )
            elif blockwise_recipe:
                output = Float8BlockwiseQTensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.T.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    fp8_dtype=fp8_dtype,
                    quantizer=None,
                    is_2D_scaled=False,
                    requires_grad=output.requires_grad,
                )
            elif mxfp8_recipe:
                output = MXFP8Tensor(
                    shape=output.shape,
                    dtype=fake_dtype,
                    fp8_dtype=fp8_dtype,
                    rowwise_data=output,
                    rowwise_scale_inv=permuted_scale.contiguous(),
                    columnwise_data=None,
                    columnwise_scale_inv=None,
                    quantizer=None,
                    requires_grad=output.requires_grad,
                )

        ctx.save_for_backward(row_id_map, pad_offsets)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, ctx.probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            row_id_map, pad_offsets = ctx.saved_tensors
            assert not isinstance(
                permuted_act_grad, QuantizedTensor
            ), "The backward of moe_permute does not support FP8."
            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                permuted_probs_grad,
                pad_offsets,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
            )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad, None


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        merging_probs: Optional[torch.Tensor],
        restore_shape: Optional[torch.Size],
        pad_offsets: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            ctx.merging_probs = merging_probs
            return inp

        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        num_experts = (row_id_map.size(1) - 1) // 2

        with_probs = merging_probs is not None
        if with_probs:
            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
        if pad_offsets is not None:
            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."

        assert not isinstance(
            inp, QuantizedTensor
        ), "The forward of moe_unpermute does not support FP8."
        unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            merging_probs,
            None,
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )

        if with_probs:
            ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
        else:
            ctx.save_for_backward(row_id_map, pad_offsets)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.merging_probs, None, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
            else:
                row_id_map, pad_offsets = ctx.saved_tensors

            fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
            per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
            blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
            mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)

            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fake_dtype = unpermuted_act_grad.dtype
                # per-tensor scaling
                if per_tensor_recipe:
                    # Kernel does not need scale in per-tensor scaling
                    fp8_scale = None
                    scale_hidden_dim = None
                    fp8_scale_inv = unpermuted_act_grad._scale_inv
                    unpermuted_act_grad = unpermuted_act_grad._data
                # blockwise scaling
                elif blockwise_recipe:
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                # mxfp8 scaling
                elif mxfp8_recipe:
                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
                    scale_hidden_dim = fp8_scale.shape[1]
                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
                else:
                    raise ValueError("Unsupported FP8 recipe")
            else:
                scale_hidden_dim = None
                fp8_dtype = None
                fp8_scale = None

            if ctx.with_probs:
                assert (
                    not fp8
                ), "The backward of moe_unpermute with merging probs does not support FP8."
                act_grad, probs_grad = (
                    triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
                        unpermuted_act_grad,
                        row_id_map,
                        fwd_input,
                        merging_probs,
                        pad_offsets,
                        ctx.num_tokens,
                        ctx.num_experts,
                        ctx.num_permuted_tokens,
                        ctx.hidden_size,
                    )
                )
            else:
                act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    None,
                    fp8_scale,
                    pad_offsets,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    scale_hidden_dim,
                )

            if fp8:
                if per_tensor_recipe:
                    act_grad = Float8Tensor(
                        data=act_grad,
                        fp8_dtype=fp8_dtype,
                        fp8_scale_inv=fp8_scale_inv,
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                    )
                elif blockwise_recipe:
                    act_grad = Float8BlockwiseQTensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.T.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        fp8_dtype=fp8_dtype,
                        quantizer=None,
                        is_2D_scaled=False,
                        requires_grad=act_grad.requires_grad,
                    )
                elif mxfp8_recipe:
                    act_grad = MXFP8Tensor(
                        shape=act_grad.shape,
                        dtype=fake_dtype,
                        fp8_dtype=fp8_dtype,
                        rowwise_data=act_grad,
                        rowwise_scale_inv=permuted_scale.contiguous(),
                        columnwise_data=None,
                        columnwise_scale_inv=None,
                        quantizer=None,
                        requires_grad=act_grad.requires_grad,
                    )

        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None, None


def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map. Tokens routed to the same expert
    are grouped together. The routing_map indicates which experts were selected
    by each token.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map : torch.Tensor
        The token-to-expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        A value of 1 means the token is routed to this expert and 0 means it is not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values are the indices of the routed experts.
    num_out_tokens : int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num : int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the workspace size is computed
        automatically by the operator.
    map_type : str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.
        Refer to `routing_map` for more details.
    """
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        output, row_id_map, _ = _moe_permute_mask_map.apply(
            inp, routing_map, num_out_tokens, None, None
        )
        return output, row_id_map
    raise ValueError("map_type should be one of 'mask' or 'index'")
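

# Illustrative usage sketch (hypothetical shapes and values; assumes a CUDA
# device). Tokens routed to the same expert end up contiguous in `permuted`,
# and `row_id_map` is the mapping table consumed by `moe_unpermute`:
#
#     inp = torch.randn(4, 16, dtype=torch.bfloat16, device="cuda")
#     routing_map = torch.tensor(
#         [[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.int32, device="cuda"
#     )  # 4 tokens, 2 experts, top-1 routing
#     permuted, row_id_map = moe_permute(inp, routing_map, num_out_tokens=4)
#     restored = moe_unpermute(permuted, row_id_map, restore_shape=inp.shape)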


def moe_permute_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map.
    Tokens routed to the same expert are grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs : torch.Tensor
        The tensor of probabilities corresponding to the tokens, of shape
        [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map : torch.Tensor
        The token-to-expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        A value of 1 means the token is routed to this expert and 0 means it is not.
    num_out_tokens : int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    """
    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, num_out_tokens, probs, None
    )
    return output, permuted_probs, row_id_map
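

# Illustrative sketch (hypothetical shapes; assumes a CUDA device): permute the
# tokens together with their router probabilities, run the experts, then merge
# the outputs back weighted by the probabilities:
#
#     permuted, permuted_probs, row_id_map = moe_permute_with_probs(
#         inp, probs, routing_map, num_out_tokens=int(routing_map.sum())
#     )
#     expert_out = ...  # expert computation on `permuted`
#     merged = moe_unpermute(expert_out, row_id_map, merging_probs=probs)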


def moe_permute_and_pad_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map, and pad each expert's
    chunk of tokens so that its length is a multiple of align_size.
    Tokens routed to the same expert are grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the tokens, of shape
        [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    tokens_per_expert : torch.Tensor
        Tensor of shape `[num_experts]` containing actual token counts per expert.
    align_size : int
        The alignment size: each expert's token count is padded up to a multiple of this value.
    """
    assert (
        tokens_per_expert is not None
    ), "tokens_per_expert must be provided to the fused permute padding function."
    assert align_size > 0, f"align_size must be positive, got {align_size}"

    # Ensure tokens_per_expert is on the same device as input to avoid device transfers
    if tokens_per_expert.device != inp.device:
        tokens_per_expert = tokens_per_expert.to(inp.device)

    # Calculate aligned token counts per expert
    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()

    if torch.equal(tokens_per_expert, target_tokens_per_expert):
        pad_offsets = None
    else:
        pad_lengths = target_tokens_per_expert - tokens_per_expert
        cum_pad = torch.cumsum(pad_lengths, dim=0)
        pad_offsets = torch.cat(
            [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
        )

    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
    )
    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
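

# Worked example of the padding arithmetic above (hypothetical values): with
# align_size = 4 and tokens_per_expert = [3, 5], the aligned targets are [4, 8],
# so pad_lengths = [1, 3] and pad_offsets = [0, 1] (an exclusive cumulative sum).
# The returned `pad_offsets` must be passed back to `moe_unpermute` to strip the
# padding, and `target_tokens_per_expert` gives the padded chunk sizes.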


def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Optional[torch.Tensor] = None,
    restore_shape: Optional[torch.Size] = None,
    map_type: str = "mask",
    probs: Optional[torch.Tensor] = None,
    pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map : torch.Tensor
        The tensor of a mapping table for sorted indices used to unpermute the tokens,
        which is the second output tensor of `moe_permute`.
    merging_probs : torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens are merged weighted by their respective probabilities.
        By default None, meaning the tokens are merged by plain accumulation.
    restore_shape : torch.Size, default = None
        The output shape after the unpermute operation.
    map_type : str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.
    probs : torch.Tensor, default = None
        Deprecated; renamed to merging_probs. Kept for backward compatibility.
    pad_offsets : torch.Tensor, default = None
        Tensor of per-expert cumulative padding offsets used to remove padding added
        during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
        and is required when unpermuting padded outputs.
    """
    if probs is not None:
        if merging_probs is not None:
            raise ValueError(
                "Both merging_probs and probs kwarg are provided. probs is deprecated."
            )
        warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
        merging_probs = probs
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(
            inp, row_id_map, merging_probs, restore_shape, pad_offsets
        )
    raise ValueError("map_type should be one of 'mask' or 'index'")
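

# Illustrative padded round trip (hypothetical; assumes a CUDA device): outputs
# of `moe_permute_and_pad_with_probs` are unpadded by passing `pad_offsets` back:
#
#     out, p_probs, row_id_map, pad_offsets, padded_counts = (
#         moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, 16)
#     )
#     expert_out = ...  # expert computation on `out`
#     restored = moe_unpermute(
#         expert_out, row_id_map, restore_shape=inp.shape, pad_offsets=pad_offsets
#     )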


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
        probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            return inp, probs

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        assert num_splits == sorted_idxs.size(0)

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data

        row_id_map = triton_permutation.make_chunk_sort_map(
            split_sizes,
            sorted_idxs,
            num_tokens,
            num_splits,
        )
        output, permuted_probs = triton_permutation.sort_chunks_by_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            hidden_size,
            is_forward=True,
        )
        if fp8:
            output = Float8Tensor(
                data=output,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, permuted_probs

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        permuted_probs_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, permuted_probs_grad

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                fake_dtype = permuted_act_grad.dtype
                permuted_act_grad = permuted_act_grad._data
            act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                permuted_probs_grad,
                ctx.num_tokens,
                ctx.hidden_size,
                is_forward=False,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[3]:
            probs_grad = None
        return act_grad, None, None, probs_grad


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and the
    chunks are then reordered according to sorted_index.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes : torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index : torch.Tensor
        Chunk indices used to permute the chunks.
    """
    output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
    return output
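

# Illustrative sketch (hypothetical values and dtypes; assumes a CUDA device):
# reorder three contiguous chunks of sizes [2, 1, 3] into the order 2, 0, 1:
#
#     inp = torch.randn(6, 16, device="cuda")
#     split_sizes = torch.tensor([2, 1, 3], dtype=torch.int32, device="cuda")
#     sorted_index = torch.tensor([2, 0, 1], dtype=torch.int32, device="cuda")
#     out = moe_sort_chunks_by_index(inp, split_sizes, sorted_index)
#     # rows of `out`: inp[3:6], then inp[0:2], then inp[2:3]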


def moe_sort_chunks_by_index_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split and sort the input tensor and probs based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and the
    chunks are then reordered according to sorted_index.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs : torch.Tensor
        The tensor of probabilities corresponding to the tokens, of shape
        [num_tokens]. It will be permuted with the tokens according to
        the split_sizes and sorted_index.
    split_sizes : torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index : torch.Tensor
        Chunk indices used to permute the chunks.
    """
    output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
    return output, permuted_probs
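

# Illustrative sketch (hypothetical values; assumes a CUDA device): the same
# chunk reordering as above, with per-token probabilities travelling along:
#
#     probs = torch.rand(6, device="cuda")
#     out, out_probs = moe_sort_chunks_by_index_with_probs(
#         inp, probs, split_sizes, sorted_index
#     )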