# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch wrapper functions for Permutation Triton kernels."""

from typing import Union

import torch
import triton

from transformer_engine.common.triton.permutation import (
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
    _row_id_map_pass_3_kernel,
    _permute_kernel,
    _unpermute_kernel,
    _unpermute_bwd_with_merging_probs_kernel,
    _make_chunk_sort_map_kernel,
    _sort_chunks_by_map_kernel,
)


def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
) -> torch.Tensor:
    """
    Prepare the row_id_map for the permutation.

    Parameters
    ----------
    routing_map: torch.Tensor
        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
        which experts are routed to which tokens. The values in it: 1 means the token is routed to
        this expert and 0 means not.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map: torch.Tensor
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
        to the first n_routed row indices above.
    """
    # Allocate on the routing map's device (instead of a hard-coded "cuda") so the
    # function also works when the input lives on a non-default CUDA device.
    row_id_map = torch.empty(
        (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=routing_map.device
    )
    block_size = 1024
    # One program per (expert, block of block_size tokens).
    grid = (num_experts, triton.cdiv(num_tokens, block_size))
    # Per-(expert, token-block) workspace shared between pass 1 and pass 2.
    workspace_tensor = torch.empty(grid, dtype=torch.int32, device=routing_map.device)

    # supposing num_tokens == 5, num_experts == 3, block_size == 3
    # and we have a routing_map like this:
    # [[1, 1, 0],
    #  [1, 0, 1],
    #  [0, 0, 1],
    #  [1, 1, 0],
    #  [0, 0, 0]]

    # pass 1: block cumsum
    # for each expert, compute the cumsum of every block_size tokens
    # the row_id_map will be like this after pass 1 (r means useless values):
    # [[1, 1, 0, r, r, r, r],
    #  [2, 0, 1, r, r, r, r],
    #  [0, 0, 2, r, r, r, r],
    #  [1, 1, 0, r, r, r, r],
    #  [0, 0, 0, r, r, r, r]]
    _row_id_map_pass_1_kernel[grid](
        routing_map,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        row_id_map.stride(0),
        row_id_map.stride(1),
        row_id_map,
        workspace_tensor,
        block_size,
    )

    # pass 2: cumsum all and process the mask
    # process the block cumsum into the global cumsum and then into the dst row indices
    # the row_id_map will be like this after pass 2 (r means useless value):
    # [[ 0,  3, -1, r, r, r, r],
    #  [ 1, -1,  5, r, r, r, r],
    #  [-1, -1,  6, r, r, r, r],
    #  [ 2,  4, -1, r, r, r, r],
    #  [-1, -1, -1, r, r, r, r]]
    _row_id_map_pass_2_kernel[grid](
        row_id_map,
        workspace_tensor,
        num_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        # load width must cover one workspace entry per (expert, token-block) pair
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
        block_size,
    )

    # pass 3: make the row_id_map from the sparse structure to the dense structure
    # the row_id_map will be like this after pass 3 (r means useless value):
    # [[3, 0, r, 1, 0, r, 2],
    #  [5, 1, r, 2, 0, r, 2],
    #  [6, r, r, 2, r, r, 1],
    #  [4, 2, r, 1, 0, r, 2],
    #  [r, r, r, r, r, r, 0]]
    grid = (num_tokens,)
    _row_id_map_pass_3_kernel[grid](
        row_id_map,
        row_id_map.stride(0),
        row_id_map.stride(1),
        num_experts,
        triton.next_power_of_2(num_experts),
    )
    return row_id_map


def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    scale: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    scale_hidden_dim: int,
):
    """
    Permute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs: torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    scale: torch.Tensor
        The scale of the input tensor. If it is not None, it will be permuted.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.
    num_out_tokens: int
        Number of tokens in the permuted tensor.
    hidden_size: int
        Hidden size of the input tensor.
    scale_hidden_dim: int
        Hidden size of the scale tensor.

    Returns
    -------
    output: torch.Tensor
        The permuted tensor of shape `[num_out_tokens, hidden_size]`.
    permuted_scale: torch.Tensor
        The permuted scale of shape `[num_out_tokens, scale_hidden_dim]`,
        or None if `scale` is None.
    permuted_probs: torch.Tensor
        The permuted probabilities of shape `[num_out_tokens,]`, or None if `probs` is None.
    """
    # Allocate outputs on the input's device (instead of a hard-coded "cuda") so the
    # function also works when the input lives on a non-default CUDA device.
    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
    if probs is not None:
        permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=inp.device)
    else:
        permuted_probs = None
    if scale is not None:
        permuted_scale = torch.empty(
            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device=inp.device
        )
    else:
        permuted_scale = None
    # One program per (token, BLOCK_SIZE-wide slice of the hidden dim).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _permute_kernel[grid](
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        scale_hidden_dim,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        probs.stride(1) if probs is not None else None,
        scale.stride(0) if scale is not None else None,
        scale.stride(1) if scale is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        permuted_scale.stride(0) if permuted_scale is not None else None,
        permuted_scale.stride(1) if permuted_scale is not None else None,
        output,
        permuted_probs,
        num_experts,
        hidden_size,
        PERMUTE_PROBS=probs is not None,
        PERMUTE_SCALE=scale is not None,
    )
    return output, permuted_scale, permuted_probs


def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
):
    """
    Unpermute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs: torch.Tensor
        The merging probabilities of the input tensor. If it is not None, it will be used as weights
        to reduce the unpermuted tokens.
    permuted_probs: torch.Tensor
        The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
    num_tokens: int
        Number of tokens in the permuted tensor.
    num_experts: int
        Number of experts in the permuted tensor.
    hidden_size: int
        Hidden size of the permuted tensor.

    Returns
    -------
    output: torch.Tensor
        The unpermuted tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs: torch.Tensor
        The unpermuted probabilities of shape `[num_tokens, num_experts]`,
        or None if `permuted_probs` is None.
    """
    # Allocate outputs on the input's device (instead of a hard-coded "cuda") so the
    # function also works when the input lives on a non-default CUDA device.
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
    if permuted_probs is not None:
        unpermuted_probs = torch.empty(
            (num_tokens, num_experts), dtype=permuted_probs.dtype, device=inp.device
        )
    else:
        unpermuted_probs = None
    # One program per (token, BLOCK_SIZE-wide slice of the hidden dim).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _unpermute_kernel[grid](
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if merging_probs is not None else None,
        merging_probs.stride(1) if merging_probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        output,
        unpermuted_probs,
        num_experts,
        hidden_size,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        WITH_MERGING_PROBS=merging_probs is not None,
        PERMUTE_PROBS=permuted_probs is not None,
    )
    return output, unpermuted_probs


def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """
    Unpermute backward pass kernel with merging probs.

    Parameters
    ----------
    fwd_output_grad: torch.Tensor
        The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input: torch.Tensor
        The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs: torch.Tensor
        The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
    num_tokens: int
        Number of tokens in the permuted tensor.
    num_experts: int
        Number of experts in the permuted tensor.
    num_out_tokens: int
        Number of tokens in the output tensor.
    hidden_size: int
        Hidden size of the output tensor.

    Returns
    -------
    act_grad: torch.Tensor
        The gradient w.r.t. the forward input, of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad: torch.Tensor
        The gradient w.r.t. the merging probabilities, of shape `[num_tokens, num_experts]`.
    """
    # Allocate gradients on the incoming gradient's device (instead of a hard-coded
    # "cuda") so the function also works on a non-default CUDA device.
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device=fwd_output_grad.device
    )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device=fwd_output_grad.device
    )
    # One program per token.
    grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[grid](
        fwd_output_grad,
        act_grad,
        fwd_input,
        merging_probs,
        merging_probs_grad,
        row_id_map,
        num_experts,
        hidden_size,
        row_id_map.stride(0),
        row_id_map.stride(1),
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
    )
    return act_grad, merging_probs_grad


def make_chunk_sort_map(
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    num_splits: int,
) -> torch.Tensor:
    """
    Make a row_id_map for chunk sort.

    Parameters
    ----------
    split_sizes: torch.Tensor
        The sizes of the chunks of shape `[num_splits,]`.
    sorted_indices: torch.Tensor
        The indices of the sorted chunks of shape `[num_splits,]`.
    num_tokens: int
        Number of tokens in the input tensor.
    num_splits: int
        Number of splits of split_sizes and sorted_indices.

    Returns
    -------
    row_id_map: torch.Tensor
        The row_id_map for the chunk sort of shape `[num_tokens,]`, consumed by
        `sort_chunks_by_map`.
    """
    # Allocate on the split_sizes tensor's device (instead of a hard-coded "cuda") so
    # the function also works when the inputs live on a non-default CUDA device.
    row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=split_sizes.device)
    # One program per token.
    grid = (num_tokens,)
    _make_chunk_sort_map_kernel[grid](
        split_sizes,
        sorted_indices,
        row_id_map,
        num_splits,
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return row_id_map


def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
):
    """
    Sort chunks with row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens,]`.
    probs: torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    num_tokens: int
        Number of tokens in the input tensor.
    hidden_size: int
        Hidden size of the input tensor.
    is_forward: bool
        Whether the sort is for forward or backward.

    Returns
    -------
    output: torch.Tensor
        The sorted tensor of shape `[num_tokens, hidden_size]`.
    permuted_probs: torch.Tensor
        The sorted probabilities of shape `[num_tokens,]`, or None if `probs` is None.
    """
    # Allocate outputs on the input's device (instead of a hard-coded "cuda") so the
    # function also works when the input lives on a non-default CUDA device.
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
    if probs is not None:
        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=inp.device)
    else:
        permuted_probs = None
    # One program per (token, BLOCK_SIZE-wide slice of the hidden dim).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _sort_chunks_by_map_kernel[grid](
        inp,
        row_id_map,
        probs,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        output,
        permuted_probs,
        hidden_size,
        PERMUTE_PROBS=probs is not None,
        FORWARD=is_forward,
    )
    return output, permuted_probs