test_ocp_mx_moe.py 32.7 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib.metadata
from dataclasses import dataclass
6
from importlib.util import find_spec
Lain's avatar
Lain committed
7
from typing import Optional
8
9
10
11
12

import pytest
import torch
from packaging import version

Lain's avatar
Lain committed
13
from vllm.platforms import current_platform
14
from vllm.utils.flashinfer import has_flashinfer
Lain's avatar
Lain committed
15

16
try:
    # The Quark OCP MX tests need amd-quark >= 0.9; the 0.8.99 floor presumably
    # also admits 0.9 pre-release/dev builds.
    QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")
    ) >= version.parse("0.8.99")
except importlib.metadata.PackageNotFoundError:
    # An importable ``quark`` module that is not provided by the ``amd-quark``
    # distribution would otherwise make importing this test module fail.
    QUARK_MXFP4_AVAILABLE = False
19

20
21
22
# trtllm-gen MXFP4 MoE path: a CUDA device with capability 100 (SM100).
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda() and (
    current_platform.is_device_capability(100)
)

# FlashInfer CUTLASS MXFP4 x BF16 MoE path: a CUDA device with capability 90
# (SM90) and FlashInfer installed. Checks short-circuit left to right.
HOPPER_MXFP4_BF16_AVAILABLE = current_platform.is_cuda() and (
    current_platform.is_device_capability(90) and has_flashinfer()
)
29

Lain's avatar
Lain committed
30
if TRTLLM_GEN_MXFP4_AVAILABLE:
31
32
33
34
35
36
37
38
39
    from flashinfer import (
        fp4_quantize,
        mxfp8_quantize,
        next_positive_power_of_2,
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )
40
41
    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
Lain's avatar
Lain committed
42

43
44
45
46
47
48
49

@dataclass
class ModelCase:
    """One checkpoint exercised by ``test_mxfp4_loading_and_execution_moe``."""

    # Model identifier handed to ``vllm_runner``.
    model_id: str
    # Tensor-parallel size; the test skips when fewer GPUs are visible.
    tp: int


50
51
52
53
54
55
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function.

    Autouse: every test in this module runs with
    ``VLLM_ALLOW_INSECURE_SERIALIZATION=1``, set through ``monkeypatch`` so the
    variable is restored after each test.
    """
    env_flag = "VLLM_ALLOW_INSECURE_SERIALIZATION"
    monkeypatch.setenv(env_flag, "1")


56
57
58
@pytest.mark.parametrize(
    "model_case",
    [
        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
        ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
        ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
    ],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    """Load a Quark OCP MX checkpoint with dummy weights and run a short greedy
    generation, asserting only that some output comes back."""
    available_gpus = torch.cuda.device_count()
    if available_gpus < model_case.tp:
        pytest.skip(
            f"This test requires >={model_case.tp} gpus, got only "
            f"{available_gpus}"
        )

    runner_kwargs = dict(
        tensor_parallel_size=model_case.tp,
        load_format="dummy",
        # A single CUDA graph capture size keeps load time down.
        cuda_graph_sizes=[16],
    )
    prompt = "Today I am in the French Alps and"

    with vllm_runner(model_case.model_id, **runner_kwargs) as llm:
        # NOTE: a `check_model` hook run through `llm.apply_model` (for the
        # qwen_1.5-moe checkpoint only) is disabled because it is broken, see
        # https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
        # It asserted that layer 0's `self_attn.qkv_proj` uses
        # `QuarkLinearMethod` with a `QuarkOCP_MX` scheme and that
        # `mlp.experts` uses `QuarkOCP_MX_MoEMethod`.
        output = llm.generate_greedy(prompt, max_tokens=20)
        assert output


106
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
    """Clamped SwiGLU used by the reference MoE.

    The last dimension of ``x`` is split (``chunk(2)``) into a gate half
    followed by a linear half. The result is
    ``gate * sigmoid(alpha * gate) * (linear + beta)``, so ``beta`` acts as an
    extra bias on the linear half. When ``limit`` is given, the gate is
    clamped from above at ``limit`` and the linear half to ``[-limit, limit]``
    first.
    """
    gate, linear = x.chunk(2, dim=-1)
    if limit is not None:
        gate = torch.clamp(gate, max=limit)
        linear = torch.clamp(linear, min=-limit, max=limit)
    gated = gate * (alpha * gate).sigmoid()
    return gated * (linear + beta)


116
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
Lain's avatar
Lain committed
117
118
119
120
121


def mxfp4_dequantize(x, scale):
    assert x.dtype == torch.uint8
    x = x.view(torch.uint8).to(torch.int32)
122
123
124
    x_unpacked = torch.zeros(
        *x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
    )
Lain's avatar
Lain committed
125
126
127
    x_unpacked[..., 0::2].copy_(x & 0xF)
    x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)

128
    x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
Lain's avatar
Lain committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
    for i, val in enumerate(fp4_lookup_table):
        x_float[x_unpacked == i] = val

    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def mxfp8_dequantize(x, scale):
    assert x.dtype == torch.float8_e4m3fn
    x_float = x.to(torch.float32)

    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def reference_moe(
    roouting_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    """Dense float reference for the MoE kernels under test.

    Routing: for every token, the ``topk`` largest logits are kept and turned
    into weights with a softmax over just those values. Experts: each selected
    expert applies ``w13``/``bias13``, then ``swiglu(alpha, beta, limit)``.
    When ``act_type == "mxfp8"``, the activation is quantized to MXFP8 and
    dequantized back. Then ``w2``/``bias2`` is applied. Expert outputs are
    summed with the routing weights.

    ``num_experts`` is accepted but unused. ``roouting_logits`` keeps its
    original spelling so the signature does not change.

    Returns a bfloat16 tensor with the same shape as ``hidden_states``.
    """
    top_logits, expert_indices = torch.topk(
        roouting_logits, k=topk, dim=-1, sorted=True
    )
    expert_weights = torch.nn.functional.softmax(top_logits, dim=1)

    # First projection of each token through each of its selected experts.
    gate_up = (
        torch.einsum("beck,bk->bec", w13[expert_indices, ...], hidden_states)
        + bias13[expert_indices, ...]
    )
    act = swiglu(gate_up, alpha=alpha, beta=beta, limit=limit)

    if act_type == "mxfp8":
        # MXFP8 quantize/dequantize round trip of the intermediate activation.
        act_q, act_sf = mxfp8_quantize(
            act.to(torch.bfloat16), is_sf_swizzled_layout=False
        )
        act = mxfp8_dequantize(act_q, act_sf)

    # Second projection, then the routing-weighted sum over selected experts.
    down = (
        torch.einsum("beck,bek->bec", w2[expert_indices, ...], act)
        + bias2[expert_indices, ...]
    )
    combined = torch.einsum("bec,be->bc", down, expert_weights)
    assert combined.shape == hidden_states.shape
    return combined.to(torch.bfloat16)


def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    """Choose ``tile_tokens_dim`` for ``trtllm_fp4_block_scale_moe``.

    The estimate starts from the per-expert token count under perfectly
    balanced routing, ``x.shape[0] * top_k // num_experts``. It is inflated
    by an assumed 1.3x imbalance (max / balanced tokens per expert; values
    below 1.0 would not make sense) and padded to a power of two with
    ``next_positive_power_of_2``. Finally it is clamped to 8..64 tokens per
    CTA tile, the range supported by the kernel.
    """
    imbalance_factor = 1.3
    min_tile, max_tile = 8, 64

    balanced_share = (x.shape[0] * top_k) // num_experts
    skewed_share = int(balanced_share * imbalance_factor)
    tile = next_positive_power_of_2(skewed_share)
    return max(min_tile, min(tile, max_tile))


def _swap_w1_w3(t: torch.Tensor, intermediate_size: int) -> torch.Tensor:
    """Return a copy of ``t`` with its two halves along dim 1 exchanged.

    ``t[:, :intermediate_size]`` and ``t[:, intermediate_size:]`` trade
    places, turning a ``[w1; w3]`` layout into ``[w3; w1]``. ``t`` itself is
    not modified.
    """
    swapped = torch.empty_like(t)
    swapped[:, :intermediate_size].copy_(t[:, intermediate_size:])
    swapped[:, intermediate_size:].copy_(t[:, :intermediate_size])
    return swapped


def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
    transpose_optimized: bool = False,
) -> torch.Tensor:
    """Run trtllm-gen's fused MXFP4 MoE (``trtllm_fp4_block_scale_moe``).

    Weights, block scales and biases are converted to the layout the kernel
    consumes:

    * the w1/w3 halves of w13 are swapped, because trtllm-gen defines swiglu
      with the halves in the other order;
    * w13 rows are reordered for the gated-activation GEMM;
    * everything is shuffled for the transposed MMA output. This uses
      ``shuffle_matrix_a``/``shuffle_matrix_sf_a``, or cached permutation
      indices when ``transpose_optimized`` is set.

    The caller's tensors are not modified.

    Args:
        router_logits: Router logits; cast to bfloat16 for the kernel.
        topk: Experts selected per token.
        num_experts, intermediate_size, hidden_size: Problem sizes, asserted
            against the weight shapes below.
        hidden_states, hidden_states_scale: Kernel activations and their
            scales (the caller passes ``None`` scales for bf16 activations).
        w13_weight, w13_weight_scale, w13_bias: ``[w1; w3]`` packed FP4
            weights ``(E, 2I, H/2)``, scales ``(E, 2I, H/32)``, bias ``(E, 2I)``.
        w2_weight, w2_weight_scale, w2_bias: packed FP4 weights
            ``(E, H, I/2)``, scales ``(E, H, I/32)``, bias ``(E, H)``.
        act_type: Not used by this function.
        alpha, beta, limit: Passed through as ``gemm1_alpha``/``gemm1_beta``/
            ``gemm1_clamp_limit``.
        transpose_optimized: Select the cached-permutation shuffling path.

    Returns:
        The first element of the ``trtllm_fp4_block_scale_moe`` result.
    """
    sf_block_size = 32
    assert (
        w13_weight.dim() == 3
        and w13_weight.shape[0] == num_experts
        and w13_weight.shape[1] == intermediate_size * 2
        and w13_weight.shape[2] == hidden_size // 2
    )
    assert (
        w13_weight_scale.dim() == 3
        and w13_weight_scale.shape[0] == num_experts
        and w13_weight_scale.shape[1] == intermediate_size * 2
        and w13_weight_scale.shape[2] == hidden_size // sf_block_size
    )
    assert (
        w2_weight.dim() == 3
        and w2_weight.shape[0] == num_experts
        and w2_weight.shape[1] == hidden_size
        and w2_weight.shape[2] == intermediate_size // 2
    )
    assert (
        w2_weight_scale.dim() == 3
        and w2_weight_scale.shape[0] == num_experts
        and w2_weight_scale.shape[1] == hidden_size
        and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
    )
    assert (
        w13_bias.dim() == 2
        and w13_bias.shape[0] == num_experts
        and w13_bias.shape[1] == intermediate_size * 2
    )
    assert (
        w2_bias.dim() == 2
        and w2_bias.shape[0] == num_experts
        and w2_bias.shape[1] == hidden_size
    )

    # Swap w1 and w3 as the definition of swiglu is different in trtllm-gen.
    # The swap goes into fresh tensors so the caller's inputs stay untouched.
    w13_weight = _swap_w1_w3(w13_weight, intermediate_size)
    w13_weight_scale = _swap_w1_w3(w13_weight_scale, intermediate_size)
    w13_bias = _swap_w1_w3(w13_bias, intermediate_size)

    # Interleave the weights and scaling factors for activation
    w13_weight = torch.stack(
        [
            reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
            for i in range(num_experts)
        ]
    ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
    w13_weight_scale = torch.stack(
        [
            reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
            for i in range(num_experts)
        ]
    ).reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
    w13_bias = torch.stack(
        [
            reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
            for i in range(num_experts)
        ]
    ).reshape(num_experts, 2 * intermediate_size)

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    if transpose_optimized:
        for i in range(num_experts):
            # w13 weight shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale shuffling
            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias shuffling
            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale shuffling
            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )

    else:
        for i in range(num_experts):
            gemm1_weights_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm1_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )

            gemm2_weights_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm2_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
            )
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
            )

    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = (
        torch.stack(gemm1_scales_shuffled)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = (
        torch.stack(gemm2_scales_shuffled)
        .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True,
    )[0]
    return tg_result


def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent.

    ``a`` is the reference and ``b`` the actual output. Both must be free of
    NaN/Inf, which raises ``Exception``, and have the same shape, which raises
    ``AssertionError``. An element mismatches when
    ``|a - b| > atol + rtol * |b|``. If more than a ``1 - percent`` fraction of
    elements mismatch, ``Exception`` is raised.
    """
    for is_bad, tensor, message in (
        (torch.isnan, a, "NaN in reference output"),
        (torch.isnan, b, "NaN in actual output"),
        (torch.isinf, a, "Inf in reference output"),
        (torch.isinf, b, "Inf in actual output"),
    ):
        if torch.any(is_bad(tensor)):
            raise Exception(message)
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    tolerance = atol + rtol * torch.abs(b)
    mismatched = torch.sum(torch.abs(a - b) > tolerance)
    mismatch_percent = mismatched / a.numel()
    allowed = 1 - percent
    if mismatch_percent > allowed:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {allowed:.4f})"
        )
Lain's avatar
Lain committed
492
493
494
495
496
497


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
    act_type: str,
    transpose_optimized: bool,
):
    """Compare trtllm-gen's MXFP4 MoE against the float reference.

    Weights are MXFP4-quantized (UE8M0 scales), and for ``act_type ==
    "mxfp8"`` the activations are MXFP8-quantized. The reference then runs on
    the dequantized tensors.
    """
    torch.manual_seed(42)
    device = "cuda:0"

    # Random problem. The draw order below fixes the sampled values for the
    # seed above, so keep it.
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    w13 = torch.randn(
        num_experts,
        intermediate_size * 2,
        hidden_size,
        device=device,
        dtype=torch.bfloat16,
    )
    w2 = torch.randn(
        num_experts,
        hidden_size,
        intermediate_size,
        device=device,
        dtype=torch.bfloat16,
    )
    bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
    bias2 = torch.randn(num_experts, hidden_size, device=device) * 10
    # Drawn from the CPU generator, then moved to the GPU.
    router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()

    def quantize_mxfp4(weight, rows, cols):
        """MXFP4-quantize ``weight``; scales come back as (E, rows, cols // 32)."""
        packed, scales = fp4_quantize(
            weight,
            torch.tensor(1.0, device=device),
            32,
            sf_use_ue8m0=True,
            is_sf_swizzled_layout=False,
        )
        scales = scales.view(torch.float8_e4m3fn).reshape(num_experts, rows, cols // 32)
        return packed, scales

    w13, w13_scale = quantize_mxfp4(w13, intermediate_size * 2, hidden_size)
    w2, w2_scale = quantize_mxfp4(w2, hidden_size, intermediate_size)

    hidden_states_scale = None
    if act_type == "mxfp8":
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)

    # Reference on the dequantized weights/activations.
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    hidden_states_ref = (
        mxfp8_dequantize(hidden_states, hidden_states_scale)
        if act_type == "mxfp8"
        else hidden_states
    ).to(torch.float32)

    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    chunk_size = 32  # tokens per reference call, to reduce memory usage
    for start in range(0, num_tokens, chunk_size):
        rows = slice(start, min(start + chunk_size, num_tokens))
        ref_result[rows].copy_(
            reference_moe(
                router_logits[rows].to(torch.float32),
                topk,
                num_experts,
                hidden_states_ref[rows],
                w13_ref,
                bias13,
                w2_ref,
                bias2,
                alpha,
                beta,
                limit,
                act_type,
            )
        )

    # The swiglu parameters are handed to trtllm-gen as per-expert tensors.
    def per_expert(value):
        if value is None:
            return None
        return torch.full((num_experts,), value, device=hidden_states.device)

    tg_result = tg_mxfp4_moe(
        router_logits,
        topk,
        num_experts,
        intermediate_size,
        hidden_size,
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        w2,
        w2_scale,
        bias2,
        act_type,
        alpha=per_expert(alpha),
        beta=per_expert(beta),
        limit=per_expert(limit),
        transpose_optimized=transpose_optimized,
    )

    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649


def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales on the last dimension by groups of 4, matching
    the transformation in mxfp4.py's BF16 (Hopper) path."""
    s = scales.to(torch.uint8)
    s_shape = s.shape
    assert s_shape[-1] % 4 == 0
    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
    # Move the 4-group dimension before the row dimension
    permuted = s.permute(0, 2, 1, 3)
    # Merge the row dim with the 4-group dim
    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
):
    """Compare FlashInfer's CUTLASS MXFP4 x BF16 MoE with the float reference.

    Uses random packed MXFP4 weights and E8M0 scale bytes in [118, 123). The
    reference runs on the dequantized weights; the kernel gets the same data
    rearranged into the layout it expects.
    """
    torch.manual_seed(42)
    device = "cuda:0"

    # Random inputs. The draw order below fixes the sampled values for the
    # seed above, so keep it.
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )

    def random_mxfp4(rows, cols):
        """Random packed FP4 bytes ``(E, rows, cols // 2)`` and scale bytes
        ``(E, rows, cols // 32)``."""
        packed = torch.randint(
            0, 256, (num_experts, rows, cols // 2), device=device, dtype=torch.uint8
        )
        scales = torch.randint(
            118, 123, (num_experts, rows, cols // 32), device=device, dtype=torch.uint8
        )
        return packed, scales

    # Weights and biases are contiguous [w1; w3] / [b1; b3] along the w13 dim.
    w13_q, w13_scale = random_mxfp4(2 * intermediate_size, hidden_size)
    w2_q, w2_scale = random_mxfp4(hidden_size, intermediate_size)
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size
    )
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size
    )
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "bf16",
    )

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    def swap_halves(t, dim):
        """Exchange the two halves of ``t`` along ``dim`` ([w1; w3] -> [w3; w1])."""
        first, second = torch.chunk(t, 2, dim=dim)
        return torch.cat([second, first], dim=dim)

    # The kernel expects [w3; w1]; its scales are also regrouped by 4.
    w13_q_swapped = swap_halves(w13_q, dim=1)
    w13_b = swap_halves(bias13.to(torch.float32), dim=-1).to(torch.bfloat16)
    w13_s_inter = _interleave_scales_lastdim_by4(swap_halves(w13_scale, dim=1))
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

    # Routing: softmax over all experts, top-k, then renormalize the kept weights.
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    def per_expert(value):
        if value is None:
            return None
        return torch.full((num_experts,), value, device=hidden_states.device)

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=per_expert(alpha),
        swiglu_beta=per_expert(beta),
        swiglu_limit=per_expert(limit),
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and has_flashinfer()
    ),
    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: Optional[float],
    beta: Optional[float],
    limit: Optional[float],
):
    """Compare FlashInfer's CUTLASS fused MoE (MXFP8 activations, MXFP4
    weights, SM100 path) against ``reference_moe``.

    The weights are quantized to MXFP4 per expert. The reference is computed
    from the dequantized weights, the float32-cast BF16 activations, and the
    ``"mxfp8"`` mode of ``reference_moe``. The kernel receives the MXFP8-quantized
    activations plus their scale factors.
    """
    from flashinfer import mxfp4_dequantize, mxfp4_quantize

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs. The RNG draw order (hidden_states, w13, w2, bias13, bias2,
    # router_logits) determines the data for the fixed seed above.
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    # Float weights in w13 format [w1; w3]
    w13 = (
        torch.randn(
            num_experts,
            2 * intermediate_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        """Quantize each of the first ``e`` slices of ``a`` to MXFP4 and stack
        the packed values and scale factors along a new leading dim."""
        # `a` is already on `device`, so no explicit device move is needed.
        qs, sfs = zip(*(mxfp4_quantize(a[i]) for i in range(e)))
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
        """Dequantize a stack of MXFP4 matrices, batch by batch.

        ``scale_tensor`` is reshaped so that row ``b`` holds the scales for
        ``mat_fp4[b]``.
        """
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        return torch.stack(
            [
                mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
                for b in range(num_batches)
            ]
        )

    def per_expert(value: Optional[float]) -> Optional[torch.Tensor]:
        """Broadcast an optional scalar SwiGLU parameter to one value per
        expert, which is the form passed to the kernel below."""
        if value is None:
            return None
        return torch.full((num_experts,), value, device=hidden_states.device)

    # Quantize weights to MXFP4 per expert (SM100 path)
    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference weights: dequantize back to float32. `.to(device)` is kept in
    # case the dequantizer returns host tensors.
    w13_ref = (
        dequant_mxfp4_batches(
            w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, 2 * intermediate_size, hidden_size)
        .to(device)
    )
    w2_ref = (
        dequant_mxfp4_batches(
            w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, hidden_size, intermediate_size)
        .to(device)
    )

    # Quantize activations to MXFP8 for the SM100 kernel path.
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
    # Reference uses the BF16 input values but quantizes the intermediate
    # activation to MXFP8.
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "mxfp8",
    )

    # Swap halves to arrange as [w3; w1] (kernel expectation); weights,
    # scales and biases must all be swapped consistently.
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    b1, b3 = torch.chunk(bias13, 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1)

    # Build routing for the kernel: softmax, top-k, then renormalize the
    # selected weights so they sum to 1 per token.
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)

    # Quant scales for the SM100 MXFP8+MXFP4 path. The per-expert input
    # scales are fixed to 1.0 here.
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2,
        swiglu_alpha=per_expert(alpha),
        swiglu_beta=per_expert(beta),
        swiglu_limit=per_expert(limit),
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 weight / MXFP8 activation quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)