utils.py 8.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional, Union
6

7
import huggingface_hub
8
import regex as re
9
10
11
12
13
14
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    HFValidationError,
    RepositoryNotFoundError,
)
15
from torch import nn
16
from transformers import PretrainedConfig
17

18
from vllm.config.lora import LoRAConfig
19
from vllm.logger import init_logger
20

21
22
23
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
40
from vllm.model_executor.layers.linear import LinearBase
41

42
# yapf: enable
43
44
45

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
47
    from vllm.model_executor.models.utils import WeightsMapper
48
49

logger = init_logger(__name__)

# Every LoRA wrapper class that `from_layer` may substitute for a base layer.
# NOTE(review): this is a set, so the order in which `from_layer` tries the
# classes is not defined; presumably each class's `can_replace_layer` is
# meant to be mutually exclusive with the others — TODO confirm.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
}


68
69
70
71
72
73
74
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig] = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA class that can replace it.

    Each class in ``_all_lora_classes`` is asked via ``can_replace_layer``
    whether it can wrap ``layer``. The first class that answers yes is
    instantiated around ``layer``, and its LoRA weights are created with
    ``create_lora_weights(max_loras, lora_config, model_config)``.

    Args:
        layer: the base module to wrap.
        max_loras: forwarded to ``create_lora_weights``.
        lora_config: the LoRA configuration, forwarded to both
            ``can_replace_layer`` and ``create_lora_weights``.
        packed_modules_list: forwarded to ``can_replace_layer``.
        model_config: optional model config, forwarded to both calls.

    Returns:
        The new LoRA-wrapped module, or ``layer`` itself (unchanged) when no
        LoRA class can replace it.
    """
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        ):
            instance_layer = lora_cls(layer)
            instance_layer.create_lora_weights(max_loras, lora_config, model_config)
            return instance_layer
    return layer


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is built from ``layer`` plus properties of ``lm_head``: its
    ``embedding_dim``, its weight's dtype and device, and its
    ``get_sharded_to_full_mapping()``. LoRA weights are then created on the
    wrapper.

    Args:
        layer: the logits processor to wrap.
        lm_head: the LM head whose properties configure the wrapper.
        max_loras: forwarded to ``create_lora_weights``.
        lora_config: forwarded to ``create_lora_weights``.
        model_config: optional model config, forwarded to
            ``create_lora_weights``.

    Returns:
        The new ``LogitsProcessorWithLoRA`` instance.
    """
    ret = LogitsProcessorWithLoRA(
        layer,
        lm_head.embedding_dim,
        lm_head.weight.dtype,
        lm_head.weight.device,
        lm_head.get_sharded_to_full_mapping(),
    )
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret

106

107
108
109
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace a submodule in a model with a new module.

    Args:
        model: the root module.
        module_name: dotted path of the submodule to replace, relative to
            ``model`` (e.g. ``"layers.0.mlp"``). A name with no dot refers
            to a direct child of ``model``.
        new_module: the module to install at ``module_name``.

    Returns:
        ``new_module``.
    """
    # "a.b.c" -> ("a.b", "c"); "c" -> ("", "c"), and get_submodule("")
    # returns `model` itself.
    parent_name, _, target_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, target_name, new_module)
    return new_module


117
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool, bool]:
    """Parse the name of a LoRA weight tensor.

    Args:
        name: the qualified name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the name after a leading
            ``base_model.model.`` prefix (if any) has been stripped, e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        A tuple ``(module_name, is_lora_a, is_bias)``:
            module_name: the name of the module, e.g. ``dense1``.
            is_lora_a: whether the tensor is lora_A / lora_embedding_A.
            is_bias: whether the tensor is a LoRA bias.

    Raises:
        ValueError: if ``name`` matches none of the supported LoRA weight
            name patterns.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that leading prefix (not later occurrences) so the mapper
    # sees the bare module path, then restore it afterwards.
    has_prefix = name.startswith(prefix)
    if has_prefix:
        name = name[len(prefix):]
    if weights_mapper:
        name = weights_mapper._map_name(name)
    if has_prefix:
        name = prefix + name

    # In some situations, we may not start with `base_model.model.`.
    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
    # we should keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    last = parts[-1]

    # `<module>.lora_A.weight` / `<module>.lora_B.weight`; the length guard
    # keeps a bare "weight" from raising IndexError.
    if last == "weight" and len(parts) >= 2 and parts[-2] in ("lora_A", "lora_B"):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A", False

    # `<module>.lora_embedding_A` / `<module>.lora_embedding_B`
    if last in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), last == "lora_embedding_A", False

    # `<module>.<lora_X>.bias`
    if last == "bias":
        return ".".join(parts[start_index:-2]), False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
164
165


166
167
168
def is_regex_target_modules(
    load_modules: Union[str, list[str]], expected_lora_modules: list[str]
) -> bool:
    """
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
    `expected_lora_modules`.

    Returns True only when ``load_modules`` is a string that compiles as a
    regex, ends with a parenthesized group (optionally followed by ``$``),
    and every ``|``-separated alternative in that group is an element of
    ``expected_lora_modules``. Returns False in every other case, including
    when ``load_modules`` is a list.
    """
    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        return False

    # Capture the trailing group, e.g. `q_proj|k_proj|v_proj` from
    # `model.*(q_proj|k_proj|v_proj)$`.
    match = re.search(r"\((.*?)\)\$?$", load_modules)
    if match is None:
        return False

    suffixes = match.group(1).split("|")
    return set(suffixes).issubset(expected_lora_modules)


199
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the module names in ``model`` that LoRA can target.

    In vLLM, all linear layers support LoRA. The result contains:
      * every entry of each submodule's ``embedding_modules`` attribute
        (when the attribute exists and is not None), and
      * the last dotted component of the qualified name of every
        ``LinearBase`` submodule (e.g. ``q_proj`` for
        ``model.layers.0.self_attn.q_proj``).

    Duplicates are removed; the order of the returned list is not
    meaningful.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # get the embedding modules if the module's embedding_modules
        # is not None. A distinct loop variable is used here so the outer
        # `module_name` is not clobbered before the linear check below.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # get all the linear suffixes.
        if isinstance(module, LinearBase):
            supported_lora_modules.add(module_name.rsplit(".", 1)[-1])

    return list(supported_lora_modules)


220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolves the given lora_path to a local path.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is passed through
         ``os.path.expanduser``; the result is not checked for existence.
      3. A relative path that exists locally is converted to an absolute
         path with ``os.path.abspath``.
      4. Otherwise lora_path is treated as a Hugging Face repo id: the repo
         is downloaded with ``huggingface_hub.snapshot_download`` and the
         local snapshot path is returned.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved local path to the lora model. If the download fails
         with one of the handled Hugging Face hub errors, the error is logged
         and the original lora_path is returned unchanged.
    """
    # Absolute paths are returned as-is, whether they exist or not.
    if os.path.isabs(lora_path):
        return lora_path

    # If the path starts with ~, expand the user home directory.
    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # A relative path that exists locally is made absolute.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # If the path does not exist locally, assume it's a Hugging Face repo.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path)
    except (
        HfHubHTTPError,
        RepositoryNotFoundError,
        EntryNotFoundError,
        HFValidationError,
    ):
        # Best-effort: log the failure and hand back the original path
        # rather than raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path