# test_lora_checkpoints.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest

6
from vllm.lora.lora_model import LoRAModel
7
from vllm.lora.peft_helper import PEFTHelper
8
from vllm.lora.utils import parse_fine_tuned_lora_name
9
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
10
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
11
from vllm.model_executor.models.utils import WeightsMapper
12

# LoRA checkpoint names used to parametrize test_load_checkpoints.
lora_lst = [
    "baichuan7B",
    "baichuan7B-zero",
    "baichuan7B-zero-regex",
    "chatglm3-6b",
]

# Baichuan module names the tests start from; entries found in
# BaiChuanBaseForCausalLM.packed_modules_mapping are expanded by the tests.
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    """Load each checkpoint in ``lora_lst`` against Baichuan's LoRA modules.

    The three Baichuan adapters must load without error. The chatglm3-6b
    adapter is loaded against the same Baichuan module set and must be
    rejected with a ``ValueError``.
    """
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    # Modules in packed_modules_mapping are expected under the names the mapping
    # gives for them; all other modules are expected under their own name.
    expected_lora_modules = {
        sub_module
        for module in BAICHUAN_LORA_MODULES
        for sub_module in packed_modules_mapping.get(module, [module])
    }

    # lora_name -> (checkpoint directory fixture, whether loading must fail).
    cases = {
        # Baichuan's own LoRA; must load.
        "baichuan7B": (baichuan_lora_files, False),
        # target_modules carry a full prefix such as
        # "model.layers.0.self_attn.W_pack"; must load.
        "baichuan7B-zero": (baichuan_zero_lora_files, False),
        # target_modules given as a regular expression such as
        # `model\\..*(W_pack|o_proj)`; must load.
        "baichuan7B-zero-regex": (baichuan_regex_lora_files, False),
        # chatglm3-6b's LoRA loaded against Baichuan's modules must be rejected.
        "chatglm3-6b": (chatglm3_lora_files, True),
    }
    # Fail loudly for a new lora_lst entry instead of silently treating it as
    # the expected-failure case.
    assert lora_name in cases, f"no checkpoint fixture registered for {lora_name!r}"
    lora_files, must_fail = cases[lora_name]

    peft_helper = PEFTHelper.from_local_dir(lora_files, max_position_embeddings=4096)

    def load_checkpoint():
        return LoRAModel.from_local_checkpoint(
            lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            model_vocab_size=64000,
        )

    if must_fail:
        expected_error = "Please verify that the loaded LoRA module is correct"
        with pytest.raises(ValueError, match=expected_error):
            load_checkpoint()
    else:
        load_checkpoint()


def test_lora_weights_mapping(baichuan_lora_files):
    """Check that a ``WeightsMapper`` renames every module of a loaded LoRA.

    The mapper is configured to replace the ``model.`` prefix with
    ``language_model.model.`` and the ``.layers.`` substring with
    ``.baichuan_layers.``; each module name in the loaded model must show both.
    """
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    # Modules in packed_modules_mapping are expected under the names the mapping
    # gives for them; all other modules are expected under their own name.
    expected_lora_modules = {
        sub_module
        for module in BAICHUAN_LORA_MODULES
        for sub_module in packed_modules_mapping.get(module, [module])
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            ".layers.": ".baichuan_layers.",
        },
    )

    peft_helper = PEFTHelper.from_local_dir(
        baichuan_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        model_vocab_size=64000,
        weights_mapper=hf_to_vllm_mapper,
    )

    # Without this, an empty ``loras`` would let the loop below pass vacuously.
    assert lora_model.loras, "the LoRA checkpoint produced no modules"
    expected_prefix = hf_to_vllm_mapper.orig_to_new_prefix["model."]
    for name in lora_model.loras:
        assert name.startswith(expected_prefix)
        assert ".baichuan_layers." in name


def test_gemma4_lora_weights_mapping():
    """A PEFT lora_A name on a Gemma4 dense MLP maps to vLLM's module name.

    The expected flag returned for this lora_A tensor is ``True``.
    """
    hf_name = (
        "base_model.model.model.language_model.layers.9.mlp.down_proj.lora_A.weight"
    )
    expected = ("model.layers.9.mlp.down_proj", True)
    result = parse_fine_tuned_lora_name(hf_name, Gemma4ForCausalLM.hf_to_vllm_mapper)
    assert result == expected


def test_gemma4_moe_lora_weights_mapping():
    """A PEFT lora_B name on Gemma4 MoE experts maps to vLLM's module name.

    The ``experts.`` segment is expected to be dropped from the mapped name, and
    the flag returned for this lora_B tensor is expected to be ``False``.
    """
    layer_prefix = "base_model.model.model.language_model.layers.9.moe."
    hf_name = layer_prefix + "experts.gate_up_proj.lora_B.weight"
    expected = ("model.layers.9.moe.gate_up_proj", False)
    result = parse_fine_tuned_lora_name(hf_name, Gemma4ForCausalLM.hf_to_vllm_mapper)
    assert result == expected