# test_lora_checkpoints.py
# SPDX-License-Identifier: Apache-2.0

3
4
from typing import List

5
6
7
import pytest

from vllm.lora.models import LoRAModel
8
from vllm.lora.peft_helper import PEFTHelper
9
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
10
from vllm.model_executor.models.utils import WeightsMapper
# LoRA adapters exercised by ``test_load_checkpoints``; each name selects the
# matching fixture inside that test.
lora_lst = [
    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]

# BaiChuan module names that LoRA adapters may target. Names that appear in
# ``BaiChuanBaseForCausalLM.packed_modules_mapping`` are expanded into their
# sub-module names by the tests before loading.
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]

@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    """Check which LoRA checkpoints load against BaiChuan's module layout.

    The three Baichuan adapters must load without error:

    * ``baichuan7B`` -- the Baichuan LoRA itself;
    * ``baichuan7B-zero`` -- ``target_modules`` contain a prefix such as
      ``model.layers.0.self_attn.W_pack``;
    * ``baichuan7B-zero-regex`` -- ``target_modules`` is a regular
      expression such as ``model\\..*(W_pack|o_proj)``.

    Loading the ``chatglm3-6b`` adapter for the BaiChuan model must raise
    ``ValueError``.
    """
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules

    # Modules listed in packed_modules_mapping are replaced by their
    # sub-module names; every other module is expected as-is.
    expected_lora_modules: List[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    # Each parametrized name selects its fixture; an unknown name fails
    # loudly with KeyError instead of silently taking the error branch.
    lora_files = {
        "baichuan7B": baichuan_lora_files,
        "baichuan7B-zero": baichuan_zero_lora_files,
        "baichuan7B-zero-regex": baichuan_regex_lora_files,
        "chatglm3-6b": chatglm3_lora_files,
    }[lora_name]

    peft_helper = PEFTHelper.from_local_dir(lora_files,
                                            max_position_embeddings=4096)

    def load_checkpoint():
        # Every case uses identical loading arguments; only the files differ.
        return LoRAModel.from_local_checkpoint(
            lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)

    if lora_name == "chatglm3-6b":
        # Loading chatglm3-6b's LoRA for the BaiChuan model must be
        # rejected with this error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        with pytest.raises(ValueError, match=expected_error):
            load_checkpoint()
    else:
        load_checkpoint()


def test_lora_weights_mapping(baichuan_lora_files):
    """Check that a ``WeightsMapper`` is applied to LoRA weight names on load.

    The mapper rewrites the ``model.`` prefix to ``language_model.model.``
    and the ``.layers.`` substring to ``.baichuan_layers.``. Every loaded
    LoRA name must start with the new prefix and contain the new substring.
    """
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules

    # Modules listed in packed_modules_mapping are replaced by their
    # sub-module names; every other module is expected as-is.
    expected_lora_modules: List[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            ".layers.": ".baichuan_layers.",
        },
    )

    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules,
        weights_mapper=hf_to_vllm_mapper,
    )

    # Without this guard the loop below would pass vacuously if no LoRA
    # weights were loaded.
    assert lora_model.loras, "no LoRA weights were loaded"

    expected_prefix = hf_to_vllm_mapper.orig_to_new_prefix["model."]
    for name in lora_model.loras:
        assert name.startswith(expected_prefix)
        assert ".baichuan_layers." in name