utils.py 9.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
import huggingface_hub
8
9
10
11
12
13
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    HFValidationError,
    RepositoryNotFoundError,
)
14
from torch import nn
15
from transformers import PretrainedConfig
16

17
from vllm.config.lora import LoRAConfig
18
from vllm.logger import init_logger
19

20
# being imported for _all_lora_classes below
21
22
23
24
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
25
    FusedMoE3DWithLoRA,
26
    FusedMoEWithLoRA,
27
28
29
30
31
32
33
34
35
36
37
38
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
39
from vllm.model_executor.layers.fused_moe import FusedMoE
40
from vllm.model_executor.layers.linear import LinearBase
41
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
42
43
44

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
46
    from vllm.model_executor.models.utils import WeightsMapper
47
48

# Module-level logger obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
49

50
51
52
53
54
55
56
57
58
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return a new process-wide LoRA id.

    Ids start at 1 and grow by one on every call. The counter is a plain
    module global; no locking is performed here.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


59
# Every LoRA wrapper class that ``from_layer`` may try when replacing a layer.
# NOTE(review): this is a ``set``, so iteration order is NOT the order written
# below; presumably the ``can_replace_layer`` checks are meant to be mutually
# exclusive for any given layer — confirm before relying on ordering.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    # Embedding and logits wrappers.
    VocabParallelEmbeddingWithLoRA,
    LogitsProcessorWithLoRA,
    # Linear-layer wrappers.
    ReplicatedLinearWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    # "ShardedLoRA" variants of the linear wrappers.
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # Fused MoE wrappers.
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


78
79
80
81
82
83
84
85
def is_moe_model(model: nn.Module) -> bool:
    """Return True if any submodule of ``model`` is a ``FusedMoE`` layer.

    When one is found, an informational message (emitted only once via
    ``logger.info_once``) notes that the fused MoE LoRA implementation is used.
    """
    has_fused_moe = any(isinstance(sub, FusedMoE) for sub in model.modules())
    if has_fused_moe:
        logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
    return has_fused_moe


86
87
88
89
90
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in a LoRA-enabled class if one can replace it.

    Each class in ``_all_lora_classes`` is asked, via ``can_replace_layer``,
    whether it accepts ``layer``. The first accepting class (in the set's
    iteration order) is instantiated around ``layer``, and its LoRA weights
    are allocated with ``create_lora_weights(max_loras, lora_config,
    model_config)``.

    Returns:
        The new LoRA wrapper, or ``layer`` itself if no class accepted it.
    """
    for candidate_cls in _all_lora_classes:
        # Keyword arguments on purpose: decorators around ``can_replace_layer``
        # read these values by name.
        accepted = candidate_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        )
        if not accepted:
            continue
        wrapped = candidate_cls(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor so it can apply LoRA adapters.

    The wrapper is built from ``layer`` together with ``lm_head``'s
    embedding dim, weight dtype, weight device and sharded-to-full mapping.
    LoRA weights are then allocated with ``create_lora_weights``.

    Returns:
        The newly created ``LogitsProcessorWithLoRA``.
    """
    embedding_dim = lm_head.embedding_dim
    head_dtype = lm_head.weight.dtype
    head_device = lm_head.weight.device
    shard_mapping = lm_head.get_sharded_to_full_mapping()
    wrapped = LogitsProcessorWithLoRA(
        layer, embedding_dim, head_dtype, head_device, shard_mapping
    )
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped

124

125
126
127
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Install ``new_module`` at dotted path ``module_name`` inside ``model``.

    The parent is everything before the last ``.`` (the empty string, i.e.
    ``model`` itself, when there is no dot). The attribute named by the last
    component is overwritten on that parent.

    Returns:
        ``new_module``, for call chaining.
    """
    parent_path, _, child_attr = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, child_attr, new_module)
    return new_module


135
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool]:
    """Parse the name of a LoRA weight tensor.

    Args:
        name: the name of the fine-tuned LoRA weight, e.g.
            ``base_model.model.model.layers.0.mlp.down_proj.lora_A.weight``.
        weights_mapper: optional mapper applied to the name (with a leading
            ``base_model.model.`` prefix temporarily removed), e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g.
                ``model.layers.0.mlp.down_proj``.
            is_lora_a: whether the tensor is lora_A (or lora_embedding_A)
                rather than lora_B (or lora_embedding_B).

    Raises:
        ValueError: if ``name`` is not a recognised LoRA weight name.
    """
    prefix = "base_model.model."

    # LoRA weight names usually start with `base_model.model.`. Remove only
    # that *leading* prefix (str.replace would also remove any later
    # occurrence) so the mapper sees the bare module path, then restore it.
    had_prefix = name.startswith(prefix)
    if had_prefix:
        name = name.removeprefix(prefix)
    if weights_mapper:
        name = weights_mapper._map_name(name)
    if had_prefix:
        name = prefix + name

    # In some situations the name does not start with `base_model.model.`
    # (e.g., ibm-granite/granite-speech-3.3-8b); then the name is kept intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # Guard the length so a one-component name raises ValueError below
    # instead of IndexError on parts[-2].
    if (
        len(parts) >= 2
        and parts[-1] == "weight"
        and parts[-2] in ("lora_A", "lora_B")
    ):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A"

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
177
178


179
180
def is_base_embeddding_weights(name: str) -> bool:
    """Return True if ``name`` ends with a base-layer embedding weight suffix.

    The hard-coded suffixes cover the input embedding
    (``.embed_tokens.base_layer.weight``) and the output head
    (``.lm_head.base_layer.weight``).
    """
    # NOTE(review): both suffixes begin with ".", so a name that is exactly
    # "lm_head.base_layer.weight" (no parent module) does not match — confirm
    # this is intended.
    input_suffix = ".embed_tokens.base_layer.weight"
    output_suffix = ".lm_head.base_layer.weight"
    return name.endswith(input_suffix) or name.endswith(output_suffix)
186
187


188
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the module names LoRA can target in ``model``.

    The result contains:
      * every entry of any non-None ``embedding_modules`` attribute found on
        a submodule (its keys, when it is a dict);
      * the last dotted component of the name of every ``LinearBase``
        submodule (in vLLM, all linear layers support LoRA);
      * the last dotted component of the name of every ``FusedMoE``
        submodule.

    Returns:
        The de-duplicated names as a list; order is unspecified.
    """
    supported_lora_modules: set[str] = set()
    for name, module in model.named_modules():
        # Add the embedding module names in bulk rather than with an inner
        # `for name in ...` loop: rebinding `name` would corrupt the suffix
        # taken from the module name below.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # Linear and fused-MoE layers are identified by their name suffix.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(name.split(".")[-1])

    return list(supported_lora_modules)


212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is returned with the user directory
         expanded (no existence check is made).
      3. A relative path that exists locally is converted to an absolute path.
      4. Anything else is treated as a Hugging Face repo id and downloaded
         with ``huggingface_hub.snapshot_download``; the local snapshot path is
         returned. If the download fails with one of the handled Hub errors,
         the error is logged and ``lora_path`` is returned unchanged.

    Parameters:
    lora_path (str): An absolute path, a relative path, or a Hugging Face
                     repo id.

    Returns:
    str: The resolved local path, or the original input if the Hub
         download failed.
    """
    if os.path.isabs(lora_path):
        resolved = lora_path
    elif lora_path.startswith("~"):
        resolved = os.path.expanduser(lora_path)
    elif os.path.exists(lora_path):
        resolved = os.path.abspath(lora_path)
    else:
        try:
            resolved = huggingface_hub.snapshot_download(repo_id=lora_path)
        except (
            HfHubHTTPError,
            RepositoryNotFoundError,
            EntryNotFoundError,
            HFValidationError,
        ):
            # Best-effort: log (with traceback) and hand back the raw input
            # rather than raising here.
            logger.exception("Error downloading the HuggingFace model")
            resolved = lora_path
    return resolved
256
257
258
259
260
261
262
263
264
265
266


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Return the packed-modules mapping to use when applying LoRA to ``model``.

    Non-MoE models get ``get_packed_modules_mapping(model)`` directly. For MoE
    models, ``get_moe_expert_mapping(model)`` must return a non-empty
    mapping. The result starts from ``get_packed_modules_mapping(model)``.
    When ``model.is_3d_moe_weight`` is false, an ``"experts"`` entry is added
    that lists the expert weight names (trailing ``.`` stripped). 3D MoE LoRA
    does not need that entry.

    Raises:
        AttributeError: if ``model`` is an MoE model and no expert mapping is
            available.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    mapping = get_packed_modules_mapping(model)
    if not model.is_3d_moe_weight:
        # Each expert-mapping entry is a 4-tuple; only its second field
        # (the weight name) is used here.
        mapping["experts"] = [
            weight_name.rstrip(".")
            for _, weight_name, _, _ in moe_packed_mapping
        ]
    return mapping