utils.py 9.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import Optional, Union
6

7
import huggingface_hub
8
import regex as re
9
10
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)
11
from torch import nn
12
from transformers import PretrainedConfig
13

14
from vllm.config import LoRAConfig
15
from vllm.logger import init_logger
16
17
18
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
19
    MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
20
    RowParallelLinearWithShardedLoRA)
21
22
23
24
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
25
                              LinearScalingRotaryEmbeddingWithLoRA,
26
27
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
28
29
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
30
                              ReplicatedLinearWithLoRA,
31
32
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
33
from vllm.model_executor.layers.linear import LinearBase
34
35
36
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
37
from vllm.model_executor.models.utils import WeightsMapper
38
39

logger = init_logger(__name__)

# Every LoRA wrapper class that `from_layer` may use to replace a base layer.
# Each class decides for itself, via its `can_replace_layer` classmethod,
# whether it applies to a given layer.
# NOTE(review): this is a set, so iteration order is not guaranteed, while
# `from_layer` returns the first class whose `can_replace_layer` accepts the
# layer. The predicates are therefore presumably mutually exclusive -- confirm
# before adding a class whose predicate may overlap an existing one.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLoRA,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: list,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in the first LoRA layer class that accepts it.

    Each class in ``_all_lora_classes`` is asked (via ``can_replace_layer``)
    whether it can replace ``layer``. The first class that accepts is
    instantiated around ``layer``, and its LoRA weights are allocated with
    ``create_lora_weights``.

    Args:
        layer: The base module to (possibly) wrap.
        max_loras: Forwarded to ``create_lora_weights`` (presumably the
            number of adapter slots to allocate -- confirm in the layer
            classes).
        lora_config: LoRA configuration, used for both the
            ``can_replace_layer`` check and weight creation.
        packed_modules_list: Forwarded to ``can_replace_layer``.
        model_config: Optional HF model config, forwarded to both calls.

    Returns:
        The LoRA-wrapped layer, or ``layer`` itself, unchanged, when no LoRA
        class applies.
    """
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            instance_layer = lora_cls(layer)
            instance_layer.create_lora_weights(max_loras, lora_config,
                                               model_config)
            return instance_layer
    # No LoRA wrapper applies to this layer; leave it as-is.
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is constructed from the following values taken from
    ``lm_head``:

    - its ``embedding_dim``;
    - the dtype and device of its weight;
    - the result of its ``get_sharded_to_full_mapping()``.

    The wrapper's LoRA weights are then allocated with
    ``create_lora_weights``.

    Args:
        layer: The logits processor to wrap.
        lm_head: The LM head whose metadata configures the wrapper.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: Forwarded to ``create_lora_weights``.
        model_config: Optional HF model config, forwarded to
            ``create_lora_weights``.

    Returns:
        The newly created ``LogitsProcessorWithLoRA``.
    """
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret

90
91
92
93
94
95
96
97
98
99

def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Install ``new_module`` at the dotted path ``module_name`` in ``model``.

    Everything before the last ``.`` names the owning submodule, resolved with
    ``model.get_submodule``. A name without a dot yields an empty owner path,
    which torch resolves to ``model`` itself. The text after the last ``.`` is
    the attribute set on that owner.

    Returns:
        ``new_module``, for convenient chaining.
    """
    owner_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(owner_path)
    setattr(owner, attr_name, new_module)
    return new_module


100
101
102
def parse_fine_tuned_lora_name(
        name: str,
        weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool, bool]:
    """Parse the name of a LoRA weight tensor.

    Args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the name after a leading
            ``base_model.model.`` is removed, e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        tuple(module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. ``dense1``. A leading
                ``base_model.model.`` is dropped; any other prefix is kept.
            is_lora_a: whether the tensor is lora_A (rather than lora_B).
            is_bias: whether the tensor is a LoRA bias.

    Raises:
        ValueError: if ``name`` is not a supported LoRA weight name.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip that prefix -- once, and only at the start -- so the mapping
    # sees the plain module path, then restore it afterwards.
    if name.startswith(prefix):
        name = name[len(prefix):]
        name = weights_mapper._map_name(name) if weights_mapper else name
        name = prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations, we may not start with `base_model.model.`.
    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
    # we should keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    last = parts[-1]

    # The length guard makes a single-component name (e.g. "weight") fall
    # through to the ValueError below instead of raising IndexError.
    if (last == "weight" and len(parts) >= 2
            and parts[-2] in ("lora_A", "lora_B")):
        new_name = ".".join(parts[start_index:-2])
        return new_name, parts[-2] == "lora_A", False

    if last in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[start_index:-1])
        return new_name, last == "lora_embedding_A", False

    if last == "bias":
        new_name = ".".join(parts[start_index:-2])
        return new_name, False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
149
150


151
152
def is_regex_target_modules(load_modules: Union[str, list[str]],
                            expected_lora_modules: list[str]) -> bool:
    """
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
    `expected_lora_modules`.

    Returns True only when all of the following hold:

    - ``load_modules`` is a ``str`` that compiles as a regex;
    - ``load_modules`` matches ``\\((.*?)\\)\\$?$``;
    - every ``|``-separated piece of the captured text is one of
      ``expected_lora_modules``.

    Returns False in every other case, including when ``load_modules`` is a
    list.
    """
    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        return False

    match = re.search(r"\((.*?)\)\$?$", load_modules)
    if match is None:
        return False

    suffixes = match.group(1).split("|")
    return set(suffixes).issubset(expected_lora_modules)


183
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """
    In vLLM, all linear layers support LoRA.

    Returns the de-duplicated names of LoRA-targetable modules in ``model``:

    - the last dotted component of the name of every ``LinearBase``
      submodule;
    - every item obtained by iterating ``model.embedding_modules`` (its keys,
      if it is a dict), when that attribute is truthy.

    The order of the returned list is not specified.
    """
    supported_lora_modules: set[str] = set()
    # step 1: traverse the model to collect the suffix of every linear layer.
    for name, module in model.named_modules():
        if isinstance(module, LinearBase):
            supported_lora_modules.add(name.split(".")[-1])
    # step 2: add the embedding modules if the model's embedding_modules
    # is not empty.
    if model.embedding_modules:
        supported_lora_modules.update(model.embedding_modules)
    return list(supported_lora_modules)


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is returned with the user's home
         directory expanded; its existence is not checked.
      3. A relative path that exists locally is returned as an absolute path.
      4. Anything else is treated as a Hugging Face repo id and fetched with
         ``huggingface_hub.snapshot_download``; the local snapshot path is
         returned. If the download raises one of the handled Hub errors, the
         error is logged and ``lora_path`` is returned unchanged.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved local path, or the original ``lora_path`` if the Hub
         download failed.
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found on disk: fall back to treating it as a Hugging Face repo id.
    try:
        snapshot_dir = huggingface_hub.snapshot_download(repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Best-effort: log the failure and hand back the original path
        # rather than raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path
    else:
        return snapshot_dir