# SPDX-License-Identifier: Apache-2.0
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

9
from tests.quantization.utils import is_quant_method_supported
10
from vllm import _custom_ops as ops
11
12
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
13
from vllm.platforms import current_platform
14

# FP8 checkpoints that test_model_load_and_run loads and runs end to end.
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
                            monkeypatch) -> None:
    """Load an FP8 checkpoint and run a short greedy generation.

    When ``force_marlin`` is True, ``VLLM_TEST_FORCE_FP8_MARLIN=1`` is set
    through ``monkeypatch``, so the override is undone after the test. This
    requests the Marlin FP8 kernels instead of the default path.
    """
    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id) as llm:
        # This does not test accuracy, only that generation runs end to end;
        # see the lm-eval tests for accuracy.
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# Checkpoints carrying FP8 KV-cache scales, one per AutoFP8 scale layout.
KV_CACHE_MODELS = [
    # Older, deprecated AutoFP8 layout with a single combined `.kv_scale`.
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # Newer AutoFP8 layout with separate `.k_scale` and `.v_scale` entries.
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
    """Load a checkpoint with FP8 KV-cache scales and run a short generation.

    Checks that the first attention layer uses ``Fp8KVCacheMethod`` and that
    its loaded k/v scales lie in the expected open interval. It then
    generates a few tokens to confirm the model runs.
    """
    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            # It is valid for scales to be 1.0 (the default value), but these
            # checkpoints are known to have scales < 1.0.
            # On ROCm, _k_scale and _v_scale are scaled by a factor of 2, as
            # described in vllm/model_executor/layers/quantization/kv_cache.py,
            # so the upper bound doubles there.
            # NOTE: the non-ROCm bound still needs validation on platforms
            # other than CUDA.
            upper_bound = 2.0 if current_platform.is_rocm() else 1.0
            assert 0.0 < attn._k_scale < upper_bound
            assert 0.0 < attn._v_scale < upper_bound

        llm.apply_model(check_model)

        # This does not test accuracy, only that generation runs end to end;
        # see the lm-eval tests for accuracy.
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])

82

@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                         monkeypatch) -> None:
    """Quantize an unquantized checkpoint to FP8 at load time.

    Loads ``facebook/opt-125m`` with ``quantization="fp8"`` and checks that
    the first decoder layer's ``fc1`` uses ``Fp8LinearMethod``. It also
    checks that the weight is stored in the dtype expected for the current
    platform and kernel choice. With ``kv_cache_dtype="fp8"``, it checks
    that attention uses ``Fp8KVCacheMethod`` with k/v scales of 1.0.
    Unsupported platforms are skipped from inside the model check.
    """
    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner("facebook/opt-125m",
                     quantization="fp8",
                     kv_cache_dtype=kv_cache_dtype) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                # The model is loaded from an unquantized checkpoint, so the
                # k/v scales are expected to keep their default of 1.0.
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            elif current_platform.is_rocm():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == current_platform.fp8_dtype()
                else:
                    # Unsupported ROCm platform. This branch is also taken
                    # when force_marlin is set: there is no Marlin
                    # expectation for ROCm here.
                    pytest.skip(
                        "Skip `test_load_fp16_model`. "
                        "It only runs on ROCm platform with FP8 compute."
                        " e.g. MI300X and above.")
            else:  # unsupported platform
                pytest.skip("Skip `test_load_fp16_model`. "
                            "It only runs on CUDA and ROCm platform.")

        llm.apply_model(check_model)


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` with a pure-PyTorch reference.

    Covers three cases:
    - dynamic per-tensor quantization (scale passed as None; the kernel
      returns the scale it used);
    - static quantization (that scale passed back in);
    - output padding via ``num_token_padding``.
    """
    # Use the platform's FP8 type rather than hard-coding float8_e4m3fn.
    # On ROCm it can be a different FP8 variant with a different finite
    # range, so a hard-coded e4m3fn reference would clamp and round
    # differently. NOTE(review): this assumes the kernel emits
    # current_platform.fp8_dtype() — confirm for new platforms.
    fp8_dtype = current_platform.fp8_dtype()

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(fp8_dtype)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                           max=finfo.max)
        return qweight.to(fp8_dtype)

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        # Upcast the fp8 values to `dtype` and undo the per-tensor scaling.
        return tensor.to(dtype) * inv_scale

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization: no scale is passed, so the kernel picks one and
    # returns it. The dequantized kernel output is the expected value for
    # every comparison below.
    y_dynamic, inv_scale = ops.scaled_fp8_quant(x, None)
    expected = per_tensor_dequantize(y_dynamic, inv_scale, dtype)

    # Reference dynamic quantization using the kernel-chosen scale.
    y = quantize_ref(x, inv_scale)
    torch.testing.assert_close(expected,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Static quantization: pass the same scale in explicitly.
    y, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(expected,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Padding: the output has num_token_padding rows, and its first
    # x.shape[0] rows must match the unpadded result.
    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert y.shape[0] == 17
    torch.testing.assert_close(
        expected,
        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
                              dtype))