permutation.py 9.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MoE token permutation API (moe_permute / moe_unpermute)"""
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor


# Public API of this module: the functional entry points that wrap the
# `_moe_permute` / `_moe_unpermute` autograd Functions defined below.
__all__ = [
    "moe_permute",
    "moe_unpermute",
]


class _moe_permute(torch.autograd.Function):
    """Autograd function that groups tokens by their expert indices (functional Permute).

    Forward calls ``tex.moe_permute_fwd``; backward calls ``tex.moe_permute_bwd``.
    A workspace object returned by the extension is cached on the class and handed
    back to it on subsequent calls.
    """

    # Workspace returned by ``tex.moe_permute_fwd`` and passed back on the next
    # call. Class-level, so it is shared by every call through this Function.
    workspace = None
    # Largest ``max(max_token_num, num_tokens) * topK`` requested so far. When a
    # call needs more, ``workspace`` is reset to ``[]`` before calling the
    # extension (presumably so it allocates a larger one -- TODO confirm).
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        dtype: tex.DType,
        indices: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Permute the rows of ``inp`` according to ``indices``.

        Parameters
        ----------
        inp: torch.Tensor
            CUDA tensor of tokens to permute; must be a ``Float8Tensor`` when
            ``dtype`` is an FP8 type.
        dtype: tex.DType
            Data type of ``inp`` as understood by the extension.
        indices: torch.Tensor
            CUDA tensor of token-to-expert indices with ``indices.size(0) ==
            inp.size(0)``; ``topK`` is ``indices.size(1)``. Converted to int32
            (with a warning) when it has another dtype.
        num_out_tokens: int
            Effective output token count, forwarded to the extension.
        max_token_num: int
            Hint used to size the cached workspace.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(permuted_act, row_id_map)``; ``(inp, None)`` when ``inp`` is empty.
        """
        # Empty input: nothing to permute, hand the input straight back.
        if not inp.numel():
            return inp, None

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert indices.is_cuda, "TransformerEngine needs CUDA."
        # Shape check: one row of expert indices per input token.
        assert inp.size(0) == indices.size(0), "Permute not possible"

        # FP8 inputs arrive as Float8Tensor; the extension works on the raw data,
        # so unwrap here and re-wrap the output with the same scaling below.
        fp8 = dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)
        if fp8:
            assert isinstance(
                inp, Float8Tensor
            ), "Input must be in Float8Tensor type for FP8 moe_permute."
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        if indices.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `indices` of Permute is {indices.dtype}! "
                "The recommended type is torch.int32."
            )
            indices = indices.to(torch.int32)

        topK = indices.size(1)

        # Reset the shared workspace whenever this call may expand to more rows
        # than any previous call did.
        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute.workspace = []

        permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            indices,
            num_out_tokens,
            _moe_permute.workspace,
            _moe_permute.max_expanded_token_num,
        )

        if fp8:
            permuted_act = Float8Tensor(
                data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = indices.size(0)
        ctx.topK = topK
        ctx.dtype = dtype
        ctx.fp8 = fp8
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        """Map the permuted-activation gradient back to the original token order.

        The second argument is the (unused) gradient of ``row_id_map``.

        Returns
        -------
        Tuple
            One entry per ``forward`` input:
            ``(act_grad, None, None, None, None)``.
        """
        # Empty gradient (e.g. forward early-returned on an empty input).
        # Autograd requires exactly one entry per forward input -- five here.
        # NOTE(review): if inp was non-empty but num_out_tokens == 0, this returns
        # a grad whose first dim differs from inp's -- confirm that case cannot occur.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

        fp8 = ctx.fp8
        if fp8:
            assert isinstance(
                permuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
            fp8_dtype = permuted_act_grad._fp8_dtype
            fp8_scale_inv = permuted_act_grad._scale_inv
            permuted_act_grad = permuted_act_grad._data

        row_id_map = ctx.row_id_map
        num_tokens = ctx.num_tokens
        topK = ctx.topK

        act_grad = None
        if ctx.needs_input_grad[0]:
            # An empty tensor stands in for probs: the permute op has none.
            act_grad = tex.moe_permute_bwd(
                permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
            )
            if fp8:
                # NOTE(review): scale_inv is multiplied by topK, presumably because
                # the extension combines up to topK contributions per token and
                # rescales the FP8 data accordingly -- confirm in tex.moe_permute_bwd.
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
                )

        return act_grad, None, None, None, None


class _moe_unpermute(torch.autograd.Function):
    """Autograd function that restores permuted tokens to their original order
    (functional Unpermute), optionally weighting them by ``probs``.

    Forward calls ``tex.moe_unpermute_fwd``; backward calls
    ``tex.moe_unpermute_bwd``, which yields gradients for both the permuted
    activations and ``probs``.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        dtype: tex.DType,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        """Unpermute ``inp`` using ``row_id_map`` produced by the permute op.

        Parameters
        ----------
        inp: torch.Tensor
            CUDA tensor of permuted tokens; must be a ``Float8Tensor`` when
            ``dtype`` is an FP8 type.
        dtype: tex.DType
            Data type of ``inp`` as understood by the extension.
        row_id_map: torch.Tensor
            CUDA mapping table from the permute op; converted to int32 (with a
            warning) when it has another dtype.
        probs: torch.Tensor or None
            CUDA tensor of shape ``[num_tokens, topK]`` used to weight the merged
            tokens; converted to float32 (with a warning) when needed. With
            ``None``, ``topK`` is 1 and ``num_tokens`` is ``row_id_map.size(0)``.

        Returns
        -------
        torch.Tensor
            The unpermuted tokens (``inp`` itself when ``inp`` is empty).
        """
        # Empty input: nothing to unpermute. Remember probs so backward can
        # build a matching zero gradient for it.
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            # Placeholder handed to the extension in place of real probs.
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # FP8 inputs arrive as Float8Tensor; the extension works on the raw data,
        # so unwrap here and re-wrap the output with the same scaling below.
        fp8 = dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)
        if fp8:
            assert isinstance(
                inp, Float8Tensor
            ), "Input must be in Float8Tensor type for FP8 moe_unpermute."
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )

        ctx.dtype = dtype
        # inp (raw data in the FP8 case) and the float32 probs feed the backward kernel.
        ctx.save_for_backward(inp, row_id_map, probs)
        ctx.fp8 = fp8
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """Compute gradients for the permuted activations and for ``probs``.

        Returns
        -------
        Tuple
            One entry per ``forward`` input: ``(act_grad, None, None, prob_grad)``.
            An entry is ``None`` when the corresponding input does not need a
            gradient.
        """
        # Empty gradient (e.g. forward early-returned on an empty input). The
        # output did not depend on probs, so its gradient is zero -- not probs
        # itself. Autograd requires exactly four entries here.
        if not unpermuted_act_grad.numel():
            probs = getattr(ctx, "probs", None)
            prob_grad = None
            if probs is not None and ctx.needs_input_grad[3]:
                prob_grad = torch.zeros_like(probs)
            return unpermuted_act_grad, None, None, prob_grad

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

        fp8 = ctx.fp8
        if fp8:
            assert isinstance(
                unpermuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
            fp8_dtype = unpermuted_act_grad._fp8_dtype
            fp8_scale_inv = unpermuted_act_grad._scale_inv
            unpermuted_act_grad = unpermuted_act_grad._data

        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        prob_grad = None
        # The extension produces both gradients in one call, so run it whenever
        # either one is needed (previously prob_grad was unbound when only probs
        # required a gradient).
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
                unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
            )
            if not ctx.needs_input_grad[0]:
                act_grad = None
            elif fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
            if not ctx.needs_input_grad[3]:
                prob_grad = None

        return act_grad, None, None, prob_grad


def moe_permute(
    inp: torch.Tensor,
    dtype: tex.DType,
    indices: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Group tokens by expert index so that tokens routed to the same expert are adjacent.

    Parameters
    ----------
    inp: torch.Tensor
        Tokens to permute, of shape `[num_tokens, hidden_size]`.
    dtype: tex.DType
        Data type of `inp`.
    indices: torch.Tensor
        Token-to-expert indices of shape `[num_tokens, topK]`, preferably `int32`
        (other dtypes are converted with a warning).
    num_out_tokens: int, default = -1
        Effective number of output tokens, i.e. tokens that are not dropped.
        `-1` means that no token is dropped.
    max_token_num: int, default = -1
        Maximum number of tokens, used to size the workspace. `-1` lets the
        operator determine the workspace size automatically.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The permuted tokens and the `row_id_map` to hand to `moe_unpermute`.
    """
    permute_args = (inp, dtype, indices, num_out_tokens, max_token_num)
    return _moe_permute.apply(*permute_args)


def moe_unpermute(
    inp: torch.Tensor,
    dtype: tex.DType,
    row_id_map: torch.Tensor,
    probs: torch.Tensor = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens, of shape `[num_permuted_tokens, hidden_size]`,
        to be unpermuted.
    dtype: tex.DType
        Data type of the input tensor.
    row_id_map: torch.Tensor
        The mapping table for sorted indices used to unpermute the tokens, i.e. the
        second output tensor of `moe_permute`.
    probs: torch.Tensor, default = None
        Probabilities of shape `[num_tokens, topK]` corresponding to the permuted
        tokens (converted to float32 with a warning if needed). If provided, the
        unpermuted tokens are merged with their respective probabilities. If `None`
        (the default), the tokens are directly merged by accumulation.

    Returns
    -------
    torch.Tensor
        The tokens restored to their original order.
    """
    return _moe_unpermute.apply(inp, dtype, row_id_map, probs)