# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch wrapper functions for Permutation Triton kernels."""

from typing import Union

import torch
import triton

from transformer_engine.common.triton.permutation import (
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
    _row_id_map_pass_3_kernel,
    _permute_kernel,
    _unpermute_kernel,
    _unpermute_bwd_with_merging_probs_kernel,
    _make_chunk_sort_map_kernel,
    _sort_chunks_by_map_kernel,
)


def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """
    Prepare the row_id_map for the permutation.

    Parameters
    ----------
    routing_map : torch.Tensor
        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
        which experts are routed to which tokens. The values in it: 1 means the token is routed to
        this expert and 0 means not.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map : torch.Tensor
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
        to the first n_routed row indices above.
    """
    # Allocate on routing_map's device rather than the current CUDA device so the
    # kernels still receive same-device buffers when the input lives on a
    # non-default GPU.
    device = routing_map.device
    row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=device)
    block_size = 1024
    grid = (num_experts, triton.cdiv(num_tokens, block_size))
    # One partial-sum slot per (expert, token-block) pair, consumed by pass 2.
    workspace_tensor = torch.empty(grid, dtype=torch.int32, device=device)

    # supposing num_tokens == 5, num_experts == 3, block_size == 3
    # and we have a routing_map like this:
    # [[1, 1, 0],
    #  [1, 0, 1],
    #  [0, 0, 1],
    #  [1, 1, 0],
    #  [0, 0, 0]]

    # pass 1: block cumsum
    # for each expert, compute the cumsum of every block_size tokens
    # the row_id_map will be like this after pass 1 (r means useless values):
    # [[1, 1, 0, r, r, r, r],
    #  [2, 0, 1, r, r, r, r],
    #  [0, 0, 2, r, r, r, r],
    #  [1, 1, 0, r, r, r, r],
    #  [0, 0, 0, r, r, r, r]]
    _row_id_map_pass_1_kernel[grid](
        routing_map,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        row_id_map.stride(0),
        row_id_map.stride(1),
        row_id_map,
        workspace_tensor,
        block_size,
    )

    # pass 2: cumsum all and process the mask
    # process the block cumsum into the global cumsum and then into the dst row indices
    # the row_id_map will be like this after pass 2 (r means useless value):
    # [[ 0,  3, -1, r, r, r, r],
    #  [ 1, -1,  5, r, r, r, r],
    #  [-1, -1,  6, r, r, r, r],
    #  [ 2,  4, -1, r, r, r, r],
    #  [-1, -1, -1, r, r, r, r]]
    _row_id_map_pass_2_kernel[grid](
        row_id_map,
        workspace_tensor,
        num_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
        block_size,
    )

    # pass 3: make the row_id_map from the sparse structure to the dense structure
    # the row_id_map will be like this after pass 3 (r means useless value):
    # [[3, 0, r, 1, 0, r, 2],
    #  [5, 1, r, 2, 0, r, 2],
    #  [6, r, r, 2, r, r, 1],
    #  [4, 2, r, 1, 0, r, 2],
    #  [r, r, r, r, r, r, 0]]
    grid = (num_tokens,)
    _row_id_map_pass_3_kernel[grid](
        row_id_map,
        row_id_map.stride(0),
        row_id_map.stride(1),
        num_experts,
        triton.next_power_of_2(num_experts),
    )
    return row_id_map


def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    scale: torch.Tensor,
    pad_offsets: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    scale_hidden_dim: int,
):
    """
    Permute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs : torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    scale : torch.Tensor
        The scale of the input tensor. If it is not None, it will be permuted.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
        If it is not None, it will be allocated output buffers with aligned sizes.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size of the input tensor.
    scale_hidden_dim : int
        Hidden size of the scale tensor.

    Returns
    -------
    output : torch.Tensor
        Permuted tensor of shape `[num_out_tokens, hidden_size]`.
    permuted_scale : torch.Tensor or None
        Permuted scale of shape `[num_out_tokens, scale_hidden_dim]`, or None if `scale` is None.
    permuted_probs : torch.Tensor or None
        Permuted probabilities of shape `[num_out_tokens,]`, or None if `probs` is None.
    """
    # Allocate on the input's device (not the current CUDA device) so the kernel
    # works when the input lives on a non-default GPU.
    device = inp.device
    # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed.
    # The kernel writes only to valid positions, leaving padding positions at zero.
    alloc = torch.zeros if pad_offsets is not None else torch.empty
    output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device=device)
    permuted_probs = (
        alloc((num_out_tokens,), dtype=probs.dtype, device=device) if probs is not None else None
    )
    permuted_scale = (
        alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device=device)
        if scale is not None
        else None
    )
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _permute_kernel[grid](
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        pad_offsets,
        # Pass output buffers as input parameters (for JAX input_output_aliases compatibility).
        # In PyTorch, these point to the same memory as the output pointers below.
        output,
        permuted_probs,
        scale_hidden_dim,
        num_tokens,
        num_out_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        probs.stride(1) if probs is not None else None,
        scale.stride(0) if scale is not None else None,
        scale.stride(1) if scale is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        permuted_scale.stride(0) if permuted_scale is not None else None,
        permuted_scale.stride(1) if permuted_scale is not None else None,
        output,
        permuted_probs,
        num_experts,
        hidden_size,
        PERMUTE_PROBS=probs is not None,
        PERMUTE_SCALE=scale is not None,
        FUSION_PAD=pad_offsets is not None,
    )
    return output, permuted_scale, permuted_probs


def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    pad_offsets: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
):
    """
    Unpermute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs : torch.Tensor
        The merging probabilities of the input tensor. If it is not None, it will be used as weights
        to reduce the unpermuted tokens.
    permuted_probs : torch.Tensor
        The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
        If it is not None, it will remove the previously fused padding.
    num_tokens : int
        Number of tokens in the permuted tensor.
    num_experts : int
        Number of experts in the permuted tensor.
    hidden_size : int
        Hidden size of the permuted tensor.

    Returns
    -------
    output : torch.Tensor
        Unpermuted tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs : torch.Tensor or None
        Unpermuted probabilities of shape `[num_tokens, num_experts]`,
        or None if `permuted_probs` is None.
    """
    # Allocate on the input's device (not the current CUDA device) so the kernel
    # works when the input lives on a non-default GPU.
    device = inp.device
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=device)
    if permuted_probs is not None:
        unpermuted_probs = torch.empty(
            (num_tokens, num_experts), dtype=permuted_probs.dtype, device=device
        )
    else:
        unpermuted_probs = None
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _unpermute_kernel[grid](
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        # Dummy buffer parameters for kernel signature consistency with _permute_kernel.
        # These are unused in unpermute but maintain consistent interface.
        output,  # output_buf_ptr (unused, passed for signature consistency)
        unpermuted_probs,  # unpermuted_probs_buf_ptr (unused, passed for signature consistency)
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if merging_probs is not None else None,
        merging_probs.stride(1) if merging_probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        output,
        unpermuted_probs,
        num_experts,
        hidden_size,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        WITH_MERGING_PROBS=merging_probs is not None,
        PERMUTE_PROBS=permuted_probs is not None,
        FUSION_UNPAD=pad_offsets is not None,
    )
    return output, unpermuted_probs


def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    pad_offsets: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """
    Unpermute backward pass kernel with merging probs.

    Parameters
    ----------
    fwd_output_grad : torch.Tensor
        The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : torch.Tensor
        The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs : torch.Tensor
        The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
    pad_offsets : torch.Tensor
        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
        If it is not None, it will be allocated output buffers with aligned sizes.
    num_tokens : int
        Number of tokens in the permuted tensor.
    num_experts : int
        Number of experts in the permuted tensor.
    num_out_tokens : int
        Number of tokens in the output tensor.
    hidden_size : int
        Hidden size of the output tensor.

    Returns
    -------
    act_grad : torch.Tensor
        Gradient w.r.t. the forward input, of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad : torch.Tensor
        Gradient w.r.t. the merging probabilities, of shape `[num_tokens, num_experts]`.
    """
    # Allocate on the gradient's device (not the current CUDA device) so the
    # kernel works when the tensors live on a non-default GPU.
    device = fwd_output_grad.device
    # Use zeros when pad_offsets is used because padding slots won't be written to
    # by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros
    # out the padding slots.
    alloc = torch.zeros if pad_offsets is not None else torch.empty
    act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device=device)
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device=device
    )
    grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[grid](
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        row_id_map.stride(0),
        row_id_map.stride(1),
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        act_grad,
        merging_probs_grad,
        num_experts,
        hidden_size,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        FUSION_UNPAD=pad_offsets is not None,
    )
    return act_grad, merging_probs_grad


def make_chunk_sort_map(
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    num_splits: int,
):
    """
    Make a row_id_map for chunk sort.

    Parameters
    ----------
    split_sizes : torch.Tensor
        The sizes of the chunks of shape `[num_splits,]`.
    sorted_indices : torch.Tensor
        The indices of the sorted chunks of shape `[num_splits,]`.
    num_tokens : int
        Number of tokens in the input tensor.
    num_splits : int
        Number of splits of split_sizes and sorted_indices.

    Returns
    -------
    row_id_map : torch.Tensor
        Per-token destination row indices for the chunk sort, of shape `[num_tokens,]`.
    """
    # Allocate on split_sizes' device (not the current CUDA device) so the kernel
    # works when the inputs live on a non-default GPU.
    row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=split_sizes.device)
    grid = (num_tokens,)
    _make_chunk_sort_map_kernel[grid](
        split_sizes,
        sorted_indices,
        row_id_map,
        num_splits,
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return row_id_map


def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
):
    """
    Sort chunks with row_id_map.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens,]`.
    probs : torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    num_tokens : int
        Number of tokens in the input tensor.
    hidden_size : int
        Hidden size of the input tensor.
    is_forward : bool
        Whether the sort is for forward or backward.

    Returns
    -------
    output : torch.Tensor
        Chunk-sorted tensor of shape `[num_tokens, hidden_size]`.
    permuted_probs : torch.Tensor or None
        Chunk-sorted probabilities of shape `[num_tokens,]`, or None if `probs` is None.
    """
    # Allocate on the input's device (not the current CUDA device) so the kernel
    # works when the input lives on a non-default GPU.
    device = inp.device
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=device)
    if probs is not None:
        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=device)
    else:
        permuted_probs = None
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _sort_chunks_by_map_kernel[grid](
        inp,
        row_id_map,
        probs,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        output,
        permuted_probs,
        hidden_size,
        PERMUTE_PROBS=probs is not None,
        FORWARD=is_forward,
    )
    return output, permuted_probs