utils.py 10.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
import huggingface_hub
8
import regex as re
9
10
11
12
13
14
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    HFValidationError,
    RepositoryNotFoundError,
)
15
from torch import nn
16
from transformers import PretrainedConfig
17

18
from vllm.config.lora import LoRAConfig
19
from vllm.logger import init_logger
20

21
# being imported for _all_lora_classes below
22
23
24
25
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
26
    FusedMoE3DWithLoRA,
27
    FusedMoEWithLoRA,
28
29
30
31
32
33
34
35
36
37
38
39
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
40
from vllm.model_executor.layers.fused_moe import FusedMoE
41
from vllm.model_executor.layers.linear import LinearBase
42
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
43
44
45

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
47
    from vllm.model_executor.models.utils import WeightsMapper
48
49

# Module-level logger used by the helpers in this file (e.g. MoE detection
# and adapter download error reporting).
logger = init_logger(__name__)
50

51
# All LoRA-wrapper layer classes that `from_layer` tries, in turn, via their
# `can_replace_layer` classmethod. The first class that accepts a layer wins.
# NOTE(review): this is a `set`, so iteration order is not the listing order
# below; the `can_replace_layer` checks presumably are mutually exclusive
# (e.g. sharded vs. non-sharded variants) — confirm before relying on order.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    # Input embedding.
    VocabParallelEmbeddingWithLoRA,
    # Linear layers (non-sharded LoRA).
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    # Logits processor / LM head.
    LogitsProcessorWithLoRA,
    # Linear layers with sharded LoRA weights.
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # Mixture-of-experts layers.
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


70
71
72
73
74
75
76
77
def is_moe_model(model: nn.Module) -> bool:
    """Return whether ``model`` contains at least one ``FusedMoE`` submodule.

    When an MoE layer is found, an info-level message (emitted via
    ``logger.info_once``) notes that the fused MoE LoRA implementation
    will be used.
    """
    has_fused_moe = any(
        isinstance(submodule, FusedMoE) for submodule in model.modules()
    )
    if has_fused_moe:
        logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
    return has_fused_moe


78
79
80
81
82
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA layer class able to replace it.

    Each class in ``_all_lora_classes`` is asked, via ``can_replace_layer``,
    whether it can stand in for ``layer``. The first class that accepts is
    instantiated around ``layer`` and its LoRA weights are allocated with
    ``create_lora_weights``.

    Returns:
        The new LoRA wrapper, or ``layer`` itself unchanged when no class
        accepts it.
    """
    # Keyword arguments are passed by name so decorators wrapping
    # `can_replace_layer` can easily access them.
    replacement_cls = next(
        (
            candidate
            for candidate in _all_lora_classes
            if candidate.can_replace_layer(
                source_layer=layer,
                lora_config=lora_config,
                packed_modules_list=packed_modules_list,
                model_config=model_config,
            )
        ),
        None,
    )
    if replacement_cls is None:
        return layer

    wrapped = replacement_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Build a ``LogitsProcessorWithLoRA`` around ``layer`` for ``lm_head``.

    The wrapper is configured from the LM head's embedding dimension, the
    dtype/device of its weight, and its sharded-to-full vocabulary mapping.
    LoRA weights are then allocated via ``create_lora_weights``.
    """
    embedding_dim = lm_head.embedding_dim
    head_weight = lm_head.weight
    sharded_to_full = lm_head.get_sharded_to_full_mapping()

    lora_logits_processor = LogitsProcessorWithLoRA(
        layer,
        embedding_dim,
        head_weight.dtype,
        head_weight.device,
        sharded_to_full,
    )
    lora_logits_processor.create_lora_weights(max_loras, lora_config, model_config)
    return lora_logits_processor

116

117
118
119
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace the submodule at dotted path ``module_name`` with ``new_module``.

    The parent is resolved with ``model.get_submodule``. A name without dots
    resolves to ``model`` itself, because the parent path is then ``""``.

    Returns:
        ``new_module``, for convenient chaining.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module


127
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool]:
    """Parse the name of a LoRA weight tensor.

    Args:
        name: The name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: Optional mapper applied to the weight name, e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        ``(module_name, is_lora_a)`` where ``module_name`` is the name of the
        adapted module (e.g. ``dense1``) and ``is_lora_a`` tells whether the
        tensor is ``lora_A`` / ``lora_embedding_A`` (True) or the ``B``
        counterpart (False).

    Raises:
        ValueError: If ``name`` is not a recognised LoRA weight name.
    """
    peft_prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Only that leading prefix is stripped (not later occurrences) so the
    # mapper sees the bare module path, and then the prefix is restored.
    if name.startswith(peft_prefix):
        name = name.removeprefix(peft_prefix)
        name = weights_mapper._map_name(name) if weights_mapper else name
        name = peft_prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # Some checkpoints (e.g. ibm-granite/granite-speech-3.3-8b) do not use
    # the `base_model.model.` prefix; in that case the name is kept intact.
    start_index = 2 if name.startswith(peft_prefix) else 0

    parts = name.split(".")

    # `<module>.lora_A.weight` / `<module>.lora_B.weight`. The length guard
    # makes a bare "weight" fall through to the ValueError below instead of
    # raising IndexError.
    if len(parts) >= 2 and parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A"

    # `<module>.lora_embedding_A` / `<module>.lora_embedding_B`.
    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
169
170


171
172
173
174
175
176
177
178
179
180
def is_base_embeddding_weights(name: str) -> bool:
    """Return True if ``name`` is the base-layer weight of an embedding.

    Matches the input embedding (``embed_tokens``) or the output embedding
    (``lm_head``) via hard-coded suffixes. Each suffix starts with ``.``, so
    the module name must be a full dotted component preceded by some prefix.
    """
    embedding_weight_suffixes = (
        ".embed_tokens.base_layer.weight",  # input embedding
        ".lm_head.base_layer.weight",  # output embedding
    )
    return name.endswith(embedding_weight_suffixes)


181
def is_regex_target_modules(
182
    load_modules: str | list[str], expected_lora_modules: list[str]
183
) -> bool:
184
    """
185
186
187
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
    `expected_lora_modules`.
    """

    def is_valid_regex(pattern):
        try:
            re.compile(pattern)
            return True
        except re.error:
            return False

    def is_subset(sub_list, full_list):
        return set(sub_list).issubset(set(full_list))

    # Similar to PEFT's processing logic, regex-related operations are only
    #  executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    if is_valid_regex(load_modules):
        match = re.search(r"\((.*?)\)\$?$", load_modules)
        if match:
            suffix = match.group(1).split("|")
            return is_subset(suffix, expected_lora_modules)
    return False


214
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the short names of all modules in ``model`` that LoRA can target.

    The result includes:
      * every entry of any ``embedding_modules`` attribute found on a
        submodule;
      * the last dotted component of the name of every ``LinearBase``
        submodule (in vLLM, all linear layers support LoRA);
      * the last dotted component of the name of every ``FusedMoE``
        submodule.

    Returns:
        The unique module names, in no particular order.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # Submodules may declare additional LoRA-capable embedding modules.
        # `update` adds them without rebinding `module_name`. The previous
        # inner `for name in ...` loop shadowed the outer module name.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # Linear and fused-MoE layers are identified by their attribute name.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(module_name.split(".")[-1])

    return list(supported_lora_modules)


238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def get_adapter_absolute_path(lora_path: str) -> str:
    """Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is returned with the home directory
         expanded, whether or not it exists.
      3. A relative path that exists locally is returned as an absolute path.
      4. Otherwise ``lora_path`` is treated as a Hugging Face repo id and
         downloaded; the local snapshot path is returned. If the download
         fails with one of the handled Hugging Face errors, the error is
         logged and the original ``lora_path`` is returned.

    Args:
        lora_path: An absolute path, a relative path, or a Hugging Face
            model identifier.

    Returns:
        The resolved local path (or ``lora_path`` unchanged on a failed
        download).
    """
    if os.path.isabs(lora_path):
        return lora_path

    # Home-relative path: expand `~` but do not check existence.
    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # Relative path (no `~` expansion applies here) that exists locally.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally: assume it names a Hugging Face repo.
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except (
        HfHubHTTPError,
        RepositoryNotFoundError,
        EntryNotFoundError,
        HFValidationError,
    ):
        # Log and fall back to the original path rather than raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path
282
283
284
285
286
287
288
289
290
291
292


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Return the packed-modules mapping used for LoRA on ``model``.

    For non-MoE models this is ``get_packed_modules_mapping(model)``.

    For MoE models, an expert mapping from ``get_moe_expert_mapping`` is
    required. Unless the model has an ``is_3d_moe_weight`` attribute, an
    ``"experts"`` entry is added that lists each expert weight name (second
    element of each mapping tuple) with trailing dots stripped. 3D MoE LoRA
    does not need that entry.

    Raises:
        AttributeError: If the model is MoE but no expert mapping is
            available.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    packed_modules_mapping = get_packed_modules_mapping(model)
    if not hasattr(model, "is_3d_moe_weight"):
        # Map "experts" to the per-expert weight names (3D MoE skips this).
        expert_weight_names = []
        for _, weight_name, _, _ in moe_packed_mapping:
            expert_weight_names.append(weight_name.rstrip("."))
        packed_modules_mapping["experts"] = expert_weight_names
    return packed_modules_mapping