utils.py 9.86 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Optional, Union
4
5
6

import torch

bnellnm's avatar
bnellnm committed
7
import vllm._custom_ops as ops
8
from tests.kernels.quant_utils import per_block_cast_to_int8
9
10
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX)
bnellnm's avatar
bnellnm committed
11
12
13
14
15
16
17
18
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
from vllm.utils import round_up
19
from vllm.utils.deep_gemm import per_block_cast_to_fp8
bnellnm's avatar
bnellnm committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run the reference (non-batched) Triton fused MoE on ``a``.

    Thin wrapper around ``fused_experts`` that translates the test-suite
    quantization knobs into its keyword flags.

    Args:
        a: Input activations.
        w1, w2: Expert weights (possibly quantized, see ``quant_dtype``).
        topk_weight, topk_ids: Router weights and selected expert ids.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional quantization
            scales, forwarded unchanged.
        quant_dtype: ``torch.float8_e4m3fn`` sets ``use_fp8_w8a8`` and
            ``torch.int8`` sets ``use_int8_w8a8``; with any other value
            neither quantization flag is set.
        per_act_token_quant: Forwarded as ``per_channel_quant``.
        block_shape: Block-quantization shape, forwarded unchanged.

    Returns:
        The tensor returned by ``fused_experts``.
    """
    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale,
                         per_channel_quant=per_act_token_quant,
                         use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                         # int8 weights (e.g. from moe_quantize_weights) must
                         # be flagged too, or they are not run as int8 w8a8.
                         use_int8_w8a8=quant_dtype == torch.int8,
                         block_shape=block_shape)


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Apply a modular MoE kernel built on ``BatchedTritonExperts`` to ``a``.

    The kernel pairs a single-dispatcher, rank-0 ``BatchedPrepareAndFinalize``
    (``num_local_experts`` taken from ``w1.shape[0]``) with batched Triton
    experts. Both get a ``max_num_tokens`` equal to the token count of ``a``
    rounded up to a multiple of 64.

    Returns:
        The tensor returned by the modular kernel.
    """
    capacity = round_up(a.shape[0], 64)
    local_experts = w1.shape[0]
    fp8 = quant_dtype == torch.float8_e4m3fn

    prepare_finalize = BatchedPrepareAndFinalize(
        capacity,
        num_dispatchers=1,
        num_local_experts=local_experts,
        rank=0,
    )
    experts = BatchedTritonExperts(
        max_num_tokens=capacity,
        num_dispatchers=1,
        use_fp8_w8a8=fp8,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    # Deliberately not named ``fused_experts`` so the module-level import of
    # that name is not shadowed inside this function.
    moe_kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return moe_kernel(a, w1, w2, topk_weight, topk_ids,
                      w1_scale=w1_scale,
                      w2_scale=w2_scale,
                      a1_scale=a1_scale,
                      a2_scale=a2_scale)


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Apply a modular MoE kernel built on ``NaiveBatchedExperts`` to ``a``.

    Same construction as ``batched_moe`` but with the naive experts
    implementation. It pairs a single-dispatcher, rank-0
    ``BatchedPrepareAndFinalize`` (``num_local_experts`` taken from
    ``w1.shape[0]``) with the naive experts. Both get a ``max_num_tokens``
    equal to the token count of ``a`` rounded up to a multiple of 64.

    Returns:
        The tensor returned by the modular kernel.
    """
    capacity = round_up(a.shape[0], 64)
    local_experts = w1.shape[0]
    fp8 = quant_dtype == torch.float8_e4m3fn

    prepare_finalize = BatchedPrepareAndFinalize(
        capacity,
        num_dispatchers=1,
        num_local_experts=local_experts,
        rank=0,
    )
    experts = NaiveBatchedExperts(
        max_num_tokens=capacity,
        num_dispatchers=1,
        use_fp8_w8a8=fp8,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    # Deliberately not named ``fused_experts`` so the module-level import of
    # that name is not shadowed inside this function.
    moe_kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return moe_kernel(a, w1, w2, topk_weight, topk_ids,
                      w1_scale=w1_scale,
                      w2_scale=w2_scale,
                      a1_scale=a1_scale,
                      a2_scale=a2_scale)


def chunk_scales(scales: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Return the slice ``[start:end]`` of ``scales``.

    A single-element scale tensor applies to every chunk, so it is returned
    as-is (same object). ``None`` passes through unchanged.
    """
    if scales is None:
        return None
    return scales if scales.numel() == 1 else scales[start:end]


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create random ``(E, m, k)`` CUDA activations and optionally quantize them.

    Each expert slice ``a[e]`` is quantized independently with
    ``moe_kernel_quantize_input``.

    Returns:
        ``(a, a_q, a_scale)``. Without ``quant_dtype``, ``a_q`` is ``a`` itself
        and ``a_scale`` is ``None``. Otherwise ``a_q`` has dtype
        ``quant_dtype`` and ``a_scale`` stacks the per-expert scales.

    Raises:
        AssertionError: if ``quant_dtype`` is neither fp8 (e4m3) nor int8.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert (quant_dtype == torch.float8_e4m3fn
            or quant_dtype == torch.int8), "only fp8/int8 supported"

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    per_expert_scales = []
    for expert in range(E):
        a_q[expert], expert_scale = moe_kernel_quantize_input(
            a[expert], None, quant_dtype, per_act_token_quant, block_shape)
        per_expert_scales.append(expert_scale)
    a_scale = torch.stack(per_expert_scales)

    # Neither per-token nor blocked: reshape to one scale per expert,
    # broadcastable over the (m, k) activation slice.
    if not per_act_token_quant and block_shape is None:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: Optional[torch.Tensor],
    quant_dtype: Union[torch.dtype, str, None],
    per_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Quantize a single expert weight tensor.

    Args:
        w: Weight tensor to quantize.
        w_s: Optional precomputed scale. It is forwarded to the non-blocked
            int8/fp8 quantizers and overwritten by the blocked and nvfp4
            paths.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Request per-token scales when they are computed
            dynamically. Not allowed together with ``block_shape`` or nvfp4.
        block_shape: ``[block_n, block_k]`` for blocked quantization.

    Returns:
        ``(w_q, w_s, w_gs)``: the quantized weight, its scales, and the
        global scale (set only by the nvfp4 path, otherwise ``None``).

    Raises:
        AssertionError: for an unsupported ``quant_dtype`` or an invalid
            option combination.
        RuntimeError: for blocked nvfp4 quantization.
    """
    assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
            or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"

    # Global scale; only the nvfp4 path assigns it.
    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            # NOTE: unlike scaled_fp8_quant, ops.scaled_int8_quant has no
            # `use_per_token_if_dynamic` argument. It returns a 3-tuple
            # (q, scale, azp), and its dynamic mode is per-token. For a
            # dynamic per-tensor scale, compute a (1,)-shaped scale here and
            # quantize statically with it.
            if w_s is None and not per_token_quant:
                w_s = (w.abs().max().to(torch.float32) / 127.0).view(1)
            w, w_s, _ = ops.scaled_int8_quant(w, w_s)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant)
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale chosen so that the tensor's amax maps to
            # FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    """Create a random ``(e, rows, cols)`` CUDA expert weight, optionally quantized.

    Each expert slice is quantized separately with ``moe_quantize_weights``
    and the results are stacked along a new leading expert dimension.

    Returns:
        ``(w_16, w, w_s, w_gs)``:
        - ``w_16``: the unquantized weight.
        - ``w``: the quantized weight, or ``w_16`` itself when
          ``quant_dtype`` is ``None``.
        - ``w_s``: the stacked scales, or ``None`` when unquantized.
        - ``w_gs``: the stacked global scales, present only when the
          quantizer returns them.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    per_expert = [
        moe_quantize_weights(w_16[idx], None, quant_dtype,
                             per_act_token_quant, block_shape)
        for idx in range(e)
    ]
    w = torch.stack([q for q, _, _ in per_expert])
    w_s = torch.stack([s for _, s, _ in per_expert])

    w_gs = None
    if e > 0 and per_expert[0][2] is not None:
        w_gs = torch.stack([gs for _, _, gs in per_expert])

    # A 2-D stacked scale must hold one value per expert (asserted below).
    # Reshape it to (e, 1, 1) so it broadcasts against the 3-D weight.
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s, w_gs


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]],
           tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]]]:
    """Create both MoE expert weights with identical quantization settings.

    Returns:
        ``(w1_parts, w2_parts)``, each being the 4-tuple produced by
        ``make_test_weight``. ``w1`` is built with ``rows=2 * n, cols=k`` and
        ``w2`` with ``rows=k, cols=n``.
    """
    shared = {
        "in_dtype": in_dtype,
        "quant_dtype": quant_dtype,
        "block_shape": block_shape,
        "per_act_token_quant": per_act_token_quant,
    }
    w1_parts = make_test_weight(e, 2 * n, k, **shared)
    w2_parts = make_test_weight(e, k, n, **shared)
    return w1_parts, w2_parts


def per_token_cast_to_fp8(
        x: torch.Tensor,
        block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)