permutation.py 8.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
11
12
from .constants import TE_DType
from .float8_tensor import Float8Tensor
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


# Public API of this module; the autograd Function classes below are private helpers.
__all__ = ["moe_permute", "moe_unpermute"]


class _moe_permute(torch.autograd.Function):
    """functional Permute"""

    workspace = None
    max_expanded_token_num = 0

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        indices: torch.Tensor,
        num_out_tokens: int,
        max_token_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Empty input check
        if not inp.numel():
37
            return inp, torch.tensor([], device=inp.device)
38
39
40
41
42
43
44
45

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert indices.is_cuda, "TransformerEngine needs CUDA."
        # Shape check
        assert inp.size(0) == indices.size(0), "Permute not possible"

        # Data type check
46
        fp8 = isinstance(inp, Float8Tensor)
47
        if fp8:
48
            dtype = inp._fp8_dtype
49
50
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
51
52
        else:
            dtype = TE_DType[inp.dtype]
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
        if indices.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `indices` of Permute is {indices.dtype}! "
                "The recommended type is torch.int32."
            )
            indices = indices.to(torch.int32)

        topK = indices.size(1)

        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
        if _moe_permute.max_expanded_token_num < input_max_expanded_token_num:
            _moe_permute.max_expanded_token_num = input_max_expanded_token_num
            _moe_permute.workspace = []

        permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd(
            inp,
            dtype,
            indices,
            num_out_tokens,
            _moe_permute.workspace,
            _moe_permute.max_expanded_token_num,
        )

        if fp8:
            permuted_act = Float8Tensor(
78
                data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
            )

        ctx.row_id_map = row_id_map
        ctx.num_tokens = indices.size(0)
        ctx.topK = indices.size(1)
        ctx.fp8 = fp8
        return permuted_act, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # Empty input check
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None, None

        if not permuted_act_grad.is_contiguous():
            permuted_act_grad = permuted_act_grad.contiguous()

100
        if ctx.fp8:
101
102
103
            assert isinstance(
                permuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
104
            dtype = permuted_act_grad._fp8_dtype
105
106
            fp8_scale_inv = permuted_act_grad._scale_inv
            permuted_act_grad = permuted_act_grad._data
107
108
        else:
            dtype = TE_DType[permuted_act_grad.dtype]
109
110
111
112

        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad = tex.moe_permute_bwd(
113
                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
114
            )
115
            if ctx.fp8:
116
                act_grad = Float8Tensor(
117
                    data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
118
119
                )

120
        return act_grad, None, None, None
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160


class _moe_unpermute(torch.autograd.Function):
    """functional Unpermute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
    ) -> torch.Tensor:
        # Empty input check
        if not inp.numel():
            ctx.probs = probs
            return inp

        # None probs check
        if probs is not None:
            assert probs.is_cuda, "TransformerEngine needs CUDA."

            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

            num_tokens = probs.size(0)
            topK = probs.size(1)
        else:
            num_tokens = row_id_map.size(0)
            topK = 1
            probs = torch.empty(0)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        # Data type check
161
        fp8 = isinstance(inp, Float8Tensor)
162
        if fp8:
163
            dtype = inp._fp8_dtype
164
165
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
166
167
        else:
            dtype = TE_DType[inp.dtype]
168
169
170
171
172
173
174
175
176
177
178
        if row_id_map.dtype != torch.int32:
            warnings.warn(
                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
                "The recommended type is torch.int32."
            )
            row_id_map = row_id_map.to(torch.int32)

        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

        if fp8:
            unpermuted_output = Float8Tensor(
179
                data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
            )

        ctx.save_for_backward(inp, row_id_map, probs)
        ctx.fp8 = fp8
        return unpermuted_output

    @staticmethod
    def backward(
        ctx,
        unpermuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
        # Empty input check
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs

        if not unpermuted_act_grad.is_contiguous():
            unpermuted_act_grad = unpermuted_act_grad.contiguous()

198
        if ctx.fp8:
199
200
201
            assert isinstance(
                unpermuted_act_grad, Float8Tensor
            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
202
            dtype = unpermuted_act_grad._fp8_dtype
203
204
            fp8_scale_inv = unpermuted_act_grad._scale_inv
            unpermuted_act_grad = unpermuted_act_grad._data
205
206
        else:
            dtype = TE_DType[unpermuted_act_grad.dtype]
207
208
209
210
211
212

        inp, row_id_map, probs = ctx.saved_tensors

        act_grad = None
        if ctx.needs_input_grad[0]:
            act_grad, prob_grad = tex.moe_unpermute_bwd(
213
                unpermuted_act_grad, inp, dtype, row_id_map, probs
214
            )
215
216
217
            if ctx.fp8:
                act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)
        if not ctx.needs_input_grad[2]:
218
219
            prob_grad = None

220
        return act_grad, None, prob_grad
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245


def moe_permute(
    inp: torch.Tensor,
    indices: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the indices. Tokens with the same index will be grouped together.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
        A `Float8Tensor` is also accepted, in which case the permuted output is a
        `Float8Tensor` with the same FP8 dtype and scale.
    indices: torch.Tensor
        The token-to-expert indices tensor of shape `[num_tokens, topK]` and dtype `int32`.
        Indices of another dtype are converted to `int32` with a warning.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to `-1`, meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to `-1`, meaning the calculation of the size of workspace is
        automatically taken over by the operator.

    Returns
    -------
    permuted_act: torch.Tensor
        The permuted tokens.
    row_id_map: torch.Tensor
        The mapping table to pass to `moe_unpermute` to restore the original token order.
    """
    return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num)


def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor = None,
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
        A `Float8Tensor` is also accepted, in which case the output is a `Float8Tensor`.
    row_id_map: torch.Tensor
        The tensor of a mapping table for sorted indices used to unpermute the tokens,
        which is the second output tensor of `moe_permute`. Expected dtype is `int32`;
        other dtypes are converted with a warning.
    probs: torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens, of shape
        `[num_tokens, topK]`. If provided, the unpermuted tokens will be merged with
        their respective probabilities; `float32` is recommended and other dtypes are
        converted with a warning. If `None` (the default), no probability weighting is
        applied and the tokens are directly merged by accumulation.

    Returns
    -------
    torch.Tensor
        The unpermuted tokens.
    """
    return _moe_unpermute.apply(inp, row_id_map, probs)