# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
11
12
13
14
from vllm.model_executor.layers.fused_moe import (
    TritonExperts,
    fused_experts,
    fused_topk,
)
15
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
16
17
18
19
20
21
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
23
24
25
26
27
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
28
29
30
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
31
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
32
from vllm.utils.deep_gemm import per_block_cast_to_fp8
33
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
34
35


36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """
    Build a placeholder FusedMoEConfig for the modular-kernel constructors.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually
    use this config, so the all-ones defaults are enough for them.
    CUTLASSFp8 is the exception: it needs some of these params set for
    its work shapes.

    The returned config always uses the no-parallel parallel config, SILU
    activation, the "cuda" device and TopK routing. Both the local and the
    logical expert counts are set to `num_experts`.
    """
    # Built up front so the constructor call below is a flat keyword list.
    no_parallel = FusedMoEParallelConfig.make_no_parallel()

    return FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        moe_parallel_config=no_parallel,
        activation=MoEActivation.SILU,
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )


bnellnm's avatar
bnellnm committed
65
66
67
68
69
70
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run `fused_experts` with a quant config assembled from the arguments.

    Args:
        a: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        topk_weight: Routing weights forwarded to `fused_experts`.
        topk_ids: Selected expert ids forwarded to `fused_experts`.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to `FusedMoEQuantConfig.make`.
        block_shape: Optional quantization block shape.

    Returns:
        Whatever `fused_experts` returns for these inputs.
    """
    scale_kwargs = {
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        **scale_kwargs,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
90
91
92
93
94
95
96
97


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize + BatchedTritonExperts.

    The kernel is built for a single dispatcher at rank 0, with
    `w1.shape[0]` local experts and a token capacity of
    `round_up(a.shape[0], 64)`. The experts receive a dummy MoE config and
    a quant config assembled from the scale/dtype arguments.

    Args:
        a: Input activations; `a.shape[0]` determines the token capacity.
        w1: First expert weight tensor; `w1.shape[0]` is the expert count.
        w2: Second expert weight tensor.
        topk_weight: Routing weights passed to the kernel.
        topk_ids: Selected expert ids passed to the kernel.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to `FusedMoEQuantConfig.make`.
        block_shape: Optional quantization block shape.

    Returns:
        The output of the modular kernel call.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
    )
    experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )

    # Named `kernel` so the module-level `fused_experts` import is not shadowed.
    kernel = FusedMoEModularKernel(prepare_finalize, experts, inplace=False)

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
132
133
134
135
136
137
138
139


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize + NaiveBatchedExperts.

    The kernel is built for a single dispatcher at rank 0, with
    `w1.shape[0]` local experts and a token capacity of
    `round_up(a.shape[0], 64)`. The experts receive a dummy MoE config and
    a quant config assembled from the scale/dtype arguments.

    Args:
        a: Input activations; `a.shape[0]` determines the token capacity.
        w1: First expert weight tensor; `w1.shape[0]` is the expert count.
        w2: Second expert weight tensor.
        topk_weight: Routing weights passed to the kernel.
        topk_ids: Selected expert ids passed to the kernel.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to `FusedMoEQuantConfig.make`.
        block_shape: Optional quantization block shape.

    Returns:
        The output of the modular kernel call.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
    )
    experts = NaiveBatchedExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )

    # Named `kernel` so the module-level `fused_experts` import is not shadowed.
    kernel = FusedMoEModularKernel(prepare_finalize, experts, inplace=False)

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
174
175


176
def chunk_scales(
177
178
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
179
180
181
182
183
184
185
186
187
188
189
190
191
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Create random (E, m, k) activations on CUDA and optionally quantize them.

    Args:
        E: Size of the leading dimension; each slice is quantized on its own.
        m: Second dimension of the activation tensor.
        k: Last dimension of the activation tensor.
        in_dtype: Dtype of the unquantized activations.
        quant_dtype: torch.float8_e4m3fn, torch.int8, or None to skip
            quantization.
        block_shape: Optional block shape passed to the quantizer.
        per_act_token_quant: Passed to the quantizer.

    Returns:
        `(a, a_q, a_scale)`. Without quantization `a_q` is `a` itself and
        `a_scale` is None. Otherwise `a_scale` is the stack of the per-slice
        scales, viewed as (E, 1, 1) when neither per-token nor block
        quantization is requested.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert quant_dtype in (torch.float8_e4m3fn, torch.int8), "only fp8/int8 supported"

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    per_slice_scales = []
    for idx in range(E):
        # Quantize one slice at a time, writing the data into a_q in place.
        a_q[idx], slice_scale = moe_kernel_quantize_input(
            a[idx], None, quant_dtype, per_act_token_quant, block_shape
        )
        per_slice_scales.append(slice_scale)
    a_scale = torch.stack(per_slice_scales)

    if block_shape is None and not per_act_token_quant:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Quantize one weight tensor to fp8, int8 or nvfp4.

    Args:
        w: Weight tensor to quantize.
        w_s: Optional scale handed to the non-blocked fp8/int8 quantizers.
            It is ignored by the blocked and nvfp4 paths.
        quant_dtype: torch.float8_e4m3fn, torch.int8 or the string "nvfp4".
        per_token_quant: Forwarded as `use_per_token_if_dynamic` to the
            non-blocked fp8/int8 quantizers. Must be False for blocked and
            nvfp4 quantization.
        block_shape: When given, blocked quantization is used.

    Returns:
        `(w_q, w_scale, w_gs)`. `w_gs` is only set for nvfp4, where it is
        `FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / max(|w|)`; otherwise it is None.

    Raises:
        RuntimeError: For blocked nvfp4, or for an unsupported `quant_dtype`
            that got past the assertion.
    """
    assert quant_dtype in (
        torch.float8_e4m3fn,
        torch.int8,
        "nvfp4",
    ), "only fp8/int8/nvfp4 supported"

    blocked = block_shape is not None
    if blocked:
        assert not per_token_quant

    w_gs = None

    if quant_dtype == torch.int8:
        if blocked:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        else:
            # NOTE(review): confirm ops.scaled_int8_quant accepts
            # `use_per_token_if_dynamic` and returns a 2-tuple; the call is
            # kept exactly as it was.
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
    elif quant_dtype == torch.float8_e4m3fn:
        if blocked:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        else:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
    elif quant_dtype == "nvfp4":
        if blocked:
            raise RuntimeError("blocked quantization not supported for nvfp4")
        assert not per_token_quant
        w_amax = torch.abs(w).max().to(torch.float32)
        w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
        w, w_s = ops.scaled_fp4_quant(w, w_gs)
    else:
        raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
261
262
263
264
265
266
267


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Create a random (e, rows, cols) expert weight on CUDA, optionally quantized.

    Args:
        e: Number of experts (leading dimension).
        rows: Second dimension of each expert weight.
        cols: Last dimension of each expert weight.
        in_dtype: Dtype of the unquantized weight.
        quant_dtype: Quantization type understood by `moe_quantize_weights`,
            or None to skip quantization.
        block_shape: Optional `[block_n, block_k]` for blocked quantization.
        per_out_ch_quant: Passed to `moe_quantize_weights` as its
            per-token flag.

    Returns:
        `(w_16, w, w_s, w_gs)`: the unquantized weight, the (possibly
        quantized) weight, its scales and the stacked third outputs of
        `moe_quantize_weights`. Without quantization `w` is `w_16` and both
        scale entries are None.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    # Each expert is quantized independently, then the pieces are re-stacked.
    per_expert = [
        moe_quantize_weights(
            w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
        )
        for idx in range(e)
    ]

    w = torch.stack([quantized for quantized, _, _ in per_expert])
    w_s = torch.stack([scale for _, scale, _ in per_expert])

    w_gs = None
    if e > 0 and per_expert[0][2] is not None:
        w_gs = torch.stack([gscale for _, _, gscale in per_expert])

    # A 2-D scale stack means one scale per expert; give it broadcastable dims.
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        # Ceil-divide to count the tiles along each dimension.
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
303
304
305
306
307
308
309


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """
    Create the two expert weights of an MoE layer via `make_test_weight`.

    Args:
        e: Number of experts.
        n: Intermediate size; the first weight has `2 * n` rows when
            `make_gate` is True and `n` rows otherwise.
        k: Hidden size.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Optional quantization type.
        block_shape: Optional quantization block shape.
        per_out_ch_quant: Forwarded to `make_test_weight`.
        make_gate: Whether the first weight includes the gate half.

    Returns:
        A pair of `make_test_weight` results: the first for shape
        (e, 2n or n, k), the second for shape (e, k, n).
    """
    w1_rows = 2 * n if make_gate else n

    # The first weight is generated before the second so the order in which
    # random numbers are drawn stays the same.
    first = make_test_weight(
        e, w1_rows, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    second = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return first, second
330
331
332


def per_token_cast_to_fp8(
333
334
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
335
336
337
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
338
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
339
340
341
342
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
343
344


345
346
347
348
349
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Create test weights together with a matching FusedMoEQuantConfig.

    Args:
        e: Number of experts.
        n: Intermediate size.
        k: Hidden size.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Optional quantization type (a torch dtype or "nvfp4").
        per_act_token_quant: Used for the quant config and also passed to
            `make_test_weights` as `per_out_ch_quant`.
        block_shape: Optional quantization block shape.
        make_gate: Forwarded to `make_test_weights`.

    Returns:
        `(w1, w2, quant_config)`, where the weights are the (possibly
        quantized) tensors and the config carries their scales.
    """
    w1_parts, w2_parts = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )
    _, w1, w1_s, w1_gs = w1_parts
    _, w2, w2_s, w2_gs = w2_parts

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)

    # TODO: make sure this is handled properly
    g1_alphas = (1 / w1_gs) if w1_gs is not None else None
    g2_alphas = (1 / w2_gs) if w2_gs is not None else None

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        # The activation scales equal the gscales in every case: the ones
        # tensors for nvfp4 and None otherwise.
        a1_scale=a1_gscale,
        a2_scale=a2_gscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )

    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Select experts with `fused_topk`, then run `fused_experts`.

    Args:
        hidden_states: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        score: Router scores; converted with `.float()` before routing.
        topk: Number of experts selected per token.
        renormalize: Forwarded to `fused_topk`.
        quant_config: Optional quantization config for `fused_experts`.
        global_num_experts: Forwarded to `fused_experts`.
        expert_map: Forwarded to `fused_experts`.

    Returns:
        The output of `fused_experts`.
    """
    router_scores = score.float()
    # The third value returned by fused_topk is not needed here.
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, router_scores, topk, renormalize
    )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
422
423


424
425
426
427
428
429
430
431
432
433
434
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

435
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
436
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478


class TestMLP(torch.nn.Module):
    """
    Reference MLP built from two BaselineMM layers around SiluAndMul.

    `forward` applies the gate-up projection, the activation and the down
    projection, in that order, and returns only the final tensor.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        # w1 feeds the gate-up projection, w2 the down projection.
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Each BaselineMM returns (output, None); only the output is used.
        gate_up, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(gate_up))
        return projected


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """
    Build a TestMLP with random CUDA weights of shapes (K, 2N) and (N, K).

    The weights are drawn with `torch.randn` and divided by 15; `in_dtype`
    is used both for the weights and as the MLP's output dtype.
    """

    def scaled_randn(*shape: int) -> torch.Tensor:
        return torch.randn(shape, device="cuda", dtype=in_dtype) / 15

    # The gate-up weight is drawn first, then the down weight.
    gate_up_weight = scaled_randn(K, N * 2)
    down_weight = scaled_randn(N, K)
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """
    MLP built from vLLM's MergedColumnParallelLinear and RowParallelLinear.

    After each linear layer is constructed, its `weight` and `weight_scale`
    parameters are replaced with the tensors supplied by the caller and its
    `input_scale` is registered as None. Only the "silu" activation is
    supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Input size of the gate-up projection and output
                size of the down projection.
            intermediate_size: Size of each of the two merged gate-up
                outputs and input size of the down projection.
            w1: Tensor installed as the gate-up projection's `weight`.
            w2: Tensor installed as the down projection's `weight`.
            hidden_act: Activation name; anything but "silu" is rejected.
            quant_config: Quantization config passed to both linear layers.
            reduce_results: Passed to the down projection.
            prefix: Name prefix for the two layers.
            w1_s: Tensor installed as the gate-up `weight_scale`.
            w2_s: Tensor installed as the down `weight_scale`.

        Raises:
            ValueError: If `hidden_act` is not "silu". This is raised after
                both layers have been built.
        """
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self._install_params(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self._install_params(self.down_proj, w2, w2_s)

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    @staticmethod
    def _install_params(
        layer: torch.nn.Module,
        weight: torch.Tensor,
        weight_scale: torch.Tensor | None,
    ) -> None:
        """Register the caller's weight and scale on `layer`; clear input_scale."""
        layer.register_parameter(
            "weight", torch.nn.Parameter(weight, requires_grad=False)
        )
        # NOTE(review): a Parameter is registered even when weight_scale is
        # None -- confirm that is what the unquantized path wants.
        layer.register_parameter(
            "weight_scale", torch.nn.Parameter(weight_scale, requires_grad=False)
        )
        layer.register_parameter("input_scale", None)

    def forward(self, x):
        """Apply gate-up projection, SiluAndMul and down projection."""
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """
    Build a RealMLP from freshly generated single-expert test weights.

    The torch default dtype is temporarily set to `in_dtype` while the MLP
    is constructed and is restored afterwards, even on error.

    Args:
        N: Intermediate size.
        K: Hidden size.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Optional quantization type. For torch.float8_e4m3fn
            the weights and scales are transposed and an `Fp8Config(True)`
            is used; for anything else no quant config and no scales are
            passed on.

    Returns:
        The constructed RealMLP.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    w1_parts, w2_parts = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    _, w1, w1_s, _ = w1_parts
    _, w2, w2_s, _ = w2_parts

    use_fp8 = quant_dtype == torch.float8_e4m3fn
    saved_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)

        # There is exactly one expert; drop the leading expert dimension.
        w1 = w1[0]
        w2 = w2[0]

        if use_fp8:
            w1 = w1.transpose(0, 1)
            w2 = w2.transpose(0, 1)
            w1_s = None if w1_s is None else w1_s[0].transpose(0, 1)
            w2_s = None if w2_s is None else w2_s[0].transpose(0, 1)
            quant_config = Fp8Config(True)
        else:
            # NOTE(review): for a non-fp8 quant_dtype the weights from
            # make_test_weights are still quantized, yet their scales are
            # dropped here -- confirm callers only pass None or fp8.
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(saved_dtype)
568
569
570
571
572
573
574
575
576
577
578
579
580


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
    """
    Assemble a non-inplace modular kernel from the Triton experts.

    The kernel pairs `MoEPrepareAndFinalizeNoEP` with
    `TritonExperts(moe_config, quant_config)` and passes `shared_experts`
    through as the third positional argument.
    """
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
        inplace=False,
    )