utils.py 17.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
11
12
13
14
from vllm.model_executor.layers.fused_moe import (
    TritonExperts,
    fused_experts,
    fused_topk,
)
15
16
17
18
19
20
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
22
23
24
25
26
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
27
28
29
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
30
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
31
from vllm.utils.deep_gemm import per_block_cast_to_fp8
32
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
33
34


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """
    Build a placeholder FusedMoEConfig for the modular-kernel constructors.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually use
    this config, so the defaults are all trivial. CUTLASSFp8 needs some of
    the parameters set for its workshapes, which is why they are exposed.

    The config is always non-parallel, uses the "silu" activation, targets
    the "cuda" device and uses TopK routing. `num_experts` is also used for
    the local and logical expert counts.
    """
    return FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation="silu",
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )


bnellnm's avatar
bnellnm committed
64
65
66
67
68
69
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run `fused_experts` on pre-routed inputs with an ad-hoc quant config.

    The scale tensors, `quant_dtype`, `per_act_token_quant` and
    `block_shape` are bundled into a `FusedMoEQuantConfig` which is then
    handed to `fused_experts` together with the activations, weights and
    top-k routing results.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
89
90
91
92
93
94
95
96


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize and
    BatchedTritonExperts on pre-routed inputs.

    The per-expert token capacity is the number of rows of `a` rounded up
    to a multiple of 64. The kernel is built for a single dispatcher on
    rank 0, with `w1.shape[0]` local experts and a dummy MoE config, and is
    run with `inplace=False`.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named `kernel` rather than `fused_experts` so the local does not shadow
    # the module-level `fused_experts` import.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
131
132
133
134
135
136
137
138


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """
    Run a modular kernel made of BatchedPrepareAndFinalize and
    NaiveBatchedExperts on pre-routed inputs.

    Set up exactly like `batched_moe` (token capacity rounded up to a
    multiple of 64, one dispatcher, rank 0, dummy MoE config,
    `inplace=False`); only the experts implementation differs.
    """
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named `kernel` rather than `fused_experts` so the local does not shadow
    # the module-level `fused_experts` import.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
173
174


175
def chunk_scales(
176
177
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
178
179
180
181
182
183
184
185
186
187
188
189
190
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Create random (E, m, k) activations on CUDA and optionally quantize them.

    Returns `(a, a_q, a_scale)`:
      * `a` is the unquantized random tensor (randn / 10) in `in_dtype`.
      * `a_q` is `a` itself when `quant_dtype` is None, otherwise the
        per-expert quantized copy in `quant_dtype`.
      * `a_scale` is None when unquantized, otherwise the stacked per-expert
        scales. For per-tensor quantization (no per-token quant and no block
        shape) it is viewed as (E, 1, 1).

    Only fp8 (float8_e4m3fn) and int8 quantization are accepted.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    a_q = a
    a_scale = None

    if quant_dtype is not None:
        assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
            "only fp8/int8 supported"
        )
        a_q = torch.zeros_like(a, dtype=quant_dtype)
        a_scale_l = [None] * E
        # Quantize one expert's batch at a time, collecting its scale.
        for e in range(E):
            a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
                a[e], None, quant_dtype, per_act_token_quant, block_shape
            )
        a_scale = torch.stack(a_scale_l)

        if not per_act_token_quant and block_shape is None:
            a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Quantize a single weight tensor to int8, fp8 (float8_e4m3fn) or "nvfp4".

    With a `block_shape`, block-wise quantization is used (int8/fp8 only;
    per-token quantization must be off and nvfp4 raises RuntimeError).
    Without one, the `ops.scaled_*_quant` kernels are used, with `w_s`
    passed through as the input scale for int8/fp8.

    Returns `(w, w_s, w_gs)`: the quantized weight, its scale, and a global
    scale that is only produced for nvfp4 (None otherwise).
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale derived from the tensor-wide absolute maximum.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
260
261
262
263
264
265
266


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Create a random (e, rows, cols) expert weight on CUDA, optionally
    quantized expert-by-expert via `moe_quantize_weights`.

    Returns `(w_16, w, w_s, w_gs)`:
      * `w_16` is the unquantized random weight (randn / 15) in `in_dtype`.
      * `w` is the quantized weight, or `w_16` itself when `quant_dtype`
        is None.
      * `w_s` is the stacked per-expert scales (None when unquantized).
        A 2-D result must have a trailing dimension of 1 and is viewed as
        (-1, 1, 1). With a `block_shape` the shape is asserted to be
        (e, ceil(rows / block_n), ceil(cols / block_k)).
      * `w_gs` is the stacked per-expert global scales when the quantizer
        produced them, otherwise None.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    w_gs = None

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        w_gs_l = [None] * e
        for idx in range(e):
            w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
            )

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        # Only some quantizers return a global scale.
        if e > 0 and w_gs_l[0] is not None:
            w_gs = torch.stack(w_gs_l)
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)

        if block_shape is not None:
            block_n, block_k = block_shape
            n_tiles = (rows + block_n - 1) // block_n
            k_tiles = (cols + block_k - 1) // block_k
            assert w_s.shape == (e, n_tiles, k_tiles)
    else:
        w = w_16
        w_s = None

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
302
303
304
305
306
307
308


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """
    Create the pair of expert weights used by the MoE tests.

    The first weight has shape (e, 2 * n, k) when `make_gate` is True and
    (e, n, k) otherwise; the second has shape (e, k, n). Each element of
    the returned pair is the 4-tuple produced by `make_test_weight`.
    """
    w1_rows = (2 if make_gate else 1) * n
    return (
        make_test_weight(
            e,
            w1_rows,
            k,
            in_dtype,
            quant_dtype,
            block_shape,
            per_out_ch_quant,
        ),
        make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
    )
329
330
331


def per_token_cast_to_fp8(
332
333
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
334
335
336
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
337
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
338
339
340
341
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
342
343


344
345
346
347
348
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Create test weights plus a matching `FusedMoEQuantConfig`.

    Weights come from `make_test_weights`, with `per_act_token_quant` also
    used as the per-output-channel weight quantization flag. For "nvfp4"
    the activation scales and global scales are all-ones tensors of shape
    (e,) on CUDA; for every other dtype they are None. `g1_alphas` and
    `g2_alphas` are the reciprocals of the weight global scales when those
    exist.

    Returns `(w1, w2, quant_config)` where `w1`/`w2` are the (possibly
    quantized) weights.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a1_scale = a1_gscale
        a2_scale = a2_gscale
    else:
        a1_scale = None
        a2_scale = None

    return (
        w1,
        w2,
        FusedMoEQuantConfig.make(
            quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            # TODO: make sure this is handled properly
            g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
            g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
        ),
    )


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Route with `fused_topk`, then run `fused_experts` on the result.

    `score` is cast to float32 before routing. The third value returned by
    `fused_topk` is discarded. `global_num_experts`, `expert_map` and
    `quant_config` are forwarded unchanged to `fused_experts`.
    """
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
421
422


423
424
425
426
427
428
429
430
431
432
433
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

434
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
435
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477


class TestMLP(torch.nn.Module):
    """
    Reference gated MLP assembled from two BaselineMM projections with a
    SiluAndMul activation in between.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        # w1 feeds the gate/up projection, w2 the down projection.
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """
    Build a TestMLP with random CUDA weights.

    The first weight has shape (K, 2 * N) and the second (N, K); both are
    drawn from randn and divided by 15. The module's output dtype is
    `in_dtype`.
    """
    gate_up_shape = (K, N * 2)
    down_shape = (N, K)
    gate_up_weight = torch.randn(gate_up_shape, device="cuda", dtype=in_dtype) / 15
    down_weight = torch.randn(down_shape, device="cuda", dtype=in_dtype) / 15
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """
    Gated MLP built from vLLM's MergedColumnParallelLinear and
    RowParallelLinear layers, with caller-supplied weights and scales
    registered directly as parameters on those layers.

    Only the "silu" activation is supported; any other `hidden_act` raises
    ValueError (after the layers have been constructed).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        # NOTE(review): w1_s / w2_s may be None here, in which case None is
        # handed to torch.nn.Parameter - confirm that is intended.
        self.gate_up_proj.register_parameter(
            "weight", torch.nn.Parameter(w1, requires_grad=False)
        )
        self.gate_up_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
        )
        self.gate_up_proj.register_parameter(
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.down_proj.register_parameter(
            "weight", torch.nn.Parameter(w2, requires_grad=False)
        )
        self.down_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
        )
        self.down_proj.register_parameter(
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiluAndMul, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """
    Build a RealMLP shared-experts module from single-expert test weights.

    Weights are generated with `make_test_weights(1, N, K, ...)`. For fp8
    (float8_e4m3fn) the expert-0 weights and scales are transposed and an
    `Fp8Config(True)` quant config is used. For any other `quant_dtype` the
    expert-0 weights are used as generated, with no scales and no quant
    config.

    The torch default dtype is set to `in_dtype` while the module is built
    and restored afterwards, even if construction raises.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
567
568
569
570
571
572
573
574
575
576
577
578
579


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
    """
    Assemble a FusedMoEModularKernel from MoEPrepareAndFinalizeNoEP and
    TritonExperts, optionally with shared experts, using `inplace=False`.
    """
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
        inplace=False,
    )