# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test ModelOpt quantization method setup and weight loading.

Run `pytest tests/quantization/test_modelopt.py`.
"""

import os

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported


@pytest.fixture(autouse=True, scope="function")
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    `LLM.apply_model` requires pickling a function, so the
    `VLLM_ALLOW_INSECURE_SERIALIZATION` environment variable is set to "1"
    through `monkeypatch` for the duration of each test.
    """
    monkeypatch.setenv(name="VLLM_ALLOW_INSECURE_SERIALIZATION", value="1")
20
21


22
23
24
25
@pytest.mark.skipif(
    not is_quant_method_supported("modelopt"),
    reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
    """Load a ModelOpt FP8 checkpoint and validate its quantized layers.

    The test is skipped when the local checkpoint directory does not exist.
    Otherwise it inspects the attention and MLP projections of the first
    entry in `model.model.layers`, checking the quantization method, the
    FP8 weight dtype and the float32 weight/input scales, then runs a short
    greedy generation to confirm the model produces output.
    """
    # TODO: provide a small publicly available test checkpoint
    ckpt_dir = (
        "/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
        "TinyLlama-1.1B-Chat-v1.0-fp8-0710"
    )

    # The checkpoint is a local path, so skip wherever it is absent.
    if not os.path.exists(ckpt_dir):
        pytest.skip(
            f"Test checkpoint not found at {ckpt_dir}. "
            "This test requires a local ModelOpt FP8 checkpoint."
        )

    def check_model(model):
        """Assert that the first layer's projections are ModelOpt FP8."""
        first_layer = model.model.layers[0]
        projections = (
            first_layer.self_attn.qkv_proj,
            first_layer.self_attn.o_proj,
            first_layer.mlp.gate_up_proj,
            first_layer.mlp.down_proj,
        )

        # Local import keeps this callback self-contained.
        from vllm.model_executor.layers.quantization.modelopt import (
            ModelOptFp8LinearMethod,
        )

        # Every projection must use the ModelOpt FP8 linear method.
        for proj in projections:
            assert isinstance(proj.quant_method, ModelOptFp8LinearMethod)

        # Weights must be stored as FP8 (e4m3fn).
        for proj in projections:
            assert proj.weight.dtype == torch.float8_e4m3fn

        # Both scales must be present and kept in float32.
        scale_names = ("weight_scale", "input_scale")
        for proj in projections:
            for scale_name in scale_names:
                assert hasattr(proj, scale_name)
            for scale_name in scale_names:
                assert getattr(proj, scale_name).dtype == torch.float32

    with vllm_runner(ckpt_dir, quantization="modelopt", enforce_eager=True) as llm:
        llm.apply_model(check_model)

        # Run a simple generation test to ensure the model works
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output
        print(f"ModelOpt FP8 output: {output}")