test_fp8.py 9.53 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
7

8
9
10
import pytest
import torch

11
from tests.quantization.utils import is_quant_method_supported
12
from vllm import _custom_ops as ops
13
14
15
16
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8KVCacheMethod,
    Fp8LinearMethod,
)
17
from vllm.platforms import current_platform
18

19
# Checkpoints exercised by `test_model_load_and_run`.
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # The checkpoint below was removed from the HF Hub, so this case is skipped.
    # TODO: add a small replacement checkpoint.
    pytest.param(
        "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
        marks=pytest.mark.skip(reason="Checkpoint removed from HF."),
    ),
]


30
31
32
33
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_model_load_and_run(
    vllm_runner, model_id: str, force_marlin: bool, use_rocm_aiter: bool, monkeypatch
) -> None:
    """Load an FP8 checkpoint and greedily generate a few tokens.

    This is a smoke test: it only checks that generation runs end to end.
    Accuracy is covered by the lm-eval tests.
    """
    # Each flag that is enabled turns on the matching environment switch.
    # The insertion order is kept so the variables are set in a fixed order.
    env_switches = {
        "VLLM_ROCM_USE_AITER": use_rocm_aiter,
        "VLLM_TEST_FORCE_FP8_MARLIN": force_marlin,
    }
    for env_name, enabled in env_switches.items():
        if enabled:
            monkeypatch.setenv(env_name, "1")

    with vllm_runner(model_id, enforce_eager=True) as llm:
        completions = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(completions[0][1])


# Checkpoints with FP8 KV-cache scales, exercised by
# `test_kv_cache_model_load_and_run`. Every entry is currently skipped, so that
# test does not run any case until a replacement checkpoint is added.
KV_CACHE_MODELS = [
    # AutoFP8 format using separate .k_scale and .v_scale
    # The original checkpoint below was removed from the Hub. To unblock CI and
    # until a small replacement with split K/V scales is found, skip this case.
    # See PR #27717 for context.
    pytest.param(
        "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
        marks=pytest.mark.skip(
            reason=(
                "Checkpoint removed from HF; temporarily disabling this "
                "AutoFP8 split K/V case (PR #27717)."
            )
        ),
    ),
]


72
73
74
75
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_kv_cache_model_load_and_run(
    vllm_runner, model_id: str, use_rocm_aiter: bool, monkeypatch
):
    """Load a checkpoint with FP8 KV-cache scales and verify they are applied.

    The test checks that the first attention layer uses `Fp8KVCacheMethod` and
    that the loaded k/v scales fall in the expected open interval. It then runs
    a short greedy generation as a smoke test; accuracy is covered by the
    lm-eval tests.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    with vllm_runner(model_id, kv_cache_dtype="fp8", enforce_eager=True) as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            # Scales of exactly 1.0 (the default value) would be valid, but
            # these checkpoints are known to have scales < 1.0. On ROCm the
            # loaded _k_scale/_v_scale are scaled by a factor of 2 (see
            # vllm/model_executor/layers/quantization/kv_cache.py), which
            # doubles the upper bound there.
            # NOTE: the non-ROCm bound still requires validation on non-CUDA
            # platforms.
            upper_bound = 2.0 if current_platform.is_rocm() else 1.0
            assert 0.0 < attn._k_scale < upper_bound
            assert 0.0 < attn._v_scale < upper_bound

        llm.apply_model(check_model)

        # Smoke test only: this runs generation without checking accuracy.
        generations = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(generations[0][1])

118

119
120
121
122
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_load_fp16_model(
    vllm_runner,
    kv_cache_dtype: str,
    force_marlin: bool,
    use_rocm_aiter: bool,
    monkeypatch,
) -> None:
    """Quantize an fp16 checkpoint to FP8 at load time and inspect the result.

    The test checks three things:

    * The first decoder layer's ``fc1`` uses ``Fp8LinearMethod``.
    * Its weight has the dtype expected for the platform: FP8, or int32 when
      packed for Marlin.
    * With ``kv_cache_dtype="fp8"``, attention uses ``Fp8KVCacheMethod`` with
      k/v scales of 1.0.
    """
    # Decide before loading the model whether this platform/parameter combination
    # can be tested. The platform queries do not need the model. `check_model`
    # runs through `LLM.apply_model`: the function is pickled and may execute
    # outside this process (TODO confirm for the engine mode in use). Raising
    # `pytest.skip` there is therefore unreliable, and it would only happen
    # after a full model load.
    if current_platform.is_rocm():
        if not current_platform.supports_fp8() or force_marlin:
            # Unsupported ROCm platform, or Marlin forced on ROCm.
            pytest.skip(
                "Skip `test_load_fp16_model`. "
                "It only runs on ROCm platform with FP8 compute."
                " e.g. MI300X and above."
            )
    elif not current_platform.is_cuda():
        # Unsupported platform.
        pytest.skip(
            "Skip `test_load_fp16_model`. "
            "It only runs on CUDA and ROCm platform."
        )

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(
        "facebook/opt-125m",
        quantization="fp8",
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
    ) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            else:
                # Only ROCm with FP8 compute and without forced Marlin reaches
                # this branch; every other combination was skipped above.
                # Weights are kept in the platform's fp8 dtype.
                assert fc1.weight.dtype == current_platform.fp8_dtype()

        llm.apply_model(check_model)
185
186


187
188
189
190
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` with a pure-PyTorch reference.

    The test uses a contiguous input and a row-strided (non-contiguous) input.
    For each one it checks:

    * dynamic quantization against the reference;
    * static quantization with the kernel-computed scale;
    * output padding via ``num_token_padding``.
    """

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        # NOTE(review): the target dtype is hard-coded to float8_e4m3fn;
        # confirm it matches the kernel's output dtype on every platform this
        # runs on (ROCm may use a different fp8 dtype).
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    def check_quant_paths(inp, num_token_padding):
        # Run the dynamic, reference, static and padding checks on one input.
        #
        # Dynamic quantization: the kernel computes the scale itself. Its
        # dequantized output is the baseline that every other path must match.
        q_dynamic, inv_scale = ops.scaled_fp8_quant(inp, None)
        baseline = per_tensor_dequantize(q_dynamic, inv_scale, dtype)

        # Reference dynamic quantization, using the kernel-computed scale.
        q_ref = quantize_ref(inp, inv_scale)
        torch.testing.assert_close(
            baseline, per_tensor_dequantize(q_ref, inv_scale, dtype)
        )

        # Static quantization with the same scale must reproduce the baseline.
        q_static, _ = ops.scaled_fp8_quant(inp, inv_scale)
        torch.testing.assert_close(
            baseline, per_tensor_dequantize(q_static, inv_scale, dtype)
        )

        # Padding: the output gains extra rows, and its leading
        # `inp.shape[0]` rows must still match the baseline.
        q_padded, _ = ops.scaled_fp8_quant(
            inp, inv_scale, num_token_padding=num_token_padding
        )
        assert q_padded.shape[0] == num_token_padding
        torch.testing.assert_close(
            baseline,
            per_tensor_dequantize(
                torch.narrow(q_padded, 0, 0, inp.shape[0]), inv_scale, dtype
            ),
        )

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
    check_quant_paths(x, num_token_padding=17)

    # Non-contiguous input: a column slice of a wider tensor keeps the parent's
    # row stride.
    m, n, padded_stride = 975, 512, 576
    padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype)
    x_nc = padded_tensor[:, :n]  # shape (m, n) with stride (padded_stride, 1)

    assert not x_nc.is_contiguous()
    assert x_nc.stride(0) == padded_stride

    check_quant_paths(x_nc, num_token_padding=m + 10)