# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch
9
import os
10

11
from tests.quantization.utils import is_quant_method_supported
12
from vllm import _custom_ops as ops
13
14
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
15
from vllm.platforms import current_platform
16
from ..utils import models_path_prefix
17

18
# FP8-quantized checkpoints exercised by test_model_load_and_run; each path
# is resolved relative to `models_path_prefix`.
MODELS = [
    os.path.join(models_path_prefix,
                 "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"),
    os.path.join(models_path_prefix,
                 "nm-testing/Phi-3-mini-128k-instruct-FP8"),
    os.path.join(models_path_prefix,
                 "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV"),
]


# NOTE(review): the skipif below also skips every ROCm platform, so the
# `use_rocm_aiter=True` parametrization can never actually run here.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
                            use_rocm_aiter: bool, monkeypatch) -> None:
    """Smoke-test that an FP8 checkpoint loads and generates a few tokens.

    Args:
        vllm_runner: Fixture returning a context-manager factory for a vLLM
            instance (used as ``with vllm_runner(model_id) as llm``).
        model_id: Path of the FP8 checkpoint, one entry of ``MODELS``.
        force_marlin: When True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1``.
        use_rocm_aiter: When True, sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# Checkpoints carrying FP8 KV-cache scales, used by
# test_kv_cache_model_load_and_run.
KV_CACHE_MODELS = [
    # Deprecated AutoFP8 format using .kv_scale
    os.path.join(models_path_prefix,
                 "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"),
    # AutoFP8 format using separate .k_scale and .v_scale
    os.path.join(models_path_prefix,
                 "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V"),
]


# NOTE(review): the skipif below also skips every ROCm platform, so the
# ROCm-specific branches in this test are currently unreachable.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
                                     use_rocm_aiter: bool, monkeypatch):
    """Load a checkpoint with FP8 KV-cache scales and check them, then run it.

    Verifies that layer 0's attention uses ``Fp8KVCacheMethod`` and that the
    loaded ``_k_scale`` / ``_v_scale`` lie in the expected range. It then
    greedily generates a few tokens.

    Args:
        vllm_runner: Fixture returning a context-manager factory for a vLLM
            instance.
        model_id: Path of the checkpoint, one entry of ``KV_CACHE_MODELS``.
        use_rocm_aiter: When True, sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            if not current_platform.is_rocm():
                # NOTE: This code path requires validation on Non-CUDA platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                assert 0.0 < attn._k_scale < 1.0
                assert 0.0 < attn._v_scale < 1.0
            else:
                # NOTE: This code path is for ROCm platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                # However on ROCm platform, the _k_scale and _v_scale will be
                # scaled by a factor of 2 as described in
                # vllm/model_executor/layers/quantization/kv_cache.py
                assert 0.0 < attn._k_scale < (1.0 * 2.0)
                assert 0.0 < attn._v_scale < (1.0 * 2.0)

        llm.apply_model(check_model)

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# NOTE(review): the skipif below also skips every ROCm platform, so the
# ROCm-specific branches in this test are currently unreachable.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                         use_rocm_aiter: bool, monkeypatch) -> None:
    """Quantize an fp16 checkpoint to FP8 at load time and inspect layer 0.

    Loads ``facebook/opt-125m`` with ``quantization="fp8"`` and checks that
    ``fc1`` of the first decoder layer uses ``Fp8LinearMethod``. With
    ``kv_cache_dtype="fp8"`` it also checks that the attention layer uses
    ``Fp8KVCacheMethod`` with k/v scales of 1.0. The expected ``fc1`` weight
    dtype depends on the platform's FP8 support and on ``force_marlin``.

    Args:
        vllm_runner: Fixture returning a context-manager factory for a vLLM
            instance.
        kv_cache_dtype: KV-cache dtype passed to the runner ("auto"/"fp8").
        force_marlin: When True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1``.
        use_rocm_aiter: When True, sets ``VLLM_ROCM_USE_AITER=1``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(os.path.join(models_path_prefix, "facebook/opt-125m"),
                     quantization="fp8",
                     kv_cache_dtype=kv_cache_dtype) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                # 1.0 is the default scale; presumably an fp16 checkpoint
                # provides no k/v scales, so the defaults are expected.
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            elif current_platform.is_rocm():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == current_platform.fp8_dtype()
                else:  # unsupported ROCm platform
                    pytest.skip(
                        "Skip `test_load_fp16_model`. "
                        "It only runs on ROCm platform with FP8 compute."
                        " e.g. MI300X and above.")
            else:  # unsupported platform
                pytest.skip("Skip `test_load_fp16_model`. "
                            "It only runs on CUDA and ROCm platform.")

        llm.apply_model(check_model)


@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Check ``ops.scaled_fp8_quant`` against a pure-PyTorch reference.

    Covers three cases. Dynamic quantization: the kernel returns the scale.
    Static quantization: the scale is passed in. Padding: the output is
    padded to ``num_token_padding`` rows. Every comparison is made on
    dequantized values in ``dtype``.
    """

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                           max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization: the kernel computes and returns the scale. Its
    # dequantized output is the baseline for every comparison below.
    dynamic_q, inv_scale = ops.scaled_fp8_quant(x, None)
    dynamic_dq = per_tensor_dequantize(dynamic_q, inv_scale, dtype)

    # Reference dynamic quantization, reusing the kernel's scale.
    ref_q = quantize_ref(x, inv_scale)
    torch.testing.assert_close(dynamic_dq,
                               per_tensor_dequantize(ref_q, inv_scale, dtype))

    # Static quantization: pass the same scale back in explicitly.
    static_q, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(
        dynamic_dq, per_tensor_dequantize(static_q, inv_scale, dtype))

    # Padding: the output gains extra rows; only the first x.shape[0] rows
    # are compared against the baseline.
    padded_q, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert padded_q.shape[0] == 17
    torch.testing.assert_close(
        dynamic_dq,
        per_tensor_dequantize(torch.narrow(padded_q, 0, 0, x.shape[0]),
                              inv_scale, dtype))