"tests/vscode:/vscode.git/clone" did not exist on "791b2fc30a25cc48e06a6bd7ce4fd62d765ac004"
test_utils.py 6.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections import OrderedDict
5
from typing import NamedTuple, Optional
6
from unittest.mock import patch
7

8
import pytest
9
from huggingface_hub.utils import HfHubHTTPError
10
11
from torch import nn

12
13
14
15
16
from vllm.lora.utils import (
    get_adapter_absolute_path,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
17
18
19
20
21
22
23
24
25
from vllm.model_executor.models.utils import WeightsMapper


class LoRANameParserTestConfig(NamedTuple):
    """A single case for ``parse_fine_tuned_lora_name``.

    Pairs a raw checkpoint tensor name with the ``(module_name, is_lora_a,
    is_bias)`` triple the parser is expected to return for it, plus an
    optional mapper passed through as the parser's second argument.
    """

    # Raw tensor name fed to the parser.
    name: str
    # Expected module name in the parser's result.
    module_name: str
    # Expected "LoRA A" flag in the parser's result.
    is_lora_a: bool
    # Expected "bias" flag in the parser's result.
    is_bias: bool
    # Mapper forwarded to the parser; defaults to None when a case needs none.
    weights_mapper: Optional[WeightsMapper] = None
26
27


28
def test_parse_fine_tuned_lora_name_valid():
    """Check that well-formed LoRA tensor names are parsed correctly.

    Each case lists a checkpoint tensor name and the expected
    ``(module_name, is_lora_a, is_bias)`` triple. The last four cases pass a
    ``WeightsMapper`` built with ``orig_to_new_prefix={"model.":
    "language_model.model."}``. They cover names both with and without a
    leading ``base_model.model.``.
    """
    fixture = [
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_A.weight", "lm_head", True, False
        ),
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_B.weight", "lm_head", False, False
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        # Test with WeightsMapper
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
    ]
    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
        # All cases share one test function, so name the offending input in
        # the failure message; otherwise only the mismatched tuples are shown.
        assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
            name, weights_mapper
        ), f"unexpected parse result for {name!r}"
114
115


116
117
118
119
120
121
122
123
124
125
def test_parse_fine_tuned_lora_name_invalid():
    """Names without a LoRA component must make the parser raise ValueError."""
    bad_names = (
        "base_model.weight",
        "base_model.model.weight",
    )
    for bad_name in bad_names:
        with pytest.raises(ValueError, match="unsupported LoRA weight"):
            parse_fine_tuned_lora_name(bad_name)


126
127
def test_replace_submodule():
    """Check that ``replace_submodule`` swaps a child module by dotted name.

    Covers a direct child (``act1``) and a nested child (``seq1.dense2``).
    After each swap, the new object must be reachable under that name via
    ``named_modules()``. The nested swap must also leave its sibling
    ``seq1.dense1`` in place.
    """
    model = nn.Sequential(
        OrderedDict(
            [
                ("dense1", nn.Linear(764, 100)),
                ("act1", nn.ReLU()),
                ("dense2", nn.Linear(100, 50)),
                (
                    "seq1",
                    nn.Sequential(
                        OrderedDict(
                            [
                                ("dense1", nn.Linear(100, 10)),
                                ("dense2", nn.Linear(10, 50)),
                            ]
                        )
                    ),
                ),
                ("act2", nn.ReLU()),
                ("output", nn.Linear(50, 10)),
                ("outact", nn.Sigmoid()),
            ]
        )
    )

    sigmoid = nn.Sigmoid()
    replace_submodule(model, "act1", sigmoid)
    # Identity check: the registered module must be the exact object passed in.
    assert dict(model.named_modules())["act1"] is sigmoid

    # Keep a handle on the sibling so we can confirm it survives the swap.
    untouched_sibling = dict(model.named_modules())["seq1.dense1"]

    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    modules_after = dict(model.named_modules())
    assert modules_after["seq1.dense2"] is dense2
    assert modules_after["seq1.dense1"] is untouched_sibling


161
# Unit tests for get_adapter_absolute_path
162
@patch("os.path.isabs")
def test_get_adapter_absolute_path_absolute(mock_isabs):
    """An input reported as absolute is returned unchanged."""
    # Force os.path.isabs to say "absolute" so the result is host-independent.
    mock_isabs.return_value = True
    lora_path = "/absolute/path/to/lora"
    resolved = get_adapter_absolute_path(lora_path)
    assert resolved == lora_path


169
@patch("os.path.expanduser")
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    """A ``~``-prefixed path resolves to what os.path.expanduser returns."""
    expanded = "/home/user/relative/path/to/lora"
    mock_expanduser.return_value = expanded

    # Path with ~ that needs to be expanded.
    lora_path = "~/relative/path/to/lora"
    resolved = get_adapter_absolute_path(lora_path)
    assert resolved == expanded


178
179
@patch("os.path.exists")
@patch("os.path.abspath")
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    """A relative path reported as existing resolves via os.path.abspath."""
    expected = "/absolute/path/to/lora"
    mock_abspath.return_value = expected
    # Pretend the relative path is present on disk.
    mock_exist.return_value = True

    lora_path = "relative/path/to/lora"
    resolved = get_adapter_absolute_path(lora_path)
    assert resolved == expected


189
190
191
@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download):
    """A path missing locally resolves to the snapshot_download result."""
    snapshot_dir = "/mock/snapshot/path"
    mock_snapshot_download.return_value = snapshot_dir
    # Not on disk, so the value is treated as a Hugging Face repo identifier.
    mock_exist.return_value = False

    repo_id = "org/repo"
    resolved = get_adapter_absolute_path(repo_id)
    assert resolved == snapshot_dir


200
201
202
203
204
@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface_error(
    mock_exist, mock_snapshot_download
):
    """When snapshot_download raises HfHubHTTPError, the input is returned as-is."""
    # Not on disk, and the Hub lookup fails.
    mock_exist.return_value = False
    mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")

    repo_id = "org/repo"
    resolved = get_adapter_absolute_path(repo_id)
    assert resolved == repo_id