utils.py 10.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
import huggingface_hub
8
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
9
from torch import nn
10
from transformers import PretrainedConfig
11

12
from vllm import envs
13
from vllm.config.lora import LoRAConfig
14
from vllm.logger import init_logger
15

16
# being imported for _all_lora_classes below
17
18
19
20
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
21
    FusedMoE3DWithLoRA,
22
    FusedMoEWithLoRA,
23
    LogitsProcessorWithLoRA,
24
    MergedColumnParallelLinearVariableSliceWithLoRA,
25
26
27
28
29
30
31
32
33
34
35
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
36
from vllm.model_executor.layers.fused_moe import FusedMoE
37
from vllm.model_executor.layers.linear import LinearBase
38
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
39
40
41

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
43
    from vllm.model_executor.models.utils import WeightsMapper
44
45

logger = init_logger(__name__)
46

luopl's avatar
luopl committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Compute the ``num_active_loras`` values to capture CUDA graphs for.

    With ``specialize=True`` the result, in ascending order, is every power of
    two that does not exceed ``max_loras + 1``, plus ``max_loras + 1`` itself.
    That value is listed only once, even when it is also a power of two.
    With ``specialize=False`` the result is just ``[max_loras + 1]``.

    CudagraphDispatcher and PunicaWrapperGPU both use this helper, so the set
    of LoRA capture sizes is defined in exactly one place.
    """
    ceiling = max_loras + 1
    if specialize:
        # ``k & (k - 1) == 0`` holds exactly when k (>= 1) is a power of two.
        return [
            k for k in range(1, ceiling + 1) if k == ceiling or not (k & (k - 1))
        ]
    return [ceiling]

64
65
66
67
68
69
70
71
72
# Module-level counter backing get_lora_id(); holds the last id handed out.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return a new LoRA id, one greater than the previously returned id.

    Ids start at 1 and are drawn from the module-level ``_GLOBAL_LORA_ID``
    counter.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


73
# All LoRA layer wrappers that from_layer() may substitute for a base layer.
# from_layer() returns the first class whose ``can_replace_layer`` accepts the
# layer. Because this is a ``set``, the order in which classes are tried is not
# guaranteed. The predicates are therefore presumably intended to be mutually
# exclusive for any given layer. NOTE(review): confirm.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    # Sharded-LoRA variants and related wrappers.
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    # Mixture-of-experts wrappers.
    FusedMoEWithLoRA,
    FusedMoE3DWithLoRA,
}


93
94
95
96
97
98
99
100
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one ``FusedMoE`` layer.

    On the first hit, an info-level message (emitted at most once, via
    ``logger.info_once``) announces that the fused MoE LoRA implementation
    will be used.
    """
    for submodule in model.modules():
        if isinstance(submodule, FusedMoE):
            logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
            return True
    return False


101
102
103
104
105
def from_layer(
    layer: nn.Module,
    max_loras: int,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> nn.Module:
    """Wrap ``layer`` in the first LoRA layer class that accepts it.

    Each class in ``_all_lora_classes`` is asked, via ``can_replace_layer``,
    whether it can replace ``layer``. The first class that accepts is
    instantiated around ``layer``. Its LoRA weights are then allocated with
    ``create_lora_weights(max_loras, lora_config, model_config)`` and the new
    wrapper is returned. If no class accepts, ``layer`` is returned unchanged.
    """

    def _accepts(lora_cls):
        # Arguments are passed by keyword so decorators wrapping
        # ``can_replace_layer`` can easily access them by name.
        return lora_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        )

    matching_cls = next(filter(_accepts, _all_lora_classes), None)
    if matching_cls is None:
        return layer

    lora_layer = matching_cls(layer)
    lora_layer.create_lora_weights(max_loras, lora_config, model_config)
    return lora_layer


def from_layer_logits_processor(
    layer: "LogitsProcessor",
    lm_head: "ParallelLMHead",
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor in ``LogitsProcessorWithLoRA``.

    The wrapper is built from ``lm_head``:
    - its ``embedding_dim``,
    - the dtype and device of its ``weight``,
    - its ``get_sharded_to_full_mapping()``.

    LoRA weights are then allocated with
    ``create_lora_weights(max_loras, lora_config, model_config)``.
    """
    hidden_size = lm_head.embedding_dim
    head_weight = lm_head.weight
    wrapped = LogitsProcessorWithLoRA(
        layer,
        hidden_size,
        head_weight.dtype,
        head_weight.device,
        lm_head.get_sharded_to_full_mapping(),
    )
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped

139

140
141
142
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Install ``new_module`` in ``model`` at the dotted path ``module_name``.

    The parent is looked up with ``model.get_submodule``, using everything
    before the last ``.``. For a top-level name this is the empty path, i.e.
    ``model`` itself. ``new_module`` is then bound on the parent under the
    final path component. Returns ``new_module``.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_path)
    setattr(parent, attr_name, new_module)
    return new_module


150
def parse_fine_tuned_lora_name(
    name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool]:
    """Parse the name of a LoRA weight into its module name and A/B role.

    Args:
        name: the qualified name of the fine-tuned LoRA weight, e.g.
            ``base_model.model.dense1.lora_A.weight`` or
            ``base_model.model.embed_tokens.lora_embedding_B``.
        weights_mapper: optional mapper applied to the weight name, e.g.
            ``model.`` -> ``language_model.model.``. It is applied after the
            leading ``base_model.model.`` prefix, if present, is stripped.

    Returns:
        tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. ``dense1``. The leading
                ``base_model.model.`` prefix is dropped when present.
            is_lora_a: whether the tensor is lora_A / lora_embedding_A
                (True) or lora_B / lora_embedding_B (False).

    Raises:
        ValueError: if ``name`` is not a supported LoRA weight name.
    """
    prefix = "base_model.model."

    # LoRA weight qualified names usually start with `base_model.model.`.
    # Strip only that *leading* prefix so the weights mapper sees the bare
    # module path, then restore it. Occurrences elsewhere in the name are
    # left intact.
    if name.startswith(prefix):
        name = name.removeprefix(prefix)
        name = weights_mapper._map_name(name) if weights_mapper else name
        name = prefix + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations, we may not start with `base_model.model.`.
    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
    # we should keep the prefix intact.
    start_index = 2 if name.startswith(prefix) else 0

    parts = name.split(".")
    # Guard the length so a bare "weight" raises ValueError, not IndexError.
    if len(parts) >= 2 and parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A"

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[start_index:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported LoRA weight")
192
193


194
195
def is_base_embeddding_weights(name: str) -> bool:
    """Return True if ``name`` is the base weight of an input/output embedding.

    Matching uses hard-coded suffixes: ``name`` must end with
    ``.embed_tokens.base_layer.weight`` or ``.lm_head.base_layer.weight``.
    The suffixes include the leading dot, so a bare
    ``lm_head.base_layer.weight`` does not match.
    """
    return any(
        name.endswith(suffix)
        for suffix in (
            ".embed_tokens.base_layer.weight",
            ".lm_head.base_layer.weight",
        )
    )
201
202


203
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """
    Collect the module names in ``model`` that can be targeted by LoRA.

    In vLLM, every linear layer (``LinearBase``) and every ``FusedMoE`` layer
    supports LoRA. For each such module, the last component of its qualified
    name is collected (e.g. ``q_proj`` for
    ``model.layers.0.self_attn.q_proj``). The entries of any module's
    ``embedding_modules`` attribute are also added as-is.

    Args:
        model: the model to scan.

    Returns:
        The de-duplicated module names, in no particular order.
    """
    supported_lora_modules: set[str] = set()
    for name, module in model.named_modules():
        # Add the embedding module names the module declares, if any.
        # ``update`` avoids a nested loop variable. The original code reused
        # ``name`` here, which clobbered the qualified module name used by the
        # linear / MoE check below.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            supported_lora_modules.update(embedding_modules)

        # Linear and fused-MoE layers are registered by their name suffix.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(name.split(".")[-1])

    return list(supported_lora_modules)


227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to a local filesystem path.

    Resolution order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with ``~`` is returned with the user's home directory
         expanded. Existence is not checked.
      3. A relative path that exists locally is returned as an absolute path.
      4. Otherwise ``lora_path`` is treated as a hub repository id and its
         snapshot is downloaded. ModelScope is used when
         ``envs.VLLM_USE_MODELSCOPE`` is set; otherwise the Hugging Face Hub
         is used. The local snapshot path is returned. If the download fails
         with one of the expected hub errors, the error is logged and
         ``lora_path`` is returned unchanged.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a hub model identifier.

    Returns:
    str: The resolved local path to the lora model (or ``lora_path`` itself
         if a hub download failed).
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith("~"):
        return os.path.expanduser(lora_path)

    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally: treat it as a remote repository id.
    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import InvalidParameter, snapshot_download
        from requests import HTTPError

        def fetch_snapshot():
            return snapshot_download(model_id=lora_path)

        recoverable_errors = (HTTPError, InvalidParameter)
        failure_message = "Error downloading the ModelScope model"
    else:

        def fetch_snapshot():
            return huggingface_hub.snapshot_download(repo_id=lora_path)

        recoverable_errors = (HfHubHTTPError, HFValidationError)
        failure_message = "Error downloading the HuggingFace model"

    try:
        return fetch_snapshot()
    except recoverable_errors:
        # Best effort: log the failure and hand back the original string
        # instead of raising here.
        logger.exception(failure_message)
        return lora_path
280
281
282
283
284
285
286
287
288
289
290


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Return the packed-module mapping used for LoRA on ``model``.

    The result maps packed module names to lists of their corresponding
    submodule names.

    For non-MoE models, this is ``get_packed_modules_mapping(model)``.

    For MoE models, the expert mapping from ``get_moe_expert_mapping(model)``
    must be non-empty. Unless ``model.is_3d_moe_weight`` is set (3D MoE LoRA
    does not need the entry), an ``"experts"`` entry is added, listing the
    per-expert weight names with trailing dots stripped.

    Raises:
        AttributeError: if ``model`` is an MoE model but no expert mapping is
            available.
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    moe_packed_mapping = get_moe_expert_mapping(model)
    if not moe_packed_mapping:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    packed_modules_mapping = get_packed_modules_mapping(model)
    if not model.is_3d_moe_weight:
        # Non-gated MoE has an empty ckpt_up_proj_name, which yields malformed
        # weight names containing ".." (e.g. "experts.0.." instead of
        # "experts.0.layer_name."). Those entries are skipped.
        expert_weight_names = []
        for _, weight_name, _, _ in moe_packed_mapping:
            if ".." in weight_name:
                continue
            expert_weight_names.append(weight_name.rstrip("."))
        packed_modules_mapping["experts"] = expert_weight_names

    return packed_modules_mapping