test_modelopt.py 3.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test ModelOpt quantization method setup and weight loading.

Run `pytest tests/quantization/test_modelopt.py`.
"""

import os

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    `LLM.apply_model` requires pickling a function, so this autouse fixture
    sets ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` through ``monkeypatch``;
    pytest restores the previous environment when each test finishes.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
                    reason="ModelOpt FP8 is not supported on this GPU type.")
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
    """Test ModelOpt FP8 checkpoint loading and structure validation.

    Loads a ModelOpt FP8 checkpoint with ``quantization="modelopt"`` and
    checks that the qkv/o/gate_up/down projections of the first decoder
    layer use ``ModelOptFp8LinearMethod``, store ``float8_e4m3fn`` weights,
    and carry ``float32`` ``weight_scale`` / ``input_scale`` tensors. It then
    runs a short greedy generation as a smoke test.

    The checkpoint path defaults to an internal location; set the
    ``VLLM_TEST_MODELOPT_FP8_CKPT`` environment variable to point at another
    local copy. The test is skipped when the path does not exist.
    """
    # TODO: provide a small publicly available test checkpoint
    model_path = os.environ.get(
        "VLLM_TEST_MODELOPT_FP8_CKPT",
        "/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
        "TinyLlama-1.1B-Chat-v1.0-fp8-0710")

    # Skip test if checkpoint doesn't exist
    if not os.path.exists(model_path):
        pytest.skip(f"Test checkpoint not found at {model_path}. "
                    "This test requires a local ModelOpt FP8 checkpoint "
                    "(set VLLM_TEST_MODELOPT_FP8_CKPT to override the path).")

    with vllm_runner(model_path, quantization="modelopt",
                     enforce_eager=True) as llm:

        def check_model(model):
            # Imported locally: check_model is pickled for LLM.apply_model,
            # so keep it self-contained.
            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptFp8LinearMethod)

            layer = model.model.layers[0]
            projections = {
                "qkv_proj": layer.self_attn.qkv_proj,
                "o_proj": layer.self_attn.o_proj,
                "gate_up_proj": layer.mlp.gate_up_proj,
                "down_proj": layer.mlp.down_proj,
            }

            for name, proj in projections.items():
                # ModelOpt quantization method must be applied to the layer.
                assert isinstance(proj.quant_method,
                                  ModelOptFp8LinearMethod), (
                                      f"{name}: unexpected quant_method "
                                      f"{type(proj.quant_method).__name__}")

                # Weights are stored in FP8.
                assert proj.weight.dtype == torch.float8_e4m3fn, (
                    f"{name}: weight dtype is {proj.weight.dtype}")

                # Both scales must exist and be float32.
                for scale_name in ("weight_scale", "input_scale"):
                    assert hasattr(proj, scale_name), (
                        f"{name}: missing {scale_name}")
                    scale = getattr(proj, scale_name)
                    assert scale.dtype == torch.float32, (
                        f"{name}.{scale_name} dtype is {scale.dtype}")

        llm.apply_model(check_model)

        # Run a simple generation test to ensure the model works
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
        print(f"ModelOpt FP8 output: {output}")