test_moe.py 20.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import pytest
import torch
8
9
from torch.nn import Parameter
from torch.nn import functional as F
10
11
12
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

13
import vllm.model_executor.layers.fused_moe  # noqa
14
15
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
                                 torch_moe_single)
16
from vllm.model_executor.layers.fused_moe import fused_moe
17
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
18
19
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe)
20
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
21
    awq_marlin_quantize, marlin_quantize)
22
23
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
24
from vllm.model_executor.models.mixtral import MixtralMoE
25
from vllm.platforms import current_platform
26
from vllm.scalar_type import scalar_types
27

28
# Parametrization values shared by the fused-MoE tests below.
NUM_EXPERTS = [8, 64]  # total (global) number of experts
EP_SIZE = [1, 4]  # expert-parallel sizes; tests treat 1 as "no EP"
TOP_KS = [2, 6]  # number of experts selected per token
31

32
33
34

@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
):
    """Check the Triton ``fused_moe`` and the iterative MoE implementation
    against the ``torch_moe`` reference.

    With ``ep_size > 1`` only ``e // ep_size`` experts are kept locally, and
    an expert map translates global expert ids to local indices (``-1`` for
    experts that are not held locally).
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        local_e = e // ep_size
        # Pick *distinct* global experts. randint could repeat an id, and the
        # indexed assignment below is undefined for duplicate indices.
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
    # The iterative implementation runs on the unpadded weights.
    iterative_output = iterative_moe(a,
                                     w1,
                                     w2,
                                     score,
                                     topk,
                                     global_num_experts=e,
                                     expert_map=e_map,
                                     renormalize=False)

    # Pad the weight if moe padding is enabled: pad the last dim by 128 and
    # slice it back off, leaving the same values in non-contiguous views
    # whose rows are strided by the padded width.
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        torch.cuda.empty_cache()
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
        torch.cuda.empty_cache()

    triton_output = fused_moe(a,
                              w1,
                              w2,
                              score,
                              topk,
                              global_num_experts=e,
                              expert_map=e_map,
                              renormalize=False)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
100
101


102
103
104
105
106
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        ep_size: int, dtype: torch.dtype, group_size: int,
                        has_zp: bool, weight_bits: int):
    """Check ``fused_moe`` with int4 / int8 weight-only quantization against
    ``torch_moe`` evaluated on the dequantized reference weights.

    Quantized weights and zero points are stored as uint8. For 4-bit
    weights, pairs of values are packed into one byte: along the last
    dimension for weights, and along dim 1 for zero points.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values share one uint8 byte.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    # Reference weights for torch_moe; overwritten per expert with the
    # dequantized values returned by quantize_weights.
    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # Quantize every w1 expert, then every w2 expert.
    for w, w_ref, w_qweight, w_scales, w_qzeros in (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    ):
        for expert_id in range(e):
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False)
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Even index -> low nibble, odd index -> high nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        local_e = e // ep_size
        # Distinct experts only: duplicate indices in the indexed assignment
        # below would be undefined behaviour.
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
                              score,
                              topk,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              global_num_experts=e,
                              expert_map=e_map,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


213
214
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
                     monkeypatch):
    """Make sure vLLM's ``MixtralMoE`` agrees with Hugging Face's
    ``MixtralSparseMoeBlock`` when both hold the same weights and see the
    same inputs."""
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Build both MoE blocks from a default Mixtral configuration.
    cfg = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(cfg).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=cfg.num_local_experts,
        top_k=cfg.num_experts_per_tok,
        hidden_size=cfg.hidden_size,
        intermediate_size=cfg.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
        dp_size=1,
    ).cuda()

    # Copy weights HF -> vLLM: the router gate as-is, and per expert the
    # w1/w3 projections concatenated into vLLM's fused w13 tensor.
    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for idx in range(cfg.num_local_experts):
        expert = hf_moe.experts[idx]
        vllm_moe.experts.w13_weight[idx][:] = torch.cat(
            (expert.w1.weight.data, expert.w3.weight.data), dim=0)
        vllm_moe.experts.w2_weight[idx][:] = expert.w2.weight.data

    # HF takes [batch_size, seq_len, hidden_dim]; vLLM takes the same tokens
    # flattened to a 1D query of shape [num_tokens, hidden_dim].
    hf_inputs = torch.randn((1, 64, cfg.hidden_size)).to(dtype).to("cuda")
    vllm_inputs = hf_inputs.flatten(0, 1)

    if padding:
        # Swap each expert weight for a same-valued view of a copy padded by
        # 128 columns, so its rows are no longer contiguous.
        for attr in ("w13_weight", "w2_weight"):
            padded_view = F.pad(getattr(vllm_moe.experts, attr), (0, 128),
                                "constant", 0)[..., 0:-128]
            setattr(vllm_moe.experts, attr,
                    Parameter(padded_view, requires_grad=False))
            torch.cuda.empty_cache()

    # Forward both blocks.
    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174  # noqa: E501
        rtol, atol = 0.01, 100
    else:
        per_dtype_tol = {
            torch.float32: 1e-3,
            torch.float16: 1e-3,
            torch.bfloat16: 1e-2,
        }
        rtol = atol = per_dtype_tol[dtype]

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               vllm_states,
                               rtol=rtol,
                               atol=atol)
286
287


288
289
290
291
292
293
294
def _marlin_quantize_experts(w: torch.Tensor, quant_type, group_size: int,
                             act_order: bool, has_zp: bool):
    """Quantize every expert slice ``w[i]`` for the Marlin MoE kernels.

    ``w[i]`` is transposed before quantization, so the act-order test
    permutation is drawn over ``w``'s last dimension (the input / K dim).
    With ``has_zp`` the AWQ-style quantizer is used; otherwise the GPTQ-style
    ``marlin_quantize`` is used with ``act_order``.

    Returns ``(w_ref, qweight, scales, g_idx, zeros, sort_indices)``, each
    stacked over experts with ``stack_and_dev``. ``g_idx``/``sort_indices``
    are ``None`` when ``has_zp`` is set; ``zeros`` is ``None`` otherwise.
    """
    w_ref_l, qweight_l, scales_l = [], [], []
    zeros_l, g_idx_l, sort_indices_l = [], [], []

    for i in range(w.shape[0]):
        if has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size)
            zeros_l.append(zeros)
        else:
            test_perm = torch.randperm(w.shape[-1])
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size, act_order,
                test_perm)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)
        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweight_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
    zeros = stack_and_dev(zeros_l) if zeros_l else None
    sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
    return w_ref, qweight, scales, g_idx, zeros, sort_indices


@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    num_bits: int,
    has_zp: bool,
    is_k_full: bool,
):
    """Check ``torch.ops.vllm.fused_marlin_moe`` on Marlin-quantized expert
    weights against ``torch_moe`` on the dequantized reference weights.

    Unsupported parameter combinations are reported as skipped rather than
    silently passing.
    """
    current_platform.seed_everything(7)

    # Filter act_order
    if act_order:
        if group_size == -1:
            pytest.skip("act_order is not tested with group_size=-1")
        if group_size in (k, n):
            pytest.skip("act_order is not tested with group_size == k or n")
        if has_zp:
            pytest.skip("act_order is not tested with zero points")
    elif not is_k_full:
        pytest.skip("is_k_full=False is only tested with act_order")

    if has_zp:
        # we don't build kernel for int8 with zero
        if num_bits == 8:
            pytest.skip("no Marlin MoE kernel for int8 with zero points")
        quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    else:
        quant_type = scalar_types.uint4b8 \
                if num_bits == 4 else scalar_types.uint8b128

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    (w_ref1, qweight1, scales1, g_idx1, zeros1,
     sort_indices1) = _marlin_quantize_experts(w1, quant_type, group_size,
                                               act_order, has_zp)
    (w_ref2, qweight2, scales2, g_idx2, zeros2,
     sort_indices2) = _marlin_quantize_experts(w2, quant_type, group_size,
                                               act_order, has_zp)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=num_bits,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
449

450
451
452

@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
                                    dtype: torch.dtype, group_size: int,
                                    act_order: bool, num_bits: int,
                                    has_zp: bool, is_k_full: bool):
    """Debug-only check of ``torch.ops.vllm.single_marlin_moe`` against the
    ``torch_moe_single`` reference on the dequantized weights."""
    # Filter act_order
    if act_order:
        if group_size == -1:
            pytest.skip("act_order is not tested with group_size=-1")
        if group_size in (k, n):
            pytest.skip("act_order is not tested with group_size == k or n")
        if has_zp:
            pytest.skip("act_order is not tested with zero points")
    elif not is_k_full:
        pytest.skip("is_k_full=False is only tested with act_order")

    if has_zp:
        # No Marlin MoE kernel is built for int8 with zero points.
        if num_bits == 8:
            pytest.skip("no Marlin MoE kernel for int8 with zero points")
        quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    else:
        quant_type = scalar_types.uint4b8 \
                if num_bits == 4 else scalar_types.uint8b128

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    w_ref_l, qweight_l, scales_l = [], [], []
    zeros_l, g_idx_l, sort_indices_l = [], [], []

    for i in range(w.shape[0]):
        if has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size)
            zeros_l.append(zeros)
        else:
            test_perm = torch.randperm(k)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size, act_order,
                test_perm)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)
        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweight_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
    zeros = stack_and_dev(zeros_l) if zeros_l else None
    sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = torch.ops.vllm.single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        w_zeros=zeros,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    torch_output = torch_moe_single(a, w_ref, score, topk)

    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563


def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on the ``_moe_C.moe_align_block_size`` custom op with
    freshly allocated output buffers."""
    n_experts = 4
    blk = 4
    topk_ids = torch.randint(0,
                             n_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')
    dev = topk_ids.device
    n_ids = topk_ids.numel()

    # Buffer length: every id plus up to blk - 1 extra slots per expert
    # (presumably the worst-case per-expert padding to a multiple of blk).
    padded_len = n_ids + n_experts * (blk - 1)
    # Every slot starts out holding n_ids (one past the last flat id index).
    sorted_ids = torch.full((padded_len, ),
                            n_ids,
                            dtype=torch.int32,
                            device=dev)
    expert_ids = torch.empty((padded_len // blk, ),
                             dtype=torch.int32,
                             device=dev)
    num_tokens_post_pad = torch.empty((1, ), dtype=torch.int32, device=dev)

    op_args = (topk_ids, n_experts, blk, sorted_ids, expert_ids,
               num_tokens_post_pad)
    opcheck(torch.ops._moe_C.moe_align_block_size, op_args)