# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch wrapper functions for Permutation Triton kernels."""

from typing import Union

import torch
import triton

from transformer_engine.common.triton.permutation import (
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
    _row_id_map_pass_3_kernel,
    _permute_kernel,
    _unpermute_kernel,
    _unpermute_bwd_with_merging_probs_kernel,
    _make_chunk_sort_map_kernel,
    _sort_chunks_by_map_kernel,
)
def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """
    Build the row_id_map consumed by the permutation kernels.

    Parameters
    ----------
    routing_map : torch.Tensor
        Mask tensor of shape `[num_tokens, num_experts]` describing the
        token-to-expert routing: a value of 1 means the token is routed to
        that expert, 0 means it is not.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map : torch.Tensor
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts
        corresponding to the first n_routed row indices above.
    """
    tokens_per_block = 1024
    row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda")
    launch_grid = (num_experts, triton.cdiv(num_tokens, tokens_per_block))
    # One int32 scratch slot per (expert, token-block) pair, used to carry
    # per-block sums from pass 1 into pass 2.
    workspace = torch.empty(launch_grid, dtype=torch.int32, device="cuda")

    # Worked example: num_tokens == 5, num_experts == 3, block size == 3,
    # with the routing_map:
    # [[1, 1, 0],
    #  [1, 0, 1],
    #  [0, 0, 1],
    #  [1, 1, 0],
    #  [0, 0, 0]]

    # Pass 1: per-block cumsum. For every expert, compute the cumulative sum
    # over each group of tokens_per_block tokens. Afterwards row_id_map looks
    # like this (r marks don't-care values):
    # [[1, 1, 0, r, r, r, r],
    #  [2, 0, 1, r, r, r, r],
    #  [0, 0, 2, r, r, r, r],
    #  [1, 1, 0, r, r, r, r],
    #  [0, 0, 0, r, r, r, r]]
    _row_id_map_pass_1_kernel[launch_grid](
        routing_map,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        row_id_map.stride(0),
        row_id_map.stride(1),
        row_id_map,
        workspace,
        tokens_per_block,
    )

    # Pass 2: turn the per-block sums into a global cumsum and then into
    # destination row indices (-1 where the token is not routed). Afterwards
    # row_id_map looks like this:
    # [[ 0,  3, -1, r, r, r, r],
    #  [ 1, -1,  5, r, r, r, r],
    #  [-1, -1,  6, r, r, r, r],
    #  [ 2,  4, -1, r, r, r, r],
    #  [-1, -1, -1, r, r, r, r]]
    _row_id_map_pass_2_kernel[launch_grid](
        row_id_map,
        workspace,
        num_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, tokens_per_block)),
        tokens_per_block,
    )

    # Pass 3: compact the sparse layout into the dense layout described in the
    # docstring. Afterwards row_id_map looks like this:
    # [[3, 0, r, 1, 0, r, 2],
    #  [5, 1, r, 2, 0, r, 2],
    #  [6, r, r, 2, r, r, 1],
    #  [4, 2, r, 1, 0, r, 2],
    #  [r, r, r, r, r, r, 0]]
    token_grid = (num_tokens,)
    _row_id_map_pass_3_kernel[token_grid](
        row_id_map,
        row_id_map.stride(0),
        row_id_map.stride(1),
        num_experts,
        triton.next_power_of_2(num_experts),
    )
    return row_id_map


def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    scale: torch.Tensor,
    pad_offsets: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    scale_hidden_dim: int,
):
    """
    Permute the input tensor according to the row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]` to be permuted.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs : torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    scale : torch.Tensor
        The scale of the input tensor. If it is not None, it will be permuted.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
        If it is not None, it will be allocated output buffers with aligned sizes.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size of the input tensor.
    scale_hidden_dim : int
        Hidden size of the scale tensor.
    """
    has_probs = probs is not None
    has_scale = scale is not None
    fused_padding = pad_offsets is not None

    # With fused padding the kernel never writes the padding rows, so those
    # buffers must start zeroed; otherwise an uninitialized buffer is enough.
    alloc_fn = torch.zeros if fused_padding else torch.empty
    output = alloc_fn((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    permuted_probs = None
    if has_probs:
        permuted_probs = alloc_fn((num_out_tokens,), dtype=probs.dtype, device="cuda")
    permuted_scale = None
    if has_scale:
        # NOTE(review): unlike output/probs, the scale buffer is always
        # torch.empty even under fused padding — padding rows of
        # permuted_scale may hold garbage; confirm downstream handles this.
        permuted_scale = torch.empty(
            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
        )

    # One program per (token, hidden-dim tile).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _permute_kernel[grid](
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        pad_offsets,
        scale_hidden_dim,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if has_probs else None,
        probs.stride(1) if has_probs else None,
        scale.stride(0) if has_scale else None,
        scale.stride(1) if has_scale else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        permuted_scale.stride(0) if permuted_scale is not None else None,
        permuted_scale.stride(1) if permuted_scale is not None else None,
        output,
        permuted_probs,
        num_experts,
        hidden_size,
        PERMUTE_PROBS=has_probs,
        PERMUTE_SCALE=has_scale,
        FUSION_PAD=fused_padding,
    )
    return output, permuted_scale, permuted_probs


def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    pad_offsets: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
):
    """
    Reverse the permutation described by the row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs : torch.Tensor
        The merging probabilities of the input tensor. If it is not None, it will be
        used as weights to reduce the unpermuted tokens.
    permuted_probs : torch.Tensor
        The permuted probabilities of the input tensor. If it is not None, it will be
        unpermuted.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
        If it is not None, it will remove the previously fused padding.
    num_tokens : int
        Number of tokens in the permuted tensor.
    num_experts : int
        Number of experts in the permuted tensor.
    hidden_size : int
        Hidden size of the permuted tensor.
    """
    with_merging = merging_probs is not None
    with_probs = permuted_probs is not None

    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    unpermuted_probs = (
        torch.empty((num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda")
        if with_probs
        else None
    )

    # One program per (token, hidden-dim tile).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _unpermute_kernel[grid](
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if with_merging else None,
        merging_probs.stride(1) if with_merging else None,
        permuted_probs.stride(0) if with_probs else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        output,
        unpermuted_probs,
        num_experts,
        hidden_size,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        WITH_MERGING_PROBS=with_merging,
        PERMUTE_PROBS=with_probs,
        FUSION_UNPAD=pad_offsets is not None,
    )
    return output, unpermuted_probs
276
277


278
def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    pad_offsets: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """
    Backward pass of the unpermutation when merging probs were used.

    Parameters
    ----------
    fwd_output_grad : torch.Tensor
        The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : torch.Tensor
        The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs : torch.Tensor
        The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
        If it is not None, it will be allocated output buffers with aligned sizes.
    num_tokens : int
        Number of tokens in the permuted tensor.
    num_experts : int
        Number of experts in the permuted tensor.
    num_out_tokens : int
        Number of tokens in the output tensor.
    hidden_size : int
        Hidden size of the output tensor.
    """
    if pad_offsets is not None:
        # Padding slots are never written by the kernel, so zero them up
        # front; this mirrors Fp8Unpadding.backward, which zeroes padding.
        act_grad = torch.zeros(
            (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
        )
    else:
        act_grad = torch.empty(
            (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
        )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
    )

    # One program per token.
    launch_grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[launch_grid](
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        row_id_map.stride(0),
        row_id_map.stride(1),
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        act_grad,
        merging_probs_grad,
        num_experts,
        hidden_size,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        FUSION_UNPAD=pad_offsets is not None,
    )
    return act_grad, merging_probs_grad
349
350


351
def make_chunk_sort_map(
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    num_splits: int,
):
    """
    Build a per-token row_id_map for chunk sorting.

    Parameters
    ----------
    split_sizes : torch.Tensor
        The sizes of the chunks of shape `[num_splits,]`.
    sorted_indices : torch.Tensor
        The indices of the sorted chunks of shape `[num_splits,]`.
    num_tokens : int
        Number of tokens in the input tensor.
    num_splits : int
        Number of splits of split_sizes and sorted_indices.
    """
    chunk_sort_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda")
    # One program per token.
    _make_chunk_sort_map_kernel[(num_tokens,)](
        split_sizes,
        sorted_indices,
        chunk_sort_map,
        num_splits,
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return chunk_sort_map
381
382
383
384
385


def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
):
    """
    Reorder token chunks according to the row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens,]`.
    probs : torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    num_tokens : int
        Number of tokens in the input tensor.
    hidden_size : int
        Hidden size of the input tensor.
    is_forward : bool
        Whether the sort is for forward or backward.
    """
    has_probs = probs is not None

    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    permuted_probs = (
        torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") if has_probs else None
    )

    # One program per (token, hidden-dim tile).
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _sort_chunks_by_map_kernel[grid](
        inp,
        row_id_map,
        probs,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if has_probs else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        output,
        permuted_probs,
        hidden_size,
        PERMUTE_PROBS=has_probs,
        FORWARD=is_forward,
    )
    return output, permuted_probs