# SPDX-License-Identifier: Apache-2.0
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import os

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
from vllm.platforms import current_platform

from ..utils import models_path_prefix

# FP8 checkpoints exercised by test_model_load_and_run. Each entry is
# resolved under models_path_prefix (presumably a local model mirror —
# confirm against ..utils).
MODELS = [
    os.path.join(models_path_prefix,
                 "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"),
    os.path.join(models_path_prefix,
                 "nm-testing/Phi-3-mini-128k-instruct-FP8"),
    os.path.join(models_path_prefix,
                 "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV"),
]


# NOTE: ROCm is skipped unconditionally here, in addition to platforms where
# the "fp8" quant method is unsupported.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
                            use_rocm_aiter: bool, monkeypatch) -> None:
    """Load an FP8 checkpoint and run a short greedy generation.

    This is a smoke test: it only checks that the model loads and generates.
    It does not check accuracy (see the lm-eval tests for that).

    Args:
        vllm_runner: fixture; called with a model path, it returns a context
            manager that yields the runner used for generation.
        model_id: checkpoint path taken from ``MODELS``.
        force_marlin: when True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1`` before
            the model is loaded.
        use_rocm_aiter: when True, sets ``VLLM_ROCM_USE_AITER=1``. It is only
            parametrized True on ROCm, which the skipif above excludes, so in
            practice it is False here.
        monkeypatch: pytest fixture that scopes the env-var changes to this
            test.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# Checkpoints carrying FP8 KV-cache scales, exercised by
# test_kv_cache_model_load_and_run. Both scale-naming formats are covered.
KV_CACHE_MODELS = [
    # Deprecated AutoFP8 format using .kv_scale
    os.path.join(models_path_prefix,
                 "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"),
    # AutoFP8 format using separate .k_scale and .v_scale
    os.path.join(models_path_prefix,
                 "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V"),
]


# NOTE: ROCm is skipped unconditionally here, in addition to platforms where
# the "fp8" quant method is unsupported.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
                                     use_rocm_aiter: bool, monkeypatch):
    """Load a checkpoint with an FP8 KV cache and check the loaded scales.

    The test asserts that the first attention layer uses
    ``Fp8KVCacheMethod`` and that its ``_k_scale`` and ``_v_scale`` fall
    inside the expected range. It then runs a short greedy generation as a
    smoke test.

    Args:
        vllm_runner: fixture; called with a model path and engine kwargs, it
            returns a context manager that yields the runner.
        model_id: checkpoint path taken from ``KV_CACHE_MODELS``.
        use_rocm_aiter: when True, sets ``VLLM_ROCM_USE_AITER=1``. It is only
            True on ROCm, which the skipif above excludes.
        monkeypatch: pytest fixture that scopes the env-var changes to this
            test.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            # NOTE: it is valid for scales to be 1.0 (default value), but
            # we know these checkpoints have scales < 1.0
            if not current_platform.is_rocm():
                # NOTE: This code path requires validation on Non-CUDA platform
                assert 0.0 < attn._k_scale < 1.0
                assert 0.0 < attn._v_scale < 1.0
            else:
                # NOTE: on ROCm the _k_scale and _v_scale are scaled by a
                # factor of 2 as described in
                # vllm/model_executor/layers/quantization/kv_cache.py
                assert 0.0 < attn._k_scale < (1.0 * 2.0)
                assert 0.0 < attn._v_scale < (1.0 * 2.0)

        llm.apply_model(check_model)

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


# NOTE: ROCm is skipped unconditionally here, in addition to platforms where
# the "fp8" quant method is unsupported.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                         use_rocm_aiter: bool, monkeypatch) -> None:
    """Load an unquantized checkpoint with ``quantization="fp8"``.

    The test then inspects the first decoder layer. ``fc1`` must use
    ``Fp8LinearMethod``, and its weight dtype must match the platform:
    FP8 when the GPU supports it and Marlin is not forced, otherwise int32
    (packed for Marlin). When ``kv_cache_dtype == "fp8"``, the attention
    layer must use ``Fp8KVCacheMethod`` with k/v scales equal to 1.0.

    Args:
        vllm_runner: fixture; called with a model path and engine kwargs, it
            returns a context manager that yields the runner.
        kv_cache_dtype: forwarded to the runner as ``kv_cache_dtype``.
        force_marlin: when True, sets ``VLLM_TEST_FORCE_FP8_MARLIN=1``.
        use_rocm_aiter: when True, sets ``VLLM_ROCM_USE_AITER=1``. It is only
            True on ROCm, which the skipif above excludes.
        monkeypatch: pytest fixture that scopes the env-var changes to this
            test.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(os.path.join(models_path_prefix, "facebook/opt-125m"),
                     quantization="fp8",
                     kv_cache_dtype=kv_cache_dtype) as llm:

        def check_model(model):
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, Fp8LinearMethod)
            if kv_cache_dtype == "fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                # Presumably the checkpoint carries no KV scales, so the
                # default value of 1.0 is expected.
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.is_cuda():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == torch.float8_e4m3fn
                else:
                    # For GPUs without hardware support, we pack the fp8 weights
                    # for weight-only quantization using Marlin kernels
                    assert fc1.weight.dtype == torch.int32
            elif current_platform.is_rocm():
                if current_platform.supports_fp8() and not force_marlin:
                    # For GPUs with hardware support, we keep weights in fp8
                    assert fc1.weight.dtype == current_platform.fp8_dtype()
                else:  # unsupported ROCm platform
                    pytest.skip(
                        "Skip `test_load_fp16_model`. "
                        "It only runs on ROCm platform with FP8 compute."
                        " e.g. MI300X and above.")
            else:  # unsupported platform
                pytest.skip("Skip `test_load_fp16_model`. "
                            "It only runs on CUDA and ROCm platform.")

        llm.apply_model(check_model)


# NOTE: ROCm is skipped unconditionally here, in addition to platforms where
# the "fp8" quant method is unsupported.
@pytest.mark.skipif(not is_quant_method_supported("fp8")
                    or current_platform.is_rocm(),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` against a pure-PyTorch reference.

    Covers three modes:

    * dynamic quantization, where the kernel computes ``inv_scale``;
    * static quantization, where the caller supplies ``inv_scale``;
    * output padding via ``num_token_padding``.

    All comparisons are made after dequantizing back to ``dtype``.
    """

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                           max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization: the kernel returns both the quantized tensor and
    # the inverse scale it chose.
    kernel_q, inv_scale = ops.scaled_fp8_quant(x, None)
    kernel_dq = per_tensor_dequantize(kernel_q, inv_scale, dtype)

    # Reference dynamic quantization, reusing the kernel's inverse scale.
    ref_q = quantize_ref(x, inv_scale)
    torch.testing.assert_close(
        kernel_dq, per_tensor_dequantize(ref_q, inv_scale, dtype))

    # Static quantization: feeding the same inverse scale back in must
    # reproduce the dynamic result.
    static_q, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(
        kernel_dq, per_tensor_dequantize(static_q, inv_scale, dtype))

    # Padding: the output has num_token_padding rows, and its first
    # x.shape[0] rows must match the unpadded result.
    padded_q, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert padded_q.shape[0] == 17
    torch.testing.assert_close(
        kernel_dq,
        per_tensor_dequantize(torch.narrow(padded_q, 0, 0, x.shape[0]),
                              inv_scale, dtype))