utils.py 8.58 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional, Union
6

7
import huggingface_hub
8
import regex as re
9
10
11
12
13
14
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    HFValidationError,
    RepositoryNotFoundError,
)
15
from torch import nn
16
from transformers import PretrainedConfig
17

18
from vllm.config.lora import LoRAConfig
19
from vllm.logger import init_logger
20

21
# being imported for _all_lora_classes below
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
38
from vllm.model_executor.layers.linear import LinearBase
39
40
41

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
43
    from vllm.model_executor.models.utils import WeightsMapper
44
45

logger = init_logger(__name__)

# Every LoRA wrapper class that ``from_layer`` may substitute for a base layer.
# ``from_layer`` returns the FIRST class whose ``can_replace_layer`` accepts a
# layer. This is a set, so iteration order is not guaranteed. The predicates are
# therefore presumably expected to be mutually exclusive for any given layer
# (e.g. sharded vs. non-sharded variants). NOTE(review): confirm before adding
# a class whose predicate could overlap an existing one.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
}


64
65
66
67
68
69
70
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig] = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA class that can replace it.

    Each class in ``_all_lora_classes`` is asked, via its
    ``can_replace_layer`` classmethod, whether it can wrap ``layer``. The
    first class that accepts is instantiated around ``layer``. Its LoRA
    weights are then allocated with ``create_lora_weights``.

    Args:
        layer: The base module that may be wrapped.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: LoRA configuration. Used both for the
            ``can_replace_layer`` check and for weight creation.
        packed_modules_list: Forwarded to ``can_replace_layer``.
        model_config: Optional model config, forwarded to both calls.

    Returns:
        The newly created LoRA wrapper around ``layer``. If no LoRA class
        can replace it, ``layer`` itself is returned unchanged.
    """
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        ):
            instance_layer = lora_cls(layer)
            instance_layer.create_lora_weights(max_loras, lora_config, model_config)
            return instance_layer
    return layer


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is configured from ``lm_head``:
      * its ``embedding_dim``;
      * the dtype and device of ``lm_head.weight``;
      * the value returned by ``lm_head.get_sharded_to_full_mapping()``.

    LoRA weights are then allocated via ``create_lora_weights``.

    Args:
        layer: The logits processor to wrap.
        lm_head: The LM head whose properties configure the wrapper.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: Forwarded to ``create_lora_weights``.
        model_config: Optional model config, forwarded to
            ``create_lora_weights``.

    Returns:
        The new ``LogitsProcessorWithLoRA`` instance.
    """
    ret = LogitsProcessorWithLoRA(
        layer,
        lm_head.embedding_dim,
        lm_head.weight.dtype,
        lm_head.weight.device,
        lm_head.get_sharded_to_full_mapping(),
    )
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret

102

103
104
105
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace a submodule in a model with a new module.

    Args:
        model: Root module that contains the submodule.
        module_name: Dotted path of the submodule relative to ``model``,
            e.g. ``"layers.0.self_attn.q_proj"``. A name without dots
            refers to a direct child of ``model``.
        new_module: Module to install in place of the existing one.

    Returns:
        ``new_module``, so callers can keep a reference in one expression.

    Raises:
        AttributeError: If the parent path does not exist in ``model``
            (raised by ``nn.Module.get_submodule``).
    """
    # Split once, e.g. "a.b.c" -> parent "a.b" and attribute "c".
    # For a dot-free name such as "c", the parent path is "", and
    # ``get_submodule("")`` returns ``model`` itself.
    parent_name, _, target_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, target_name, new_module)
    return new_module


113
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool, bool]:
    """Parse the name of a LoRA weight.

    Args:
        name: The name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: Optional mapper applied to the name, e.g. mapping
            ``model.`` -> ``language_model.model.``. If the name starts
            with ``base_model.model.``, that prefix is removed before
            mapping and restored afterwards.

    Returns:
        A tuple ``(module_name, is_lora_a, is_bias)``, where:
            module_name: the name of the module, e.g. ``model.dense1``.
            is_lora_a: whether the tensor is ``lora_A`` / ``lora_embedding_A``
                rather than ``lora_B`` / ``lora_embedding_B``.
            is_bias: whether the tensor is a LoRA bias.

    Raises:
        ValueError: If ``name`` does not match a supported LoRA layout.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that LEADING prefix (str.replace would also remove any later
    # occurrence) so the mapper sees the bare model path.
    if name.startswith(prefix):
        name = name[len(prefix):]
        name = weights_mapper._map_name(name) if weights_mapper else name
        # recover the prefix `base_model.model.`
        name = prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations, we may not start with `base_model.model.`.
    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
    # we should keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # Length guard: a bare "weight" must raise ValueError, not IndexError.
    if (
        len(parts) >= 2
        and parts[-1] == "weight"
        and parts[-2] in ("lora_A", "lora_B")
    ):
        new_name = ".".join(parts[start_index:-2])
        return new_name, parts[-2] == "lora_A", False

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[start_index:-1])
        return new_name, parts[-1] == "lora_embedding_A", False

    if parts[-1] == "bias":
        # Drops the last two components (presumably `lora_B.bias`).
        new_name = ".".join(parts[start_index:-2])
        return new_name, False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
160
161


162
163
164
def is_regex_target_modules(
    load_modules: Union[str, list[str]], expected_lora_modules: list[str]
) -> bool:
    """Check whether a regex ``target_modules`` only targets expected modules.

    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
    `expected_lora_modules`.

    Args:
        load_modules: The adapter's ``target_modules`` value, either a
            string or a list of names.
        expected_lora_modules: Module names that may be targeted.

    Returns:
        True only if ``load_modules`` is a string that compiles as a regex,
        ends with a parenthesised alternation (optionally followed by
        ``$``), and every ``|``-separated alternative is in
        ``expected_lora_modules``. False otherwise.
    """
    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        return False

    # For `model.*(q_proj|k_proj|v_proj)$` this captures
    # `q_proj|k_proj|v_proj`.
    match = re.search(r"\((.*?)\)\$?$", load_modules)
    if match is None:
        return False
    suffixes = match.group(1).split("|")
    return set(suffixes).issubset(expected_lora_modules)


195
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the module names of ``model`` that LoRA can target.

    In vLLM, all linear layers support LoRA. The result contains:
      * every entry of any submodule's ``embedding_modules`` attribute
        (keys, if it is a dict), when that attribute is not ``None``;
      * the last dotted component of the name of every ``LinearBase``
        submodule, e.g. ``q_proj`` for ``model.layers.0.self_attn.q_proj``.

    The list is built from a set, so it has no duplicates and its order is
    unspecified.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # Add the embedding modules when the module defines
        # `embedding_modules`. A separate variable is used so the outer
        # loop's `module_name` is not shadowed; it is still needed below.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # get all the linear suffixes.
        if isinstance(module, LinearBase):
            supported_lora_modules.add(module_name.split(".")[-1])

    return list(supported_lora_modules)


216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolves the given lora_path to a local path, preferably absolute.

    Resolution order:
      1. An absolute path is returned as-is, whether or not it exists.
      2. A path starting with ``~`` is returned after
         ``os.path.expanduser``, without checking that it exists. If the
         expansion fails, expanduser returns the path unchanged.
      3. A relative path that exists locally is converted with
         ``os.path.abspath``.
      4. Otherwise ``lora_path`` is treated as a Hugging Face repo id and
         the local snapshot path is returned.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved local path to the lora model. If the Hugging Face
         download fails with one of the caught hub errors, the error is
         logged and the original ``lora_path`` is returned unchanged.
    """

    # Check if the path is an absolute path. Return it no matter exists or not.
    if os.path.isabs(lora_path):
        return lora_path

    # If the path starts with ~, expand the user home directory.
    # Existence is not checked on this branch.
    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # Check if the relative path exists locally (resolved against the
    # current working directory).
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # If the path does not exist locally, assume it's a Hugging Face repo.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path)
    except (
        HfHubHTTPError,
        RepositoryNotFoundError,
        EntryNotFoundError,
        HFValidationError,
    ):
        # Deliberate best-effort: log the failure and return the original
        # path instead of raising here, leaving the caller to handle it.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path