test_fp8.py 17.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
7

8
9
import logging

10
import pytest
11
import regex as re
12
13
import torch

14
from tests.quantization.utils import is_quant_method_supported
15
from vllm import _custom_ops as ops
16
from vllm.config.model import ModelConfig
17
from vllm.model_executor.layers.fused_moe import FusedMoE
18
from vllm.model_executor.layers.quantization.fp8 import (
19
    Fp8Config,
20
21
    Fp8KVCacheMethod,
    Fp8LinearMethod,
22
    Fp8MoEMethod,
23
)
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
from vllm.platforms import current_platform
26

27
# FP8 checkpoints exercised by `test_model_load_and_run`.
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # The checkpoint below was removed from the HF.
    # TODO: add a small replacement checkpoint.
    pytest.param(
        "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
        marks=pytest.mark.skip(reason="Checkpoint removed from HF."),
    ),
]


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize(
    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_model_load_and_run(
    vllm_runner, model_id: str, force_marlin: bool, use_rocm_aiter: bool, monkeypatch
) -> None:
    """Smoke-test loading an FP8 checkpoint and running a short greedy generation.

    Environment switches are applied before the engine is created:
    ``VLLM_ROCM_USE_AITER`` (only parametrized on ROCm) and
    ``VLLM_TEST_FORCE_FP8_MARLIN`` (only parametrized off ROCm).
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id, enforce_eager=True) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])


# Checkpoints with calibrated FP8 KV-cache scales, exercised by
# `test_kv_cache_model_load_and_run`.
KV_CACHE_MODELS = [
    # AutoFP8 format using separate .k_scale and .v_scale
    # The original checkpoint below was removed from the Hub. To unblock CI and
    # until a small replacement with split K/V scales is found, skip this case.
    # See PR #27717 for context.
    pytest.param(
        "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
        marks=pytest.mark.skip(
            reason=(
                "Checkpoint removed from HF; temporarily disabling this "
                "AutoFP8 split K/V case (PR #27717)."
            )
        ),
    ),
]


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_kv_cache_model_load_and_run(
    vllm_runner, model_id: str, use_rocm_aiter: bool, monkeypatch
) -> None:
    """Load a checkpoint with FP8 KV-cache scales and run a short generation.

    Checks that the first attention layer uses ``Fp8KVCacheMethod`` and that
    its loaded ``_k_scale``/``_v_scale`` lie strictly between 0 and the
    platform-specific upper bound.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    with vllm_runner(model_id, kv_cache_dtype="fp8", enforce_eager=True) as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            # NOTE: it is valid for scales to be 1.0 (default value), but
            # we know these checkpoints have scales < 1.0
            if current_platform.is_rocm():
                # On ROCm the _k_scale and _v_scale are scaled by a factor of 2
                # as described in
                # vllm/model_executor/layers/quantization/kv_cache.py
                upper_bound = 1.0 * 2.0
            else:
                # NOTE: This code path requires validation on Non-CUDA platform
                upper_bound = 1.0
            assert 0.0 < attn._k_scale < upper_bound
            assert 0.0 < attn._v_scale < upper_bound

        llm.apply_model(check_model)

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])

@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize(
    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_online_quantization(
    vllm_runner,
    kv_cache_dtype: str,
    force_marlin: bool,
    use_rocm_aiter: bool,
    monkeypatch,
) -> None:
    """Quantize an unquantized checkpoint to FP8 online and run it.

    Checks that:
    - the first decoder ``fc1`` uses ``Fp8LinearMethod``;
    - with an fp8 KV cache, attention uses ``Fp8KVCacheMethod`` with
      default (1.0) k/v scales;
    - the stored weight dtype matches the kernel path. That is fp8 on GPUs
      with FP8 support, or int32-packed Marlin weights on CUDA without it
      (or when Marlin is forced).
    """
    # Decide up front whether this platform can run the test. Previously the
    # skip happened inside `check_model`, i.e. inside `LLM.apply_model`
    # after the model was already loaded. pytest's skip exception derives
    # from BaseException, which is fragile to raise from that callback.
    if current_platform.is_rocm():
        if force_marlin or not current_platform.supports_fp8():
            pytest.skip(
                "Skip `test_online_quantization`. "
                "It only runs on ROCm platform with FP8 compute."
                " e.g. MI300X and above."
            )
    elif not current_platform.is_cuda():
        pytest.skip(
            "Skip `test_online_quantization`. "
            "It only runs on CUDA and ROCm platform."
        )

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(
        "facebook/opt-125m",
        quantization="fp8",
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
    ) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                # Online quantization leaves the KV scales at their
                # default value of 1.0.
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            else:
                # Only ROCm with FP8 compute reaches here; every other
                # platform was skipped before the model was loaded.
                assert fc1.weight.dtype == current_platform.fp8_dtype()

        llm.apply_model(check_model)

        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Bound the memory used while loading a model with online FP8 quant.

    Runs ``allenai/OLMoE-1B-7B-0125-Instruct`` with ``quantization="fp8"``
    and captures DEBUG logs from the spawned workers. It then parses the
    reported post-load and peak weight-loading memory (GiB) and compares
    them with model-specific thresholds.
    """
    # Why `allenai/OLMoE-1B-7B-0125-Instruct`:
    # 1. it covers both Linear and MoE paths
    # 2. other CI tests already use it, so it adds no disk usage on runners
    # `ibm-granite/granite-3.0-1b-a400m-base` (2.5 GiB bf16, 1.3 GiB fp8),
    # likely the smallest MoE model in vLLM, would be preferable, but adding
    # one more model makes CI run out of disk space.
    model_name = "allenai/OLMoE-1B-7B-0125-Instruct"

    # caplog_mp_spawn depends on VLLM_LOGGING_CONFIG_PATH, which spawned
    # workers read but forked workers ignore, so force spawn.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    with (
        caplog_mp_spawn(logging.DEBUG) as log_holder,
        vllm_runner(model_name, quantization="fp8", enforce_eager=True) as llm,
    ):
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])

    log_text = log_holder.text

    def _first_gib(pattern):
        # First capture of `pattern` anywhere in the captured logs, as a
        # float; None if it never appears. The patterns cannot match across
        # a line break, so this equals scanning line by line.
        found = re.search(pattern, log_text)
        return None if found is None else float(found.group(1))

    model_memory_gib = _first_gib(r"Model loading took ([\d.]+) GiB memory")
    peak_memory_gib = _first_gib(
        r"Peak GPU memory after loading weights: ([\d.]+) GiB"
    )

    assert model_memory_gib is not None, "Could not find model loading memory log"
    assert peak_memory_gib is not None, "Could not find peak memory log"
    print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
    print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")

    # Model specific: OLMoE fp8 online quant currently uses 6.65 GiB for
    # weight loading (the bf16 checkpoint is ~12.89 GiB).
    expected_model_memory_gib = 6.7

    # The observed peak today is 9.06 GiB, ~1.36x the post-load figure.
    # Streaming load-and-quantize briefly keeps each weight alive in both
    # bf16 and fp8, so a somewhat higher peak is expected.
    expected_peak_memory_gib = expected_model_memory_gib * 1.4

    assert model_memory_gib < expected_model_memory_gib, (
        f"{model_memory_gib=} higher than {expected_model_memory_gib}"
    )
    assert peak_memory_gib < expected_peak_memory_gib, (
        f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
    )


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
    vllm_runner,
    monkeypatch,
    caplog,
) -> None:
    """Smoke-test online FP8 quantization combined with ``load_format="dummy"``.

    Only checks that generation runs end to end; outputs are printed, not
    asserted. ``monkeypatch`` and ``caplog`` are requested but not used.
    """
    model_id = "ibm-granite/granite-3.0-1b-a400m-base"
    runner_kwargs = dict(
        quantization="fp8",
        enforce_eager=True,
        load_format="dummy",
    )
    with vllm_runner(model_id, **runner_kwargs) as llm:
        generations = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(generations[0][1])


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` against a pure-torch reference.

    Covers:
    - dynamic per-tensor quantization (the kernel computes the scale);
    - static per-tensor quantization (the scale is passed in);
    - output row padding via ``num_token_padding``;
    - all of the above for a non-contiguous, row-strided input.
    """
    # Loop-invariant: the platform's fp8 dtype and its representable range.
    fp8_dtype = current_platform.fp8_dtype()
    finfo = torch.finfo(fp8_dtype)

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested: scale in fp32, saturate to the fp8
        # range, then cast.
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(fp8_dtype)

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        # Undo per-tensor quantization so results can be compared in `dtype`.
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization. The kernel's dequantized output is the baseline
    # that every check below compares against.
    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
    ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)

    # Reference dynamic quantization
    y = quantize_ref(x, inv_scale)
    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

    # Static quantization
    y, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

    # Padding: output gets 17 rows; the first x.shape[0] rows must match.
    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert y.shape[0] == 17
    torch.testing.assert_close(
        ref_y,
        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype),
    )

    # non-contiguous input with padding
    m, n, padded_stride = 975, 512, 576
    padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype)
    x_nc = padded_tensor[:, :n]  # shape (m, n) with stride (padded_stride, 1)

    assert not x_nc.is_contiguous()
    assert x_nc.stride(0) == padded_stride

    # dynamic quantization
    ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None)
    ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype)

    # reference dynamic quantization
    y_nc = quantize_ref(x_nc, inv_scale_nc)
    torch.testing.assert_close(
        ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)
    )

    # static quantization
    y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc)
    torch.testing.assert_close(
        ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)
    )

    # padding after non-contiguous input quantization
    y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc, num_token_padding=m + 10)
    assert y_nc_pad.shape[0] == m + 10
    torch.testing.assert_close(
        ref_y_nc,
        per_tensor_dequantize(
            torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
        ),
    )


@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="FP8 e4m3fn weight reloading is not supported on e4m3fnuz platforms",
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True])  # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing that is applied to the weights such as padding and repacking
# (excluding device sharding) must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False])  # skip True
def test_fp8_reloading(
    default_vllm_config,
    method_cls,
    is_checkpoint_fp8_serialized,
    weight_block_size,
    use_marlin,
    dist_init,
    monkeypatch,
):
    """Check that FP8 weights can be reloaded after post-processing.

    Builds a 1x1 layer (``torch.nn.Linear`` or ``FusedMoE``) with the FP8
    quant method and loads zeros into every parameter. It then runs
    ``process_weights_after_loading``, reloads through the same weight
    loaders and post-processes again. On reload, every parameter must still
    expose the exact ``weight_loader`` object captured during the first load.
    """
    # NOTE(rob): this test fails when using DeepGEMM because the
    # shapes are invalid. Previously the test was passing because
    # we set fp8_backend to None, which sidestepped the issue.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")

    if is_checkpoint_fp8_serialized is False:
        pytest.skip("FP8 weight reloading does not support online quantization")

    if method_cls is Fp8MoEMethod and weight_block_size is None:
        pytest.skip(
            "FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
            "If this is your use case, consider using a restore function like #26327"
        )

    # Set model config as model_config.dtype is required in Fp8LinearMethod.
    default_vllm_config.model_config = ModelConfig()

    with torch.device("cuda:0"):
        config = Fp8Config(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            weight_block_size=weight_block_size,
        )

        if method_cls is Fp8LinearMethod:
            layer = torch.nn.Linear(1, 1)
            method = method_cls(config)
            method.create_weights(
                layer=layer,
                input_size_per_partition=1,
                output_partition_sizes=[1],
                input_size=1,
                output_size=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )
            method.use_marlin = use_marlin

        else:
            layer = FusedMoE(
                num_experts=1,
                top_k=1,
                hidden_size=1,
                intermediate_size=1,
            )
            method = method_cls(config, layer)
            method.create_weights(
                layer=layer,
                num_experts=1,
                hidden_size=1,
                intermediate_size_per_partition=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )

    # capture weights format during loading
    original_metadata = [
        (name, param.shape, getattr(param, "weight_loader", default_weight_loader))
        for name, param in layer.named_parameters()
    ]

    # test loading
    for name, shape, _ in original_metadata:
        param = getattr(layer, name)
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, torch.zeros(shape))  # cannot use empty

    method.process_weights_after_loading(layer)

    # test reloading works after loading
    # assuming that no reshaping occurred
    for name, shape, original_weight_loader in original_metadata:
        param = getattr(layer, name)
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        # Post-processing must not replace the loader a reload would use.
        assert weight_loader is original_weight_loader
        weight_loader(param, torch.zeros(shape))  # cannot use empty

    method.process_weights_after_loading(layer)


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
    """Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
    # `LLM.apply_model` (used below) needs to pickle `check_layers`.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(
        "facebook/opt-125m",
        kv_cache_dtype="fp8",
        # Layers appear to be selected by their index, given as a string
        # (the check below compares against str(i)).
        kv_cache_dtype_skip_layers=["0", "2"],
        enforce_eager=True,
    ) as llm:

        def check_layers(model):
            # Skipped layers are expected to report "auto"; every other
            # decoder layer's attention should report "fp8".
            for i, layer in enumerate(model.model.decoder.layers):
                expected = "auto" if str(i) in ["0", "2"] else "fp8"
                assert layer.self_attn.attn.kv_cache_dtype == expected

        llm.apply_model(check_layers)