untest_fp8.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
7

8
9
import pytest
import torch
10
import os
11

12
from tests.quantization.utils import is_quant_method_supported
13
from vllm import _custom_ops as ops
14
from vllm.model_executor.layers.fused_moe import FusedMoE
15
from vllm.model_executor.layers.quantization.fp8 import (
16
    Fp8Config,
17
18
    Fp8KVCacheMethod,
    Fp8LinearMethod,
19
    Fp8MoEMethod,
20
)
21
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22
from vllm.platforms import current_platform
23
from ..utils import models_path_prefix
24

25
# Checkpoints exercised by `test_model_load_and_run`. Every path is resolved
# relative to `models_path_prefix`.
MODELS = [
    os.path.join(models_path_prefix, "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"),
    # This checkpoint is no longer available on HF, so the case is skipped.
    # TODO: add a small replacement checkpoint.
    pytest.param(
        os.path.join(models_path_prefix, "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV"),
        marks=pytest.mark.skip(reason="Checkpoint removed from HF."),
    ),
]


36
37
38
39
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize(
    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_model_load_and_run(
    vllm_runner, model_id: str, force_marlin: bool, use_rocm_aiter: bool, monkeypatch
) -> None:
    """Load an FP8 checkpoint and greedily generate a handful of tokens.

    Args:
        vllm_runner: Fixture used as a context manager yielding the LLM wrapper.
        model_id: Path of the checkpoint to load.
        force_marlin: When true, sets `VLLM_TEST_FORCE_FP8_MARLIN=1`.
        use_rocm_aiter: When true, sets `VLLM_ROCM_USE_AITER=1`.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    # (flag, environment variable) pairs, applied in this order.
    env_toggles = (
        (use_rocm_aiter, "VLLM_ROCM_USE_AITER"),
        (force_marlin, "VLLM_TEST_FORCE_FP8_MARLIN"),
    )
    for enabled, env_name in env_toggles:
        if enabled:
            monkeypatch.setenv(env_name, "1")

    with vllm_runner(model_id, enforce_eager=True) as llm:
        # Smoke test only: this checks that generation runs end to end, not
        # that the output is accurate (see the lm-eval tests for accuracy).
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        first_result = outputs[0]
        print(first_result[1])


# Checkpoints exercised by `test_kv_cache_model_load_and_run`.
KV_CACHE_MODELS = [
    # AutoFP8 format using separate .k_scale and .v_scale.
    # The checkpoint originally used here is gone from the Hub. Until a small
    # replacement with split K/V scales is available, the case is skipped so
    # that CI is not blocked. See PR #27717 for context.
    pytest.param(
        os.path.join(models_path_prefix, "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V"),
        marks=pytest.mark.skip(
            reason=(
                "Checkpoint removed from HF; temporarily disabling this "
                "AutoFP8 split K/V case (PR #27717)."
            )
        ),
    ),
]


80
81
82
83
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_kv_cache_model_load_and_run(
    vllm_runner, model_id: str, use_rocm_aiter: bool, monkeypatch
):
    """Load a checkpoint with an FP8 KV cache, inspect its scales and generate.

    The first decoder layer's attention module must use `Fp8KVCacheMethod`
    and carry K/V scales inside the platform-specific expected range.

    Args:
        vllm_runner: Fixture used as a context manager yielding the LLM wrapper.
        model_id: Path of the checkpoint to load.
        use_rocm_aiter: When true, sets `VLLM_ROCM_USE_AITER=1`.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_id, kv_cache_dtype="fp8", enforce_eager=True) as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            # NOTE: a scale of 1.0 (the default value) is valid in general,
            # but these checkpoints are known to have scales < 1.0.
            # On ROCm the _k_scale and _v_scale are scaled by a factor of 2,
            # as described in
            # vllm/model_executor/layers/quantization/kv_cache.py, so the
            # upper bound doubles there.
            # NOTE: the non-ROCm bound still requires validation on non-CUDA
            # platforms.
            scale_upper_bound = 2.0 if current_platform.is_rocm() else 1.0
            assert 0.0 < attn._k_scale < scale_upper_bound
            assert 0.0 < attn._v_scale < scale_upper_bound

        llm.apply_model(check_model)

        # Smoke test only: this checks that generation runs end to end, not
        # that the output is accurate (see the lm-eval tests for accuracy).
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        first_result = outputs[0]
        print(first_result[1])

126

127
128
129
130
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize(
    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_online_quantization(
    vllm_runner,
    kv_cache_dtype: str,
    force_marlin: bool,
    use_rocm_aiter: bool,
    monkeypatch,
) -> None:
    """Load opt-125m with `quantization="fp8"` and check the resulting layers.

    Verifies that the first decoder layer's `fc1` uses `Fp8LinearMethod`,
    that the attention module uses `Fp8KVCacheMethod` with unit K/V scales
    when `kv_cache_dtype == "fp8"`, and that the `fc1` weight dtype matches
    the platform (fp8 with hardware support, int32 when packed for Marlin).
    Finally runs a short greedy generation.

    Args:
        vllm_runner: Fixture used as a context manager yielding the LLM wrapper.
        kv_cache_dtype: KV cache dtype forwarded to the runner.
        force_marlin: When true, sets `VLLM_TEST_FORCE_FP8_MARLIN=1`.
        use_rocm_aiter: When true, sets `VLLM_ROCM_USE_AITER=1`.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        quantization="fp8",
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
    ) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            elif current_platform.is_rocm():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == current_platform.fp8_dtype()
                else:  # unsupported ROCm platform
                    # The skip message names this test so that skip reports
                    # point at the right place.
                    pytest.skip(
                        "Skip `test_online_quantization`. "
                        "It only runs on ROCm platform with FP8 compute."
                        " e.g. MI300X and above."
                    )
            else:  # unsupported platform
                pytest.skip(
                    "Skip `test_online_quantization`. "
                    "It only runs on CUDA and ROCm platform."
                )

        llm.apply_model(check_model)

        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])

199

200
201
202
203
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare `ops.scaled_fp8_quant` with a pure-PyTorch reference.

    The same sequence of checks (dynamic quantization, reference quantization
    with the kernel's scale, static quantization, and quantization with token
    padding) runs on a small contiguous input and on a non-contiguous slice.

    Args:
        dtype: Input dtype of the tensor being quantized.
    """
    fp8_dtype = current_platform.fp8_dtype()
    fp8_limits = torch.finfo(fp8_dtype)

    def quantize_ref(tensor, inv_scale):
        # Reference implementation written to align fully with the kernel
        # being tested.
        scaled = tensor.to(torch.float32) * inv_scale.reciprocal()
        return scaled.clamp(min=fp8_limits.min, max=fp8_limits.max).to(fp8_dtype)

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        return tensor.to(dtype) * inv_scale

    def check_against_reference(inp, padded_rows):
        # Dynamic quantization: the kernel picks the scale itself.
        kernel_q, inv_scale = ops.scaled_fp8_quant(inp, None)
        expected = per_tensor_dequantize(kernel_q, inv_scale, dtype)

        # Reference quantization using the scale the kernel produced.
        reference_q = quantize_ref(inp, inv_scale)
        torch.testing.assert_close(
            expected, per_tensor_dequantize(reference_q, inv_scale, dtype)
        )

        # Static quantization with that same scale.
        static_q, _ = ops.scaled_fp8_quant(inp, inv_scale)
        torch.testing.assert_close(
            expected, per_tensor_dequantize(static_q, inv_scale, dtype)
        )

        # Padding: only the first `inp.shape[0]` rows are compared.
        padded_q, _ = ops.scaled_fp8_quant(
            inp, inv_scale, num_token_padding=padded_rows
        )
        assert padded_q.shape[0] == padded_rows
        unpadded_q = torch.narrow(padded_q, 0, 0, inp.shape[0])
        torch.testing.assert_close(
            expected, per_tensor_dequantize(unpadded_q, inv_scale, dtype)
        )

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
    check_against_reference(x, 17)

    # Non-contiguous input: slicing columns keeps the wider row stride.
    m, n, padded_stride = 975, 512, 576
    padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype)
    x_nc = padded_tensor[:, :n]  # shape (m, n) with stride (padded_stride, 1)

    assert not x_nc.is_contiguous()
    assert x_nc.stride(0) == padded_stride

    check_against_reference(x_nc, m + 10)
277
278


279
280
281
282
@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="FP8 e4m3fn weight reloading is not supported on e4m3fnuz platforms",
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# Online quantization is not supported by FP8 weight reloading.
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True])  # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# Whatever postprocessing the weights go through (padding, repacking, ...;
# device sharding excluded) has to be applied to the reloaded weights too.
#
# That holds for marlin and for the per-tensor Fp8MoEMethod.
@pytest.mark.parametrize("use_marlin", [False])  # skip True
def test_fp8_reloading(
    default_vllm_config,
    method_cls,
    is_checkpoint_fp8_serialized,
    weight_block_size,
    use_marlin,
    dist_init,
    monkeypatch,
):
    """Check that FP8 weights can be loaded a second time after processing.

    Creates a minimal layer for `method_cls`, loads zero weights, runs
    `process_weights_after_loading`, then loads and processes again while
    asserting every parameter still exposes the weight loader it had before.

    Args:
        default_vllm_config: Fixture requested but not referenced in the body.
        method_cls: Quantization method class under test.
        is_checkpoint_fp8_serialized: Forwarded to `Fp8Config`.
        weight_block_size: Forwarded to `Fp8Config`.
        use_marlin: Value assigned to `method.use_marlin` for the linear method.
        dist_init: Fixture requested but not referenced in the body.
        monkeypatch: pytest fixture used to set the environment variable.
    """
    # NOTE(rob): with DeepGEMM enabled this test fails because the shapes
    # are invalid. It used to pass only because fp8_backend was set to None,
    # which sidestepped the issue.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")

    if is_checkpoint_fp8_serialized is False:
        pytest.skip("FP8 weight reloading does not support online quantization")

    if weight_block_size is None and method_cls is Fp8MoEMethod:
        pytest.skip(
            "FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
            "If this is your use case, consider using a restore function like #26327"
        )

    # Keyword arguments shared by both `create_weights` calls.
    shared_kwargs = {
        "params_dtype": torch.bfloat16,
        "weight_loader": default_weight_loader,
    }

    with torch.device("cuda:0"):
        config = Fp8Config(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            weight_block_size=weight_block_size,
        )

        if method_cls is not Fp8LinearMethod:
            layer = FusedMoE(
                num_experts=1,
                top_k=1,
                hidden_size=1,
                intermediate_size=1,
            )
            method = method_cls(config, layer)
            method.create_weights(
                layer=layer,
                num_experts=1,
                hidden_size=1,
                intermediate_size_per_partition=1,
                **shared_kwargs,
            )
        else:
            layer = torch.nn.Linear(1, 1)
            method = method_cls(config)
            method.create_weights(
                layer=layer,
                input_size_per_partition=1,
                output_partition_sizes=[1],
                input_size=1,
                output_size=1,
                **shared_kwargs,
            )
            method.use_marlin = use_marlin

    # Snapshot each parameter's name, shape and loader before anything loads.
    original_metadata = [
        (
            param_name,
            parameter.shape,
            getattr(parameter, "weight_loader", default_weight_loader),
        )
        for param_name, parameter in layer.named_parameters()
    ]

    def load_zeros_and_process(verify_loader):
        # Feed zeros into every recorded parameter, then post-process.
        for param_name, param_shape, recorded_loader in original_metadata:
            parameter = getattr(layer, param_name)
            loader = getattr(parameter, "weight_loader", default_weight_loader)
            if verify_loader:
                assert loader is recorded_loader
            loader(parameter, torch.zeros(param_shape))  # cannot use empty
        method.process_weights_after_loading(layer)

    # First load.
    load_zeros_and_process(verify_loader=False)

    # Reload after processing, assuming that no reshaping occurred.
    load_zeros_and_process(verify_loader=True)
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
    """Test that kv_cache_dtype_skip_layers skips quantization for specified layers.

    Layers whose index is listed must report `kv_cache_dtype == "auto"`;
    every other layer must report `"fp8"`.

    Args:
        vllm_runner: Fixture used as a context manager yielding the LLM wrapper.
        monkeypatch: pytest fixture used to set the environment variable.
    """
    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Single source of truth for both the runner argument and the check.
    skipped_layers = ("0", "2")

    with vllm_runner(
        # Resolve the model via the shared prefix, as the other tests here do.
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        kv_cache_dtype="fp8",
        kv_cache_dtype_skip_layers=list(skipped_layers),
        enforce_eager=True,
    ) as llm:

        def check_layers(model):
            for i, layer in enumerate(model.model.decoder.layers):
                expected = "auto" if str(i) in skipped_layers else "fp8"
                assert layer.self_attn.attn.kv_cache_dtype == expected

        llm.apply_model(check_layers)