utils.py 9.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
import huggingface_hub
8
import regex as re
9
10
11
12
13
14
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    HFValidationError,
    RepositoryNotFoundError,
)
15
from torch import nn
16
from transformers import PretrainedConfig
17

18
from vllm.config.lora import LoRAConfig
19
from vllm.logger import init_logger
20

21
# being imported for _all_lora_classes below
22
23
24
25
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
26
    FusedMoEWithLoRA,
27
28
29
30
31
32
33
34
35
36
37
38
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
39
from vllm.model_executor.layers.fused_moe import FusedMoE
40
from vllm.model_executor.layers.linear import LinearBase
41
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
42
43
44

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
46
    from vllm.model_executor.models.utils import WeightsMapper
47
48

logger = init_logger(__name__)
49

50
# Candidate LoRA wrapper classes. `from_layer` iterates this collection and
# uses the first class whose `can_replace_layer` accepts a given layer.
# NOTE: this is a set, so iteration order is not the listing order below.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    # Vocab-parallel embedding.
    VocabParallelEmbeddingWithLoRA,
    # Linear layers.
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    # Logits processor.
    LogitsProcessorWithLoRA,
    # "ShardedLoRA" variants of the linear wrappers.
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # Fused mixture-of-experts layer.
    FusedMoEWithLoRA,
}


68
69
70
71
72
73
74
75
def is_moe_model(model: nn.Module) -> bool:
    """Return whether ``model`` contains at least one ``FusedMoE`` submodule.

    When such a submodule is found, an info-level message is emitted through
    ``logger.info_once`` announcing that the fused MoE LoRA path is used.
    """
    has_fused_moe = any(isinstance(sub, FusedMoE) for sub in model.modules())
    if has_fused_moe:
        logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
    return has_fused_moe


76
77
78
79
80
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in a LoRA-capable class, if one accepts it.

    The classes in ``_all_lora_classes`` are tried in turn. The first one whose
    ``can_replace_layer`` returns a truthy value is instantiated around
    ``layer``, and its LoRA weights are allocated with ``create_lora_weights``.

    Args:
        layer: The base module to wrap.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: LoRA configuration, used for matching and allocation.
        packed_modules_list: Forwarded to ``can_replace_layer``.
        model_config: Optional HF config, used for matching and allocation.

    Returns:
        The new wrapper instance, or ``layer`` itself when no class matches.
    """
    # NOTE(review): _all_lora_classes is a set, so if several classes could
    # accept the same layer, the winner depends on set iteration order.
    # Presumably the can_replace_layer checks are mutually exclusive; confirm.
    # The generator is lazy, so can_replace_layer stops being called after the
    # first match.
    matching_cls = next(
        (
            candidate
            for candidate in _all_lora_classes
            # specifying kwargs so they can be easily accessed in decorator
            if candidate.can_replace_layer(
                source_layer=layer,
                lora_config=lora_config,
                packed_modules_list=packed_modules_list,
                model_config=model_config,
            )
        ),
        None,
    )
    if matching_cls is None:
        return layer

    wrapped = matching_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is built from ``layer`` plus four values taken from
    ``lm_head``: its ``embedding_dim``, the dtype and device of its ``weight``,
    and the result of ``get_sharded_to_full_mapping()``. Its LoRA weights are
    then allocated with ``create_lora_weights``.

    Returns:
        The freshly built ``LogitsProcessorWithLoRA``.
    """
    head_weight = lm_head.weight
    wrapper = LogitsProcessorWithLoRA(
        layer,
        lm_head.embedding_dim,
        head_weight.dtype,
        head_weight.device,
        lm_head.get_sharded_to_full_mapping(),
    )
    wrapper.create_lora_weights(max_loras, lora_config, model_config)
    return wrapper

114

115
116
117
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Put ``new_module`` at the dotted path ``module_name`` inside ``model``.

    The parent is resolved with ``model.get_submodule``. For a name without
    dots the parent path is empty, so ``model`` itself is the parent.

    Returns:
        ``new_module``, for convenient chaining.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module


125
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool]:
    """Parse the name of a LoRA weight.

    Args:
        name: the qualified name of the fine-tuned LoRA tensor, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: optional mapper applied to the name (after the
            leading ``base_model.model.`` prefix is stripped, if present),
            e.g. ``model.`` -> ``language_model.model.``.

    Returns:
        tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. ``dense1``.
            is_lora_a: whether the tensor is ``lora_A`` /
                ``lora_embedding_A`` rather than the ``B`` counterpart.

    Raises:
        ValueError: if ``name`` does not end in ``lora_A.weight``,
            ``lora_B.weight``, ``lora_embedding_A`` or ``lora_embedding_B``.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that *leading* prefix (str.replace would also remove any later
    # occurrence inside the name) so the mapper sees the bare module path, then
    # restore it.
    if name.startswith(prefix):
        name = name.removeprefix(prefix)
        if weights_mapper:
            name = weights_mapper._map_name(name)
        name = prefix + name
    elif weights_mapper:
        name = weights_mapper._map_name(name)

    # In some situations the name does not start with `base_model.model.`
    # (e.g., ibm-granite/granite-speech-3.3-8b); then keep it intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # Guard the length so a bare "weight" raises ValueError, not IndexError.
    if (
        len(parts) >= 2
        and parts[-1] == "weight"
        and parts[-2] in ("lora_A", "lora_B")
    ):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A"

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
167
168


169
def is_regex_target_modules(
170
    load_modules: str | list[str], expected_lora_modules: list[str]
171
) -> bool:
172
    """
173
174
175
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffix in the regular expression is present in the
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    `expected_lora_modules`.
    """

    def is_valid_regex(pattern):
        try:
            re.compile(pattern)
            return True
        except re.error:
            return False

    def is_subset(sub_list, full_list):
        return set(sub_list).issubset(set(full_list))

    # Similar to PEFT's processing logic, regex-related operations are only
    #  executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    if is_valid_regex(load_modules):
        match = re.search(r"\((.*?)\)\$?$", load_modules)
        if match:
            suffix = match.group(1).split("|")
            return is_subset(suffix, expected_lora_modules)
    return False


202
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect the module names that LoRA can target in ``model``.

    In vLLM, all linear layers support LoRA. The result contains:
      * every entry of any non-None ``embedding_modules`` attribute found on a
        submodule (iterating it, so for a dict this means its keys), and
      * the last dotted component of the qualified name of every
        ``LinearBase`` or ``FusedMoE`` submodule.

    Args:
        model: The model whose submodules are inspected.

    Returns:
        A de-duplicated list of names. It is built from a set, so its order is
        not guaranteed.
    """
    supported_lora_modules: set[str] = set()
    for module_name, module in model.named_modules():
        # Embedding modules are declared through an `embedding_modules`
        # attribute rather than detected by type. Using `update` avoids an
        # inner loop variable; the previous `for name in ...` overwrote the
        # outer module name used for the suffix below.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # Register linear and fused-MoE layers by their name suffix.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(module_name.split(".")[-1])

    return list(supported_lora_modules)


226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def get_adapter_absolute_path(lora_path: str) -> str:
    """Resolve ``lora_path`` to a local path for the LoRA adapter.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is returned with ``os.path.expanduser``
         applied; its existence is not checked.
      3. A relative path that exists locally is made absolute.
      4. Anything else is treated as a Hugging Face repo id and downloaded with
         ``huggingface_hub.snapshot_download``. The local snapshot path is
         returned.

    Parameters:
    lora_path (str): An absolute path, a relative path, or a Hugging Face
                     model identifier.

    Returns:
    str: The resolved local path. If the download fails with one of the
         handled Hugging Face errors, the error is logged and the original
         ``lora_path`` is returned instead of raising.
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    # A relative path (no `~` expansion involved) that exists on disk.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally: assume a Hugging Face repo id.
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except (
        HfHubHTTPError,
        RepositoryNotFoundError,
        EntryNotFoundError,
        HFValidationError,
    ):
        # Best-effort: log the failure and let the caller deal with the
        # unresolved path rather than raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Build the packed-modules mapping used when applying LoRA to ``model``.

    For a non-MoE model this is simply ``get_packed_modules_mapping(model)``.
    For an MoE model, an ``"experts"`` entry is added to that mapping. It lists
    the second element of every 4-tuple returned by
    ``get_moe_expert_mapping(model)``, with trailing dots stripped.

    Raises:
        AttributeError: if the model is MoE but ``get_moe_expert_mapping``
            returns an empty/falsy mapping.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    # Combine the static packed mapping with the dynamically expanded
    # per-expert weight names.
    # NOTE(review): the dict from get_packed_modules_mapping is mutated in
    # place; presumably it is a fresh copy, not shared model state - confirm.
    packed_modules_mapping = get_packed_modules_mapping(model)
    packed_modules_mapping["experts"] = [
        expert_weight.rstrip(".")
        for _, expert_weight, _, _ in moe_packed_mapping
    ]
    return packed_modules_mapping