# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
from vllm.platforms import current_platform

# FP8 checkpoints (per their names) that test_model_load_and_run loads
# without passing an explicit `quantization` argument.
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
                            use_rocm_aiter: bool, monkeypatch) -> None:
    """Load an FP8 checkpoint and run a short greedy generation.

    Smoke test only: it checks that the model loads and generates, not that
    the output is accurate (accuracy is covered by the lm-eval tests).

    Args:
        vllm_runner: pytest fixture used as a context manager to build the
            model under test.
        model_id: checkpoint name taken from ``MODELS``.
        force_marlin: when True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1``.
        use_rocm_aiter: when True (ROCm only), sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# Checkpoints that ship FP8 KV-cache scales, one per AutoFP8 scale layout.
KV_CACHE_MODELS = [
    # Legacy (deprecated) AutoFP8 layout: one combined `.kv_scale`.
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # Newer AutoFP8 layout: separate `.k_scale` and `.v_scale`.
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
                                     use_rocm_aiter: bool, monkeypatch):
    """Load a checkpoint with FP8 KV-cache scales and run a generation.

    Checks that the first layer's attention uses ``Fp8KVCacheMethod`` and
    that the loaded k/v scales fall in the range expected for these
    checkpoints. It then runs a short greedy generation as a smoke test.

    Args:
        vllm_runner: pytest fixture used as a context manager to build the
            model under test.
        model_id: checkpoint name taken from ``KV_CACHE_MODELS``.
        use_rocm_aiter: when True (ROCm only), sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            if not current_platform.is_rocm():
                # NOTE: This code path requires validation on Non-CUDA platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                assert 0.0 < attn._k_scale < 1.0
                assert 0.0 < attn._v_scale < 1.0
            else:
                # NOTE: This code path is for ROCm platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                # However on ROCm platform, the _k_scale and _v_scale will be
                # scaled by a factor of 2 as described in
                # vllm/model_executor/layers/quantization/kv_cache.py
                assert 0.0 < attn._k_scale < (1.0 * 2.0)
                assert 0.0 < attn._v_scale < (1.0 * 2.0)

        llm.apply_model(check_model)

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                         use_rocm_aiter: bool, monkeypatch) -> None:
    """Load ``facebook/opt-125m`` with ``quantization="fp8"`` and inspect it.

    Checks that the first decoder layer's ``fc1`` uses ``Fp8LinearMethod``.
    When ``kv_cache_dtype == "fp8"``, it also checks that attention uses
    ``Fp8KVCacheMethod`` with k/v scales of exactly 1.0.

    The expected ``fc1`` weight dtype then depends on the platform:

    * On CUDA it is FP8 when the hardware supports FP8 and Marlin is not
      forced, and packed ``int32`` (Marlin) otherwise.
    * On ROCm with FP8 support it is ``current_platform.fp8_dtype()``.

    The test is skipped, after the quant-method checks above, on ROCm without
    FP8 support and on platforms other than CUDA or ROCm.

    Args:
        vllm_runner: pytest fixture used as a context manager to build the
            model under test.
        kv_cache_dtype: ``"auto"`` or ``"fp8"``, forwarded to the runner.
        force_marlin: when True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1``.
        use_rocm_aiter: when True (ROCm only), sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner("facebook/opt-125m",
                     quantization="fp8",
                     kv_cache_dtype=kv_cache_dtype) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                # 1.0 is the default scale; presumably this checkpoint
                # provides no KV scales of its own (TODO confirm).
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            elif current_platform.is_rocm():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == current_platform.fp8_dtype()
                else:  # unsupported ROCm platform
                    pytest.skip(
                        "Skip `test_load_fp16_model`. "
                        "It only runs on ROCm platform with FP8 compute."
                        " e.g. MI300X and above.")
            else:  # unsupported platform
                pytest.skip("Skip `test_load_fp16_model`. "
                            "It only runs on CUDA and ROCm platform.")

        llm.apply_model(check_model)


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` against a pure-PyTorch reference.

    Covers three cases:

    * dynamic quantization, where no scale is passed and the op returns one;
    * static quantization, where the scale from the dynamic case is passed in;
    * ``num_token_padding``, where the output gets extra rows.

    All quantized results are compared after dequantizing back to ``dtype``.

    Args:
        dtype: input dtype of the random test tensor (fp16 or bf16).
    """

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        # NOTE(review): hard-codes float8_e4m3fn. Platforms whose fp8 dtype
        # differs (cf. current_platform.fp8_dtype()) may not match the kernel
        # exactly; confirm.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                           max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        # Cast the fp8 values up to `dtype` and multiply by the scale.
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization (no scale passed in; the op returns the scale).
    # NOTE: despite its name, `ref_y` holds the dequantized *kernel* output;
    # it is the baseline that every later result is compared against.
    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
    ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)

    # Reference dynamic quantization, reusing the kernel-computed scale.
    y = quantize_ref(x, inv_scale)
    torch.testing.assert_close(ref_y,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Static quantization
    y, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(ref_y,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Padding: the output has 17 rows; only the first x.shape[0] rows are
    # compared.
    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert y.shape[0] == 17
    torch.testing.assert_close(
        ref_y,
        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
                              dtype))