utils.py 17.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
11
12
13
14
from vllm.model_executor.layers.fused_moe import (
    TritonExperts,
    fused_experts,
    fused_topk,
)
15
16
17
18
19
20
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
22
23
24
25
26
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
27
28
29
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
30
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
31
from vllm.utils.deep_gemm import per_block_cast_to_fp8
32
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
33
34


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """
    Build a placeholder FusedMoEConfig to satisfy the mk constructor interface.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually use
    this config, which is why every size defaults to 1.

    CUTLASSFp8 is the exception: it needs some params set for its workshapes.
    """
    # Config produced by FusedMoEParallelConfig.make_no_parallel().
    no_parallel = FusedMoEParallelConfig.make_no_parallel()

    # Every expert is treated as local, so the local count mirrors the total.
    local_expert_count = num_experts

    dummy_config = FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=local_expert_count,
        moe_parallel_config=no_parallel,
        activation="silu",
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )
    return dummy_config


bnellnm's avatar
bnellnm committed
63
64
65
66
67
68
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run ``fused_experts`` with a quant config built from the given options.

    Args:
        a: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        topk_weight: Routing weights forwarded to ``fused_experts``.
        topk_ids: Selected expert ids forwarded to ``fused_experts``.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional first activation scale.
        a2_scale: Optional second activation scale.
        quant_dtype: Quantization dtype, or ``None``.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional block shape forwarded to
            ``FusedMoEQuantConfig.make``.

    Returns:
        The tensor returned by ``fused_experts``.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
88
89
90
91
92
93
94
95


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize + BatchedTritonExperts.

    Args:
        a: Input activations; ``a.shape[0]`` determines ``max_num_tokens``.
        w1: First expert weight tensor; ``w1.shape[0]`` is passed as the
            number of local experts.
        w2: Second expert weight tensor.
        topk_weight: Routing weights forwarded to the kernel.
        topk_ids: Selected expert ids forwarded to the kernel.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional first activation scale.
        a2_scale: Optional second activation scale.
        quant_dtype: Quantization dtype, or ``None``.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional block shape forwarded to
            ``FusedMoEQuantConfig.make``.

    Returns:
        The tensor returned by the modular kernel.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named `kernel` rather than `fused_experts` so the module-level
    # `fused_experts` import is not shadowed inside this function.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
130
131
132
133
134
135
136
137


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize + NaiveBatchedExperts.

    Args:
        a: Input activations; ``a.shape[0]`` determines ``max_num_tokens``.
        w1: First expert weight tensor; ``w1.shape[0]`` is passed as the
            number of local experts.
        w2: Second expert weight tensor.
        topk_weight: Routing weights forwarded to the kernel.
        topk_ids: Selected expert ids forwarded to the kernel.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional first activation scale.
        a2_scale: Optional second activation scale.
        quant_dtype: Quantization dtype, or ``None``.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional block shape forwarded to
            ``FusedMoEQuantConfig.make``.

    Returns:
        The tensor returned by the modular kernel.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named `kernel` rather than `fused_experts` so the module-level
    # `fused_experts` import is not shadowed inside this function.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
172
173


174
def chunk_scales(
175
176
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
177
178
179
180
181
182
183
184
185
186
187
188
189
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Create random ``(E, m, k)`` CUDA activations and optionally quantize them.

    Args:
        E: Size of the leading dimension; each ``a[e]`` is quantized separately.
        m: Size of the second dimension.
        k: Size of the last dimension.
        in_dtype: Dtype of the unquantized activations.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8``, or ``None`` to
            skip quantization.
        block_shape: Optional block shape forwarded to
            ``moe_kernel_quantize_input``.
        per_act_token_quant: Forwarded to ``moe_kernel_quantize_input``.

    Returns:
        ``(a, a_q, a_scale)``. Without quantization ``a_q`` is ``a`` and
        ``a_scale`` is ``None``.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    if quant_dtype is None:
        return a, a, None

    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
        "only fp8/int8 supported"
    )

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    per_expert_scales = []
    for e in range(E):
        a_q[e], scale_e = moe_kernel_quantize_input(
            a[e], None, quant_dtype, per_act_token_quant, block_shape
        )
        per_expert_scales.append(scale_e)
    a_scale = torch.stack(per_expert_scales)

    # Neither per-token nor blocked: reshape the stacked scales to (E, 1, 1).
    if not per_act_token_quant and block_shape is None:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Quantize one weight tensor to fp8, int8 or nvfp4.

    Args:
        w: Weight tensor to quantize.
        w_s: Optional scale passed through to the non-blocked int8/fp8 ops.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Passed as ``use_per_token_if_dynamic`` to the
            non-blocked int8/fp8 ops; must be false for blocked and nvfp4
            quantization.
        block_shape: When given, blocked quantization is used (int8/fp8 only).

    Returns:
        ``(w, w_s, w_gs)``: the quantized weight, its scale, and the global
        scale, which is only set on the nvfp4 path and ``None`` otherwise.

    Raises:
        RuntimeError: For nvfp4 with ``block_shape``, or an unsupported
            ``quant_dtype`` that got past the assertion.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale derived from the largest absolute weight value.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
259
260
261
262
263
264
265


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Create a random ``(e, rows, cols)`` CUDA weight and optionally quantize it.

    Each ``w_16[idx]`` is quantized independently with
    ``moe_quantize_weights`` and the results are stacked.

    Args:
        e: Size of the leading dimension.
        rows: Size of the second dimension.
        cols: Size of the last dimension.
        in_dtype: Dtype of the unquantized weight.
        quant_dtype: Quantization type, or ``None`` to skip quantization.
        block_shape: Optional ``[block_n, block_k]`` for blocked quantization.
        per_out_ch_quant: Passed to ``moe_quantize_weights`` as its
            ``per_token_quant`` argument.

    Returns:
        ``(w_16, w, w_s, w_gs)``: the unquantized weight, the (possibly
        quantized) weight, its scales, and the stacked global scales. Without
        quantization ``w`` is ``w_16`` and both scales are ``None``.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    w_l = []
    w_s_l = []
    w_gs_l = []
    for idx in range(e):
        w_q, w_q_s, w_q_gs = moe_quantize_weights(
            w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
        )
        w_l.append(w_q)
        w_s_l.append(w_q_s)
        w_gs_l.append(w_q_gs)

    w = torch.stack(w_l)
    w_s = torch.stack(w_s_l)

    # Global scales are stacked only when the first entry has one.
    w_gs = None
    if e > 0 and w_gs_l[0] is not None:
        w_gs = torch.stack(w_gs_l)

    # 2-D stacked scales (trailing dim of 1) are reshaped to (-1, 1, 1).
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
301
302
303
304
305
306
307


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """
    Create the two test weights via ``make_test_weight``.

    Args:
        e: Leading dimension of both weights.
        n: Sizes the first weight's rows and the second weight's columns.
        k: Sizes the first weight's columns and the second weight's rows.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or ``None`` to skip quantization.
        block_shape: Optional block shape for blocked quantization.
        per_out_ch_quant: Forwarded to ``make_test_weight``.
        make_gate: When true the first weight has ``2 * n`` rows, else ``n``.

    Returns:
        Two ``make_test_weight`` result tuples: the first for shape
        ``(e, 2n or n, k)``, the second for shape ``(e, k, n)``.
    """
    w1_rows = (2 if make_gate else 1) * n

    # The first weight is created before the second, as in a tuple literal,
    # so the random draws happen in a fixed order.
    first = make_test_weight(
        e,
        w1_rows,
        k,
        in_dtype,
        quant_dtype,
        block_shape,
        per_out_ch_quant,
    )
    second = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return first, second
328
329
330


def per_token_cast_to_fp8(
331
332
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
333
334
335
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
336
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
337
338
339
340
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
341
342


343
344
345
346
347
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Create test weights plus a matching ``FusedMoEQuantConfig``.

    Args:
        e: Leading dimension of both weights.
        n: Forwarded to ``make_test_weights``.
        k: Forwarded to ``make_test_weights``.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or ``None`` for no quantization.
        per_act_token_quant: Used for the quant config and also passed to
            ``make_test_weights`` as ``per_out_ch_quant``.
        block_shape: Optional block shape for blocked quantization.
        make_gate: Forwarded to ``make_test_weights``.

    Returns:
        ``(w1, w2, quant_config)`` where ``w1``/``w2`` are the (possibly
        quantized) weights.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    a1_scale: torch.Tensor | None = None
    a2_scale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a1_scale = a1_gscale
        a2_scale = a2_gscale

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        # TODO: make sure this is handled properly
        g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
        g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
    )

    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Route with ``fused_topk`` and then run ``fused_experts``.

    Args:
        hidden_states: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        score: Router scores; converted to float before routing.
        topk: Forwarded to ``fused_topk``.
        renormalize: Forwarded to ``fused_topk``.
        quant_config: Optional quantization config for ``fused_experts``.
        global_num_experts: Forwarded to ``fused_experts``.
        expert_map: Optional expert map forwarded to ``fused_experts``.

    Returns:
        The tensor returned by ``fused_experts``.
    """
    # The third value returned by fused_topk is not needed here.
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
420
421


422
423
424
425
426
427
428
429
430
431
432
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

433
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
434
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476


class TestMLP(torch.nn.Module):
    """
    Reference MLP built from two ``BaselineMM`` layers and ``SiluAndMul``.

    The forward pass is gate/up projection, activation, then down projection.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        """
        Args:
            w1: Weight for the gate/up projection.
            w2: Weight for the down projection.
            out_dtype: Output dtype of both projections.
        """
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, and the down projection."""
        # Each projection returns a (tensor, None) pair; only the tensor is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """
    Build a ``TestMLP`` with random CUDA weights.

    Args:
        N: The gate/up weight has ``2 * N`` columns; the down weight has
            ``N`` rows.
        K: Rows of the gate/up weight and columns of the down weight.
        in_dtype: Dtype of the weights and of the MLP output.

    Returns:
        A ``TestMLP`` wrapping the two random weights.
    """
    gate_up_shape = (K, N * 2)
    down_shape = (N, K)

    # The gate/up weight is drawn first so the random draw order is fixed.
    gate_up_weight = torch.randn(gate_up_shape, device="cuda", dtype=in_dtype) / 15
    down_weight = torch.randn(down_shape, device="cuda", dtype=in_dtype) / 15

    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """
    MLP built from vLLM's ``MergedColumnParallelLinear`` and
    ``RowParallelLinear`` layers with externally supplied weights.

    After each layer is constructed, its ``weight``, ``weight_scale`` and
    ``input_scale`` parameters are re-registered from the constructor
    arguments.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Input size of the gate/up projection and output size
                of the down projection.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of the down projection.
            w1: Weight registered on ``gate_up_proj``.
            w2: Weight registered on ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both linear layers.
            reduce_results: Forwarded to ``RowParallelLinear``.
            prefix: Prefix used to name the two linear layers.
            w1_s: Weight scale registered on ``gate_up_proj``.
            w2_s: Weight scale registered on ``down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()

        # Reject unsupported activations before building any layers.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.gate_up_proj.register_parameter(
            "weight", torch.nn.Parameter(w1, requires_grad=False)
        )
        # NOTE(review): when w1_s is None this presumably registers an empty
        # Parameter rather than None -- confirm that is intended.
        self.gate_up_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
        )
        self.gate_up_proj.register_parameter("input_scale", None)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.down_proj.register_parameter(
            "weight", torch.nn.Parameter(w2, requires_grad=False)
        )
        # NOTE(review): same caveat as for w1_s when w2_s is None.
        self.down_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
        )
        self.down_proj.register_parameter("input_scale", None)

        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, and the down projection."""
        # Each layer returns a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """
    Build a ``RealMLP`` from a single-expert set of test weights.

    For ``torch.float8_e4m3fn`` the weights and scales are transposed and an
    ``Fp8Config(True)`` is used. For any other ``quant_dtype`` the scales and
    the quant config are dropped.

    Args:
        N: Intermediate size passed to ``RealMLP``.
        K: Hidden size passed to ``RealMLP``.
        in_dtype: Dtype of the unquantized weights; also set as the torch
            default dtype while the MLP is constructed.
        quant_dtype: Quantization type forwarded to ``make_test_weights``.

    Returns:
        The constructed ``RealMLP``.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )

    # The default dtype is switched only for construction; the finally block
    # restores it even if construction fails.
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
566
567
568
569
570
571
572
573
574
575
576
577
578


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
    """
    Assemble a modular kernel from MoEPrepareAndFinalizeNoEP and TritonExperts.

    Args:
        moe_config: MoE config passed to ``TritonExperts``.
        quant_config: Quantization config passed to ``TritonExperts``.
        shared_experts: Optional module passed through to the kernel.

    Returns:
        A ``FusedMoEModularKernel`` constructed with ``inplace=False``.
    """
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    modular_kernel = FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
        inplace=False,
    )
    return modular_kernel