"vllm/vscode:/vscode.git/clone" did not exist on "8f36444c4f9a55669bcb64e20b5588c0dd72bd93"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
11
12
13
14
15
16
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
17
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
18
19
20
21
22
23
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
24
from vllm.utils.deep_gemm import per_block_cast_to_fp8
25
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """
    Build a placeholder FusedMoEConfig for the mk constructor interface.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually
    use this config, so every size defaults to 1. CUTLASSFp8 needs some
    of the params set for its workshapes, which is why they can be
    overridden.

    The config always uses a non-parallel layout, treats every expert as
    local (``num_local_experts == num_experts``), the "silu" activation,
    the "cuda" device and TopK routing.
    """
    no_parallel = FusedMoEParallelConfig.make_no_parallel()
    dummy_config = FusedMoEConfig(
        routing_method=RoutingMethodType.TopK,
        device="cuda",
        in_dtype=in_dtype,
        activation="silu",
        moe_parallel_config=no_parallel,
        num_local_experts=num_experts,
        intermediate_size_per_partition=intermediate_size_per_partition,
        hidden_dim=hidden_dim,
        experts_per_token=experts_per_token,
        num_experts=num_experts,
    )
    return dummy_config


bnellnm's avatar
bnellnm committed
56
57
58
59
60
61
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run ``fused_experts`` with a quant config assembled from the arguments.

    Args:
        a: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token ids of the selected experts.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape.

    Returns:
        Whatever ``fused_experts`` returns for these inputs.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
81
82
83
84
85
86
87
88


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular MoE kernel built from ``BatchedPrepareAndFinalize`` and
    ``BatchedTritonExperts``.

    The kernel is configured for a single dispatcher at rank 0, with
    ``w1.shape[0]`` local experts and a token capacity of
    ``round_up(a.shape[0], 64)``. The experts receive a dummy MoE config
    from ``make_dummy_moe_config()``.

    Args:
        a: Input activations; ``a.shape[0]`` is used as the token count.
        w1: First expert weight tensor; ``w1.shape[0]`` is the expert count.
        w2: Second expert weight tensor.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token ids of the selected experts.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape.

    Returns:
        The output of the modular kernel for these inputs.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Deliberately not called `fused_experts`: that name is a module-level
    # import and shadowing it here makes the call below easy to misread.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
122
123
124
125
126
127
128
129


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular MoE kernel built from ``BatchedPrepareAndFinalize`` and
    ``NaiveBatchedExperts``.

    The kernel is configured for a single dispatcher at rank 0, with
    ``w1.shape[0]`` local experts and a token capacity of
    ``round_up(a.shape[0], 64)``. The experts receive a dummy MoE config
    from ``make_dummy_moe_config()``.

    Args:
        a: Input activations; ``a.shape[0]`` is used as the token count.
        w1: First expert weight tensor; ``w1.shape[0]`` is the expert count.
        w2: Second expert weight tensor.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token ids of the selected experts.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantization dtype, or None for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape.

    Returns:
        The output of the modular kernel for these inputs.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Deliberately not called `fused_experts`: that name is a module-level
    # import and shadowing it here makes the call below easy to misread.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
163
164


165
def chunk_scales(
166
167
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
168
169
170
171
172
173
174
175
176
177
178
179
180
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Create random per-expert activations on CUDA and optionally quantize them.

    Args:
        E: Number of experts (leading dim of the generated tensor).
        m: Rows per expert.
        k: Columns per expert.
        in_dtype: Dtype of the unquantized activations.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8``, or None to
            skip quantization.
        block_shape: Optional quantization block shape.
        per_act_token_quant: Forwarded to ``moe_kernel_quantize_input``.

    Returns:
        ``(a, a_q, a_scale)``: the ``(E, m, k)`` random activations (scaled
        by 1/10), their quantized form, and the stacked per-expert scales.
        Without quantization ``a_q`` is ``a`` itself and ``a_scale`` is None.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    a_q = a
    a_scale = None

    if quant_dtype is not None:
        assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
            "only fp8/int8 supported"
        )
        a_q = torch.zeros_like(a, dtype=quant_dtype)

        # Quantize each expert's activations independently and collect the scales.
        expert_scales = []
        for e in range(E):
            a_q[e], scale = moe_kernel_quantize_input(
                a[e], None, quant_dtype, per_act_token_quant, block_shape
            )
            expert_scales.append(scale)
        a_scale = torch.stack(expert_scales)

        # Neither per-token nor blocked: one scale per expert, shaped (E, 1, 1).
        if not per_act_token_quant and block_shape is None:
            a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Quantize one weight tensor to fp8, int8 or nvfp4.

    Args:
        w: Weight tensor to quantize.
        w_s: Scale passed through to ``ops.scaled_int8_quant`` /
            ``ops.scaled_fp8_quant`` on the non-blocked path; not read on
            the blocked or nvfp4 paths.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Forwarded as ``use_per_token_if_dynamic``; must be
            False for blocked and nvfp4 quantization.
        block_shape: When given, use blocked quantization
            (``per_block_cast_to_int8`` / ``per_block_cast_to_fp8``).

    Returns:
        ``(w, w_s, w_gs)``: quantized weight, its scale(s), and the nvfp4
        global scale (None for fp8/int8).

    Raises:
        RuntimeError: blocked quantization requested for nvfp4, or an
            unsupported ``quant_dtype`` reaches a dispatch branch.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale derived from the tensor-wide absolute maximum.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
250
251
252
253
254
255
256


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Create a random ``(e, rows, cols)`` expert weight on CUDA and optionally
    quantize it expert by expert.

    Args:
        e: Number of experts.
        rows: Rows of each expert's weight.
        cols: Columns of each expert's weight.
        in_dtype: Dtype of the unquantized weight.
        quant_dtype: Quantization type accepted by ``moe_quantize_weights``,
            or None to skip quantization.
        block_shape: Optional ``[block_n, block_k]`` quantization block shape.
        per_out_ch_quant: Passed to ``moe_quantize_weights`` as its
            ``per_token_quant`` argument.

    Returns:
        ``(w_16, w, w_s, w_gs)``: the unquantized weight (scaled by 1/15),
        the stacked quantized weight, the stacked scales, and the stacked
        global scales (None unless the quantizer produced one). Without
        quantization ``w`` is ``w_16`` and both scales are None.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    # One (weight, scale, global_scale) triple per expert.
    quantized = [
        moe_quantize_weights(
            w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
        )
        for idx in range(e)
    ]

    w = torch.stack([q_w for q_w, _, _ in quantized])
    w_s = torch.stack([q_s for _, q_s, _ in quantized])

    w_gs = None
    if e > 0 and quantized[0][2] is not None:
        w_gs = torch.stack([q_gs for _, _, q_gs in quantized])

    # Stacked scales of shape (e, 1) are reshaped to (e, 1, 1).
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        # Blocked scales must have one entry per (expert, row tile, col tile).
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
292
293
294
295
296
297
298


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """
    Create the two MoE weight sets via ``make_test_weight``.

    The first weight has shape ``(e, 2 * n, k)`` when ``make_gate`` is True
    and ``(e, n, k)`` otherwise; the second has shape ``(e, k, n)``.

    Args:
        e: Number of experts.
        n: Size of the ``n`` dimension.
        k: Size of the ``k`` dimension.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or None for no quantization.
        block_shape: Optional quantization block shape.
        per_out_ch_quant: Forwarded to ``make_test_weight``.
        make_gate: Whether the first weight gets the doubled ``2 * n`` rows.

    Returns:
        Two ``(w_16, w, w_s, w_gs)`` tuples, first weight then second.
    """
    # The first weight is generated before the second so the random stream
    # is consumed in a fixed order.
    w1_rows = (2 if make_gate else 1) * n
    first = make_test_weight(
        e,
        w1_rows,
        k,
        in_dtype,
        quant_dtype,
        block_shape,
        per_out_ch_quant,
    )
    second = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return first, second
319
320
321


def per_token_cast_to_fp8(
322
323
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
324
325
326
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
327
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
328
329
330
331
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
332
333


334
335
336
337
338
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Create test weights and a matching ``FusedMoEQuantConfig``.

    ``per_act_token_quant`` is also used as the ``per_out_ch_quant`` setting
    when generating the weights.

    Args:
        e: Number of experts.
        n: Size of the ``n`` dimension of the weights.
        k: Size of the ``k`` dimension of the weights.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or None for no quantization.
        per_act_token_quant: Per-token activation quantization flag.
        block_shape: Optional quantization block shape.
        make_gate: Forwarded to ``make_test_weights``.

    Returns:
        ``(w1, w2, quant_config)`` with the (possibly quantized) weights.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4: all-ones per-expert global scales.
    # Every other quant type gets no activation scales at all.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        # The activation scales are the same tensors as the global scales
        # (or None when there are none).
        a1_scale=a1_gscale,
        a2_scale=a2_gscale,
        # TODO: make sure this is handled properly
        g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
        g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
    )
    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Route tokens with ``fused_topk`` and run them through ``fused_experts``.

    Args:
        hidden_states: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        score: Router scores; cast to float32 before top-k selection.
        topk: Number of experts selected per token.
        renormalize: Forwarded to ``fused_topk``.
        quant_config: Optional quantization config for ``fused_experts``.
        global_num_experts: Forwarded to ``fused_experts``.
        expert_map: Optional expert map forwarded to ``fused_experts``.

    Returns:
        The output of ``fused_experts``.
    """
    # The third value returned by fused_topk is not needed here.
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
411
412


413
414
415
416
417
418
419
420
421
422
423
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

424
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
425
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467


class TestMLP(torch.nn.Module):
    """
    Reference MLP made of two ``BaselineMM`` layers with a ``SiluAndMul``
    activation between them.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        """
        Args:
            w1: Matrix for the gate/up projection.
            w2: Matrix for the down projection.
            out_dtype: Output dtype of both projections.
        """
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        # Each projection returns (output, extra); only the output is used.
        gate_up = self.gate_up_proj(x)[0]
        activated = self.act_fn(gate_up)
        return self.down_proj(activated)[0]


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """
    Build a ``TestMLP`` with random CUDA weights.

    The gate/up matrix has shape ``(K, 2 * N)`` and the down matrix
    ``(N, K)``; both are standard-normal samples divided by 15, and the
    MLP's output dtype is ``in_dtype``.
    """

    def scaled_randn(*shape: int) -> torch.Tensor:
        return torch.randn(shape, device="cuda", dtype=in_dtype) / 15

    # Gate/up is sampled before down so the random stream order is fixed.
    gate_up_weight = scaled_randn(K, N * 2)
    down_weight = scaled_randn(N, K)
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """
    MLP built from ``MergedColumnParallelLinear`` and ``RowParallelLinear``
    with a ``SiluAndMul`` activation, whose layer parameters are replaced by
    caller-supplied tensors.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Input size of the gate/up projection and output
                size of the down projection.
            intermediate_size: Size of each of the two merged gate/up
                outputs and input size of the down projection.
            w1: Tensor installed as the gate/up projection's ``weight``.
            w2: Tensor installed as the down projection's ``weight``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config passed to both linear layers.
            reduce_results: Forwarded to ``RowParallelLinear``.
            prefix: Name prefix for the two layers.
            w1_s: Tensor installed as the gate/up ``weight_scale``.
            w2_s: Tensor installed as the down ``weight_scale``.

        Raises:
            ValueError: if ``hidden_act`` is not ``"silu"``.
        """
        # Validate before building any layers so a bad activation fails fast.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self._install_params(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self._install_params(self.down_proj, w2, w2_s)

        self.act_fn = SiluAndMul()

    @staticmethod
    def _install_params(
        layer: torch.nn.Module,
        weight: torch.Tensor,
        weight_scale: torch.Tensor | None,
    ) -> None:
        """Register the test tensors as frozen parameters on ``layer``."""
        layer.register_parameter(
            "weight", torch.nn.Parameter(weight, requires_grad=False)
        )
        # NOTE(review): when weight_scale is None, torch.nn.Parameter(None)
        # appears to create an empty tensor parameter rather than registering
        # None -- confirm this is what the unquantized path wants.
        layer.register_parameter(
            "weight_scale", torch.nn.Parameter(weight_scale, requires_grad=False)
        )
        # No static input scale is supplied; the slot is registered as None.
        layer.register_parameter("input_scale", None)

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """
    Build a ``RealMLP`` from single-expert test weights.

    Weights come from ``make_test_weights(1, N, K, ...)`` and the expert
    dimension is dropped. For ``torch.float8_e4m3fn`` the weights and their
    scales are transposed and an ``Fp8Config`` is used; for every other
    ``quant_dtype`` the scales are discarded and no quant config is passed.

    The torch default dtype is temporarily set to ``in_dtype`` while the MLP
    is constructed and is always restored afterwards.

    Args:
        N: Intermediate size of the MLP.
        K: Hidden size of the MLP.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or None for no quantization.

    Returns:
        The constructed ``RealMLP``.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            # True is passed positionally as Fp8Config's first argument.
            quant_config = Fp8Config(True)
        else:
            # NOTE(review): this branch is also taken for non-fp8 quant types
            # (e.g. int8 / "nvfp4"); their quantized weights are then used
            # without scales or a quant config -- confirm that is intended.
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        # Restore the process-wide default dtype even if construction fails.
        torch.set_default_dtype(old_dtype)