test_permutation.py 73.8 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
import os
6
7
import random

8
9
10
11
import torch
import pytest
from typing import Dict, List

12
import transformer_engine.pytorch as te
13
from transformer_engine.common import recipe
14
15
from transformer_engine.pytorch import (
    moe_permute as te_permute,
16
    moe_permute_with_probs as te_permute_with_probs,
17
    moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
18
19
    moe_unpermute as te_unpermute,
    moe_sort_chunks_by_index as te_sort_chunks_by_index,
20
    moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
21
)
22
from transformer_engine.pytorch import (
23
24
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
25
26
    Float8BlockQuantizer,
    MXFP8Quantizer,
27
)
28
import transformer_engine_torch as tex
29
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
30
import copy
31
32
33
34
35
36

# Fixed RNG seed so the randomly generated test inputs are reproducible run to run.
seed = 1234
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
del _seed_rng


37
def pytorch_permute_index_map(tokens, indices, num_out_tokens: int = None):
    """
    Permute the tokens based on the indices. Tokens with the same index are grouped together.
    The indices shape is [num_tokens] or [num_tokens, top_k]; it indicates which experts were
    selected by each token.

    Args:
        tokens: torch.Tensor
            The input token tensor, indexed along dim 0.
        indices: torch.Tensor
            The token to expert indices tensor, with shape [num_tokens] or [num_tokens, topk].
        num_out_tokens: int, optional
            The effective output token count. When the capacity factor is enabled it should
            equal the number of tokens not dropped. Defaults to None, meaning no tokens are
            dropped (all num_tokens * topk entries are kept).

    Returns:
        torch.Tensor:
            The permuted tensor: the first ``num_out_tokens`` rows of the expert-sorted order.
        torch.Tensor:
            ``sorted_indices``: the stable argsort of the flattened indices. Entry i is the
            flattened (token, slot) position that feeds permuted row i. It always has the full
            flattened length, even when tokens are dropped.
    """
    topk = 1 if indices.dim() == 1 else indices.size(1)
    flatten_indices = indices.view(-1)
    # A stable sort keeps tokens routed to the same expert in their original token order.
    sorted_indices = torch.argsort(flatten_indices, stable=True)
    if num_out_tokens is None:
        num_out_tokens = flatten_indices.size(0)

    # Flattened position p belongs to token p // topk.
    permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk)
    return permuted_tokens, sorted_indices


69
def pytorch_unpermute_index_map(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor = None,
):
    """
    Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens
    with their corresponding probabilities.

    Args:
        permuted_tokens: torch.Tensor
            The tensor of permuted tokens to be unpermuted, shape [num_permuted, hidden].
        sorted_indices: torch.Tensor
            The tensor of sorted indices used to unpermute the tokens (as returned by the index-map
            permute reference). Only its first ``permuted_tokens.size(0)`` entries are used.
        probs: torch.Tensor, optional
            The [num_tokens, topk] probabilities of the permuted tokens. If provided, the topk
            copies of each token are weighted by their probability before being summed.

    Returns:
        torch.Tensor:
            The unpermuted tokens, shape [num_tokens, hidden], optionally merged with probabilities.
            Slots whose tokens were dropped contribute zeros.
    """
    if probs is not None:
        # Unpermute and merge the tokens with their probabilities
        num_unpermuted_tokens = probs.numel()
        topk = probs.size(1)
    else:
        # Unpermute the tokens without merge
        num_unpermuted_tokens = sorted_indices.size(0)
        topk = 1
    unpermuted_tokens = torch.zeros(
        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )

    # Scatter each permuted row back to its flattened (token, slot) position; positions
    # beyond permuted_tokens.size(0) (dropped tokens) remain zero.
    unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens)
    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
    # Reduce the topk copies of every token into one row.
    unpermuted_tokens = unpermuted_tokens.sum(dim=1)
    return unpermuted_tokens


114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def pytorch_permute_mask_map(tokens, routing_map):
    """Reference mask-map permutation: group tokens by the expert they are routed to.

    Tokens are laid out expert by expert (expert 0 first); within one expert they keep
    their original token order. A token routed to several experts appears once per expert.

    Args:
        tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
        routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts];
            nonzero / True entries mark the selected experts.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: the permuted tokens and ``sorted_indices``, the
        source token index of every permuted row.
    """
    n_tokens, _hidden = tokens.shape
    n_experts = routing_map.shape[1]

    # Expert-major boolean mask: [num_experts, num_tokens].
    expert_major_mask = routing_map.bool().T.contiguous()

    # Row e of the grid is [0, 1, ..., num_tokens - 1]; masking it with expert e's row keeps the
    # tokens routed to e, and the row-major masked_select concatenates experts in order.
    token_grid = torch.arange(n_tokens, device=expert_major_mask.device)
    token_grid = token_grid.unsqueeze(0).expand(n_experts, -1)
    sorted_indices = token_grid.masked_select(expert_major_mask)

    return tokens.index_select(0, sorted_indices), sorted_indices


def pytorch_unpermute_mask_map(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
):
    """
    Reference mask-map unpermutation: scatter-add permuted rows back to their source tokens.

    When ``probs`` is given, each permuted row is first scaled by the probability of its
    (token, expert) pair, so the copies of a token are merged as a weighted sum.

    Args:
        permuted_tokens (torch.Tensor): The permuted token tensor, [num_permuted, hidden].
        sorted_indices (torch.Tensor): Source token index of every permuted row.
        restore_shape (torch.Size): Shape of the output, (num_tokens, hidden).
        probs (torch.Tensor, optional): Unpermuted probabilities; expected to have the same
            shape as ``routing_map``.
        routing_map (torch.Tensor, optional): Token to expert mapping, shape
            [num_tokens, num_experts]. Required whenever ``probs`` is given.

    Returns:
        torch.Tensor: The tokens restored to their original order, shape ``restore_shape``.
    """
    _num_tokens, hidden = restore_shape

    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        # Walk probs expert-major so the weights line up with the expert-major permuted rows.
        expert_major_probs = probs.T.contiguous()
        expert_major_mask = routing_map.T.contiguous()
        row_weights = expert_major_probs.masked_select(expert_major_mask)
        permuted_tokens = permuted_tokens * row_weights.unsqueeze(-1)

    restored = torch.zeros(
        restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
    )
    # Broadcast each row's destination index across the hidden dimension and accumulate.
    scatter_index = sorted_indices.unsqueeze(1).expand(-1, hidden)
    restored.scatter_add_(0, scatter_index, permuted_tokens)
    return restored


def pytorch_sort_chunks_by_index(
    input: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_idxs: torch.Tensor,
):
    """
    Reference chunk sort: split ``input`` along dim 0 into consecutive chunks whose row
    counts are ``split_sizes``, then concatenate the chunks in the order ``sorted_idxs``.

    Returns:
        torch.Tensor: the reordered tensor (only the output; no row_id_map is produced).
    """
    chunks = torch.split(input, split_sizes.tolist(), dim=0)
    chunk_order = sorted_idxs.tolist()
    reordered = [chunks[position] for position in chunk_order]
    return torch.cat(reordered, dim=0)


194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
    """Estimated tolerances for a datatype.

    Based on tolerances for torch.testing.assert_close. The returned dict is meant to be
    splatted into ``assert_close`` as ``**tols``.

    Raises:
        ValueError: if ``te_dtype`` has no tolerance entry.
    """
    if te_dtype == tex.DType.kFloat32:
        return dict(rtol=1.0e-6, atol=1.0e-6)
    if te_dtype == tex.DType.kFloat16:
        return dict(rtol=3.0e-3, atol=1.0e-5)
    if te_dtype == tex.DType.kBFloat16:
        return dict(rtol=2.0e-2, atol=1.0e-5)
    if te_dtype in (tex.DType.kFloat8E5M2, tex.DType.kFloat8E4M3):
        return dict(rtol=2.0e-1, atol=1.0e-1)
    raise ValueError(f"Unsupported dtype ({te_dtype})")


211
212
213
214
215
216
217
218
219
220
def backward_wrapper(
    act, backward_input, forward_input=None, retain_graph=True, accumulate_grad=False
):
    """Run ``act.backward(backward_input)``, optionally clearing input gradients first.

    The benchmarks call this repeatedly on the same graph, so by default the listed
    inputs' ``.grad`` is reset to ``None`` to avoid accumulating across calls.

    Args:
        act: output tensor whose graph is back-propagated.
        backward_input: gradient with respect to ``act``.
        forward_input: iterable of tensors whose ``.grad`` is reset to ``None`` before the
            backward pass. Defaults to no tensors.
        retain_graph: forwarded to ``act.backward`` so the graph can be reused.
        accumulate_grad: if true, existing gradients are kept and accumulate.

    Returns:
        Whatever ``act.backward`` returns.
    """
    if forward_input is None:
        forward_input = []
    # Set forward_input.grad to None to avoid grad accumulation.
    if not accumulate_grad:
        for tensor in forward_input:
            tensor.grad = None
    return act.backward(backward_input, retain_graph=retain_graph)


221
def _test_permutation_index_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_probs,
    BENCHMARK=False,
):
    """Compare TE's index-map ``moe_permute``/``moe_unpermute`` with the PyTorch reference.

    Both the reference (``pytorch_permute_index_map`` / ``pytorch_unpermute_index_map``) and
    the TE ops (``map_type="index"``) run forward and backward on identical random inputs.
    Outputs and gradients are compared with ``dtype_tols`` tolerances. When ``BENCHMARK`` is
    true the comparisons are skipped and each op is timed with ``perf_test_cuda_kernel``.

    Args:
        te_dtype: TE dtype of the activations; only kFloat32, kFloat16 and kBFloat16 run,
            others skip.
        num_tokens: number of input tokens.
        num_expert: number of experts.
        hidden_size: hidden dimension of each token.
        topK: number of experts selected per token.
        num_out_tokens: number of permuted output tokens; ``None`` means
            ``num_tokens * topK`` (nothing dropped).
        with_probs: whether unpermute merges tokens with routing probabilities.
        BENCHMARK: time the kernels instead of asserting correctness.
    """
    if not with_probs and topK > 1:
        pytest.skip("Only permutations with topK=1 and without probabilities are supported.")

    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "index map:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()

    pytorch_permute_fwd_input.requires_grad_(True)

    # Each token picks topK distinct experts.
    if num_tokens > 0:
        indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)])
    else:
        indices = torch.empty((num_tokens, topK))
    indices = indices.to(torch.int32).cuda()

    probs = None
    if with_probs:
        # Row-normalized so each token's topK probabilities sum to 1.
        probs = torch.rand(num_tokens, topK).cuda()
        row_sums = probs.sum(dim=1, keepdim=True)
        probs = probs / row_sums
        probs.requires_grad_(True)

    ###################################################################################################################################
    #
    # PyTorch Permutation
    #
    ###################################################################################################################################
    pytorch_permute_output, sorted_indices = pytorch_permute_index_map(
        pytorch_permute_fwd_input, indices, num_out_tokens
    )
    pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)

    pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
    pytorch_unpermute_fwd_input.requires_grad_(True)

    pytorch_unpermute_output = pytorch_unpermute_index_map(
        pytorch_unpermute_fwd_input, sorted_indices, probs=probs
    )
    pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # TE Permutation
    #
    ###################################################################################################################################
    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)
    te_permute_bwd_input = pytorch_permute_bwd_input.detach()

    te_permute_output, row_id_map = te_permute(
        te_permute_fwd_input, indices, num_out_tokens, map_type="index"
    )
    te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

    te_probs = None
    if with_probs:
        te_probs = probs.detach()
        te_probs.requires_grad_(True)
    te_unpermute_fwd_input = te_permute_output.detach()
    te_unpermute_fwd_input.requires_grad_(True)
    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

    te_unpermute_output = te_unpermute(
        te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
    )
    te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    te_permute_output_ = te_permute_output.float()
    te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
    te_unpermute_output_ = te_unpermute_output.float()
    te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_permute_output.float(),
            te_permute_output_,
            msg="Mismatch in te_permute fwd",
        )
        torch.testing.assert_close(
            pytorch_permute_fwd_input.grad.float(),
            te_permute_fwd_input_grad,
            msg="Mismatch in te_permute bwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_output.float(),
            te_unpermute_output_,
            msg="Mismatch in te_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_fwd_input.grad.float(),
            te_unpermute_fwd_input_grad,
            msg="Mismatch in te_unpermute bwd",
            **tols,
        )
        if with_probs:
            torch.testing.assert_close(
                probs.grad.float(),
                te_probs.grad.float(),
                msg="Mismatch in te_unpermute bwd",
                **tols,
            )

    if not pytorch_permute_fwd_input.numel():
        print("Empty pytorch_permute_fwd_input activation test passed.")
        return

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:
        t1 = perf_test_cuda_kernel(
            lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens)
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens, map_type="index")
        )
        print(f"permute\t\tfwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_permute_output,
                pytorch_permute_bwd_input,
                forward_input=[pytorch_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_permute_output,
                te_permute_bwd_input,
                forward_input=[te_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute\t\tbwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: pytorch_unpermute_index_map(
                pytorch_unpermute_fwd_input, sorted_indices, probs=probs
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs, map_type="index")
        )
        print(f"unpermute\tfwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_unpermute_output,
                pytorch_unpermute_bwd_input,
                forward_input=(
                    [pytorch_unpermute_fwd_input, probs]
                    if with_probs
                    else [pytorch_unpermute_fwd_input]
                ),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_unpermute_output,
                te_unpermute_bwd_input,
                forward_input=(
                    [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
                ),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute\tbwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")


438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def _test_permutation_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_probs,
    BENCHMARK=False,
):
    """Compare TE's mask-map ``moe_permute``/``moe_unpermute`` with the PyTorch reference.

    A random boolean ``routing_map`` with exactly ``num_out_tokens`` routed (token, expert)
    pairs is built. The reference (``pytorch_permute_mask_map`` /
    ``pytorch_unpermute_mask_map``) and the TE ops (``map_type="mask"``) then run forward and
    backward on identical inputs. Outputs and gradients are compared with ``dtype_tols``
    tolerances. When ``BENCHMARK`` is true the comparisons are skipped and each op is timed
    with ``perf_test_cuda_kernel``.

    Args:
        te_dtype: TE dtype of the activations; only kFloat32, kFloat16 and kBFloat16 run,
            others skip.
        num_tokens: number of input tokens.
        num_expert: number of experts.
        hidden_size: hidden dimension of each token.
        topK: used for the skip check and the default ``num_out_tokens``.
        num_out_tokens: number of routed (token, expert) pairs; ``None`` means
            ``num_tokens * topK``.
        with_probs: whether unpermute merges tokens with routing probabilities.
        BENCHMARK: time the kernels instead of asserting correctness.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "mask map:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()

    pytorch_permute_fwd_input.requires_grad_(True)

    restore_shape = pytorch_permute_fwd_input.shape

    # Scatter exactly num_out_tokens ones uniformly over the [num_tokens, num_expert] grid.
    # Rows are not guaranteed to receive any ones, so some tokens may be routed nowhere.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    probs = None
    if with_probs:
        probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
        row_sums = probs.sum(dim=1, keepdim=True)
        # Tokens routed to no expert have an all-zero row; dividing by 0 would give NaN,
        # so divide such rows by 1 and keep them at zero.
        row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums)
        probs = probs / row_sums
        probs = probs.to(dtype)
        probs.requires_grad_(True)

    ###################################################################################################################################
    #
    # PyTorch Permutation
    #
    ###################################################################################################################################
    pytorch_permute_output, sorted_indices = pytorch_permute_mask_map(
        pytorch_permute_fwd_input, routing_map
    )
    pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)

    pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
    pytorch_unpermute_fwd_input.requires_grad_(True)

    pytorch_unpermute_output = pytorch_unpermute_mask_map(
        pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
    )
    pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # TE Permutation
    #
    ###################################################################################################################################
    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)
    te_permute_bwd_input = pytorch_permute_bwd_input.detach()

    te_permute_output, row_id_map = te_permute(
        te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
    )
    te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

    te_probs = None
    if with_probs:
        te_probs = probs.detach()
        te_probs.requires_grad_(True)
    te_unpermute_fwd_input = te_permute_output.detach()
    te_unpermute_fwd_input.requires_grad_(True)
    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

    te_unpermute_output = te_unpermute(
        te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
    )
    te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    te_permute_output_ = te_permute_output.float()
    te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
    te_unpermute_output_ = te_unpermute_output.float()
    te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_permute_output.float(),
            te_permute_output_,
            msg="Mismatch in te_permute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_permute_fwd_input.grad.float(),
            te_permute_fwd_input_grad,
            msg="Mismatch in te_permute bwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_output.float(),
            te_unpermute_output_,
            msg="Mismatch in te_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_fwd_input.grad.float(),
            te_unpermute_fwd_input_grad,
            msg="Mismatch in te_unpermute bwd",
            **tols,
        )
        if with_probs:
            torch.testing.assert_close(
                probs.grad.float(),
                te_probs.grad.float(),
                msg="Mismatch in te_unpermute bwd",
                **tols,
            )

    if not pytorch_permute_fwd_input.numel():
        print("Empty pytorch_permute_fwd_input activation test passed.")
        return

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:
        t1 = perf_test_cuda_kernel(
            lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_permute(
                te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
            )
        )
        print(f"permute\t\tfwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_permute_output,
                pytorch_permute_bwd_input,
                forward_input=[pytorch_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_permute_output,
                te_permute_bwd_input,
                forward_input=[te_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute\t\tbwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: pytorch_unpermute_mask_map(
                pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_unpermute(
                te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
            )
        )
        print(f"unpermute\tfwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_unpermute_output,
                pytorch_unpermute_bwd_input,
                forward_input=(
                    [pytorch_unpermute_fwd_input, probs]
                    if with_probs
                    else [pytorch_unpermute_fwd_input]
                ),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_unpermute_output,
                te_unpermute_bwd_input,
                forward_input=(
                    [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
                ),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute\tbwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")


659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
def _test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs=False,
    align_size=16,
    BENCHMARK=False,
):
    """
    Compare the fused permute+pad / unpad+unpermute path against the unfused reference.

    Reference path: ``moe_permute_with_probs`` + ``Fp8Padding`` and, in the reverse
    direction, ``Fp8Unpadding`` + ``moe_unpermute``.
    Fused path: ``moe_permute_and_pad_with_probs`` and ``moe_unpermute(..., pad_offsets=...)``.
    Forward outputs, permuted probs and input gradients of both paths are compared.

    Args:
        te_dtype: TE dtype (``tex.DType``) of activations and probs; other dtypes skip.
        num_tokens: number of input tokens.
        num_expert: number of experts (columns of the routing map).
        hidden_size: hidden dimension of each token.
        topK: experts per token; the test is skipped when ``topK > num_expert``.
        num_out_tokens: number of routed (token, expert) pairs; ``None`` means
            ``num_tokens * topK``.
        with_merging_probs: when True, ``moe_unpermute`` is given the probs as
            ``merging_probs`` and the probs gradients are compared as well.
        align_size: each expert's token count is padded up to a multiple of this value.
        BENCHMARK: when True, skip the checks and time both paths instead.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "permutation and padding:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
        f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    # Scatter exactly num_out_tokens True entries uniformly over the whole (token, expert)
    # grid, so an individual token may be routed to more or fewer than topK experts.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # Row-normalized probs over the routed experts.
    # NOTE(review): a token with no routed expert gets a 0/0 = NaN row here; presumably those
    # entries are never read because routing_map masks them out — confirm.
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs = probs.to(dtype)
    probs.requires_grad_(True)

    # Expected padded size: each expert's token count rounded up to a multiple of align_size.
    tokens_per_expert = routing_map.sum(dim=0).cpu()
    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_bwd_input = torch.rand(
        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
    ).cuda()
    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_fwd_input.requires_grad_(True)

    restore_shape = permute_pad_fwd_input.shape
    ###################################################################################################################################
    #
    # moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
    #
    ###################################################################################################################################
    # permute + padding
    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
        permute_pad_fwd_input,
        probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )
    tokens_per_expert_list = tokens_per_expert.tolist()
    fp8_padding = Fp8Padding(num_expert, align_size)
    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
    permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)

    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

    # unpadding + unpermute
    # Start from a fresh leaf so the unpermute gradient is isolated from the permute graph.
    unpermute_unpad_fwd_input = permuted_paded_output.detach()
    unpermute_unpad_fwd_input.requires_grad_(True)

    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
    unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)

    probs_naive = probs
    unpermuted_unpaded_output = te_unpermute(
        unpaded_output,
        row_id_map,
        merging_probs=probs_naive if with_merging_probs else None,
        restore_shape=restore_shape,
    )

    unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # fusion moe_permute_with_probs and Fp8Padding, fusion moe_unpermute and Fp8Unpadding
    #
    ###################################################################################################################################
    # fusion permute_and_pad
    fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
    fusion_permute_and_pad_fwd_input.requires_grad_(True)
    probs_fusion = probs_naive.detach().clone()
    probs_fusion.requires_grad_(True)

    # Use a distinct name for the fused row_id_map so the reference `row_id_map` above
    # stays intact (the benchmark below reuses both).
    (
        fusion_permuted_padded_output,
        fusion_permuted_padded_probs,
        fused_row_id_map,
        pad_offsets,
        _,
    ) = te_permute_and_pad_with_probs(
        fusion_permute_and_pad_fwd_input,
        probs_fusion,
        routing_map,
        tokens_per_expert,
        align_size,
    )
    # Match the [N, 1] shape produced by padding the unsqueezed probs on the reference path.
    fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

    # fusion unpad and unpermute
    fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
    fusion_unpermute_unpad_fwd_input.requires_grad_(True)

    fusion_unpermuted_unpaded_output = te_unpermute(
        fusion_unpermute_unpad_fwd_input,
        fused_row_id_map,
        merging_probs=probs_fusion if with_merging_probs else None,
        restore_shape=restore_shape,
        pad_offsets=pad_offsets,
    )

    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
    fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    permuted_paded_output_ = permuted_paded_output.float()
    fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
    permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
    fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float()

    unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
    fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
    unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
    fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            permuted_paded_output_,
            fusion_permuted_padded_output_,
            msg="Mismatch in te_permute_and_pad fwd",
            **tols,
        )
        torch.testing.assert_close(
            permute_pad_fwd_input_grad,
            fusion_permute_and_pad_fwd_input_grad,
            msg="Mismatch in te_permute_and_pad bwd",
            **tols,
        )
        torch.testing.assert_close(
            unpermuted_unpaded_output_,
            fusion_unpermuted_unpaded_output_,
            msg="Mismatch in te_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            unpermute_unpad_fwd_input_grad,
            fusion_unpermute_unpad_fwd_input_grad,
            msg="Mismatch in te_unpermute bwd",
            **tols,
        )
        # Permuted/padded probs are a forward output of the permute step.
        torch.testing.assert_close(
            permuted_paded_probs.float(),
            fusion_permuted_padded_probs.float(),
            msg="Mismatch in te_permute_and_pad probs fwd",
            **tols,
        )
        if with_merging_probs:
            torch.testing.assert_close(
                probs_naive.grad.float(),
                probs_fusion.grad.float(),
                msg="Mismatch in te_unpermute bwd",
                **tols,
            )

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:
        # Time the same configuration that was checked above.
        merging_probs_naive = probs_naive if with_merging_probs else None
        merging_probs_fusion = probs_fusion if with_merging_probs else None

        def permute_and_pad():
            bench_permuted_output, bench_permuted_probs, _ = te_permute_with_probs(
                permute_pad_fwd_input,
                probs,
                routing_map,
                num_out_tokens=num_out_tokens,
            )
            fp8_padding(bench_permuted_output, tokens_per_expert_list)
            fp8_padding(bench_permuted_probs.unsqueeze(-1), tokens_per_expert_list)

        def fusion_permute_and_pad():
            _, bench_padded_probs, _, _, _ = te_permute_and_pad_with_probs(
                fusion_permute_and_pad_fwd_input,
                probs_fusion,
                routing_map,
                tokens_per_expert,
                align_size,
            )
            # Mirror the unsqueeze done on the reference path so both timings cover the same work.
            bench_padded_probs.unsqueeze(-1)

        t1 = perf_test_cuda_kernel(permute_and_pad)
        t2 = perf_test_cuda_kernel(fusion_permute_and_pad)
        print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                permuted_paded_output,
                permute_pad_bwd_input,
                forward_input=[permute_pad_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_permuted_padded_output,
                fusion_permute_pad_bwd_input,
                forward_input=[fusion_permute_and_pad_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")

        # Forward-only timings: the reference path previously ran a backward pass here,
        # which inflated its "fwd" number relative to the fused path.
        def unpad_unpermute():
            bench_unpaded_output = fp8_unpadding(
                unpermute_unpad_fwd_input, tokens_per_expert_list
            )
            te_unpermute(
                bench_unpaded_output,
                row_id_map,
                merging_probs=merging_probs_naive,
                restore_shape=restore_shape,
            )

        def fusion_unpad_unpermute():
            te_unpermute(
                fusion_unpermute_unpad_fwd_input,
                fused_row_id_map,
                merging_probs=merging_probs_fusion,
                restore_shape=restore_shape,
                pad_offsets=pad_offsets,
            )

        t1 = perf_test_cuda_kernel(unpad_unpermute)
        t2 = perf_test_cuda_kernel(fusion_unpad_unpermute)
        print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                unpermuted_unpaded_output,
                unpermute_unpad_bwd_input,
                forward_input=[unpermute_unpad_fwd_input, probs_naive],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_unpermuted_unpaded_output,
                fusion_unpermute_bwd_input,
                forward_input=[fusion_unpermute_unpad_fwd_input, probs_fusion],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")


def _test_permutation_and_padding_with_merging_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    align_size=16,
    BENCHMARK=False,
):
    """
    Test the combination of merging_probs AND pad_offsets together in moe_unpermute.

    This specifically tests the backward pass fix where pad_offsets must be used
    when computing gradients with merging_probs. The reference path is
    ``Fp8Unpadding`` followed by ``moe_unpermute(merging_probs=...)``; the fused path is
    ``moe_unpermute`` given both ``merging_probs`` and the ``pad_offsets`` returned by
    ``moe_permute_and_pad_with_probs``. The forward output, the activation gradient and
    the probs gradient are compared.

    Args:
        te_dtype: TE dtype (``tex.DType``) of activations and probs; other dtypes skip.
        num_tokens: number of input tokens.
        num_expert: number of experts (columns of the routing map).
        hidden_size: hidden dimension of each token.
        topK: experts per token; the test is skipped when ``topK > num_expert``.
        num_out_tokens: number of routed (token, expert) pairs; ``None`` means
            ``num_tokens * topK``.
        align_size: each expert's token count is padded up to a multiple of this value.
        BENCHMARK: when True, skip the checks and time both paths instead.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "permutation and padding with merging probs:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
        f" align_size:{align_size} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    # Scatter exactly num_out_tokens True entries uniformly over the (token, expert) grid.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # Row-normalized probs over the routed experts.
    # NOTE(review): a token with no routed expert gets a 0/0 = NaN row — presumably never
    # read because routing_map masks it out; confirm.
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs = probs.to(dtype)
    probs.requires_grad_(True)

    # Expected padded size: each expert's token count rounded up to a multiple of align_size.
    tokens_per_expert = routing_map.sum(dim=0).cpu()
    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_bwd_input = torch.rand(
        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
    ).cuda()
    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_fwd_input.requires_grad_(True)

    restore_shape = permute_pad_fwd_input.shape
    ###################################################################################################################################
    #
    # Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs
    #
    ###################################################################################################################################
    # permute + padding
    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
        permute_pad_fwd_input,
        probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )
    tokens_per_expert_list = tokens_per_expert.tolist()
    fp8_padding = Fp8Padding(num_expert, align_size)
    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)

    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

    # Reference: unpadding + unpermute WITH merging_probs.
    # Fresh leaves so each path accumulates its own activation / probs gradients.
    ref_unpermute_fwd_input = permuted_paded_output.detach()
    ref_unpermute_fwd_input.requires_grad_(True)

    ref_probs = probs.detach()
    ref_probs.requires_grad_(True)

    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
    unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
    ref_unpermuted_output = te_unpermute(
        unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape
    )

    ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets
    #
    ###################################################################################################################################
    # fusion permute_and_pad
    fusion_permute_fwd_input = permute_pad_fwd_input.detach()
    fusion_permute_fwd_input.requires_grad_(True)
    fusion_probs = probs.detach()
    fusion_probs.requires_grad_(True)

    (
        fusion_permuted_padded_output,
        fusion_permuted_padded_probs,
        fused_row_id_map,
        pad_offsets,
        _,
    ) = te_permute_and_pad_with_probs(
        fusion_permute_fwd_input,
        fusion_probs,
        routing_map,
        tokens_per_expert,
        align_size,
    )

    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

    # Fused: unpermute with BOTH merging_probs AND pad_offsets
    fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach()
    fusion_unpermute_fwd_input.requires_grad_(True)

    fusion_merging_probs = probs.detach()
    fusion_merging_probs.requires_grad_(True)

    fusion_unpermuted_output = te_unpermute(
        fusion_unpermute_fwd_input,
        fused_row_id_map,
        fusion_merging_probs,
        restore_shape=restore_shape,
        pad_offsets=pad_offsets,
    )

    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
    fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    # Check forward pass
    ref_unpermuted_output_ = ref_unpermuted_output.float()
    fusion_unpermuted_output_ = fusion_unpermuted_output.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            ref_unpermuted_output_,
            fusion_unpermuted_output_,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets fwd",
            **tols,
        )

        # Check backward pass - activation gradients
        ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float()
        fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float()

        torch.testing.assert_close(
            ref_unpermute_fwd_input_grad,
            fusion_unpermute_fwd_input_grad,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)",
            **tols,
        )

        # Check backward pass - probs gradients
        ref_probs_grad = ref_probs.grad.float()
        fusion_probs_grad = fusion_merging_probs.grad.float()

        torch.testing.assert_close(
            ref_probs_grad,
            fusion_probs_grad,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)",
            **tols,
        )

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:

        def ref_unpad_unpermute():
            unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
            return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape)

        def fused_unpermute():
            return te_unpermute(
                fusion_unpermute_fwd_input,
                fused_row_id_map,
                fusion_merging_probs,
                restore_shape=restore_shape,
                pad_offsets=pad_offsets,
            )

        t1 = perf_test_cuda_kernel(ref_unpad_unpermute)
        t2 = perf_test_cuda_kernel(fused_unpermute)
        print(f"unpermute_unpad_with_probs\tfwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                ref_unpermuted_output,
                unpermute_unpad_bwd_input,
                forward_input=[ref_unpermute_fwd_input, ref_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_unpermuted_output,
                fusion_unpermute_bwd_input,
                forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute_unpad_with_probs\tbwd: naive: {t1:.3f} ms,  fusion: {t2:.3f} ms")


1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
def _test_permutation_mask_map_fp8(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    recipe,
):
    """
    Check that mask-map ``moe_permute`` moves the raw bytes of a quantized tensor exactly.

    The input is quantized with the quantizer matching ``recipe``. The raw rowwise FP8
    payload is permuted with the PyTorch reference and compared with zero tolerance against
    TE's output. For block-scaling and MXFP8 recipes, the rowwise scale-inverse tensor is
    compared the same way.

    Args:
        te_dtype: FP8 dtype handed to the quantizer.
        num_tokens: number of input tokens.
        num_expert: number of experts (columns of the routing map).
        hidden_size: hidden dimension of each token.
        topK: experts per token; the test is skipped when ``topK > num_expert``.
        num_out_tokens: number of routed (token, expert) pairs; ``None`` means
            ``num_tokens * topK``.
        recipe: FP8 recipe selecting the quantizer (delayed, current scaling, block
            scaling or MXFP8). Note: this parameter shadows the module-level ``recipe``.

    Raises:
        ValueError: if ``recipe`` is none of the supported FP8 recipes.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    if recipe.delayed():
        quantizer = Float8Quantizer(
            scale=torch.full([1], 1.0).cuda().squeeze(),
            amax=torch.full([1], 1.0).cuda(),
            fp8_dtype=te_dtype,
        )
    elif recipe.float8_current_scaling():
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=te_dtype,
            device=torch.device("cuda"),
            columnwise=False,
        )
    elif recipe.float8_block_scaling():
        quantizer = Float8BlockQuantizer(
            fp8_dtype=te_dtype,
            rowwise=True,
            columnwise=False,
            amax_epsilon=0.0,
            force_pow_2_scales=True,  # Fp8 sub-channel a2a requires e8 scales
            block_scaling_dim=1,  # 1x128 scaling
        )
    elif recipe.mxfp8():
        quantizer = MXFP8Quantizer(
            fp8_dtype=te_dtype,
            rowwise=True,
            columnwise=False,
        )
    else:
        raise ValueError("Unsupported FP8 recipe")

    permute_fwd_input = torch.rand(
        size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
    )
    # Make an empty fp8 tensor
    permute_fwd_input_fp8 = quantizer.make_empty(
        permute_fwd_input.shape,
        dtype=permute_fwd_input.dtype,
        device=permute_fwd_input.device,
    )
    # quantize the tensor
    quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8)

    # Extract the raw payload (and scales where they are per-row) as plain tensors for the
    # PyTorch reference. For block scaling the scale-inverse is transposed before permuting
    # (presumably it is stored with tokens along the last dim — confirm).
    if recipe.float8_block_scaling():
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
        pytorch_permute_fwd_scale_input = copy.deepcopy(
            permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous()
        )
    elif recipe.mxfp8():
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
        pytorch_permute_fwd_scale_input = copy.deepcopy(
            permute_fwd_input_fp8._rowwise_scale_inv.contiguous()
        )
    else:
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data)
        pytorch_permute_fwd_scale_input = None

    # Scatter exactly num_out_tokens True entries uniformly over the (token, expert) grid.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # PyTorch Permutation
    pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
    pytorch_permute_scale_output = None
    if pytorch_permute_fwd_scale_input is not None:
        pytorch_permute_scale_output, _ = pytorch_permute_mask_map(
            pytorch_permute_fwd_scale_input, routing_map
        )

    # TE Permutation
    permute_output, _ = te_permute(
        permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
    )
    if recipe.float8_block_scaling():
        te_permute_output = permute_output._rowwise_data
        te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous()
    elif recipe.mxfp8():
        te_permute_output = permute_output._rowwise_data
        te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous()
    else:
        te_permute_output = permute_output._data
        te_permute_scale_output = None

    # Permutation only moves rows, so the payload must match bit-exactly.
    torch.testing.assert_close(
        pytorch_permute_output,
        te_permute_output,
        atol=0,
        rtol=0,
    )
    if recipe.float8_block_scaling() or recipe.mxfp8():
        torch.testing.assert_close(
            pytorch_permute_scale_output,
            te_permute_scale_output,
            atol=0,
            rtol=0,
        )


1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
def _test_moe_chunk_sort(
    te_dtype,
    num_tokens,
    num_expert,
    tp_size,
    hidden_size,
    BENCHMARK=False,
):
    """
    Compare ``moe_sort_chunks_by_index`` against the PyTorch reference implementation.

    The input rows are split into ``num_expert * tp_size`` chunks of random sizes (each
    token is assigned to a uniformly random chunk). ``sorted_idxs`` is the transpose of a
    ``(tp_size, num_expert)`` arange, so it lists chunk ids expert-by-expert across tp ranks.
    The forward output and input gradient of TE and the reference are compared.

    Args:
        te_dtype: TE dtype (``tex.DType``) of the activations; other dtypes skip.
        num_tokens: total number of rows across all chunks.
        num_expert: number of experts per tp rank.
        tp_size: tensor-parallel size.
        hidden_size: hidden dimension of each row.
        BENCHMARK: when True, skip the checks and time both implementations instead.
    """
    print(
        "chunk permute:"
        f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()

    pytorch_fwd_input.requires_grad_(True)

    # Random chunk sizes summing to num_tokens: each token lands in one random chunk.
    _split_sizes = [0] * (num_expert * tp_size)
    for _ in range(num_tokens):
        idx = random.randint(0, num_expert * tp_size - 1)
        _split_sizes[idx] += 1
    split_sizes = torch.tensor(_split_sizes, dtype=torch.int32).ravel()
    split_sizes_cuda = split_sizes.to(device="cuda")

    _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
    sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
    sorted_idxs_cuda = sorted_idxs.to(device="cuda")

    ###################################################################################################################################
    #
    # PyTorch Permutation
    #
    ###################################################################################################################################
    pytorch_output = pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
    pytorch_output.backward(pytorch_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # TE Permutation
    #
    ###################################################################################################################################
    # Fresh leaf sharing the same data so TE accumulates its own gradient.
    te_fwd_input = pytorch_fwd_input.detach()
    te_fwd_input.requires_grad_(True)
    te_bwd_input = pytorch_bwd_input.detach()

    te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
    te_output.backward(te_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    te_output_ = te_output.float()
    te_fwd_input_grad = te_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_output.float(),
            te_output_,
            msg="Mismatch in te_permute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_fwd_input.grad.float(),
            te_fwd_input_grad,
            msg="Mismatch in te_permute bwd",
            **tols,
        )

    if not pytorch_fwd_input.numel():
        print("Empty pytorch_fwd_input activation test passed.")
        return

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:
        t1 = perf_test_cuda_kernel(
            lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
        )
        print(f"chunk sort\t\tfwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_output,
                pytorch_bwd_input,
                forward_input=[pytorch_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_output,
                te_bwd_input,
                forward_input=[te_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms,  TE: {t2:.3f} ms")


def _test_permutation_mask_map_alongside_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    tp_size,
    BENCHMARK=False,
):
    """Compare TE mask-map permutation + chunk sorting with probs against a PyTorch reference.

    Tokens are permuted with a random ``routing_map``, the permuted rows are chunk-sorted
    (chunk indices laid out as ``(tp_size, num_expert)`` and visited expert-major), then
    chunk-sorted again with the inverse order, and finally unpermuted. The TE path carries
    the probs through the first chunk sort and applies them before the second sort; the
    PyTorch reference applies them inside its unpermute. Forward outputs and gradients
    w.r.t. the input activations and ``probs`` are compared with ``dtype_tols``.

    With ``BENCHMARK=True`` the comparisons are skipped and TE kernel timings are printed.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "mask map alongside probs:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()

    pytorch_permute_fwd_input.requires_grad_(True)

    restore_shape = pytorch_permute_fwd_input.shape

    # Random routing map with exactly `num_out_tokens` True entries spread over all
    # (token, expert) pairs; the per-token count is not constrained to topK.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # Normalize probs over each token's routed experts. A token with no routed expert
    # produces a 0/0 = NaN row. NOTE(review): those entries are masked out by routing_map
    # and appear not to reach the compared tensors — confirm if the routing scheme changes.
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs = probs.to(dtype)
    probs.requires_grad_(True)

    # Randomly distribute the permuted tokens over num_expert * tp_size chunks.
    split_sizes = [0] * (num_expert * tp_size)
    for _ in range(num_out_tokens):
        idx = random.randint(0, num_expert * tp_size - 1)
        split_sizes[idx] += 1
    split_sizes = torch.tensor(split_sizes, dtype=torch.int32)
    split_sizes_cuda = split_sizes.to(device="cuda")

    # First chunk sort: chunk ids laid out as (tp_size, num_expert), visited expert-major.
    _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
    sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
    sorted_idxs_cuda = sorted_idxs.to(device="cuda")

    # Second chunk sort undoes the first: chunk sizes in the sorted order, and the
    # inverse permutation of sorted_idxs (argsort of a permutation is its inverse).
    split_sizes_2 = split_sizes[sorted_idxs.long()]
    split_sizes_2_cuda = split_sizes_2.to(device="cuda")

    sorted_idxs_2 = torch.argsort(sorted_idxs).to(torch.int32)
    sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda")

    ###################################################################################################################################
    #
    # PyTorch Permutation
    #
    ###################################################################################################################################
    pytorch_permute_output, sorted_indices = pytorch_permute_mask_map(
        pytorch_permute_fwd_input, routing_map
    )

    pytorch_permute_output = pytorch_sort_chunks_by_index(
        pytorch_permute_output, split_sizes, sorted_idxs
    )

    pytorch_permute_output = pytorch_sort_chunks_by_index(
        pytorch_permute_output, split_sizes_2, sorted_idxs_2
    )

    pytorch_unpermute_output = pytorch_unpermute_mask_map(
        pytorch_permute_output, sorted_indices, restore_shape, probs, routing_map
    )
    pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # TE Permutation
    #
    ###################################################################################################################################
    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)

    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

    te_probs = probs.detach()
    te_probs.requires_grad_(True)

    te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
        te_permute_fwd_input,
        te_probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )

    te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
        te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
    )

    # Scale each permuted row by its permuted prob, keeping the activation dtype.
    te_permute_output_dtype = te_permute_output.dtype
    te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
    te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)

    te_permute_output = te_sort_chunks_by_index(
        te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda
    )

    te_unpermute_output = te_unpermute(
        te_permute_output,
        row_id_map,
        restore_shape=restore_shape,
        map_type="mask",
    )
    te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

    ###############################################################################################

    tols = dtype_tols(te_dtype)

    te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
    te_unpermute_output_ = te_unpermute_output.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_unpermute_output.float(),
            te_unpermute_output_,
            msg="Mismatch in fused_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_permute_fwd_input.grad.float(),
            te_permute_fwd_input_grad,
            msg="Mismatch in fused_permute bwd",
            **tols,
        )
        torch.testing.assert_close(
            probs.grad.float(), te_probs.grad.float(), msg="Mismatch in prob grad", **tols
        )
    else:
        t1 = perf_test_cuda_kernel(
            lambda: te_permute_with_probs(
                te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
            )
        )
        print(f"permute\t\tfwd: TE: {t1:.3f} ms")

        te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
            te_permute_fwd_input,
            te_probs,
            routing_map,
            num_out_tokens=num_out_tokens,
        )
        te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_permute_output,
                te_permute_bwd_input,
                forward_input=[te_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute\t\tbwd: TE: {t2:.3f} ms")

        chunk_sort_fwd_input = te_permute_output.detach()
        chunk_sort_fwd_input.requires_grad_(True)
        chunk_sort_fwd_probs = te_permuted_probs.detach()
        chunk_sort_fwd_probs.requires_grad_(True)
        t1 = perf_test_cuda_kernel(
            lambda: te_sort_chunks_by_index_with_probs(
                chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
            )
        )
        print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")

        chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
            chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                chunk_sort_output,
                te_permute_bwd_input,
                forward_input=[chunk_sort_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")


def perf_test_cuda_kernel(cuda_kernel_fn, warmup_iters=50, timed_iters=100):
    """Return the average elapsed time per call of ``cuda_kernel_fn``, in milliseconds.

    The callable is run ``warmup_iters`` times untimed, then ``timed_iters`` times between
    two CUDA timing events. The defaults (50 warmup, 100 timed) match the historical
    hard-coded values. Skips the calling test when CUDA is not available.

    Raises:
        ValueError: if ``timed_iters`` is less than 1.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")
    if timed_iters < 1:
        raise ValueError(f"timed_iters must be >= 1, got {timed_iters}")

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warm up caches / lazy initialization before timing.
    for _ in range(warmup_iters):
        cuda_kernel_fn()

    start_event.record()
    for _ in range(timed_iters):
        cuda_kernel_fn()
    end_event.record()
    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / timed_iters


# TE tensor dtypes used to parametrize the non-FP8 tests; BF16 is added only when
# te.is_bf16_available() reports support on this device.
_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
if te.is_bf16_available():
    _te_dtypes.append(tex.DType.kBFloat16)


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
):
    """Run the index-map permutation test with probs and benchmarking disabled."""
    with_probs = True
    BENCHMARK = False

    _test_permutation_index_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
):
    """Run the mask-map permutation test with probs and benchmarking disabled."""
    with_probs = True
    BENCHMARK = False

    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs,
):
    """Run the mask-map permute-and-pad test, with and without merging probs, no benchmarking."""
    BENCHMARK = False

    _test_permutation_and_padding_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_merging_probs=with_merging_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
def test_permutation_and_padding_with_merging_probs(
    te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens
):
    """Exercise moe_unpermute backward when merging_probs and pad_offsets are both supplied.

    Benchmarking is disabled for this test.
    """
    _test_permutation_and_padding_with_merging_probs(
        num_out_tokens=num_out_tokens,
        topK=topK,
        hidden_size=hidden_size,
        num_expert=num_expert,
        num_tokens=num_tokens,
        te_dtype=te_dtype,
        BENCHMARK=False,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
    """Run the mask-map permutation test with zero tokens and zero output tokens."""
    with_probs = True
    BENCHMARK = False

    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=0,
        num_expert=8,
        hidden_size=4096,
        topK=2,
        num_out_tokens=0,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_permutation_mask_map_alongside_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    tp_size,
):
    """Run the mask-map permutation + chunk-sort-with-probs comparison for each tp_size."""
    _test_permutation_mask_map_alongside_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        tp_size=tp_size,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
    """Run the permutation + chunk-sort-with-probs comparison on zero tokens (tp_size=2)."""
    empty_case = dict(
        num_tokens=0,
        num_out_tokens=0,
        num_expert=8,
        hidden_size=4096,
        topK=2,
        tp_size=2,
    )
    _test_permutation_mask_map_alongside_probs(te_dtype=te_dtype, **empty_case)


# Device capability flags (and human-readable reasons) used to skip FP8 tests on
# hardware that lacks support for a given FP8 format or recipe.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
# FP8 recipes the FP8 permutation test is parametrized over.
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    recipe,
):
    """Run the FP8 mask-map permutation test for each FP8 dtype and recipe.

    MXFP8 and FP8 block-scaling recipes are skipped when the device does not support them.
    Note: inside this function ``recipe`` is the parametrized recipe instance taken from
    ``fp8_recipes`` and shadows the module-level ``recipe`` import.
    """
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    _test_permutation_mask_map_fp8(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        recipe=recipe,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
):
    """Run the index-map permutation test with topK=1, no probs, no dropped tokens."""
    topK = 1
    num_out_tokens = None
    with_probs = False
    BENCHMARK = False

    _test_permutation_index_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
):
    """Run the mask-map permutation test with topK=1, no probs, no dropped tokens."""
    topK = 1
    num_out_tokens = None
    with_probs = False
    BENCHMARK = False

    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
    te_dtype,
    num_tokens,
    num_expert,
    tp_size,
    hidden_size,
):
    """Run the MoE chunk-sort test for each tp_size with benchmarking disabled."""
    BENCHMARK = False

    _test_moe_chunk_sort(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        tp_size=tp_size,
        hidden_size=hidden_size,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_chunk_permutation_empty_input(te_dtype):
    """Run the MoE chunk-sort test on zero tokens (tp_size=2)."""
    BENCHMARK = False

    _test_moe_chunk_sort(
        te_dtype=te_dtype,
        num_tokens=0,
        num_expert=8,
        tp_size=2,
        hidden_size=4096,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.skipif(
    os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
    reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case",
)
def test_permutation_single_case():
    """Run every permutation variant on one tiny configuration with benchmarking enabled.

    Skipped unless the environment variable RUN_BENCHMARK_TESTS is set to "1".
    """
    print("GPU:", torch.cuda.get_device_name(0))

    # Switch to tex.DType.kFloat32 or tex.DType.kFloat16 to exercise other dtypes.
    te_dtype = tex.DType.kBFloat16

    num_tokens = 12
    num_expert = 4
    hidden_size = 16
    topK = 2
    # One fewer than num_tokens * topK so the "dropped token" path is exercised.
    num_out_tokens = num_tokens * topK - 1
    with_probs = True
    BENCHMARK = True

    _test_permutation_index_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )

    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=with_probs,
        BENCHMARK=BENCHMARK,
    )

    _test_permutation_and_padding_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=BENCHMARK,
    )

    _test_permutation_and_padding_with_merging_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=BENCHMARK,
    )

    _test_moe_chunk_sort(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        tp_size=4,
        hidden_size=hidden_size,
        BENCHMARK=BENCHMARK,
    )

    _test_permutation_mask_map_alongside_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        tp_size=4,
        BENCHMARK=BENCHMARK,
    )


def benchmark_single_case(
    te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
    """Benchmark every permutation variant for one configuration, grouped in NVTX ranges.

    An outer NVTX range labelled ``tokens-experts-hidden-topK-ep-tp`` wraps one nested
    range per variant. ``ep_size`` is only used in that label; ``tp_size`` is also
    forwarded to ``_test_permutation_mask_map_alongside_probs``. Ranges are opened with
    the ``torch.cuda.nvtx.range`` context manager so pushes and pops stay balanced even
    if a variant raises (for example through ``pytest.skip``).
    """
    # Arguments shared by every variant below; all of them run in benchmark mode.
    common = dict(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=True,
    )

    case_label = f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
    with torch.cuda.nvtx.range(case_label):
        with torch.cuda.nvtx.range("permutation_index_map_with_probs"):
            _test_permutation_index_map(with_probs=True, **common)

        with torch.cuda.nvtx.range("permutation_mask_map_with_probs"):
            _test_permutation_mask_map(with_probs=True, **common)

        with torch.cuda.nvtx.range("permutation_mask_map_without_probs"):
            _test_permutation_mask_map(with_probs=False, **common)

        with torch.cuda.nvtx.range("permutation_and_padding_mask_map"):
            _test_permutation_and_padding_mask_map(**common)

        with torch.cuda.nvtx.range("permutation_and_padding_with_merging_probs"):
            _test_permutation_and_padding_with_merging_probs(**common)

        with torch.cuda.nvtx.range("permutation_mask_map_alongside_probs"):
            _test_permutation_mask_map_alongside_probs(tp_size=tp_size, **common)


@pytest.mark.skipif(
    os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
    reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark",
)
def test_benchmark_multiple_cases():
    """Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark

    Each configuration below is benchmarked with ``num_out_tokens = num_tokens * topK``
    (no dropped tokens).
    """
    print("GPU:", torch.cuda.get_device_name(0))

    # Switch to tex.DType.kFloat32 or tex.DType.kFloat16 to benchmark other dtypes.
    te_dtype = tex.DType.kBFloat16

    # (ep_size, tp_size, num_tokens, num_expert, hidden_size, topK)
    cases = [
        (64, 2, 4096, 256, 7168, 8),
        (8, 1, 8192 * 2, 128, 4096, 6),
        (64, 2, 16384, 4, 7168, 1),
    ]
    for ep_size, tp_size, num_tokens, num_expert, hidden_size, topK in cases:
        num_out_tokens = num_tokens * topK
        benchmark_single_case(
            te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
        )


if __name__ == "__main__":
    # Running the module directly executes the multi-configuration benchmark.
    test_benchmark_multiple_cases()