utils.py 10.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING
6

7
import huggingface_hub
8
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
9
from torch import nn
10
from transformers import PretrainedConfig
11

12
from vllm import envs
13
from vllm.config.lora import LoRAConfig
14
from vllm.logger import init_logger
15

16
# being imported for _all_lora_classes below
17
18
19
20
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
21
    FusedMoE3DWithLoRA,
22
    FusedMoEWithLoRA,
23
    LogitsProcessorWithLoRA,
24
    MergedColumnParallelLinearVariableSliceWithLoRA,
25
26
27
28
29
30
31
32
33
34
35
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
36
from vllm.model_executor.layers.fused_moe import FusedMoE
37
from vllm.model_executor.layers.linear import LinearBase
38
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
39
40
41

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
43
    from vllm.model_executor.models.utils import WeightsMapper
44
45

logger = init_logger(__name__)
46

47
48
49
50
51
52
53
54
55
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next id from the module-level LoRA id counter.

    The counter starts at 0 and is incremented before being returned, so the
    first call yields 1, the next 2, and so on.
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID = _GLOBAL_LORA_ID + 1
    next_id = _GLOBAL_LORA_ID
    return next_id


56
# LoRA wrapper classes that `from_layer` may substitute for a base layer.
# `from_layer` returns the first class (in this set's iteration order) whose
# `can_replace_layer` accepts the layer, so the insertion order below is kept
# exactly as-is.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    # embedding
    VocabParallelEmbeddingWithLoRA,
    # linear layers
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    # logits processor
    LogitsProcessorWithLoRA,
    # sharded-LoRA linear variants
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # fused MoE
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


76
77
78
79
80
81
82
83
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one ``FusedMoE`` submodule.

    When such a submodule is found, an info-level message is emitted via
    ``logger.info_once`` saying the fused MoE LoRA implementation will be used.
    """
    has_fused_moe = any(isinstance(submodule, FusedMoE) for submodule in model.modules())
    if has_fused_moe:
        logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
    return has_fused_moe


84
85
86
87
88
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA layer class able to replace it.

    Classes in ``_all_lora_classes`` are tried in that set's iteration order.
    The first one whose ``can_replace_layer`` accepts ``layer`` is
    instantiated around it. The instance's LoRA weights are then allocated
    with ``create_lora_weights(max_loras, lora_config, model_config)``, and
    the wrapped layer is returned. If no class accepts, ``layer`` itself is
    returned unchanged.
    """
    replacement_cls = next(
        (
            candidate
            for candidate in _all_lora_classes
            # Arguments are passed by keyword so decorators applied to
            # `can_replace_layer` can access them by name.
            if candidate.can_replace_layer(
                source_layer=layer,
                lora_config=lora_config,
                packed_modules_list=packed_modules_list,
                model_config=model_config,
            )
        ),
        None,
    )
    if replacement_cls is None:
        return layer

    wrapped = replacement_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is constructed from ``lm_head``'s embedding dimension, the
    dtype and device of its weight, and its sharded-to-full mapping. Its LoRA
    weights are allocated before it is returned.
    """
    embedding_dim = lm_head.embedding_dim
    head_weight = lm_head.weight
    lora_processor = LogitsProcessorWithLoRA(
        layer,
        embedding_dim,
        head_weight.dtype,
        head_weight.device,
        lm_head.get_sharded_to_full_mapping(),
    )
    lora_processor.create_lora_weights(max_loras, lora_config, model_config)
    return lora_processor

122

123
124
125
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Install ``new_module`` at the dotted path ``module_name`` inside ``model``.

    Everything before the last ``.`` names the parent, which is resolved with
    ``get_submodule``; an empty parent path resolves to ``model`` itself. The
    last path component is then set as an attribute on that parent.

    Returns:
        ``new_module``.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_path)
    setattr(parent, attr_name, new_module)
    return new_module


133
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: "WeightsMapper | None" = None
) -> tuple[str, bool]:
    """Parse the qualified name of a LoRA weight tensor.

    Args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the name (with any
            leading ``base_model.model.`` removed first), e.g. mapping
            ``model.`` -> ``language_model.model.``.

    Returns:
        A tuple ``(module_name, is_lora_a)``:
            module_name: the name of the module, e.g. ``model.dense1``.
            is_lora_a: whether the tensor is lora_a (``lora_A`` or
                ``lora_embedding_A``) rather than lora_b.

    Raises:
        ValueError: if ``name`` is not a recognized LoRA weight name.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that leading prefix (not every occurrence of it, as
    # `str.replace` would) so the mapper sees the bare name, then restore it.
    if name.startswith(prefix):
        name = name.removeprefix(prefix)
        name = weights_mapper._map_name(name) if weights_mapper else name
        name = prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations we may not start with `base_model.model.`
    # (e.g., ibm-granite/granite-speech-3.3-8b); then keep the name intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # The length guard keeps a single-component name (e.g. "weight") from
    # hitting an IndexError on `parts[-2]`; it falls through to ValueError.
    if (
        len(parts) >= 2
        and parts[-1] == "weight"
        and parts[-2] in ("lora_A", "lora_B")
    ):
        new_name = ".".join(parts[start_index:-2])
        return new_name, parts[-2] == "lora_A"

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[start_index:-1])
        return new_name, parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
175
176


177
178
def is_base_embeddding_weights(name: str) -> bool:
    """Return True if ``name`` is the base-layer weight of an input/output embedding.

    The check is against hardcoded suffixes: ``.embed_tokens.base_layer.weight``
    and ``.lm_head.base_layer.weight``. Both include a leading dot, so a bare
    ``lm_head.base_layer.weight`` with no parent component does not match.
    """
    return any(
        name.endswith(suffix)
        for suffix in (
            ".embed_tokens.base_layer.weight",
            ".lm_head.base_layer.weight",
        )
    )
184
185


186
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the short names of modules in ``model`` that LoRA can target.

    Names come from three sources:

    * the items of any non-None ``embedding_modules`` attribute found on a
      module (iterated directly, so for a dict these are its keys);
    * the last dotted component of every ``LinearBase`` submodule's name
      (e.g. ``q_proj`` for ``model.layers.0.self_attn.q_proj``);
    * likewise for every ``FusedMoE`` submodule.

    Args:
        model: the model to scan.

    Returns:
        The de-duplicated names, in arbitrary order.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # Use a distinct loop variable here: reusing the outer name would
        # clobber `module_name` before the linear / MoE check below reads it.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            for embedding_name in embedding_modules:
                supported_lora_modules.add(embedding_name)

        # Linear and fused-MoE layers are identified by the last component
        # of their dotted module path.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(module_name.split(".")[-1])

    return list(supported_lora_modules)


210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local path for a LoRA adapter.

    Resolution order:
      1. An absolute path is returned unchanged (existence is not checked).
      2. A path starting with ``~`` is returned with the user home expanded.
      3. A relative path that exists locally is converted to an absolute path.
      4. Otherwise the path is treated as a hub repository id and downloaded,
         from ModelScope when ``envs.VLLM_USE_MODELSCOPE`` is set and from the
         Hugging Face Hub otherwise. The local snapshot path is returned.
         If the download raises one of the handled hub errors, the error is
         logged and the original ``lora_path`` is returned instead of raising.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a hub model identifier.

    Returns:
    str: The resolved local path to the lora model, or ``lora_path`` itself
         if a hub download failed.
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # A relative path that exists locally wins over a hub lookup.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    if envs.VLLM_USE_MODELSCOPE:
        # Treat the path as a ModelScope repo id.
        from modelscope.hub.snapshot_download import InvalidParameter, snapshot_download
        from requests import HTTPError

        def _download():
            return snapshot_download(model_id=lora_path)

        handled_errors = (HTTPError, InvalidParameter)
        failure_message = "Error downloading the ModelScope model"
    else:
        # Treat the path as a Hugging Face Hub repo id.
        def _download():
            return huggingface_hub.snapshot_download(repo_id=lora_path)

        handled_errors = (HfHubHTTPError, HFValidationError)
        failure_message = "Error downloading the HuggingFace model"

    try:
        snapshot_path = _download()
    except handled_errors:
        # Best effort: log and hand back the original path rather than raising.
        logger.exception(failure_message)
        return lora_path
    return snapshot_path
263
264
265
266
267
268
269
270
271
272
273


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Build the packed-modules mapping used when applying LoRA to ``model``.

    The mapping goes from packed module names to lists of their submodule
    names. For non-MoE models it is just ``get_packed_modules_mapping(model)``.

    For MoE models (as detected by ``is_moe_model``), the expert mapping from
    ``get_moe_expert_mapping`` must be non-empty. Unless
    ``model.is_3d_moe_weight`` is set, an ``"experts"`` entry is added to the
    static mapping. It lists each expert weight name from the expert mapping
    with its trailing ``.`` stripped, skipping malformed names that contain
    ``..``. 3D MoE LoRA does not need this entry.

    Raises:
        AttributeError: if ``model`` is an MoE model and its expert mapping
            is empty.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    packed_modules_mapping = get_packed_modules_mapping(model)
    if model.is_3d_moe_weight:
        # 3D MoE LoRA does not need the per-expert `experts` entry.
        return packed_modules_mapping

    # Non-gated MoE has an empty ckpt_up_proj_name, which produces weight
    # names containing ".." (e.g. "experts.0.." instead of
    # "experts.0.layer_name."); drop those malformed entries.
    packed_modules_mapping["experts"] = [
        weight_name.rstrip(".")
        for _, weight_name, _, _ in moe_packed_mapping
        if ".." not in weight_name
    ]
    return packed_modules_mapping