# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Permutation kernels written with OpenAI Triton."""

from typing import Union

import torch
import triton
import triton.language as tl

from transformer_engine_torch import DType as TE_DType


@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the row-id map: per-block ranking of routed tokens.

    Program ``(e, b)`` handles expert ``e`` and the ``b``-th block of
    ``BLOCK_SIZE`` consecutive tokens. It writes, for every token in the block,
    the inclusive running count of routed tokens inside the block (so the first
    routed token gets 1) or 0 when the token is not routed to the expert, into
    ``row_id_map[e, token]`` (rows are ``num_tokens`` elements apart). The
    number of routed tokens of the block is stored into
    ``workspace[e * cdiv(num_tokens, BLOCK_SIZE) + b]`` for the second pass.

    NOTE(review): the ranking relies on the routing map holding only 0/1
    values -- confirm against the caller.
    """
    expert_id = tl.program_id(0)
    block_id = tl.program_id(1)
    token_ids = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = token_ids < num_tokens

    # Out-of-range tokens are read as "not routed".
    routing_ptrs = (
        routing_map_ptr + expert_id * stride_routing_map_expert + token_ids * stride_routing_map_token
    )
    routed = tl.load(routing_ptrs, mask=in_bounds, other=0).to(tl.int64)

    # The inclusive prefix sum gives a 1-based rank; multiplying by the mask
    # zeroes the entries of tokens that are not routed to this expert.
    rank_in_block = tl.cumsum(routed) * routed
    tl.store(
        row_id_map_ptr + expert_id * num_tokens + token_ids,
        rank_in_block,
        mask=in_bounds,
    )

    # One workspace slot per (expert, block), laid out expert-major.
    blocks_per_expert = tl.cdiv(num_tokens, BLOCK_SIZE)
    routed_in_block = tl.sum(routed)
    tl.store(workspace_ptr + expert_id * blocks_per_expert + block_id, routed_in_block)


@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Second pass of the row-id map: turn per-block ranks into global rows.

    Program ``(e, b)`` re-reads the per-block ranks written by the first pass
    for expert ``e`` / token block ``b`` and adds the total number of routed
    tokens of every workspace slot whose flat index is smaller than
    ``e * cdiv(num_tokens, BLOCK_SIZE) + b``. The result, written back in
    place, is a zero-based row index for routed tokens and ``-1`` for tokens
    whose first-pass rank is 0.

    ``WORKSPACE_LOAD_WIDTH`` is the (constexpr) number of workspace slots
    loaded at once; the host passes a power of two covering the whole
    workspace.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
    )

    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Every lane is reduced by tl.sum below, so lanes that are masked out must
    # contribute an explicit 0 instead of an unspecified value.
    n_tokens_per_chunk = tl.load(
        workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0
    )
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id,
        mask=(offset < num_tokens),
    )


def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
) -> torch.Tensor:
    """Build the row-id map used by the mask-based permute/unpermute kernels.

    Args:
        routing_map: Routing mask indexed as ``[token, expert]`` (its
            ``stride(0)`` is used as the token stride and ``stride(1)`` as the
            expert stride).
        num_tokens: Number of tokens.
        num_experts: Number of experts.

    Returns:
        An int64 tensor of shape ``(num_experts, num_tokens)`` on the CUDA
        device. Entry ``[e, t]`` is ``-1`` when token ``t`` is not routed to
        expert ``e``; otherwise it is the zero-based position of the pair
        ``(e, t)`` among all routed pairs enumerated expert by expert, tokens
        in order.
    """
    row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
    # Nothing to fill for an empty map; returning here also avoids specialising
    # the kernels below with a zero-sized grid / zero-width workspace load.
    if num_experts == 0 or num_tokens == 0:
        return row_id_map

    block_size = 256
    grid = (num_experts, triton.cdiv(num_tokens, block_size))
    # One slot per (expert, token block): routed-token count of that block.
    workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda")
    # block cumsum
    _row_id_map_pass_1_kernel[grid](
        routing_map,
        row_id_map,
        workspace_tensor,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        block_size,
    )
    # cumsum all and process the mask
    _row_id_map_pass_2_kernel[grid](
        row_id_map,
        workspace_tensor,
        num_tokens,
        # tl.arange needs a power-of-two extent that covers the whole workspace.
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
        block_size,
    )
    return row_id_map


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter each token's hidden vector to the rows given by the row-id map.

    One program handles one input token. The hidden dimension is processed in
    blocks of ``BLOCK_SIZE``; for every expert whose ``row_id_map[expert,
    token]`` entry is not ``-1`` the block is copied to that output row. When
    ``PERMUTE_PROBS`` is set, ``probs[token, expert]`` is also copied to
    ``permuted_probs[dst_row]`` (once, while handling the first hidden block).
    """
    # Widen the program id so that ``pid * stride`` offsets are computed in
    # 64-bit arithmetic and cannot wrap for very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    cur_pos = 0
    while cur_pos < hidden_size:
        cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
        mask = cur_off < hidden_size
        input_off = pid * stride_input_token + cur_off * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        for expert_idx in range(num_experts):
            dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if dst_row != -1:
                output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
                tl.store(output_ptr + output_off, inp, mask=mask)
                if PERMUTE_PROBS:
                    # The prob is a scalar per (token, expert): write it only once.
                    if cur_pos == 0:
                        prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                        prob = tl.load(probs_ptr + prob_off)
                        permuted_prob_off = dst_row * stride_permuted_probs_token
                        tl.store(permuted_probs_ptr + permuted_prob_off, prob)
        cur_pos += BLOCK_SIZE


def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """Permute tokens (and optionally their probs) according to a row-id map.

    Args:
        inp: 2D input, indexed as ``[token, hidden]``.
        row_id_map: Map indexed as ``[expert, token]`` holding the destination
            row of each routed (expert, token) pair, or ``-1``.
        probs: Optional 2D probs indexed as ``[token, expert]``; ``None`` skips
            the prob permutation.
        num_tokens: Number of input tokens (one kernel program each).
        num_experts: Number of experts.
        num_out_tokens: Number of rows of the permuted output.
        hidden_size: Size of the hidden dimension.

    Returns:
        ``(output, permuted_probs)`` where ``output`` has shape
        ``(num_out_tokens, hidden_size)`` and ``permuted_probs`` has shape
        ``(num_out_tokens,)`` or is ``None`` when ``probs`` is ``None``. Both
        are allocated uninitialised; only rows referenced by ``row_id_map``
        are written.
    """
    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    if probs is not None:
        permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
    else:
        permuted_probs = None
    grid = (num_tokens,)
    _permute_kernel[grid](
        inp,
        output,
        row_id_map,
        probs,
        permuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        probs.stride(1) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        PERMUTE_PROBS=probs is not None,
    )
    return output, permuted_probs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    merging_probs_ptr,
    permuted_probs_ptr,
    unpermuted_probs_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    FP8_DTYPE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather and sum the permuted rows that belong to one output token.

    One program handles one output token. For every expert whose
    ``row_id_map[expert, token]`` is not ``-1`` the corresponding input row is
    loaded, converted to fp32, optionally scaled by
    ``merging_probs[token, expert]`` and accumulated; the sum is cast back and
    stored to the token's output row.

    When ``PERMUTE_PROBS`` is set, ``unpermuted_probs[token, expert]`` is
    filled from ``permuted_probs[src_row]``, or with 0 for experts the token
    is not routed to.

    ``FP8_DTYPE`` is ``"e5m2"``, ``"e4m3"`` or ``None``. In the FP8 cases the
    data is read and written as uint8 and bitcast to/from the FP8 type.
    """
    if FP8_DTYPE == "e5m2":
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        data_type = input_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    compute_type = tl.float32

    # Widen the program id so that ``pid * stride`` offsets are computed in
    # 64-bit arithmetic and cannot wrap for very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        for expert_idx in range(num_experts):
            src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if src_row != -1:
                input_off = src_row * stride_input_token + current_offset * stride_input_hidden
                inp = tl.load(input_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    inp = inp.to(data_type, bitcast=True)
                inp = inp.to(compute_type)
                if WITH_MERGING_PROBS:
                    merging_prob_off = (
                        pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
                    )
                    merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
                    inp *= merging_prob
                accumulator += inp
            if PERMUTE_PROBS:
                # Probs are scalars per (token, expert): handle them only once.
                if current_start == 0:
                    unpermuted_prob_off = (
                        pid * stride_unpermuted_probs_token
                        + expert_idx * stride_unpermuted_probs_expert
                    )
                    if src_row != -1:
                        permuted_prob_off = src_row * stride_permuted_probs_token
                        prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
                    else:
                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
        if FP8_DTYPE is not None:
            if not WITH_MERGING_PROBS:
                # Directly adding these value may cause overflow for fp8, we scale it here.
                # The outside fp8_scale_inv is also scaled in the meantime.
                accumulator /= num_experts
            accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
        else:
            accumulator = accumulator.to(data_type)
        output_off = pid * stride_output_token + current_offset * stride_output_hidden
        tl.store(output_ptr + output_off, accumulator, mask=mask)
        current_start += BLOCK_SIZE


def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    merging_probs: Union[torch.Tensor, None],
    permuted_probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Undo a mask-map permutation, merging the copies of each token.

    Args:
        inp: 2D permuted input, indexed as ``[row, hidden]``.
        row_id_map: Map indexed as ``[expert, token]`` holding the permuted row
            of each routed (expert, token) pair, or ``-1``.
        merging_probs: Optional 2D weights indexed as ``[token, expert]``
            applied to each copy before summation.
        permuted_probs: Optional 1D probs indexed by permuted row; when given
            they are scattered back to a ``[token, expert]`` tensor.
        num_tokens: Number of output tokens (one kernel program each).
        num_experts: Number of experts.
        hidden_size: Size of the hidden dimension.
        fp8_dtype: ``TE_DType.kFloat8E5M2`` / ``TE_DType.kFloat8E4M3`` select
            the FP8 code path of the kernel; any other value selects the
            regular path.

    Returns:
        ``(output, unpermuted_probs)`` where ``output`` has shape
        ``(num_tokens, hidden_size)`` and ``unpermuted_probs`` has shape
        ``(num_tokens, num_experts)`` or is ``None`` when ``permuted_probs`` is
        ``None``.
    """
    # The kernel takes the FP8 format as a constexpr string (or None).
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_dtype_name = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_dtype_name = "e4m3"
    else:
        fp8_dtype_name = None
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    if permuted_probs is not None:
        unpermuted_probs = torch.empty(
            (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda"
        )
    else:
        unpermuted_probs = None
    grid = (num_tokens,)
    _unpermute_kernel[grid](
        inp,
        output,
        row_id_map,
        merging_probs,
        permuted_probs,
        unpermuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        merging_probs.stride(0) if merging_probs is not None else None,
        merging_probs.stride(1) if merging_probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
        unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
        WITH_MERGING_PROBS=merging_probs is not None,
        PERMUTE_PROBS=permuted_probs is not None,
        FP8_DTYPE=fp8_dtype_name,
    )
    return output, unpermuted_probs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    merging_probs_ptr,
    merging_probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    FP8_DTYPE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward of the unpermute-with-merging-probs operation.

    One program handles one token of the forward output. For every expert the
    token is routed to (``row_id_map[expert, token] != -1``):

    * the activation gradient row ``dst_row`` is set to
      ``fwd_output_grad[token] * merging_probs[token, expert]``;
    * ``merging_probs_grad[token, expert]`` is set to the dot product of
      ``fwd_input[dst_row]`` and ``fwd_output_grad[token]`` (fp32 accumulate).

    For experts the token is not routed to, the prob gradient is set to 0.

    ``FP8_DTYPE`` is ``"e5m2"``, ``"e4m3"`` or ``None``. In the FP8 cases the
    data is read and written as uint8 and bitcast to/from the FP8 type.
    """
    if FP8_DTYPE == "e5m2":
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        data_type = fwd_output_grad_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    compute_type = tl.float32

    # Widen the program id so that ``pid * stride`` offsets are computed in
    # 64-bit arithmetic and cannot wrap for very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    for expert_idx in range(num_experts):
        dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
        probs_grad_off = (
            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
        )
        if dst_row != -1:
            # The merging prob does not depend on the hidden block: load it once.
            merging_prob_off = (
                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
            current_start = 0
            while current_start < hidden_size:
                current_offset = current_start + tl.arange(0, BLOCK_SIZE)
                mask = current_offset < hidden_size
                input_off = (
                    pid * stride_fwd_output_grad_token
                    + current_offset * stride_fwd_output_grad_hidden
                )
                # other=0: masked lanes take part in the tl.sum reduction
                # below, so they must hold a well-defined zero.
                inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask, other=0)
                if FP8_DTYPE is not None:
                    inp = inp.to(data_type, bitcast=True)
                inp = inp.to(compute_type)
                output = inp * merging_prob
                output = output.to(data_type)
                if FP8_DTYPE is not None:
                    output = output.to(pytorch_tensor_dtype, bitcast=True)
                output_off = (
                    dst_row * stride_fwd_input_grad_token
                    + current_offset * stride_fwd_input_grad_hidden
                )
                tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)

                fwd_input_off = (
                    dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
                )
                fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask, other=0)
                if FP8_DTYPE is not None:
                    fwd_input = fwd_input.to(data_type, bitcast=True)
                prob_grad_accum += fwd_input.to(compute_type) * inp
                current_start += BLOCK_SIZE
            probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
            tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
        else:
            tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)


def unpermute_with_mask_map_bwd_with_merging_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    merging_probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Backward pass of unpermute when merging probs were applied.

    Args:
        fwd_output_grad: 2D gradient of the unpermuted output, indexed as
            ``[token, hidden]``.
        row_id_map: Map indexed as ``[expert, token]`` holding the permuted row
            of each routed (expert, token) pair, or ``-1``.
        fwd_input: 2D permuted input of the forward pass, indexed as
            ``[row, hidden]``.
        merging_probs: 2D weights indexed as ``[token, expert]``.
        num_tokens: Number of unpermuted tokens (one kernel program each).
        num_experts: Number of experts.
        num_out_tokens: Number of rows of the permuted activation gradient.
        hidden_size: Size of the hidden dimension.
        fp8_dtype: ``TE_DType.kFloat8E5M2`` / ``TE_DType.kFloat8E4M3`` select
            the FP8 code path of the kernel; any other value selects the
            regular path.

    Returns:
        ``(act_grad, merging_probs_grad)`` with shapes
        ``(num_out_tokens, hidden_size)`` and ``(num_tokens, num_experts)``.
        ``act_grad`` is allocated uninitialised; only rows referenced by
        ``row_id_map`` are written.
    """
    # The kernel takes the FP8 format as a constexpr string (or None).
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_dtype_name = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_dtype_name = "e4m3"
    else:
        fp8_dtype_name = None
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
    )
    merging_probs_grad = torch.empty(
        (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
    )
    grid = (num_tokens,)
    _unpermute_bwd_with_merging_probs_kernel[grid](
        fwd_output_grad,
        act_grad,
        fwd_input,
        merging_probs,
        merging_probs_grad,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        merging_probs.stride(0),
        merging_probs.stride(1),
        merging_probs_grad.stride(0),
        merging_probs_grad.stride(1),
        fp8_dtype_name,
    )
    return act_grad, merging_probs_grad


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_idxs_kernel(
    # pointers
    input_ptr,
    split_sizes_ptr,
    sorted_indices_ptr,
    output_ptr,
    dst_rows_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    num_splits,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    IDX_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Reorder contiguous chunks of rows and record each row's destination.

    The input rows are split into ``num_splits`` consecutive chunks of sizes
    ``split_sizes``; ``sorted_indices[j]`` is the input chunk placed at output
    position ``j``. One program handles one input row: it finds the chunk the
    row belongs to, computes the destination row (sizes of all chunks that
    precede it in the output plus the row's offset inside its chunk), stores
    it to ``dst_rows[pid]`` and copies the row (and, with ``PERMUTE_PROBS``,
    its prob) there.

    ``IDX_LOAD_WIDTH`` is a constexpr extent covering ``num_splits``.
    """
    # Widen the program id so that ``pid * stride`` offsets are computed in
    # 64-bit arithmetic and cannot wrap for very large tensors.
    pid = tl.program_id(0).to(tl.int64)

    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    # other=0 keeps the padding lanes well defined (they index split 0 but are
    # never dereferenced nor summed thanks to the masks below).
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
    )

    # get chunk idx of the current token in the input tensor
    input_chunk_idx = -1
    in_chunk_offset = tl.zeros([], dtype=tl.int64)
    acc_chunk_sizes = tl.zeros([], dtype=tl.int64)
    cursor = 0
    while cursor < num_splits:
        cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64)
        acc_chunk_sizes += cur_chunk_size
        # First chunk whose cumulative end lies beyond this row.
        if input_chunk_idx == -1 and acc_chunk_sizes > pid:
            input_chunk_idx = cursor
            in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
        cursor += 1

    # get chunk idx of the current token in the output tensor
    output_chunk_idx = 0
    cursor = 0
    while cursor < num_splits:
        cur_input_idx = tl.load(sorted_indices_ptr + cursor)
        if cur_input_idx == input_chunk_idx:
            output_chunk_idx = cursor
        cursor += 1

    # make row_id_map
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits, other=0
    ).to(tl.int64)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)

    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
        output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE

    if PERMUTE_PROBS:
        prob_off = pid * stride_probs_token
        prob = tl.load(probs_ptr + prob_off)
        permuted_prob_off = dst_row * stride_permuted_probs_token
        tl.store(permuted_probs_ptr + permuted_prob_off, prob)


def sort_chunks_by_idx(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: Union[torch.Tensor, None],
    num_tokens: int,
    hidden_size: int,
    num_splits: int,
):
    """Reorder contiguous row chunks of ``inp`` according to ``sorted_indices``.

    Args:
        inp: 2D input, indexed as ``[token, hidden]``.
        split_sizes: Sizes of the ``num_splits`` consecutive input chunks.
        sorted_indices: For each output position, the index of the input chunk
            placed there.
        probs: Optional 1D per-token probs, reordered alongside the rows;
            ``None`` skips this.
        num_tokens: Number of rows (one kernel program each).
        hidden_size: Size of the hidden dimension.
        num_splits: Number of chunks.

    Returns:
        ``(output, row_id_map, permuted_probs)`` where ``output`` has shape
        ``(num_tokens, hidden_size)``, ``row_id_map`` is an int64 tensor of
        shape ``(num_tokens,)`` holding the destination row of every input
        row, and ``permuted_probs`` has shape ``(num_tokens,)`` or is ``None``
        when ``probs`` is ``None``.
    """
    row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    if probs is not None:
        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
    else:
        permuted_probs = None
    grid = (num_tokens,)
    _sort_chunks_by_idxs_kernel[grid](
        inp,
        split_sizes,
        sorted_indices,
        output,
        row_id_map,
        probs,
        permuted_probs,
        num_splits,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        PERMUTE_PROBS=probs is not None,
        # tl.arange needs a power-of-two extent that covers all splits.
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return output, row_id_map, permuted_probs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_map(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather rows through a row map: ``output[pid] = input[row_id_map[pid]]``.

    One program handles one output row. With ``PERMUTE_PROBS`` the per-row
    prob is gathered the same way:
    ``permuted_probs[pid] = probs[row_id_map[pid]]``.
    """
    # Widen the program id so that ``pid * stride`` offsets are computed in
    # 64-bit arithmetic and cannot wrap for very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    # Row of the input that this program reads from.
    src_row = tl.load(row_id_map_ptr + pid)
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
        output_offsets = pid * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE
    if PERMUTE_PROBS:
        prob_off = src_row * stride_probs_token
        prob = tl.load(probs_ptr + prob_off)
        permuted_prob_off = pid * stride_permuted_probs_token
        tl.store(permuted_probs_ptr + permuted_prob_off, prob)


def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: Union[torch.Tensor, None],
    num_tokens: int,
    hidden_size: int,
):
    """Gather the rows of ``inp`` through ``row_id_map``.

    Output row ``i`` is a copy of input row ``row_id_map[i]``.

    Args:
        inp: 2D input, indexed as ``[token, hidden]``.
        row_id_map: 1D map giving, for every output row, the input row to read.
        probs: Optional 1D per-row probs, gathered the same way; ``None`` skips
            this.
        num_tokens: Number of output rows (one kernel program each).
        hidden_size: Size of the hidden dimension.

    Returns:
        ``(output, permuted_probs)`` where ``output`` has shape
        ``(num_tokens, hidden_size)`` and ``permuted_probs`` has shape
        ``(num_tokens,)`` or is ``None`` when ``probs`` is ``None``.
    """
    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    if probs is not None:
        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
    else:
        permuted_probs = None
    grid = (num_tokens,)
    _sort_chunks_by_map[grid](
        inp,
        output,
        row_id_map,
        probs,
        permuted_probs,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        PERMUTE_PROBS=probs is not None,
    )
    return output, permuted_probs