# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE Permutation API"""
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor


__all__ = [
    "moe_permute",
    "moe_unpermute",
    "moe_sort_chunks_by_index",
]


class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

    workspace = None
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check
        assert inp.size(0) == index.size(0), "Permute not possible"

        # Data type check
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            assert (
                inp._quantizer.scale.ndim == 0
            ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        if fp8:
            permuted_act = Float8Tensor(
                data=permuted_act,
                fp8_dtype=dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=permuted_act.shape,
                dtype=fake_dtype,
            )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        ctx.fp8 = fp8
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                permuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
            dtype = permuted_act_grad._fp8_dtype
            fp8_scale_inv = permuted_act_grad._scale_inv
            fake_dtype = permuted_act_grad.dtype
            permuted_act_grad = permuted_act_grad._data
        else:
            dtype = TE_DType[permuted_act_grad.dtype]

        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )
            if ctx.fp8:
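                # The dequantization scale is multiplied by topK to account for the
                # topK-way reduction over each token's permuted copies in the backward.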
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.topK,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )

        return act_grad, None, None, None


class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            fake_dtype = inp.dtype
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output,
                fp8_dtype=dtype,
                fp8_scale_inv=fp8_scale_inv,
                shape=unpermuted_output.shape,
                dtype=fake_dtype,
            )

        ctx.save_for_backward(inp, row_id_map, probs)
        ctx.fp8 = fp8
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                unpermuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
            dtype = unpermuted_act_grad._fp8_dtype
            fp8_scale_inv = unpermuted_act_grad._scale_inv
            fake_dtype = unpermuted_act_grad.dtype
            unpermuted_act_grad = unpermuted_act_grad._data
        else:
            dtype = TE_DType[unpermuted_act_grad.dtype]

        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        if ctx.needs_input_grad[0]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
            if ctx.fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=dtype,
                    fp8_scale_inv=fp8_scale_inv,
                    shape=act_grad.shape,
                    dtype=fake_dtype,
                )
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad


class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."

        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            else:
                fp8_dtype = None
            act_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
                fp8_dtype,
            )
            if fp8:
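                # The dequantization scale is multiplied by num_experts to account for
                # the per-token reduction over expert copies in the backward unpermute.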
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
                )
        return act_grad, None, None


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
        restore_shape: torch.Size,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            ctx.probs = probs
            return inp

        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        num_experts = row_id_map.size(0)

        with_probs = probs is not None
        if with_probs:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            if not with_probs:
                fp8_scale_inv = inp._scale_inv * num_experts
            else:
                fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            fp8_dtype = None
        unpermuted_output = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            hidden_size,
            fp8_dtype=fp8_dtype,
        )
        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )

        if with_probs:
            ctx.save_for_backward(inp, row_id_map, probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors

            fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fp8_scale_inv = unpermuted_act_grad._scale_inv
                unpermuted_act_grad = unpermuted_act_grad._data
            else:
                fp8_dtype = None

            if ctx.with_probs:
                act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs(
                    unpermuted_act_grad,
                    row_id_map,
                    fwd_input,
                    probs,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    fp8_dtype,
                )
            else:
                act_grad = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                )

            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )

        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None


def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map. Tokens with the same designated expert
    will be grouped together. The routing_map indicates which experts were selected by
    each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map: torch.Tensor
        The token-to-expert mapping tensor.
        If map_type is 'mask', routing_map has shape [num_tokens, num_experts] and
        dtype 'int32'; a value of 1 means the token is routed to that expert, and 0
        means it is not.
        If map_type is 'index', routing_map has shape [num_tokens, topK] and dtype
        'int32'; its values are the indices of the routed experts.
    num_out_tokens: int, default = -1
        The effective output token count, i.e. the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the operator computes the workspace size
        automatically.
    map_type: str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.
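
    Examples
    --------
    A minimal illustrative sketch for the default 'mask' map type (the shapes and
    routing values below are hypothetical, and a CUDA device is assumed):

    >>> inp = torch.randn(4, 16, device="cuda")  # [num_tokens, hidden_size]
    >>> # Route each of the 4 tokens to experts 0 and 2 out of 4 experts.
    >>> routing_map = torch.zeros(4, 4, dtype=torch.int32, device="cuda")
    >>> routing_map[:, 0] = 1
    >>> routing_map[:, 2] = 1
    >>> num_out_tokens = int(routing_map.sum())  # 8, since no tokens are dropped
    >>> permuted, row_id_map = moe_permute(inp, routing_map, num_out_tokens=num_out_tokens)
    >>> # permuted has shape [num_out_tokens, hidden_size] == [8, 16]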
    """
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens)
    raise ValueError("map_type should be one of 'mask' or 'index'")


def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor = None,
    restore_shape: torch.Size = None,
    map_type: str = "mask",
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map: torch.Tensor
        The mapping table for the sorted indices used to unpermute the tokens,
        i.e. the second output tensor of `moe_permute`.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to `None`, meaning the tokens are directly merged by accumulation.
    restore_shape: torch.Size
        The shape of the output tensor after the unpermute operation.
    map_type: str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.
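
    Examples
    --------
    A sketch continuing the `moe_permute` example (hypothetical values; a CUDA device
    is assumed, and probs is taken to be of shape [num_tokens, num_experts] for the
    'mask' map type):

    >>> probs = torch.full((4, 4), 0.5, dtype=torch.float32, device="cuda")
    >>> out = moe_unpermute(permuted, row_id_map, probs, restore_shape=inp.shape)
    >>> # out has shape [num_tokens, hidden_size] == [4, 16]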
    """
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape)
    raise ValueError("map_type should be one of 'mask' or 'index'")


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if not inp.numel():
            return inp

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        assert num_splits == sorted_idxs.size(0)

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output, row_id_map = triton_permutation.sort_chunks_by_idx(
            inp,
            split_sizes,
            sorted_idxs,
            num_tokens,
            hidden_size,
            num_splits,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            act_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                ctx.num_tokens,
                ctx.hidden_size,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        return act_grad, None, None


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to split_sizes, and the resulting
    chunks are reordered according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.
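
    Examples
    --------
    A minimal illustrative sketch (hypothetical sizes; a CUDA device is assumed) that
    swaps two chunks of 2 and 3 rows:

    >>> inp = torch.randn(5, 16, device="cuda")
    >>> split_sizes = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    >>> sorted_index = torch.tensor([1, 0], dtype=torch.int32, device="cuda")
    >>> out = moe_sort_chunks_by_index(inp, split_sizes, sorted_index)
    >>> # out rows are ordered as [inp[2:5], inp[0:2]]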
    """
    return _moe_chunk_sort.apply(inp, split_sizes, sorted_index)