# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE Permutation API"""

import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
11
12
13
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
14
15
16
17
18


# Public API of this module.
__all__ = [
    "moe_permute",
    "moe_unpermute",
    "moe_sort_chunks_by_index",
]


class _moe_permute_index_map(torch.autograd.Function):
    """functional Permute with index router map"""

    # Workspace handed to and returned by ``tex.moe_permute_fwd``. It is shared across
    # calls at class level and reset to ``[]`` whenever a call needs a larger
    # ``max_expanded_token_num`` than the one it was last sized for.
    workspace = None
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        index: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check: nothing to permute, return the input and an empty row_id_map.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert index.is_cuda, "TransformerEngine needs CUDA."
        # Shape check
        assert inp.size(0) == index.size(0), "Permute not possible"

        # Data type check: FP8 inputs are unwrapped to raw data and re-wrapped afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]
        if index.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `index` of Permute is {index.dtype}! "
                "The recommended type is torch.int32."
            )
            index = index.to(torch.int32)

        topK = index.size(1)

        # Reset the shared workspace when this call can expand to more rows than the
        # workspace was last sized for.
        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute_index_map.workspace = []

        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            index,
            num_out_tokens,
            _moe_permute_index_map.workspace,
            _moe_permute_index_map.max_expanded_token_num,
        )

        if fp8:
            permuted_act = Float8Tensor(
                data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
            )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = index.size(0)
        ctx.topK = index.size(1)
        ctx.fp8 = fp8
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                permuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
            dtype = permuted_act_grad._fp8_dtype
            fp8_scale_inv = permuted_act_grad._scale_inv
            permuted_act_grad = permuted_act_grad._data
        else:
            dtype = TE_DType[permuted_act_grad.dtype]

        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
            )
            if ctx.fp8:
                # NOTE(review): scale_inv is multiplied by topK here, presumably to account
                # for the kernel combining topK rows per token -- confirm against the kernel.
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
                )

        # index, num_out_tokens and max_token_num receive no gradient.
        return act_grad, None, None, None


class _moe_unpermute_index_map(torch.autograd.Function):
    """functional Unpermute with index router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check: the output is the (empty) input itself. probs is kept on ctx
        # so that backward can return a gradient of the matching shape.
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            # Without probs, topK is treated as 1 and the kernel receives an empty tensor.
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check: FP8 inputs are unwrapped to raw data and re-wrapped afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            dtype = TE_DType[inp.dtype]

        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
            )

        ctx.save_for_backward(inp, row_id_map, probs)
        ctx.fp8 = fp8
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check
        if not unpermuted_act_grad.numel():
            # Forward returned the empty input unchanged, so the output does not depend
            # on probs: its gradient w.r.t. probs is zero (not probs itself).
            probs_grad = None if ctx.probs is None else torch.zeros_like(ctx.probs)
            return unpermuted_act_grad, None, probs_grad

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        if ctx.fp8:
            assert isinstance(
                unpermuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
            dtype = unpermuted_act_grad._fp8_dtype
            fp8_scale_inv = unpermuted_act_grad._scale_inv
            unpermuted_act_grad = unpermuted_act_grad._data
        else:
            dtype = TE_DType[unpermuted_act_grad.dtype]

        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        # A single kernel call produces both gradients, so run it whenever either the
        # input or probs needs one (previously the probs gradient was lost when only
        # probs required grad).
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[2]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, dtype, row_id_map, probs
            )
            if not ctx.needs_input_grad[0]:
                act_grad = None
            elif ctx.fp8:
                act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)
        if not ctx.needs_input_grad[2]:
            prob_grad = None

        return act_grad, None, prob_grad


class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def _resolve_num_out_tokens(routing_map: torch.Tensor, num_out_tokens) -> int:
        """Return the number of permuted rows to produce.

        ``None`` or a negative value (``moe_permute`` defaults to ``-1``) means that no
        token is dropped. In that case every routed (token, expert) entry of the mask
        becomes one output row, so the count is the sum of ``routing_map``. This assumes
        the mask holds 0/1 values, as documented in ``moe_permute``. Calling ``.item()``
        copies the count to the host, so this path synchronizes with the device.
        """
        if num_out_tokens is None or num_out_tokens < 0:
            return int(routing_map.sum().item())
        return num_out_tokens

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input check: nothing to permute, return the input and an empty row_id_map.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."

        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        # The kernel needs a concrete, non-negative output row count.
        num_out_tokens = _moe_permute_mask_map._resolve_num_out_tokens(
            routing_map, num_out_tokens
        )

        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

        # FP8 inputs are unwrapped to raw data and re-wrapped afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            else:
                fp8_dtype = None
            # The gradient of a permutation is the matching unpermutation (no probs).
            act_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
                fp8_dtype,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
                )
        # routing_map and num_out_tokens receive no gradient.
        return act_grad, None, None


class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
        restore_shape: torch.Size,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check: the output is the (empty) input itself. probs is kept on ctx
        # so that backward can return a gradient of the matching shape.
        if not inp.numel():
            ctx.probs = probs
            return inp

        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        num_experts = row_id_map.size(0)

        with_probs = probs is not None
        if with_probs:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # FP8 inputs are unwrapped to raw data and re-wrapped afterwards; without probs
        # the output scale_inv is multiplied by num_experts.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            if not with_probs:
                fp8_scale_inv = inp._scale_inv * num_experts
            else:
                fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            fp8_dtype = None
        unpermuted_output = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            hidden_size,
            fp8_dtype=fp8_dtype,
        )
        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )

        # The forward input is only needed for the probs gradient.
        if with_probs:
            ctx.save_for_backward(inp, row_id_map, probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            # Forward returned the empty input unchanged, so the output does not depend
            # on probs: its gradient w.r.t. probs is zero (not probs itself).
            probs_grad = None if ctx.probs is None else torch.zeros_like(ctx.probs)
            return unpermuted_act_grad, None, probs_grad, None

        act_grad = None
        probs_grad = None
        # With probs, a single kernel call produces both gradients, so run it whenever
        # either the input or probs needs one (previously the probs gradient was lost
        # when only probs required grad). Without probs, needs_input_grad[2] is False.
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[2]:
            if ctx.with_probs:
                fwd_input, row_id_map, probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors

            fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fp8_scale_inv = unpermuted_act_grad._scale_inv
                unpermuted_act_grad = unpermuted_act_grad._data
            else:
                fp8_dtype = None

            if ctx.with_probs:
                act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs(
                    unpermuted_act_grad,
                    row_id_map,
                    fwd_input,
                    probs,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    fp8_dtype,
                )
            else:
                # Without probs, the gradient of an unpermutation is the matching permutation.
                act_grad = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                )

            if not ctx.needs_input_grad[0]:
                act_grad = None
            elif fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )

        if not ctx.needs_input_grad[2]:
            probs_grad = None
        # row_id_map and restore_shape receive no gradient.
        return act_grad, None, probs_grad, None


def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map.
    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map: torch.Tensor
        The token to expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values in it are the routed expert indices.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation. Only used when
        map_type is 'index'.
        By default, set to '-1', meaning the calculation of the size of workspace is
        automatically taken over by the operator.
    map_type: str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The permuted tensor and the row_id_map; pass the row_id_map (with the same
        map_type) to `moe_unpermute` to restore the original token order.

    Raises
    ------
    ValueError
        If map_type is neither 'mask' nor 'index'.
    """
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens)
    raise ValueError("map_type should be one of 'mask' or 'index'")


def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor = None,
    restore_shape: torch.Size = None,
    map_type: str = "mask",
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
    row_id_map: torch.Tensor
        The tensor of a mapping table for sorted indices used to unpermute the tokens,
        which is the second output tensor of `moe_permute`.
    probs: torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to None, which means that the tokens are directly merged by accumulation.
    restore_shape: torch.Size, default = None
        The output shape `[num_tokens, hidden_size]` after the unpermute operation.
        Only used when map_type is 'mask'; if None, the shape of `inp` is used.
    map_type: str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.

    Returns
    -------
    torch.Tensor
        The unpermuted tensor.

    Raises
    ------
    ValueError
        If map_type is neither 'mask' nor 'index'.
    """
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape)
    raise ValueError("map_type should be one of 'mask' or 'index'")


class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input check. Return exactly one tensor, like the normal path: backward
        # accepts a single output gradient, so returning a tuple here broke autograd.
        if not inp.numel():
            return inp

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."

        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        assert num_splits == sorted_idxs.size(0)

        # FP8 inputs are unwrapped to raw data and re-wrapped afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output, row_id_map = triton_permutation.sort_chunks_by_idx(
            inp,
            split_sizes,
            sorted_idxs,
            num_tokens,
            hidden_size,
            num_splits,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            # Undo the chunk reordering using the row map saved in forward.
            act_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                ctx.num_tokens,
                ctx.hidden_size,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        # split_sizes and sorted_idxs receive no gradient.
        return act_grad, None, None


def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and then sorted
    according to the sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    torch.Tensor
        The tensor with its chunks reordered according to sorted_index.
    """
    return _moe_chunk_sort.apply(inp, split_sizes, sorted_index)