test_utils.py 6.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections import OrderedDict
5
from typing import NamedTuple
6
from unittest.mock import MagicMock, patch
7

8
import pytest
9
from huggingface_hub.utils import HfHubHTTPError
10
11
from torch import nn

12
13
14
15
16
from vllm.lora.utils import (
    get_adapter_absolute_path,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
17
18
19
20
21
22
23
from vllm.model_executor.models.utils import WeightsMapper


class LoRANameParserTestConfig(NamedTuple):
    """One expected parse result for ``parse_fine_tuned_lora_name``.

    Used as fixture entries: parsing ``name`` (optionally through
    ``weights_mapper``) should yield ``(module_name, is_lora_a)``.
    """

    # Checkpoint weight name fed to the parser.
    name: str
    # Module name the parser is expected to return.
    module_name: str
    # Expected flag: True for LoRA "A" tensors, False for LoRA "B" tensors.
    is_lora_a: bool
    # Optional name remapping handed to the parser; None means no remapping.
    weights_mapper: WeightsMapper | None = None


def test_parse_fine_tuned_lora_name_valid():
    """Valid LoRA weight names resolve to the expected (module_name, is_lora_a).

    Covers ``lm_head``, embedding and projection weights, names with and
    without the ``base_model.model.`` prefix, and names rewritten through a
    ``WeightsMapper`` prefix mapping.
    """

    def _mapper() -> WeightsMapper:
        # Fresh instance per fixture entry, so no state is shared between cases.
        return WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

    fixture = [
        # The 4th positional field is ``weights_mapper``; passing a stray
        # boolean there would hand a non-mapper value to the parser.
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_A.weight", "lm_head", True
        ),
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_B.weight", "lm_head", False
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
        ),
        # Test with WeightsMapper: expected names carry the remapped prefix.
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            weights_mapper=_mapper(),
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            weights_mapper=_mapper(),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            weights_mapper=_mapper(),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            weights_mapper=_mapper(),
        ),
    ]

    for case in fixture:
        # The name is included in the assertion message to pinpoint the failing entry.
        assert (case.module_name, case.is_lora_a) == parse_fine_tuned_lora_name(
            case.name, case.weights_mapper
        ), case.name


def test_parse_fine_tuned_lora_name_invalid():
    """Plain weight names (no LoRA A/B tensor in the name) raise ValueError."""
    # A tuple (rather than a set) keeps iteration order independent of string
    # hash randomization, so failures are reported reproducibly.
    fixture = (
        "base_model.weight",
        "base_model.model.weight",
    )
    for name in fixture:
        with pytest.raises(ValueError, match="unsupported LoRA weight"):
            parse_fine_tuned_lora_name(name)


def test_replace_submodule():
    """replace_submodule swaps both top-level and nested children by dotted name."""
    model = nn.Sequential(
        OrderedDict(
            [
                ("dense1", nn.Linear(764, 100)),
                ("act1", nn.ReLU()),
                ("dense2", nn.Linear(100, 50)),
                (
                    "seq1",
                    nn.Sequential(
                        OrderedDict(
                            [
                                ("dense1", nn.Linear(100, 10)),
                                ("dense2", nn.Linear(10, 50)),
                            ]
                        )
                    ),
                ),
                ("act2", nn.ReLU()),
                ("output", nn.Linear(50, 10)),
                ("outact", nn.Sigmoid()),
            ]
        )
    )

    # Top-level child: the exact instance passed in must be installed.
    sigmoid = nn.Sigmoid()
    replace_submodule(model, "act1", sigmoid)
    assert dict(model.named_modules())["act1"] is sigmoid

    # Nested child addressed with a dotted path.
    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    assert dict(model.named_modules())["seq1.dense2"] is dense2


# Unit tests for get_adapter_absolute_path
@patch("os.path.isabs")
def test_get_adapter_absolute_path_absolute(mock_isabs):
    """A path reported as absolute is returned unchanged."""
    path = "/absolute/path/to/lora"
    mock_isabs.return_value = True
    assert get_adapter_absolute_path(path) == path


@patch("os.path.expanduser")
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    """A ``~``-prefixed path is returned as expanded by ``os.path.expanduser``."""
    # Path with ~ that needs to be expanded
    path = "~/relative/path/to/lora"
    absolute_path = "/home/user/relative/path/to/lora"
    mock_expanduser.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("os.path.exists")
@patch("os.path.abspath")
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    """An existing relative path is resolved via ``os.path.abspath``."""
    # Stacked @patch decorators inject mocks bottom-up, so abspath comes first.
    # Relative path that exists locally
    path = "relative/path/to/lora"
    absolute_path = "/absolute/path/to/lora"
    mock_exist.return_value = True
    mock_abspath.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download):
    """A non-local path is treated as a Hub repo id and its snapshot path returned."""
    # Hugging Face model identifier
    path = "org/repo"
    absolute_path = "/mock/snapshot/path"
    mock_exist.return_value = False
    mock_snapshot_download.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface_error(
    mock_exist, mock_snapshot_download
):
    """If the Hub download raises HfHubHTTPError, the input path is returned."""
    # Hugging Face model identifier with download error
    path = "org/repo"
    mock_exist.return_value = False
    mock_snapshot_download.side_effect = HfHubHTTPError(
        "failed to query model info",
        response=MagicMock(),
    )
    assert get_adapter_absolute_path(path) == path