test_fp8.py 3.53 KB
Newer Older
1
2
3
4
5
6
7
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

8
from tests.quantization.utils import is_quant_method_supported
9
from vllm import _custom_ops as ops
10
11
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Model identifiers fed to the parametrized test_model_load_and_run.
# Both names carry an "-FP8" suffix (presumably pre-quantized FP8
# checkpoints -- TODO confirm against the model cards).
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
def test_model_load_and_run(vllm_runner, model: str):
    """Smoke-test that ``model`` loads and completes a short greedy run.

    Accuracy is deliberately not checked here (the lm-eval tests cover
    that); the test passes as long as generation finishes without error.
    """
    prompt_batch = ["Hello my name is"]
    with vllm_runner(model) as runner:
        results = runner.generate_greedy(prompts=prompt_batch, max_tokens=10)
        # Show the second field of the first result for manual inspection.
        first_result = results[0]
        print(first_result[1])

29

30
31
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
    """Check FP8 quantization applied while loading ``facebook/opt-125m``.

    Loads the model with ``quantization="fp8"`` and verifies that ``fc1``
    of the first decoder layer uses ``Fp8LinearMethod`` and that its weight
    dtype matches what the current GPU's compute capability calls for.
    """
    with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
        # Walk down from the runner wrapper to the underlying model.
        engine = llm.model.llm_engine
        model_runner = engine.model_executor.driver_worker.model_runner
        first_fc1 = model_runner.model.model.decoder.layers[0].fc1
        assert isinstance(first_fc1.quant_method, Fp8LinearMethod)

        # Fold (major, minor) into a single number, e.g. (8, 9) -> 89.
        major, minor = torch.cuda.get_device_capability()
        if major * 10 + minor >= 89:
            # GPUs with hardware support keep the weights in fp8.
            expected_dtype = torch.float8_e4m3fn
        else:
            # GPUs without hardware support get the fp8 weights packed
            # for weight-only quantization using Marlin kernels.
            expected_dtype = torch.int32
        assert first_fc1.weight.dtype == expected_dtype
48
49


50
51
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Cross-check ``ops.scaled_fp8_quant`` against a pure-PyTorch reference.

    Exercises three call forms -- no scale passed in (dynamic), scale
    passed in (static), and ``batch_dim_padding`` -- and compares every
    result after per-tensor dequantization back to ``dtype``.
    """
    fp8 = torch.float8_e4m3fn

    def quantize_ref(tensor, inv_scale):
        # Reference quantizer meant to align fully with the kernel under
        # test: scale in float32, saturate to the fp8 range, then cast.
        limits = torch.finfo(fp8)
        scaled = tensor.to(torch.float32) * inv_scale.reciprocal()
        return scaled.clamp(min=limits.min, max=limits.max).to(fp8)

    def per_tensor_dequantize(tensor, inv_scale, out_dtype):
        # Cast back to the working dtype and undo the scaling.
        return tensor.to(out_dtype) * inv_scale

    # The shape is deliberately not a multiple of 4: scaled_fp8_quant is
    # vectorized by 4, so this covers the leftover-element edge case.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
    num_rows = x.shape[0]

    # Dynamic quantization: no scale is supplied, the op returns one.
    q_dynamic, inv_scale = ops.scaled_fp8_quant(x, None)
    expected = per_tensor_dequantize(q_dynamic, inv_scale, dtype)

    # Reference dynamic quantization, reusing the scale the op returned.
    q_reference = quantize_ref(x, inv_scale)
    assert torch.allclose(
        expected, per_tensor_dequantize(q_reference, inv_scale, dtype))

    # Static quantization: the same scale is passed back in explicitly.
    q_static, _ = ops.scaled_fp8_quant(x, inv_scale)
    assert torch.allclose(
        expected, per_tensor_dequantize(q_static, inv_scale, dtype))

    # Padding: the leading dimension grows to 17; only the first
    # num_rows rows are compared against the expected values.
    q_padded, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
    assert q_padded.shape[0] == 17
    unpadded = q_padded[:num_rows]
    assert torch.allclose(
        expected, per_tensor_dequantize(unpadded, inv_scale, dtype))