permutation.py 23.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Permutation kernels written with OpenAI Triton."""

from typing import Union

import torch
import triton
import triton.language as tl

from transformer_engine_torch import DType as TE_DType


@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,  # (num_experts, num_tokens) 0/1 mask, indexed via the strides below
    row_id_map_ptr,  # (num_experts, num_tokens) output, row-major
    workspace_ptr,  # (num_experts, cdiv(num_tokens, BLOCK_SIZE)) per-chunk token counts
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of row-id-map construction. Grid: (num_experts, num token blocks).
    # Each program computes, for one expert and one BLOCK_SIZE slice of tokens,
    # the block-local rank of every routed token plus the block's routed-token count.
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # 0/1 routing flags for this expert over the token slice (0 beyond num_tokens).
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int64)
    # 1-based rank of each routed token within this block; multiplying by the
    # mask zeroes the entries of tokens not routed to this expert.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    # Per-(expert, block) routed-token count; pass 2 prefix-sums these counts
    # to turn block-local ranks into global destination rows.
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)


@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,  # (num_experts, num_tokens); holds block-local ranks on entry
    workspace_ptr,  # flattened per-(expert, block) token counts from pass 1
    # sizes
    num_tokens,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,  # next pow2 >= total number of chunks
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of row-id-map construction. Grid: (num_experts, num token blocks).
    # Converts each block-local rank (1-based, 0 = not routed) into a global
    # 0-based destination row by adding the token counts of all preceding chunks;
    # unrouted entries become -1.
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    # Linear index of this (expert, token-block) chunk in the workspace.
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
    )

    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Counts of all chunks strictly before this one. Fix: `other=0` is required
    # here — masked-off lanes of tl.load are undefined without it, and every
    # lane feeds the tl.sum below (pass 1 already specifies other=0 this way).
    n_tokens_per_chunk = tl.load(
        workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0
    )
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        # -1 converts the 1-based block-local rank into a 0-based global row.
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id,
        mask=(offset < num_tokens),
    )


def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """Build a (num_experts, num_tokens) destination-row map from a routing mask.

    Each entry holds the global destination row for a routed (expert, token)
    pair and -1 where the token is not routed to that expert. The map is
    produced in two kernel passes: block-local cumulative sums, then a global
    offset fix-up.
    """
    block_size = 256
    n_blocks = triton.cdiv(num_tokens, block_size)
    launch_grid = (num_experts, n_blocks)
    row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
    # One routed-token count per (expert, token-block) pair.
    block_counts = torch.empty(launch_grid, dtype=torch.int64, device="cuda")
    # Pass 1: block-local cumulative sums plus per-block counts.
    _row_id_map_pass_1_kernel[launch_grid](
        routing_map,
        row_id_map,
        block_counts,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        block_size,
    )
    # Pass 2: add global offsets and mark unrouted entries with -1.
    _row_id_map_pass_2_kernel[launch_grid](
        row_id_map,
        block_counts,
        num_tokens,
        triton.next_power_of_2(num_experts * n_blocks),
        block_size,
    )
    return row_id_map


@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,  # (num_tokens, hidden_size) source activations
    output_ptr,  # (num_out_tokens, hidden_size) permuted destination
    row_id_map_ptr,  # (num_experts, num_tokens); -1 marks "not routed"
    probs_ptr,  # (num_tokens, num_experts) — only read when PERMUTE_PROBS
    permuted_probs_ptr,  # (num_out_tokens,) — only written when PERMUTE_PROBS
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source token (pid). The token's row is copied to the
    # destination row of every expert it is routed to, BLOCK_SIZE columns at a time.
    pid = tl.program_id(0)
    cur_pos = 0
    while cur_pos < hidden_size:
        cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
        mask = cur_off < hidden_size
        input_off = pid * stride_input_token + cur_off * stride_input_hidden
        # Load the slice once; it may be stored to several destination rows below.
        inp = tl.load(input_ptr + input_off, mask=mask)
        for expert_idx in range(num_experts):
            dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if dst_row != -1:
                output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
                tl.store(output_ptr + output_off, inp, mask=mask)
                if PERMUTE_PROBS:
                    # Probabilities are scalar per (token, expert); write them
                    # only on the first column block to avoid repeated stores.
                    if cur_pos == 0:
                        prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                        prob = tl.load(probs_ptr + prob_off)
                        permuted_prob_off = dst_row * stride_permuted_probs_token
                        tl.store(permuted_probs_ptr + permuted_prob_off, prob)
        cur_pos += BLOCK_SIZE


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
try:
    # Autotune the copy tile width, keyed on the hidden dimension.
    _permute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": width}) for width in (64, 128, 256, 512, 1024)
        ],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    # Autotuning is best-effort; fall back to the plain jitted kernel.
    pass


172
173
174
def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """Gather token rows into their expert-sorted positions via ``row_id_map``.

    Returns the permuted activations and, when ``probs`` is given, the
    per-destination-row probabilities (otherwise ``None``).
    """
    has_probs = probs is not None
    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    permuted_probs = (
        torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") if has_probs else None
    )
    # One program per source token.
    _permute_kernel[(num_tokens,)](
        inp,
        output,
        row_id_map,
        probs,
        permuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if has_probs else None,
        probs.stride(1) if has_probs else None,
        permuted_probs.stride(0) if has_probs else None,
        PERMUTE_PROBS=has_probs,
    )
    return output, permuted_probs
207
208
209
210
211
212
213
214


@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,  # permuted activations (num_out_tokens, hidden_size)
    output_ptr,  # unpermuted result (num_tokens, hidden_size)
    row_id_map_ptr,  # (num_experts, num_tokens); -1 marks "not routed"
    merging_probs_ptr,  # (num_tokens, num_experts) — read when WITH_MERGING_PROBS
    permuted_probs_ptr,  # (num_out_tokens,) — read when PERMUTE_PROBS
    unpermuted_probs_ptr,  # (num_tokens, num_experts) — written when PERMUTE_PROBS
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    FP8_DTYPE: tl.constexpr,  # "e5m2" / "e4m3" / None
    BLOCK_SIZE: tl.constexpr,
):
    # FP8 tensors arrive as uint8 storage; pick the Triton fp8 type to bitcast
    # through, and remember uint8 for bitcasting the result back on store.
    if FP8_DTYPE == "e5m2":
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        data_type = input_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    # All accumulation happens in fp32 regardless of storage dtype.
    compute_type = tl.float32

    # One program per destination token (pid): sum (optionally prob-weighted)
    # contributions from every expert the token was routed to.
    pid = tl.program_id(0)
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        for expert_idx in range(num_experts):
            src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if src_row != -1:
                input_off = src_row * stride_input_token + current_offset * stride_input_hidden
                inp = tl.load(input_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    # Reinterpret the raw uint8 bytes as fp8 before upcasting.
                    inp = inp.to(data_type, bitcast=True)
                inp = inp.to(compute_type)
                if WITH_MERGING_PROBS:
                    merging_prob_off = (
                        pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
                    )
                    merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
                    inp *= merging_prob
                accumulator += inp
            if PERMUTE_PROBS:
                # Scalar per (token, expert); write only on the first column block.
                if current_start == 0:
                    unpermuted_prob_off = (
                        pid * stride_unpermuted_probs_token
                        + expert_idx * stride_unpermuted_probs_expert
                    )
                    if src_row != -1:
                        permuted_prob_off = src_row * stride_permuted_probs_token
                        prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
                    else:
                        # Unrouted (token, expert) pairs get probability 0.
                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
        if FP8_DTYPE is not None:
            if not WITH_MERGING_PROBS:
                # Directly adding these value may cause overflow for fp8, we scale it here.
                # The outside fp8_scale_inv is also scaled in the meantime.
                accumulator /= num_experts
            # Downcast to fp8, then bitcast to uint8 to match the torch storage dtype.
            accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
        else:
            accumulator = accumulator.to(data_type)
        output_off = pid * stride_output_token + current_offset * stride_output_hidden
        tl.store(output_ptr + output_off, accumulator, mask=mask)
        current_start += BLOCK_SIZE


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
try:
    # Autotune the copy tile width, keyed on the hidden dimension.
    _unpermute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": width}) for width in (64, 128, 256, 512, 1024)
        ],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    # Autotuning is best-effort; fall back to the plain jitted kernel.
    pass


310
311
312
def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Reduce expert-sorted rows back to token order via ``row_id_map``.

    When ``merging_probs`` is given, each expert contribution is weighted by
    its probability before summation. When ``permuted_probs`` is given, they
    are scattered back to (token, expert) layout and returned as the second
    result (otherwise ``None``).
    """
    # Map the TE dtype enum to the string tag the kernel dispatches on.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_tag = "e4m3"
    else:
        fp8_tag = None
    has_merging = merging_probs is not None
    has_probs = permuted_probs is not None
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    unpermuted_probs = (
        torch.empty((num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda")
        if has_probs
        else None
    )
    # One program per destination token.
    _unpermute_kernel[(num_tokens,)](
        inp,
        output,
        row_id_map,
        merging_probs,
        permuted_probs,
        unpermuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if has_merging else None,
        merging_probs.stride(1) if has_merging else None,
        permuted_probs.stride(0) if has_probs else None,
        unpermuted_probs.stride(0) if has_probs else None,
        unpermuted_probs.stride(1) if has_probs else None,
        WITH_MERGING_PROBS=has_merging,
        PERMUTE_PROBS=has_probs,
        FP8_DTYPE=fp8_tag,
    )
    return output, unpermuted_probs
359
360
361


@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,  # grad w.r.t. unpermute output (num_tokens, hidden_size)
    fwd_input_grad_ptr,  # grad w.r.t. unpermute input (num_out_tokens, hidden_size)
    fwd_input_ptr,  # saved forward input (num_out_tokens, hidden_size)
    merging_probs_ptr,  # (num_tokens, num_experts)
    merging_probs_grad_ptr,  # (num_tokens, num_experts) output
    row_id_map_ptr,  # (num_experts, num_tokens); -1 marks "not routed"
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    FP8_DTYPE: tl.constexpr,  # "e5m2" / "e4m3" / None
    BLOCK_SIZE: tl.constexpr,
):
    # FP8 tensors arrive as uint8 storage; pick the Triton fp8 type to bitcast
    # through, and remember uint8 for bitcasting results back on store.
    if FP8_DTYPE == "e5m2":
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        data_type = fwd_output_grad_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    # All accumulation happens in fp32 regardless of storage dtype.
    compute_type = tl.float32

    # One program per token (pid). For each expert this token was routed to:
    #   grad_input[dst_row]   = grad_output[pid] * merging_prob
    #   grad_prob[pid,expert] = sum_h(fwd_input[dst_row] * grad_output[pid])
    pid = tl.program_id(0)
    for expert_idx in range(num_experts):
        dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
        if dst_row != -1:
            # Per-column partial products; reduced to a scalar after the loop.
            prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
            current_start = 0
            while current_start < hidden_size:
                current_offset = current_start + tl.arange(0, BLOCK_SIZE)
                mask = current_offset < hidden_size
                input_off = (
                    pid * stride_fwd_output_grad_token
                    + current_offset * stride_fwd_output_grad_hidden
                )
                inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    # Reinterpret the raw uint8 bytes as fp8 before upcasting.
                    inp = inp.to(data_type, bitcast=True)
                inp = inp.to(compute_type)
                merging_prob_off = (
                    pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
                )
                merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
                output = inp * merging_prob
                output = output.to(data_type)
                if FP8_DTYPE is not None:
                    # Bitcast back to uint8 to match the torch storage dtype.
                    output = output.to(pytorch_tensor_dtype, bitcast=True)
                output_off = (
                    dst_row * stride_fwd_input_grad_token
                    + current_offset * stride_fwd_input_grad_hidden
                )
                tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)

                # Accumulate fwd_input . grad_output for the merging-prob gradient.
                fwd_input_off = (
                    dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
                )
                fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
                if FP8_DTYPE is not None:
                    fwd_input = fwd_input.to(data_type, bitcast=True)
                prob_grad_accum += fwd_input.to(compute_type) * inp
                current_start += BLOCK_SIZE
            probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
            probs_grad_off = (
                pid * stride_merging_probs_grad_token
                + expert_idx * stride_merging_probs_grad_expert
            )
            tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
        else:
            # Unrouted (token, expert) pairs contribute zero probability gradient.
            probs_grad_off = (
                pid * stride_merging_probs_grad_token
                + expert_idx * stride_merging_probs_grad_expert
            )
            tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
451
452


453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
try:
    # Autotune the copy tile width, keyed on the hidden dimension.
    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": width}) for width in (64, 128, 256, 512, 1024)
        ],
        key=["hidden_size"],
    )(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
    # Autotuning is best-effort; fall back to the plain jitted kernel.
    pass


468
def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Backward of prob-merged unpermutation.

    Produces the gradient w.r.t. the permuted activations (scaled by the
    merging probabilities) and the gradient w.r.t. the merging probabilities
    themselves.
    """
    # Map the TE dtype enum to the string tag the kernel dispatches on.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_tag = "e4m3"
    else:
        fp8_tag = None
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
    )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
    )
    # One program per token.
    _unpermute_bwd_with_merging_probs_kernel[(num_tokens,)](
        fwd_output_grad,
        act_grad,
        fwd_input,
        merging_probs,
        merging_probs_grad,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        fp8_tag,
    )
    return act_grad, merging_probs_grad
516
517
518
519
520
521
522
523
524
525


@triton.jit
def _sort_chunks_by_idxs_kernel(
    # pointers
    input_ptr,  # (num_tokens, hidden_size) chunked source rows
    split_sizes_ptr,  # (num_splits,) input chunk lengths
    sorted_indices_ptr,  # (num_splits,) output position -> input chunk index
    output_ptr,  # (num_tokens, hidden_size) reordered rows
    dst_rows_ptr,  # (num_tokens,) records each row's destination (row_id_map)
    probs_ptr,  # (num_tokens,) — only read when PERMUTE_PROBS
    permuted_probs_ptr,  # (num_tokens,) — only written when PERMUTE_PROBS
    # sizes
    num_splits,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    IDX_LOAD_WIDTH: tl.constexpr,  # next pow2 >= num_splits, for vector loads
    BLOCK_SIZE: tl.constexpr,
):
    # One program per input row (pid): compute the row's destination under the
    # chunk reordering given by sorted_indices, then copy the row there.
    pid = tl.program_id(0)

    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
    )

    # get chunk idx of the current token in the input tensor
    # (linear scan of the running sum of split sizes until it passes pid)
    input_chunk_idx = -1
    in_chunk_offset = tl.zeros([], dtype=tl.int64)
    acc_chunk_sizes = tl.zeros([], dtype=tl.int64)
    cursor = 0
    while cursor < num_splits:
        cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64)
        acc_chunk_sizes += cur_chunk_size
        if input_chunk_idx == -1 and acc_chunk_sizes > pid:
            input_chunk_idx = cursor
            # Offset of this row inside its chunk.
            in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
        cursor += 1

    # get chunk idx of the current token in the output tensor
    # (position of input_chunk_idx within sorted_indices)
    output_chunk_idx = 0
    cursor = 0
    while cursor < num_splits:
        cur_input_idx = tl.load(sorted_indices_ptr + cursor)
        if cur_input_idx == input_chunk_idx:
            output_chunk_idx = cursor
        cursor += 1

    # make row_id_map
    # Destination row = total size of output chunks before ours + offset within chunk.
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
    ).to(tl.int64)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)

    # Copy the row to its destination, BLOCK_SIZE columns at a time.
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
        output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE

    # Carry the per-row probability along with the row itself.
    if PERMUTE_PROBS:
        prob_off = pid * stride_probs_token
        prob = tl.load(probs_ptr + prob_off)
        permuted_prob_off = dst_row * stride_permuted_probs_token
        tl.store(permuted_probs_ptr + permuted_prob_off, prob)

596

597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
try:
    # Autotune the copy tile width, keyed on the hidden dimension.
    _sort_chunks_by_idxs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": width}) for width in (64, 128, 256, 512, 1024)
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_idxs_kernel)
except RuntimeError:
    # Autotuning is best-effort; fall back to the plain jitted kernel.
    pass


612
613
614
615
def sort_chunks_by_idx(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    num_splits: int,
):
    """Reorder row chunks of ``inp`` according to ``sorted_indices``.

    Returns the reordered tensor, the per-row destination map (for the
    inverse operation), and — when ``probs`` is given — the reordered
    probabilities (otherwise ``None``).
    """
    has_probs = probs is not None
    row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    permuted_probs = (
        torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") if has_probs else None
    )
    # One program per input row.
    _sort_chunks_by_idxs_kernel[(num_tokens,)](
        inp,
        split_sizes,
        sorted_indices,
        output,
        row_id_map,
        probs,
        permuted_probs,
        num_splits,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if has_probs else None,
        permuted_probs.stride(0) if has_probs else None,
        PERMUTE_PROBS=has_probs,
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return output, row_id_map, permuted_probs
649
650
651


@triton.jit
def _sort_chunks_by_map_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per output row (pid): gather the mapped row of the input
    # into row `pid` of the output, BLOCK_SIZE columns at a time.
    pid = tl.program_id(0)
    src_row = tl.load(row_id_map_ptr + pid)
    col_base = 0
    while col_base < hidden_size:
        cols = col_base + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < hidden_size
        values = tl.load(
            input_ptr + src_row * stride_input_token + cols * stride_input_hidden,
            mask=in_bounds,
        )
        tl.store(
            output_ptr + pid * stride_output_token + cols * stride_output_hidden,
            values,
            mask=in_bounds,
        )
        col_base += BLOCK_SIZE
    # Carry the per-row probability along with the row itself.
    if PERMUTE_PROBS:
        prob = tl.load(probs_ptr + src_row * stride_probs_token)
        tl.store(permuted_probs_ptr + pid * stride_permuted_probs_token, prob)
688
689


690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
try:
    # Autotune the copy tile width, keyed on the hidden dimension.
    _sort_chunks_by_map_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": width}) for width in (64, 128, 256, 512, 1024)
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_map_kernel)
except RuntimeError:
    # Autotuning is best-effort; fall back to the plain jitted kernel.
    pass


705
706
707
def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
):
    """Gather rows of ``inp`` by the per-row map produced by `sort_chunks_by_idx`.

    Returns the gathered tensor and — when ``probs`` is given — the gathered
    probabilities (otherwise ``None``).
    """
    has_probs = probs is not None
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    permuted_probs = (
        torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") if has_probs else None
    )
    # One program per output row.
    _sort_chunks_by_map_kernel[(num_tokens,)](
        inp,
        output,
        row_id_map,
        probs,
        permuted_probs,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if has_probs else None,
        permuted_probs.stride(0) if has_probs else None,
        PERMUTE_PROBS=has_probs,
    )
    return output, permuted_probs