"vllm/compilation/fusion_attn.py" did not exist on "96846bb3607370798540c7d325f8d06dbd67dcf4"
test_moe.py 22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import pytest
import torch
8
9
from torch.nn import Parameter
from torch.nn import functional as F
10
11
12
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

13
import vllm.model_executor.layers.fused_moe  # noqa
14
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
15
from vllm.config import VllmConfig, set_current_vllm_config
16
from vllm.model_executor.layers.fused_moe import fused_moe
17
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
18
19
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe)
20
21
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_fp4_like)
22
23
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch)
24
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
25
    awq_marlin_quantize, marlin_quantize)
26
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
28
from vllm.model_executor.models.mixtral import MixtralMoE
29
from vllm.platforms import current_platform
30
from vllm.scalar_type import ScalarType, scalar_types
31

32
# Sweep values shared by the parametrized MoE tests in this file.
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

# Config that the tests make current (via set_current_vllm_config) around
# the MoE calls they compare.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
):
    """Compare the Triton ``fused_moe`` and the iterative torch MoE against
    the ``torch_moe`` reference.

    With ``ep_size > 1`` only a subset of ``e // ep_size`` experts is kept
    and an expert map translates global expert ids to local indices. With
    ``padding`` the weights handed to ``fused_moe`` are non-contiguous views.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        local_e = e // ep_size
        # Pick *distinct* global expert ids. torch.randint may repeat ids,
        # and assigning through duplicate indices below has an undefined
        # winner (and would leave fewer than local_e distinct experts).
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        # Global expert id -> local index; -1 marks experts not selected.
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
        iterative_output = iterative_moe(a,
                                         w1,
                                         w2,
                                         score,
                                         topk,
                                         global_num_experts=e,
                                         expert_map=e_map,
                                         renormalize=False)

        # Pad the last dim and slice the padding back off: same shape and
        # values, but a non-contiguous stride, as with padded MoE weights.
        if padding:
            w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
            torch.cuda.empty_cache()
            w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
            torch.cuda.empty_cache()

        triton_output = fused_moe(a,
                                  w1,
                                  w2,
                                  score,
                                  topk,
                                  global_num_experts=e,
                                  expert_map=e_map,
                                  renormalize=False)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
110
111


112
113
114
115
116
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        ep_size: int, dtype: torch.dtype, group_size: int,
                        has_zp: bool, weight_bits: int):
    """Compare ``fused_moe`` on int4/int8 weight-only quantized weights
    (``use_int4_w4a16`` / ``use_int8_w8a16``) against ``torch_moe`` run on
    the reference weights returned by ``quantize_weights``.

    Raises:
        ValueError: if ``weight_bits`` is neither 4 nor 8.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values are packed into each uint8 (see packing below).
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    # Only filled (and only passed to fused_moe) when has_zp is True.
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # Quantize every expert of w1, then every expert of w2, writing the
    # results into the per-weight output buffers in place.
    for w, w_ref, w_qweight, w_scales, w_qzeros in (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    ):
        for expert_id in range(e):
            # quantize_weights is applied to the transposed expert weight and
            # its outputs are transposed back.
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False)
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack neighbouring 4-bit values into one byte: even index in
                # the low nibble, odd index in the high nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        local_e = e // ep_size
        # Distinct global expert ids: torch.randint may repeat ids, and the
        # assignment through duplicate indices below has an undefined winner.
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        # Global expert id -> local index; -1 marks experts not selected.
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(a,
                                  w1_qweight,
                                  w2_qweight,
                                  score,
                                  topk,
                                  renormalize=False,
                                  use_int4_w4a16=weight_bits == 4,
                                  use_int8_w8a16=weight_bits == 8,
                                  global_num_experts=e,
                                  expert_map=e_map,
                                  w1_scale=w1_scales,
                                  w2_scale=w2_scales,
                                  w1_zp=w1_qzeros if has_zp else None,
                                  w2_zp=w2_qzeros if has_zp else None,
                                  block_shape=[0, group_size])
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


224
225
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
                     monkeypatch):
    """Check that vLLM's ``MixtralMoE`` produces the same outputs as
    HuggingFace's ``MixtralSparseMoeBlock`` given identical weights and
    inputs."""
    # is_rocm_aiter_moe_enabled exposes cache_clear(); reset it so the value
    # is recomputed under this test's environment.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled)
    is_rocm_aiter_moe_enabled.cache_clear()

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Build both MoE blocks from the default Mixtral config.
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
        dp_size=1,
    ).cuda()

    # Copy HF weights into vLLM: the router gate directly, and each expert's
    # w1/w3 concatenated (w1 first) into vLLM's single w13 tensor.
    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for idx in range(config.num_local_experts):
        hf_expert = hf_moe.experts[idx]
        vllm_moe.experts.w13_weight[idx][:] = torch.cat(
            (hf_expert.w1.weight.data, hf_expert.w3.weight.data), dim=0)
        vllm_moe.experts.w2_weight[idx][:] = hf_expert.w2.weight.data

    # HF takes [batch_size, seq_len, hidden_dim]; vLLM takes the flattened
    # [num_tokens, hidden_dim] view of the same data.
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    vllm_inputs = hf_inputs.flatten(0, 1)

    # Padding: pad the last dim and slice it back off so each weight keeps
    # its values but becomes a strided (non-contiguous) view.
    if padding:
        for param_name in ("w13_weight", "w2_weight"):
            strided = F.pad(getattr(vllm_moe.experts, param_name), (0, 128),
                            "constant", 0)[..., 0:-128]
            setattr(vllm_moe.experts, param_name,
                    Parameter(strided, requires_grad=False))
            torch.cuda.empty_cache()

    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174  # noqa: E501
        rtol, atol = 0.01, 100
    else:
        mixtral_moe_tol = {
            torch.float32: 1e-3,
            torch.float16: 1e-3,
            torch.bfloat16: 1e-2,
        }
        rtol = atol = mixtral_moe_tol[dtype]

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               vllm_states,
                               rtol=rtol,
                               atol=atol)
304
305


306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def marlin_moe_generate_valid_test_cases():
    """Return the parameter tuples used by ``test_fused_marlin_moe``.

    Takes the Cartesian product of the sweep lists below and keeps only the
    combinations accepted by the filter. Each tuple is ordered
    ``(m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type,
    is_k_full)``, matching the test's parametrize spec.
    """
    import itertools
    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(m_list, n_list, k_list, e_list,
                                         topk_list, ep_size_list, dtype_list,
                                         group_size_list, act_order_list,
                                         quant_type_list, is_k_full_list)

    def is_valid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
                 quant_type, is_k_full):
        """Return True if this combination should be tested.

        Only n, k, group_size, act_order, quant_type and is_k_full affect
        the result; the other arguments exist so a product tuple can be
        unpacked straight into this call.
        """
        # fp8 is only swept with group_size -1 or 128.
        if quant_type == scalar_types.float8_e4m3fn and \
                group_size not in [-1, 128]:
            return False
        # group_size 16 is used exactly for float4_e2m1f, and vice versa.
        if quant_type == scalar_types.float4_e2m1f and group_size != 16:
            return False
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        if act_order:
            # act_order needs a group size other than -1 / a full dimension,
            # and is only combined with uint4b8.
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # is_k_full=False is only exercised together with act_order.
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]


360
def _marlin_quantize_experts(w: torch.Tensor, quant_type: ScalarType,
                             group_size: int, act_order: bool, has_zp: bool):
    """Quantize every expert in ``w`` (one weight matrix per index of dim 0)
    for ``fused_marlin_moe`` and stack the per-expert results.

    The quantization helper is chosen by ``quant_type`` (fp4, fp8, AWQ when
    ``has_zp``, otherwise GPTQ-style ``marlin_quantize``).

    Returns:
        ``(w_ref, qweight, scales, global_scale, zeros, g_idx,
        sort_indices)``. ``w_ref`` is the reference weight to feed
        ``torch_moe``; outputs that the selected scheme does not produce
        are ``None``.
    """
    w_ref_l, qweight_l, scales_l = [], [], []
    global_scale_l, zeros_l, g_idx_l, sort_indices_l = [], [], [], []
    # Permutation length for marlin_quantize: the last dim of each expert
    # weight (k for w1, n for w2).
    perm_size = w.shape[-1]

    for i in range(w.shape[0]):
        if quant_type == scalar_types.float4_e2m1f:
            w_ref, qweight, scales, global_scale = \
                rand_marlin_weight_fp4_like(w[i], group_size)
            global_scale_l.append(global_scale)
        elif quant_type == scalar_types.float8_e4m3fn:
            w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
        elif has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size)
            zeros_l.append(zeros)
        else:
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scales, g_idx, sort_indices, _ = \
                marlin_quantize(w[i].transpose(1, 0), quant_type,
                                group_size, act_order, test_perm)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)

        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    def _stack_or_none(tensors):
        # Optional outputs are only produced by some schemes.
        return stack_and_dev(tensors) if tensors else None

    return (stack_and_dev(w_ref_l), stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l), _stack_or_none(global_scale_l),
            _stack_or_none(zeros_l), _stack_or_none(g_idx_l),
            _stack_or_none(sort_indices_l))


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
                          "act_order, quant_type, is_k_full"),
                         marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare ``fused_marlin_moe`` on quantized expert weights against
    ``torch_moe`` run on the corresponding reference weights."""
    torch.cuda.manual_seed(0)
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    # marlin_moe_generate_valid_test_cases() already excludes these
    # combinations; skip explicitly (instead of silently passing) if one is
    # ever supplied.
    if quant_type == scalar_types.float8_e4m3fn and (
            group_size not in [-1, 128] or act_order):
        pytest.skip("unsupported float8_e4m3fn configuration")
    if act_order and (group_size in (-1, k, n) or has_zp):
        pytest.skip("unsupported act_order configuration")
    if not act_order and not is_k_full:
        pytest.skip("is_k_full=False is only tested with act_order")
    if (quant_type == scalar_types.float4_e2m1f) != (group_size == 16):
        pytest.skip("group_size 16 is used exactly for float4_e2m1f")

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        # Global expert id -> local index; -1 marks experts not selected.
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    (w_ref1, qweight1, scales1, global_scale1, zeros1, g_idx1,
     sort_indices1) = _marlin_quantize_experts(w1, quant_type, group_size,
                                               act_order, has_zp)
    (w_ref2, qweight2, scales2, global_scale2, zeros2, g_idx2,
     sort_indices2) = _marlin_quantize_experts(w2, quant_type, group_size,
                                               act_order, has_zp)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574


def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on ``moe_align_block_size`` with small random top-k
    ids and output buffers sized by the padded-length bound below."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0,
                             num_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')
    dev = topk_ids.device
    num_ids = topk_ids.numel()

    # Bound on the padded length: every id plus up to block_size - 1 extra
    # slots per expert.
    max_num_tokens_padded = num_ids + num_experts * (block_size - 1)
    # Pre-filled with num_ids (one past the last valid id index).
    sorted_ids = torch.full((max_num_tokens_padded, ),
                            num_ids,
                            dtype=torch.int32,
                            device=dev)
    expert_ids = torch.empty((max_num_tokens_padded // block_size, ),
                             dtype=torch.int32,
                             device=dev)
    num_tokens_post_pad = torch.empty((1, ), dtype=torch.int32, device=dev)

    opcheck(torch.ops._moe_C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))
bnellnm's avatar
bnellnm committed
578
579


580
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check ``moe_sum`` against ``Tensor.sum`` over dim 1, then run
    ``opcheck`` on the op."""
    # (m, topk, k) input; moe_sum writes its reduction over dim 1 into
    # `actual`. Named to avoid shadowing the builtin `input`.
    topk_values = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)

    expected = topk_values.sum(dim=1)
    torch.ops._moe_C.moe_sum(topk_values, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (topk_values, actual))