# utils.py (19.3 KB)
import itertools
2
from dataclasses import dataclass, field
# Cyrus Leung committed
3
4
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Protocol, Tuple, Union, overload)
5

6
import torch
7
import torch.nn as nn
8
from torch.func import functional_call
9
from transformers import PretrainedConfig
10

11
12
13
import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                     get_global_forced_attn_backend)
14
15
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                         SchedulerConfig)
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
20
from vllm.model_executor.models import ModelRegistry
21
from vllm.multimodal.base import NestedTensors
22
from vllm.platforms import current_platform
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils import is_pin_memory_available
25
26

logger = init_logger(__name__)
27

28
29
# Value type of each WeightsMapper table: a name fragment (substring, prefix
# or suffix, depending on the table) maps to its replacement, or to None.
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
30

31

32
33
34
@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns.

    The tables are applied in order -- substring, then prefix, then suffix --
    and each stage sees the name produced by the previous stage. Mapping a
    pattern to ``None`` drops every weight whose (partially mapped) name
    matches it.
    """

    # Substring -> replacement; only the first occurrence is replaced.
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    # Prefix -> replacement, applied when the name starts with the prefix.
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    # Suffix -> replacement, applied when the name ends with the suffix.
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the mapped name for ``key``, or ``None`` to drop it."""
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None

                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                key = new_key + key[len(prefix):]

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                # Slice rather than ``rsplit``: ``str.rsplit`` rejects an
                # empty separator, so an empty suffix pattern would raise.
                key = key[:len(key) - len(suffix)] + new_key

        return key

    def apply(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Lazily rename ``weights``, dropping those mapped to ``None``."""
        return ((out_name, data) for name, data in weights
                if (out_name := self._map_name(name)) is not None)
69

70
71

class AutoWeightsLoader:
    """
    Helper class to load weights into a :class:`torch.nn.Module`. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.

    Weights whose qualified name starts with one of ``skip_prefixes`` are
    silently dropped. A weight that matches no child module or parameter
    raises ``ValueError`` unless its qualified name starts with one of
    ``ignore_unexpected_prefixes``.
    """

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[List[str]] = None,
        ignore_unexpected_prefixes: Optional[List[str]] = None,
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
        """Group consecutive weights by the first component of their name.

        Yields ``(prefix, group)`` where ``group`` iterates over
        ``(rest_of_name, data)``; ``rest_of_name`` is ``""`` when the weight
        name contains no ``.``. This is built on :func:`itertools.groupby`,
        so only *adjacent* weights are grouped, and each ``group`` becomes
        unusable once the outer iterator advances.
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join two name fragments with ``.``, omitting empty fragments."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether ``qualname`` falls under one of ``skip_prefixes``."""
        return qualname.startswith(tuple(self.skip_prefixes))

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether an unmatched ``qualname`` may be dropped without error."""
        return qualname.startswith(tuple(self.ignore_unexpected_prefixes))

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load ``weights`` into the single parameter ``param``.

        Each weight is expected to have an empty remaining name. A nested
        name raises ``ValueError`` unless it may be ignored. The parameter's
        own ``weight_loader`` attribute is used when present, otherwise
        ``default_weight_loader``. Yields the qualified name of every weight
        that was loaded.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                continue

            if weight_name != "":
                if not self._can_ignore_unexpected(weight_qualname):
                    raise ValueError(
                        f"Attempted to load nested weight '{weight_qualname}' "
                        f"into a single parameter '{base_prefix}'")

                continue

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

            yield weight_qualname

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively dispatch ``weights`` to the children of ``module``.

        Yields the qualified names of the weights loaded by this loader.
        """
        # Placeholder for a layer missing from this pipeline-parallel rank:
        # its weights are not loaded.
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                # NOTE: names loaded by a child's own ``load_weights`` are
                # not reported back to the caller.
                module_load_weights(weights)
                return

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if self._can_skip(prefix):
                continue

            if child_prefix in child_modules:
                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
            elif child_prefix in child_params:
                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
            elif not self._can_ignore_unexpected(prefix):
                msg = (f"There is no module or parameter named '{prefix}' "
                       f"in {type(self.module).__name__}")
                raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        *,
        mapper: Optional[WeightsMapper] = None,
    ) -> List[str]:
        """Load ``weights`` into ``self.module``.

        If ``mapper`` is given, weight names are first remapped, and possibly
        dropped, by it. Returns the qualified names of the weights loaded by
        this loader. Weights loaded by a child's own ``load_weights`` are not
        included.
        """
        if mapper is not None:
            weights = mapper.apply(weights)

        return list(self._load_module("", self.module, weights))
205
206


207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def init_vllm_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Build the vLLM-registered model class resolved from
    ``hf_config.architectures``, forwarding the configs supplied by the
    outer (wrapping) vLLM model.
    """
    model_class, _arch = ModelRegistry.resolve_model_cls(
        hf_config.architectures)

    optional_configs = {
        "lora_config": lora_config,
        "multimodal_config": multimodal_config,
        "scheduler_config": scheduler_config,
    }
    return build_model(model_class, hf_config, cache_config, quant_config,
                       **optional_configs)


233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    - A tensor of shape ``(B, N, ...)`` is returned with shape
      ``(B * N, ...)``; ``concat`` has no effect in this case.
    - A list of ``B`` tensors is flattened along each element's first
      dimension. With ``concat=True`` the elements are concatenated along
      dim 0 into one tensor. Otherwise a flat Python list of the per-item
      sub-tensors is returned.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


271
272
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    A tensor has every dimension except the last merged into the first. A
    (possibly nested) sequence of tensors is flattened element by element
    and the results are concatenated along dim 0.
    """

    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Describe how many embeddings a NestedTensors holds, as a readable
    arithmetic expression such as ``"2 x 3 + 5"``, for error messages.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = [_embedding_count_expression(inner) for inner in embeddings]
        return " + ".join(parts)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))


# Cyrus Leung committed
297
298
299
300
301
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions of ``inputs_embeds`` selected by the boolean mask
    ``is_multimodal``.

    The flattened ``multimodal_embeddings`` must provide exactly one entry
    per ``True`` value in ``is_multimodal``. Otherwise ``ValueError`` is
    raised.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    # Narrows the ``.item()`` result for type checkers.
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds
322
323


# Cyrus Leung committed
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
                                                          List[torch.Tensor]]],
) -> torch.Tensor:
    """
    Embed token IDs and multimodal inputs and combine their embeddings.

    ``multimodal_token_id`` is used to determine whether a token ID should
    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.

    Compared to ``merge_multimodal_embeddings``, this avoids running
    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
    which causes issues when the placeholder token ID exceeds the
    vocabulary size of the language model.

    Returns a new tensor of shape ``(input_ids.shape[0], hidden)``, where
    ``hidden`` is ``text_embeds.shape[1]``. Raises ``ValueError`` if the
    number of multimodal embeddings differs from the number of placeholder
    tokens.
    """
    is_multimodal = input_ids == multimodal_token_id
    is_text = ~is_multimodal

    text_embeds = get_text_embeds(input_ids[is_text])
    multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])

    # Uninitialised buffer. Text positions are written below. Every
    # multimodal position is then overwritten by the merge, whose count
    # check guarantees full coverage.
    merged_embeds = torch.empty(
        (input_ids.shape[0], text_embeds.shape[1]),
        dtype=text_embeds.dtype,
        device=text_embeds.device,
    )

    merged_embeds[is_text] = text_embeds

    return _merge_multimodal_embeddings(
        merged_embeds,
        is_multimodal,
        multimodal_embeds,
    )


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int,
) -> torch.Tensor:
    """
    Overwrite ``inputs_embeds`` at every position where ``input_ids`` equals
    ``placeholder_token_id`` with the flattened ``multimodal_embeddings``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    placeholder_mask = input_ids == placeholder_token_id
    return _merge_multimodal_embeddings(inputs_embeds, placeholder_mask,
                                        multimodal_embeddings)


384
385
class LayerFn(Protocol):
    """Callable that builds a single layer given its qualified ``prefix``."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


390
391
392
393
394
395
396
397
398
class PPMissingLayer(torch.nn.Identity):
    """
    Placeholder occupying the slot of a layer that is absent from this
    pipeline-parallel rank.

    Any constructor arguments are accepted and ignored. As an ``Identity``
    module it returns its input unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are intentionally discarded; Identity needs none.
        super().__init__()


399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
# Running total of parameter bytes offloaded to CPU so far.
_CPU_OFFLOAD_BYTES = 0
# Offload budget in bytes; with the initial value 0 nothing is offloaded.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget to ``max_bytes`` and reset the total."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Offload ``module``'s parameters to CPU memory while the global budget
    set by ``set_cpu_offload_max_bytes`` lasts. Then wrap its ``forward`` so
    the weights are copied to the device of the module's first parameter on
    every call.

    Offloading is per parameter: once the budget is used up, the remaining
    parameters of the module stay where they are.

    Returns ``module`` itself. It is left untouched when it has no
    parameters, when its first parameter is already on CPU, or when the
    budget is already exhausted.
    """
    # ``next(..., None)`` keeps parameter-free modules (e.g. nn.Identity)
    # from raising StopIteration.
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module

    device = first_param.device
    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the original forward while functional_call runs the
            # module, so the call does not re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`; if the parameter
                    # is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                # Reinstall the wrapper even if the call raised. Otherwise
                # later calls would bypass the copy to ``device``.
                module.forward = forward

        module.forward = forward

    return module


465
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Only the indices ``[start_layer, end_layer)`` returned by
    ``get_pp_indices`` for this pipeline rank are built. Each is created via
    ``layer_fn(prefix=f"{prefix}.{idx}")`` and passed through
    ``maybe_offload_to_cpu``. Every other index holds a ``PPMissingLayer``
    placeholder, so the list keeps the full model's layer indexing.

    Returns:
        ``(start_layer, end_layer, modules)``.
    """
    # NOTE(review): imported lazily, presumably to avoid an import cycle
    # with vllm.distributed -- confirm.
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Each returned name ends with ``.`` so it can be used as a prefix match
    against parameter names. The result is cached per ``id(model)``.
    """
    model_id = id(model)
    # NOTE(review): entries are keyed by ``id(model)`` and never evicted.
    # If a model is garbage-collected and a new one reuses its id, the stale
    # entry is returned.
    cached = _model_to_pp_missing_layer_names.get(model_id)
    if cached is not None:
        return cached

    missing_layer_names = [
        # NOTE: the trailing dot is used to match the prefix of the layer.
        # without the dot, we could match a layer that is not missing,
        # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
        name + '.'
        for name, module in model.named_modules()
        if isinstance(module, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    Returns ``True`` if ``model`` is itself a ``PPMissingLayer``, or if
    ``name`` starts with the dotted name of one of its missing layers.
    """
    if isinstance(model, PPMissingLayer):
        return True

    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
516
517
518
519
520


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """Return a factory that builds zero-filled ``IntermediateTensors``.

    The returned callable takes ``(batch_size, dtype, device)``. It produces
    one zero tensor of shape ``(batch_size, hidden_size)`` for each name in
    ``keys``.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
533
534
535
536


class LLMWrapper(nn.Module):
    """
    To align with the key names of LoRA trained with PEFT, we need to add an
    additional layer to the llm's implementation.

    The wrapped module is registered as a submodule under ``name``, so its
    parameter names gain a ``f"{name}."`` prefix. Attribute lookups the
    wrapper cannot resolve, and calls, are forwarded to the wrapped module.
    """

    def __init__(self, llm: nn.Module, name: str) -> None:
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)

    def __getattr__(self, key: str):
        # Only reached when normal attribute lookup fails. The wrapped module
        # is fetched through nn.Module.__getattr__ to avoid recursing here.
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm

        return getattr(llm, key)

    # We need to explicitly override this
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577


def get_vit_attn_backend() -> _Backend:
    """Choose the attention backend for vision-encoder (ViT) layers.

    Precedence:
    1. A globally forced backend.
    2. The backend named by ``VLLM_ATTENTION_BACKEND``.
    3. A platform default:
       - ``FLASH_ATTN`` when ``current_platform.has_device_capability(80)``
         holds and ``is_flash_attn_2_available()`` is true.
       - ``XFORMERS``, with a warning, when that capability holds but
         flash-attn 2 is unavailable.
       - ``TORCH_SDPA`` on CPU.
       - ``XFORMERS`` everywhere else.
    """
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
    if selected_backend is None:
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)
    if selected_backend is None:
        # For Volta and Turing GPUs, use xformers instead.
        device_available = current_platform.has_device_capability(80)
        if device_available:
            from transformers.utils import is_flash_attn_2_available
            if is_flash_attn_2_available():
                selected_backend = _Backend.FLASH_ATTN
            else:
                logger.warning(
                    "Current `vllm-flash-attn` has a bug inside vision module, "
                    "so we use xformers backend instead. You can run "
                    "`pip install flash-attn` to use flash-attention backend.")
                selected_backend = _Backend.XFORMERS
        elif current_platform.is_cpu():
            selected_backend = _Backend.TORCH_SDPA
        else:
            selected_backend = _Backend.XFORMERS
    return selected_backend