utils.py 12.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING
6

7
import huggingface_hub
8
import regex as re
9
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
10
from torch import nn
11
from transformers import PretrainedConfig
12

13
from vllm import envs
14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16

17
# being imported for _all_lora_classes below
18
19
20
21
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
22
    FusedMoE3DWithLoRA,
23
    FusedMoEWithLoRA,
24
    LogitsProcessorWithLoRA,
25
    MergedColumnParallelLinearVariableSliceWithLoRA,
26
27
28
29
30
31
32
33
34
35
36
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
37
from vllm.model_executor.layers.fused_moe import FusedMoE
38
from vllm.model_executor.layers.linear import LinearBase
39
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
40
41
42

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
44
    from vllm.model_executor.models.utils import WeightsMapper
45
46

# Module-level logger, named after this module's import path.
logger = init_logger(__name__)
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Compute the ``num_active_loras`` values to capture CUDA graphs for.

    With ``specialize=False`` only the single catch-all value
    ``max_loras + 1`` is returned. With ``specialize=True`` every power of
    two in ``[1, max_loras + 1]`` is returned too, in ascending order, and
    ``max_loras + 1`` is always the last entry.

    This is the single source of truth for LoRA capture cases, shared by
    CudagraphDispatcher and PunicaWrapperGPU.
    """
    upper = max_loras + 1
    if specialize:
        counts: list[int] = []
        for candidate in range(1, upper + 1):
            # n & (n - 1) clears the lowest set bit; zero means a power of two.
            is_power_of_two = (candidate & (candidate - 1)) == 0
            if is_power_of_two or candidate == upper:
                counts.append(candidate)
        return counts
    return [upper]


67
68
69
70
71
72
73
74
75
# Last LoRA id handed out by ``get_lora_id``; 0 means none issued yet.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next LoRA id from the module-level counter.

    Ids start at 1 and increase by one on every call. The most recently
    issued id is stored back into ``_GLOBAL_LORA_ID``.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


76
# Candidate LoRA wrapper classes tried by ``from_layer``. It returns the first
# class, in set iteration order, whose ``can_replace_layer`` accepts the layer.
# NOTE: the literal's element order is kept as-is, since insertion order can
# influence set iteration order.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    # Embeddings and unsharded linear variants.
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    # Logits processor.
    LogitsProcessorWithLoRA,
    # Sharded-LoRA linear variants.
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # Fused MoE variants.
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


96
97
98
99
100
101
102
103
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one FusedMoE submodule.

    On the first FusedMoE found, an info-level message (emitted once via
    ``info_once``) notes that the fused MoE LoRA implementation is used.
    """
    for submodule in model.modules():
        if isinstance(submodule, FusedMoE):
            logger.info_once(
                "MoE model detected. Using fused MoE LoRA implementation."
            )
            return True
    return False


104
105
106
107
108
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA layer class that accepts it.

    Each class in ``_all_lora_classes`` is asked, via ``can_replace_layer``,
    whether it can wrap ``layer``. The first class that accepts is
    instantiated around ``layer``, and ``create_lora_weights`` is called on
    the new instance. If no class accepts, ``layer`` is returned unchanged.
    """
    # Keyword arguments are passed on purpose so that decorators wrapping
    # ``can_replace_layer`` can easily access them by name.
    matching_cls = next(
        (
            candidate
            for candidate in _all_lora_classes
            if candidate.can_replace_layer(
                source_layer=layer,
                lora_config=lora_config,
                packed_modules_list=packed_modules_list,
                model_config=model_config,
            )
        ),
        None,
    )
    if matching_cls is None:
        return layer

    wrapped = matching_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is built from ``layer`` plus properties read from
    ``lm_head``:
    * its ``embedding_dim``;
    * the dtype and device of its ``weight``;
    * the mapping returned by its ``get_sharded_to_full_mapping()``.

    Returns:
        The new wrapper, after ``create_lora_weights`` has been called on it.
    """
    hidden_size = lm_head.embedding_dim
    weight_dtype = lm_head.weight.dtype
    weight_device = lm_head.weight.device
    sharded_to_full = lm_head.get_sharded_to_full_mapping()

    wrapped = LogitsProcessorWithLoRA(
        layer, hidden_size, weight_dtype, weight_device, sharded_to_full
    )
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped

142

143
144
145
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace the submodule at dotted path ``module_name`` with ``new_module``.

    The parent is looked up with ``model.get_submodule``. The attribute named
    by the last path component is then re-bound on that parent.

    Returns:
        ``new_module``.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module


153
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: "WeightsMapper | None" = None
) -> tuple[str, bool]:
    """Parse the qualified name of a LoRA weight tensor.

    Args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the name (with any leading
            ``base_model.model.`` removed), e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. ``model.dense1``.
            is_lora_a: whether the tensor is ``lora_A`` /
                ``lora_embedding_A`` (True) or the ``B`` counterpart (False).

    Raises:
        ValueError: if ``name`` is not a recognised LoRA weight name.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that *leading* prefix (not every occurrence) so the mapper
    # sees the real module path, then restore it afterwards.
    if name.startswith(prefix):
        name = name.removeprefix(prefix)
        name = weights_mapper._map_name(name) if weights_mapper else name
        name = prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations (e.g., ibm-granite/granite-speech-3.3-8b) the name
    # does not start with `base_model.model.`; then keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # Guard the length so a bare "weight" raises ValueError, not IndexError.
    if (
        len(parts) >= 2
        and parts[-1] == "weight"
        and parts[-2] in ("lora_A", "lora_B")
    ):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A"

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
195
196


Jiayi Yan's avatar
Jiayi Yan committed
197
def is_base_embedding_weights(name: str) -> bool:
    """Return True if ``name`` is the ``base_layer`` weight of an embedding.

    The check is a hard-coded suffix match on
    ``.embed_tokens.base_layer.weight`` (input embedding) and
    ``.lm_head.base_layer.weight`` (output embedding).
    """
    return any(
        name.endswith(suffix)
        for suffix in (
            ".embed_tokens.base_layer.weight",
            ".lm_head.base_layer.weight",
        )
    )
204
205


206
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the names of ``model``'s modules that LoRA can target.

    In vLLM, all linear layers support LoRA. The result contains:

    * every entry of a submodule's ``embedding_modules`` attribute, when that
      attribute is present and not None;
    * the last dotted component of the name of every ``LinearBase`` and
      ``FusedMoE`` submodule.

    Returns:
        The unique names, as a list in unspecified order.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # Add the embedding modules if the module declares any.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            # ``update`` avoids re-binding the outer loop variable, which
            # would otherwise corrupt the isinstance branch below.
            supported_lora_modules.update(embedding_modules)

        # Linear and fused-MoE layers are keyed by their last name component.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(module_name.split(".")[-1])

    return list(supported_lora_modules)


230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def is_supported_lora_module(
    module_name: str,
    supported_lora_modules: list[str],
) -> bool:
    """Check if a module is in the model's supported LoRA modules.

    A module matches when it equals one of the supported names, or ends with
    ``"." + name`` (e.g. "model.layers.0.self_attn.o_proj" matches
    "o_proj"). Supported names are regex-escaped, so characters such as
    ``.`` are matched literally rather than as wildcards.

    Args:
        module_name: Full dot-separated module name.
        supported_lora_modules: List of module suffixes supported by the
            model.

    Returns:
        True if the module is supported, False otherwise.
    """
    return any(
        target_module == module_name
        or re.match(rf".*\.{re.escape(target_module)}$", module_name)
        for target_module in supported_lora_modules
    )


def is_in_target_modules(
    module_name: str,
    target_modules: list[str] | None,
) -> bool:
    """Check if a module passes the deployment-time target_modules filter.

    When target_modules is None (no restriction), all modules pass.
    Otherwise, the module's suffix must be in the target_modules list.

    Args:
        module_name: Full dot-separated module name.
        target_modules: Optional deployment-time restriction list from
            LoRAConfig.target_modules.

    Returns:
        True if the module passes the filter, False otherwise.
    """
    if target_modules is None:
        return True
    module_suffix = module_name.split(".")[-1]
    return module_suffix in set(target_modules)


281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
    1. An absolute path is returned unchanged, whether or not it exists.
    2. A path starting with ``~`` is returned with the user home expanded.
    3. A relative path that exists locally is returned as an absolute path.
    4. Otherwise the path is treated as a hub repo id and downloaded. The
       ModelScope hub is used when ``envs.VLLM_USE_MODELSCOPE`` is set,
       otherwise the Hugging Face Hub. The local snapshot path is returned.
       If the download raises one of the expected hub errors, the error is
       logged and the original ``lora_path`` is returned instead.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved local path to the lora model (or ``lora_path`` itself
         when a hub download fails).
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # A relative path that is present on disk.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally: treat the path as a hub repository id.
    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import InvalidParameter, snapshot_download
        from requests import HTTPError

        def _download() -> str:
            return snapshot_download(model_id=lora_path)

        expected_errors = (HTTPError, InvalidParameter)
        failure_msg = "Error downloading the ModelScope model"
    else:

        def _download() -> str:
            return huggingface_hub.snapshot_download(repo_id=lora_path)

        expected_errors = (HfHubHTTPError, HFValidationError)
        failure_msg = "Error downloading the HuggingFace model"

    try:
        return _download()
    except expected_errors:
        # Best-effort: log and fall back to the caller's path instead of
        # raising here.
        logger.exception(failure_msg)
        return lora_path
334
335
336
337
338
339
340
341
342
343
344


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Build the packed-module mapping used for LoRA on ``model``.

    Non-MoE models get ``get_packed_modules_mapping(model)`` as-is. MoE
    models additionally require ``get_moe_expert_mapping`` to return a
    non-empty mapping. Unless ``model.is_3d_moe_weight`` is set, an
    ``"experts"`` entry is then added, listing the per-expert weight names.

    Raises:
        AttributeError: if ``model`` is an MoE model but its expert mapping
            is empty.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    packed_modules_mapping = get_packed_modules_mapping(model)
    # 3D MoE LoRA does not need the expanded "experts" entry.
    if not model.is_3d_moe_weight:
        # Skip malformed entries: non-gated MoE has an empty
        # ckpt_up_proj_name, which yields weight names containing ".."
        # (e.g., "experts.0.." instead of "experts.0.layer_name.").
        packed_modules_mapping["experts"] = [
            weight_name.rstrip(".")
            for _, weight_name, _, _ in moe_packed_mapping
            if ".." not in weight_name
        ]
    return packed_modules_mapping