# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Permutation kernels written with OpenAI Triton."""

from typing import Union

import torch
import triton
import triton.language as tl

from triton.language import core
from triton.language.standard import _log2


# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698


@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
    """
    One compare-and-swap step of the bitonic sort, applied to values and indices together.

    `x` is viewed as `[n_outer * 2**i, 2, 2**(n_dims - i - 1)]`; the two entries along the
    middle axis form the pair that is compared. Where `(left > right) ^ flip` is non-zero the
    pair is exchanged, and the matching entries of `indices` are exchanged the same way.

    Returns the updated `(x, indices)` with the same shapes as the inputs.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
    y = tl.reshape(x, shape)
    z = tl.reshape(indices, shape)

    # 0 selects the first ("left") element of each pair, 1 selects the second ("right")
    mask = tl.arange(0, 2)[None, :, None]

    # broadcast the left / right value of every pair back to both positions of the pair
    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )

    # same extraction for the indices that travel with the values
    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)

    # signed integer type with the same bit width as x, used to XOR the raw bits
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)

    il_value = l_value.to(idtype, bitcast=True)
    ir_value = r_value.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # XOR swap: `v ^ (l ^ r)` turns l into r and r into l; XOR with 0 leaves v unchanged
    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
    ret = ix ^ flag1
    # apply the identical swap decision to the indices
    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
    ind = indices ^ flag2

    return ret.to(x.dtype, bitcast=True), ind


@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    # Run `stage` compare-and-swap steps (one merge stage of the bitonic sort) on `x`,
    # permuting `indices` in lockstep. `order` selects the direction, see below.
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    if order == 2:
        # alternating: consecutive runs of 2**stage elements get flip = 0, 1, 0, 1, ...
        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        # uniform direction: every element uses the same flip value (0 or 1)
        flip = tl.full(x.shape, value=order, dtype=tl.int32)
    # step index passed down runs from (n_dims - stage) to (n_dims - 1)
    for i in tl.static_range(stage):
        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
    return x, indices


@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    # Bitonic argsort: runs merge stages 1..n_dims on `x`, moving `indices` with the values.
    # Intermediate stages use the alternating order (2); the last stage uses order 1
    # (descending), so the returned values are sorted in descending order.
    for i in tl.static_range(1, n_dims + 1):
        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
    return x, indices

79
80
81
82
83
84
85
86
87
88
89
90

@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pass 1 of building the row_id_map: block-local cumulative sum of the routing mask.

    Program (pid_m, pid_n) handles expert `pid_m` and the `pid_n`-th block of `BLOCK_SIZE`
    tokens. For every token of the block it stores `cumsum(mask) * mask` into row_id_map,
    i.e. the 1-based rank of the token among the routed tokens of this block, or 0 when the
    token is not routed to the expert. The number of routed tokens of the block is written to
    `workspace[pid_m * cdiv(num_tokens, BLOCK_SIZE) + pid_n]`.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # tokens past num_tokens are read as 0 so they contribute nothing to the cumsum / sum
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int32)
    # multiplying by the mask zeroes the entries of tokens that are not routed
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)


@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pass 2 of building the row_id_map: turn block-local ranks into destination row indices.

    Program (pid_m, pid_n) handles the same (expert, token block) chunk as in pass 1. It sums
    the workspace counters of all chunks with a smaller flat index
    (`pid_m * cdiv(num_tokens, BLOCK_SIZE) + pid_n`) and adds that total to the block-local
    rank. Entries whose block-local rank is 0 (token not routed) become -1; all others become
    `rank + preceding_total - 1`.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        mask=(offset < num_tokens),
        other=0,
    )

    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # `other=0` is required: every lane takes part in the tl.sum below, and lanes that are
    # masked off would otherwise hold unspecified values
    n_tokens_per_chunk = tl.load(
        workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0
    )
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id,
        mask=(offset < num_tokens),
    )


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@triton.jit
def _row_id_map_pass_3_kernel(
    # pointers
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    LOAD_SIZE: tl.constexpr,
):
    # Pass 3 of building the row_id_map: compact one token's row from the sparse layout
    # (one slot per expert, -1 when not routed) into the dense layout. One program per token.
    # After this pass, for the token handled by the program:
    #   [0, n_routed)                         -> destination rows, in descending order
    #   [num_experts, num_experts + n_routed) -> expert index belonging to each of those rows
    #   [2 * num_experts]                     -> n_routed
    pid = tl.program_id(0)
    n_dims: tl.constexpr = _log2(LOAD_SIZE)
    off = tl.arange(0, LOAD_SIZE)
    # lanes beyond num_experts are filled with -1 so they look like "not routed"
    row_id_map = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
        mask=off < num_experts,
        other=-1,
    )
    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
    # the lane number is the expert index; _argsort moves it together with its row id
    indices = off
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    # only the first n_routed sorted entries are written back (the rest would be -1 padding)
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )


193
194
195
196
197
def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """
    Prepare the row_id_map for the permutation.

    Parameters
    ----------
    routing_map: torch.Tensor
        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
        which experts are routed to which tokens. The values in it: 1 means the token is routed to
        this expert and 0 means not.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map: torch.Tensor
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`,
        dtype int32, allocated on the device of `routing_map`.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
        to the first n_routed row indices above.
    """
    # allocate next to the input instead of on whatever the current CUDA device happens to be
    device = routing_map.device
    row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=device)
    if num_tokens == 0:
        # no rows to fill; also avoids compiling the kernels with a zero-width workspace load
        return row_id_map

    block_size = 1024
    grid = (num_experts, triton.cdiv(num_tokens, block_size))
    # one counter per (expert, token block): number of routed tokens in that chunk
    workspace_tensor = torch.empty(grid, dtype=torch.int32, device=device)

    # supposing num_tokens == 5, num_experts == 3, block_size == 3
    # and we have a routing_map like this:
    # [[1, 1, 0],
    #  [1, 0, 1],
    #  [0, 0, 1],
    #  [1, 1, 0],
    #  [0, 0, 0]]

    # pass 1: block cumsum
    # for each expert, compute the cumsum of every block_size tokens
    # the row_id_map will be like this after pass 1 (r means useless values):
    # [[1, 1, 0, r, r, r, r],
    #  [2, 0, 1, r, r, r, r],
    #  [0, 0, 2, r, r, r, r],
    #  [1, 1, 0, r, r, r, r],
    #  [0, 0, 0, r, r, r, r]]
    _row_id_map_pass_1_kernel[grid](
        routing_map,
        row_id_map,
        workspace_tensor,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        row_id_map.stride(0),
        row_id_map.stride(1),
        block_size,
    )

    # pass 2: cumsum all and process the mask
    # process the block cumsum into the global cumsum and then into the dst row indices
    # the row_id_map will be like this after pass 2 (r means useless value):
    # [[ 0,  3, -1, r, r, r, r],
    #  [ 1, -1,  5, r, r, r, r],
    #  [-1, -1,  6, r, r, r, r],
    #  [ 2,  4, -1, r, r, r, r],
    #  [-1, -1, -1, r, r, r, r]]
    _row_id_map_pass_2_kernel[grid](
        row_id_map,
        workspace_tensor,
        num_tokens,
        row_id_map.stride(0),
        row_id_map.stride(1),
        # the whole workspace is loaded by one program, so the width must cover all chunks
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
        block_size,
    )

    # pass 3: make the row_id_map from the sparse structure to the dense structure
    # the row_id_map will be like this after pass 3 (r means useless value):
    # [[3, 0, r, 1, 0, r, 2],
    #  [5, 1, r, 2, 0, r, 2],
    #  [6, r, r, 2, r, r, 1],
    #  [4, 2, r, 1, 0, r, 2],
    #  [r, r, r, r, r, r, 0]]
    grid = (num_tokens,)
    _row_id_map_pass_3_kernel[grid](
        row_id_map,
        num_experts,
        row_id_map.stride(0),
        row_id_map.stride(1),
        triton.next_power_of_2(num_experts),
    )
    return row_id_map


@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    scale_ptr,
    permuted_probs_ptr,
    permuted_scale_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    scale_hidden_dim,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_scale_token,
    stride_scale_hidden,
    stride_permuted_probs_token,
    stride_permuted_scale_token,
    stride_permuted_scale_hidden,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    PERMUTE_SCALE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scatter one source token to every destination row listed for it in the row_id_map.

    Program (pid_t, pid_h) handles source token `pid_t` and the `pid_h`-th block of
    `BLOCK_SIZE` hidden elements. The token's slice is loaded once and stored to each of its
    `n_routed` destination rows. With `PERMUTE_SCALE` the matching slice of the scale tensor
    is copied to the same rows; with `PERMUTE_PROBS` the (token, expert) probability is
    written to `permuted_probs[dst_row]` and a zero probability zero-fills the output slice.
    """
    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
    # 64-bit row index for the token-stride products of the input/output offsets
    src_row = pid_t.to(tl.int64)
    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
    # the last column of the row_id_map row holds the number of routed experts
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
            )
            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
        if PERMUTE_PROBS:
            # expert indices live in columns [num_experts, num_experts + n_routed)
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
            prob = tl.load(probs_ptr + prob_off)
            # a single hidden-block program per token writes the scalar probability
            if pid_h == 0:
                permuted_prob_off = dst_row * stride_permuted_probs_token
                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
            if prob == 0.0:
                # for routing_map padding
                # dst_row != -1 and prob == 0.0 means that this slot is padded
                tl.store(output_ptr + output_off, 0.0, mask=mask)
            else:
                tl.store(output_ptr + output_off, inp, mask=mask)
        else:
            tl.store(output_ptr + output_off, inp, mask=mask)
368
369


370
371
372
373
374
375
376
377
# Autotune BLOCK_SIZE for the permute kernel; configs are re-evaluated whenever the value of
# `hidden_size` changes.
try:
    _permute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    # NOTE(review): deliberately best-effort. Presumably triton.autotune raises RuntimeError
    # when no active GPU driver is available (e.g. importing on a CPU-only machine) - confirm.
    # The un-tuned JIT kernel is kept so that the module can still be imported.
    pass


387
388
389
def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    scale: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    scale_hidden_dim: int,
):
    """
    Permute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    probs: torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    scale: torch.Tensor
        The scale of the input tensor. If it is not None, it will be permuted.
    num_tokens: int
        Number of tokens in the input tensor.
    num_experts: int
        Number of experts in the input tensor.
    num_out_tokens: int
        Number of tokens in the permuted tensor.
    hidden_size: int
        Hidden size of the input tensor.
    scale_hidden_dim: int
        Hidden size of the scale tensor.

    Returns
    -------
    output: torch.Tensor
        The permuted tensor of shape `[num_out_tokens, hidden_size]`.
    permuted_scale: torch.Tensor or None
        The permuted scale of shape `[num_out_tokens, scale_hidden_dim]`, None if `scale` is None.
    permuted_probs: torch.Tensor or None
        The permuted probabilities of shape `[num_out_tokens]`, None if `probs` is None.
    """
    # allocate the results next to the input instead of on the current CUDA device
    device = inp.device
    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=device)
    if probs is not None:
        permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=device)
    else:
        permuted_probs = None

    if scale is not None:
        permuted_scale = torch.empty(
            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device=device
        )
    else:
        permuted_scale = None
    # the grid depends on the BLOCK_SIZE chosen by the autotuner, hence the callable
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _permute_kernel[grid](
        inp,
        output,
        row_id_map,
        probs,
        scale,
        permuted_probs,
        permuted_scale,
        num_experts,
        hidden_size,
        scale_hidden_dim,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        # strides of absent tensors are passed as None; the kernel never reads them because
        # the corresponding PERMUTE_* flag is False
        probs.stride(0) if probs is not None else None,
        probs.stride(1) if probs is not None else None,
        scale.stride(0) if scale is not None else None,
        scale.stride(1) if scale is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        permuted_scale.stride(0) if permuted_scale is not None else None,
        permuted_scale.stride(1) if permuted_scale is not None else None,
        PERMUTE_PROBS=probs is not None,
        PERMUTE_SCALE=scale is not None,
    )
    return output, permuted_scale, permuted_probs
464
465
466
467
468
469
470
471


@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    merging_probs_ptr,
    permuted_probs_ptr,
    unpermuted_probs_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gather the permuted rows of one token and reduce them into a single output row.

    Program (pid_t, pid_h) handles output token `pid_t` and the `pid_h`-th block of
    `BLOCK_SIZE` hidden elements. The `n_routed` permuted rows listed in the row_id_map are
    summed in float32, optionally weighted by the (token, expert) merging probability, and the
    sum is stored in the input's element type. With `PERMUTE_PROBS` the program with
    `pid_h == 0` also scatters `permuted_probs` back to `[token, expert]`, writing 0.0 for the
    expert slots first.
    """
    data_type = input_ptr.dtype.element_ty
    compute_type = tl.float32

    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    if PERMUTE_PROBS:
        # write 0.0 to probs_grad that are not routed
        if pid_h == 0:
            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
            unpermuted_prob_off = (
                pid_t * stride_unpermuted_probs_token
                + stride_unpermuted_probs_expert * map_load_off
            )
            tl.store(
                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
            )
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
    # the last column of the row_id_map row holds the number of routed experts
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
        if WITH_MERGING_PROBS:
            # expert indices live in columns [num_experts, num_experts + n_routed)
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            merging_prob_off = (
                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            inp *= merging_prob
        accumulator += inp
        if PERMUTE_PROBS:
            # a single hidden-block program per token moves the scalar probability
            if pid_h == 0:
                expert_idx = tl.load(
                    row_id_map_ptr
                    + pid_t * stride_row_id_map_token
                    + (num_experts + idx) * stride_row_id_map_expert
                )
                unpermuted_prob_off = (
                    pid_t * stride_unpermuted_probs_token
                    + expert_idx * stride_unpermuted_probs_expert
                )
                permuted_prob_off = src_row * stride_permuted_probs_token
                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    accumulator = accumulator.to(data_type)
    dst_row = pid_t.to(tl.int64)
    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)
557
558


559
560
561
562
563
564
565
566
# Autotune BLOCK_SIZE for the unpermute kernel; configs are re-evaluated whenever the value
# of `hidden_size` changes.
try:
    _unpermute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    # NOTE(review): deliberately best-effort. Presumably triton.autotune raises RuntimeError
    # when no active GPU driver is available (e.g. importing on a CPU-only machine) - confirm.
    # The un-tuned JIT kernel is kept so that the module can still be imported.
    pass


576
577
578
def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
):
    """
    Unpermute the input tensor based on the row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_out_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    merging_probs: torch.Tensor
        The merging probabilities of the input tensor. If it is not None, it will be used as weights
        to reduce the unpermuted tokens.
    permuted_probs: torch.Tensor
        The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
    num_tokens: int
        Number of tokens in the unpermuted (output) tensor.
    num_experts: int
        Number of experts, i.e. the second dimension of the unpermuted probabilities.
    hidden_size: int
        Hidden size of the input and output tensors.

    Returns
    -------
    output: torch.Tensor
        The unpermuted tensor of shape `[num_tokens, hidden_size]`.
    unpermuted_probs: torch.Tensor or None
        The unpermuted probabilities of shape `[num_tokens, num_experts]`, None if
        `permuted_probs` is None.
    """
    # allocate the results next to the input instead of on the current CUDA device
    device = inp.device
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=device)
    if permuted_probs is not None:
        unpermuted_probs = torch.empty(
            (num_tokens, num_experts), dtype=permuted_probs.dtype, device=device
        )
    else:
        unpermuted_probs = None
    # the grid depends on the BLOCK_SIZE chosen by the autotuner, hence the callable
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _unpermute_kernel[grid](
        inp,
        output,
        row_id_map,
        merging_probs,
        permuted_probs,
        unpermuted_probs,
        num_experts,
        hidden_size,
        row_id_map.stride(0),
        row_id_map.stride(1),
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        # strides of absent tensors are passed as None; the kernel never reads them because
        # the corresponding flag below is False
        merging_probs.stride(0) if merging_probs is not None else None,
        merging_probs.stride(1) if merging_probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
        WITH_MERGING_PROBS=merging_probs is not None,
        PERMUTE_PROBS=permuted_probs is not None,
    )
    return output, unpermuted_probs
640
641
642


@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    merging_probs_ptr,
    merging_probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward of the unpermute with merging probabilities; one program per token.

    For each of the token's `n_routed` permuted rows the program walks the hidden dimension
    in chunks of `BLOCK_SIZE` and
      * writes `fwd_output_grad[token] * merging_prob` to `fwd_input_grad[dst_row]`;
      * accumulates `fwd_input[dst_row] * fwd_output_grad[token]` in float32, whose sum over
        the hidden dimension is stored to `merging_probs_grad[token, expert]`.
    The token's whole `merging_probs_grad` row is zeroed first.
    """
    data_type = fwd_output_grad_ptr.dtype.element_ty
    compute_type = tl.float32

    pid = tl.program_id(0)
    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
    token_probs_grad_off = (
        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
    )
    # experts the token is not routed to keep a zero gradient
    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
    # the last column of the row_id_map row holds the number of routed experts
    n_routed = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        expert_idx = tl.load(
            row_id_map_ptr
            + pid * stride_row_id_map_token
            + (num_experts + idx) * stride_row_id_map_expert
        )
        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        current_start = 0
        while current_start < hidden_size:
            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
            mask = current_offset < hidden_size
            src_row = pid.to(tl.int64)
            input_off = (
                src_row * stride_fwd_output_grad_token
                + current_offset * stride_fwd_output_grad_hidden
            )
            # `other=0.0` is required: masked-off lanes are reduced by tl.sum below, and
            # without a default their content is unspecified
            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask, other=0.0)
            inp = inp.to(compute_type)
            merging_prob_off = (
                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            output = inp * merging_prob
            output = output.to(data_type)
            output_off = (
                dst_row * stride_fwd_input_grad_token
                + current_offset * stride_fwd_input_grad_hidden
            )
            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)

            fwd_input_off = (
                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
            )
            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask, other=0.0)
            prob_grad_accum += fwd_input.to(compute_type) * inp
            current_start += BLOCK_SIZE
        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
        probs_grad_off = (
            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
        )
        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
727
728


729
730
731
732
733
734
735
736
# Autotune BLOCK_SIZE for the unpermute backward kernel; configs are re-evaluated whenever
# the value of `hidden_size` changes.
try:
    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
    # NOTE(review): deliberately best-effort. Presumably triton.autotune raises RuntimeError
    # when no active GPU driver is available (e.g. importing on a CPU-only machine) - confirm.
    # The un-tuned JIT kernel is kept so that the module can still be imported.
    pass


746
def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """
    Unpermute backward pass kernel with merging probs.

    Parameters
    ----------
    fwd_output_grad: torch.Tensor
        The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input: torch.Tensor
        The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs: torch.Tensor
        The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
    num_tokens: int
        Number of tokens in the unpermuted tensor (rows of `fwd_output_grad`).
    num_experts: int
        Number of experts (columns of `merging_probs`).
    num_out_tokens: int
        Number of tokens in the permuted tensor (rows of `fwd_input`).
    hidden_size: int
        Hidden size of `fwd_output_grad` and `fwd_input`.

    Returns
    -------
    act_grad: torch.Tensor
        The gradient w.r.t. `fwd_input`, of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad: torch.Tensor
        The gradient w.r.t. `merging_probs`, of shape `[num_tokens, num_experts]`.
    """
    # allocate the results next to the incoming gradient instead of on the current CUDA device
    device = fwd_output_grad.device
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device=device
    )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device=device
    )
    # one program per token; the kernel loops over the hidden dimension itself
    grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[grid](
        fwd_output_grad,
        act_grad,
        fwd_input,
        merging_probs,
        merging_probs_grad,
        row_id_map,
        num_experts,
        hidden_size,
        row_id_map.stride(0),
        row_id_map.stride(1),
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
    )
    return act_grad, merging_probs_grad
809
810
811


@triton.jit
def _make_chunk_sort_map_kernel(
    # pointers
    split_sizes_ptr,
    sorted_indices_ptr,
    dst_rows_ptr,
    # sizes
    num_splits: tl.constexpr,
    # metas
    IDX_LOAD_WIDTH: tl.constexpr,
):
    """
    Compute the destination row of one token when its chunk is moved by `sorted_indices`.

    One program per input row `pid`. The program finds the chunk that contains `pid` from
    the cumulative split sizes, locates the position of that chunk in `sorted_indices`, and
    writes `sum(sizes of the chunks placed before it) + offset inside the chunk` to
    `dst_rows[pid]`.
    """
    pid = tl.program_id(0)

    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    # padded lanes get -1 so that they can never compare equal to a chunk index below
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits, other=-1
    )

    # get chunk idx of the current token in the input tensor
    input_split_sizes = tl.load(
        split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
    ).to(tl.int32)
    input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
    # 1 for every chunk that ends at or before pid, i.e. the chunks that precede pid's chunk
    input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
    input_chunk_idx = tl.sum(input_split_sizes_mask)
    input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
    in_chunk_offset = pid - input_split_sizes_presum

    # get chunk idx of the current token in the output tensor
    output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
    output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)

    # make row_id_map
    # sizes of the chunks in output order; padded lanes are masked off and read as 0
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits, other=0
    ).to(tl.int32)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)


852
def make_chunk_sort_map(
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    num_splits: int,
) -> torch.Tensor:
    """
    Make a row_id_map for chunk sort.

    Parameters
    ----------
    split_sizes: torch.Tensor
        The sizes of the chunks of shape `[num_splits,]`.
    sorted_indices: torch.Tensor
        The indices of the sorted chunks of shape `[num_splits,]`.
    num_tokens: int
        Number of tokens in the input tensor.
    num_splits: int
        Number of splits of split_sizes and sorted_indices.

    Returns
    -------
    row_id_map: torch.Tensor
        int32 tensor of shape `[num_tokens,]`, allocated on the device of
        `split_sizes` and filled by `_make_chunk_sort_map_kernel` (one program
        per token).
    """
    # Allocate next to the inputs instead of on whichever CUDA device is current,
    # so the map never ends up on a different device than the tensors it indexes.
    row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=split_sizes.device)
    grid = (num_tokens,)
    _make_chunk_sort_map_kernel[grid](
        split_sizes,
        sorted_indices,
        row_id_map,
        num_splits,
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return row_id_map
882
883
884


@triton.jit
def _sort_chunks_by_map_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    hidden_size: tl.constexpr,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    FORWARD: tl.constexpr,
):
    """
    Copy one `BLOCK_SIZE`-wide slice of one row from input to output.

    Program id 0 selects the row and program id 1 selects the slice of the
    hidden dimension. With `FORWARD`, the program's own row is the source and
    `row_id_map[row]` is the destination; otherwise the two roles are swapped.
    With `PERMUTE_PROBS`, the program handling slice 0 also moves one element
    of `probs` from the source row to the destination row.
    """
    token_id = tl.program_id(0)
    block_id = tl.program_id(1)

    # Both row indices are widened to int64 before being multiplied by strides.
    own_row = token_id.to(tl.int64)
    mapped_row = tl.load(row_id_map_ptr + token_id).to(tl.int64)
    if FORWARD:
        src_row = own_row
        dst_row = mapped_row
    else:
        src_row = mapped_row
        dst_row = own_row

    # Columns covered by this program; the tail past hidden_size is masked off.
    cols = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < hidden_size
    src_ptrs = input_ptr + (src_row * stride_input_token + cols * stride_input_hidden)
    dst_ptrs = output_ptr + (dst_row * stride_output_token + cols * stride_output_hidden)
    row_slice = tl.load(src_ptrs, mask=in_bounds)
    tl.store(dst_ptrs, row_slice, mask=in_bounds)

    # Kept as two nested checks: the outer one is a compile-time flag, the inner
    # one is a runtime test that restricts the probs copy to a single program per row.
    if PERMUTE_PROBS:
        if block_id == 0:
            prob = tl.load(probs_ptr + src_row * stride_probs_token)
            tl.store(permuted_probs_ptr + dst_row * stride_permuted_probs_token, prob)
926
927


928
929
930
931
932
933
934
935
try:
    # Wrap the JIT kernel in an autotuner that selects BLOCK_SIZE from the
    # candidates below, keyed on `hidden_size`.
    _sort_chunks_by_map_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_map_kernel)
except RuntimeError:
    # Best effort: on failure the name keeps referring to the plain JIT kernel.
    # NOTE(review): presumably guards environments without a usable GPU driver
    # (e.g. documentation builds) -- confirm.
    pass


945
946
947
def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
):
    """
    Sort chunks with row_id_map.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map: torch.Tensor
        Row mapping tensor of shape `[num_tokens,]`. In the forward direction,
        row `i` of `inp` is written to row `row_id_map[i]` of the output; in the
        backward direction, row `i` of the output is read from row
        `row_id_map[i]` of `inp`.
    probs: torch.Tensor
        The probabilities of the input tensor. If it is not None, it will be permuted.
    num_tokens: int
        Number of tokens in the input tensor.
    hidden_size: int
        Hidden size of the input tensor.
    is_forward: bool
        Whether the sort is for forward or backward.

    Returns
    -------
    output: torch.Tensor
        Reordered tensor of shape `[num_tokens, hidden_size]`, with the dtype and
        device of `inp`.
    permuted_probs: torch.Tensor or None
        Reordered probabilities of shape `[num_tokens,]` with the dtype and device
        of `probs`, or None when `probs` is None.
    """
    # Allocate the results next to their inputs instead of on whichever CUDA
    # device is current, so outputs never land on a different device.
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
    if probs is not None:
        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=probs.device)
    else:
        permuted_probs = None
    # The second grid dimension depends on the BLOCK_SIZE chosen by the autotuner.
    # pylint: disable=unnecessary-lambda-assignment
    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
    _sort_chunks_by_map_kernel[grid](
        inp,
        output,
        row_id_map,
        probs,
        permuted_probs,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        PERMUTE_PROBS=probs is not None,
        FORWARD=is_forward,
    )
    return output, permuted_probs