utils.py 9.19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional, Union
6

7
import huggingface_hub
8
import regex as re
9
10
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)
11
from torch import nn
12
from transformers import PretrainedConfig
13

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
17
18
19
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
Jee Jee Li's avatar
Jee Jee Li committed
20
                              ColumnParallelLinearWithShardedLoRA,
21
22
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
Jee Jee Li's avatar
Jee Jee Li committed
23
                              MergedColumnParallelLinearWithShardedLoRA,
24
                              MergedQKVParallelLinearWithLoRA,
Jee Jee Li's avatar
Jee Jee Li committed
25
                              MergedQKVParallelLinearWithShardedLoRA,
26
                              QKVParallelLinearWithLoRA,
Jee Jee Li's avatar
Jee Jee Li committed
27
                              QKVParallelLinearWithShardedLoRA,
28
                              ReplicatedLinearWithLoRA,
29
                              RowParallelLinearWithLoRA,
Jee Jee Li's avatar
Jee Jee Li committed
30
                              RowParallelLinearWithShardedLoRA,
31
                              VocabParallelEmbeddingWithLoRA)
32
from vllm.model_executor.layers.linear import LinearBase
33

34
# yapf: enable
35
36
37
38
39
40

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.models.utils import WeightsMapper
41
42

logger = init_logger(__name__)
43

44
# Candidate LoRA wrapper classes. `from_layer` iterates over this collection
# and asks each class whether it can replace a given base layer. Because this
# is a set, the order written here is not an iteration-order guarantee.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: list,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in a LoRA layer class if one of them accepts it.

    Each class in ``_all_lora_classes`` is asked, via ``can_replace_layer``,
    whether it can replace ``layer``. The first class that answers truthy is
    instantiated around ``layer``, has ``create_lora_weights`` called on it,
    and is returned.

    Args:
        layer: the module to (possibly) wrap.
        max_loras: forwarded to ``create_lora_weights``.
        lora_config: forwarded to ``can_replace_layer`` and
            ``create_lora_weights``.
        packed_modules_list: forwarded to ``can_replace_layer``.
        model_config: optional config forwarded to both calls.

    Returns:
        The new LoRA wrapper instance, or ``layer`` itself (unchanged) when
        no class reports that it can replace it.
    """
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            instance_layer = lora_cls(layer)
            instance_layer.create_lora_weights(max_loras, lora_config,
                                               model_config)
            return instance_layer
    # No LoRA class matched: hand the original layer back untouched.
    return layer


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Build a ``LogitsProcessorWithLoRA`` around ``layer``.

    The wrapper is constructed from ``layer`` plus values read off
    ``lm_head``: its ``embedding_dim``, the dtype and device of its
    ``weight``, and the result of ``get_sharded_to_full_mapping()``.
    ``create_lora_weights`` is then called on the wrapper before it is
    returned.

    Args:
        layer: the logits processor to wrap.
        lm_head: the LM head whose attributes parameterize the wrapper.
        max_loras: forwarded to ``create_lora_weights``.
        lora_config: forwarded to ``create_lora_weights``.
        model_config: optional config forwarded to ``create_lora_weights``.

    Returns:
        The newly created ``LogitsProcessorWithLoRA``.
    """
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret

92
93
94
95
96
97
98
99
100
101

def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    The parent is looked up with ``model.get_submodule`` using everything
    before the last dot (an empty path when there is no dot), and the final
    path component is rebound on it. Returns ``new_module``.
    """
    parent_path, _, leaf_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, leaf_name, new_module)
    return new_module


102
def parse_fine_tuned_lora_name(
    name: str,
    weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool, bool]:
    """Parse the name of lora weights.

    args:
        name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.lora_A.weight
        weights_mapper: maps the name of weight, e.g.
            `model.` -> `language_model.model.`,
    return:
        tuple(module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. dense1,
            is_lora_a: whether the tensor is lora_a (True) or lora_b.
            is_bias: whether the tensor is lora bias.
    raises:
        ValueError: if the name does not look like a supported LoRA weight.
    """
    prefix = "base_model.model."

    # LoRA weight qualified name usually starts with `base_model.model.`,
    # so we remove that prefix before applying the mapper and put it back
    # afterwards. Only the *leading* occurrence is stripped; any later
    # occurrence inside the name is part of the module path.
    has_prefix = name.startswith(prefix)
    if has_prefix:
        name = name[len(prefix):]
    if weights_mapper:
        name = weights_mapper._map_name(name)
    if has_prefix:
        name = prefix + name

    # In some situations, we may not start with `base_model.model.`.
    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
    # we should keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    leaf = parts[-1]
    # A single-component name has no parent component; guard so that such
    # input falls through to the ValueError below instead of an IndexError.
    parent = parts[-2] if len(parts) > 1 else None

    if leaf == "weight" and parent in ("lora_A", "lora_B"):
        new_name = ".".join(parts[start_index:-2])
        return new_name, parent == "lora_A", False

    if leaf in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[start_index:-1])
        return new_name, leaf == "lora_embedding_A", False

    if leaf == "bias":
        new_name = ".".join(parts[start_index:-2])
        return new_name, False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
151
152


153
154
def is_regex_target_modules(load_modules: Union[str, list[str]],
                            expected_lora_modules: list[str]) -> bool:
    """
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
    `expected_lora_modules`.

    Returns True only when `load_modules` is a `str`, compiles as a regular
    expression, ends with a parenthesized group (optionally followed by `$`),
    and every `|`-separated alternative in that group is contained in
    `expected_lora_modules`. Returns False in every other case.
    """
    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    # Reject strings that are not valid regular expressions.
    try:
        re.compile(load_modules)
    except re.error:
        return False

    # Capture the contents of the trailing `( ... )` group, e.g.
    # `q_proj|k_proj|v_proj` for the pattern shown in the docstring.
    match = re.search(r"\((.*?)\)\$?$", load_modules)
    if not match:
        return False

    suffixes = match.group(1).split("|")
    return set(suffixes).issubset(set(expected_lora_modules))


185
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """
    In vLLM, all linear layers support LoRA.

    Walks `model.named_modules()` and collects:
      * every entry of a module's `embedding_modules` attribute, when that
        attribute exists and is not None;
      * the last dotted component of the name of every `LinearBase` module.

    Returns the collected names as a list without duplicates; the order is
    not specified because the names are accumulated in a set.
    """
    supported_lora_modules: set[str] = set()
    for name, module in model.named_modules():
        # get the embedding modules if the module's embedding_modules
        # is not empty. `update` is used instead of an inner `for name in ...`
        # loop so the outer loop's `name` is not overwritten before the
        # linear-layer check below reads it.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # get all the linear suffixes.
        if isinstance(module, (LinearBase, )):
            supported_lora_modules.add(name.split(".")[-1])

    return list(supported_lora_modules)


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolves the given lora_path to an absolute local path.

    If the lora_path is identified as a Hugging Face model identifier,
    it will download the model and return the local snapshot path.
    Otherwise, it treats the lora_path as a local file path and
    converts it to an absolute path.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved absolute local path to the lora model. If the download
         attempt fails with one of the handled Hugging Face errors, the
         original lora_path is returned unchanged.
    """

    # An absolute path is returned as-is, whether or not it exists.
    if os.path.isabs(lora_path):
        return lora_path

    # If the path starts with ~, expand the user home directory.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    # Check if the relative path exists locally.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # If the path does not exist locally, assume it's a Hugging Face repo.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(
            repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Handle errors that may occur during the download.
        # Return original path instead of throwing error here; the failure
        # is still recorded in the log with its traceback.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path