utils.py 21.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from dataclasses import dataclass, field
5
6
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
                    Protocol, Set, Tuple, Union, overload)
7

8
import torch
9
import torch.nn as nn
10
from torch.func import functional_call
11
from transformers import PretrainedConfig
12

13
from vllm.config import VllmConfig
14
from vllm.logger import init_logger
15
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
16
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
17
from vllm.sequence import IntermediateTensors
18
from vllm.utils import is_pin_memory_available
19
20

# Module-level logger obtained from vLLM's ``init_logger``; the helpers in
# this file use it for debug/warning messages during weight loading.
logger = init_logger(__name__)
21

22
23
# Maps an original weight-name fragment to its replacement. A replacement of
# `None` means the matching weight is ignored (dropped).
WeightsMapping = Mapping[str, Optional[str]]


@dataclass
class WeightsMapper:
    """Rename checkpoint weight names using substring, prefix and suffix rules.

    The three rule tables are applied in the order substring -> prefix ->
    suffix, each table in its iteration order, and every rule sees the name
    as rewritten by the rules before it. A rule whose replacement is
    ``None`` drops the weight entirely.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the rewritten name for ``key``, or ``None`` to drop it."""
        # (rule table, "does this rule apply?", "how to rewrite the name")
        stages = (
            (self.orig_to_new_substr, str.__contains__,
             lambda name, old, new: name.replace(old, new, 1)),
            (self.orig_to_new_prefix, str.startswith,
             lambda name, old, new: name.replace(old, new, 1)),
            (self.orig_to_new_suffix, str.endswith,
             lambda name, old, new: new.join(name.rsplit(old, 1))),
        )

        for table, matches, rewrite in stages:
            for old, new in table.items():
                if not matches(key, old):
                    continue
                if new is None:
                    return None
                key = rewrite(key, old, new)

        return key

    def apply(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Lazily rename ``weights``, skipping those mapped to ``None``."""
        renamed = ((self._map_name(name), data) for name, data in weights)
        return ((name, data) for name, data in renamed if name is not None)
63

64
65

class AutoWeightsLoader:
    """
    Helper class to load weights into a :class:`torch.nn.Module`. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.

    Detailed weight loading information can be viewed by setting the
    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
    """

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[List[str]] = None,
        ignore_unexpected_prefixes: Optional[List[str]] = None,
    ) -> None:
        """
        Args:
            module: The root module that receives the weights.
            skip_prefixes: Weights whose qualified name starts with one of
                these prefixes are skipped without being loaded. Child
                modules are matched against their name plus a trailing
                ``"."``.
            ignore_unexpected_prefixes: A weight that matches no child
                module or parameter is dropped (instead of raising
                ``ValueError``) when its qualified name starts with one of
                these prefixes.
        """
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
        """
        Group consecutive weights by the first dotted component of their
        name, yielding ``(prefix, iterator of (rest_of_name, tensor))``.

        As with :func:`itertools.groupby`, only *adjacent* weights are
        grouped, and each inner iterator must be consumed before the outer
        iterator is advanced (advancing it invalidates the previous group).
        A name with no dot yields ``rest_of_name == ""``.
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join ``prefix`` and ``rest`` with a dot, omitting empty parts."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether ``qualname`` starts with any of ``skip_prefixes``."""
        # str.startswith accepts a tuple; an empty tuple never matches.
        return qualname.startswith(tuple(self.skip_prefixes))

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether ``qualname`` starts with any ignore-unexpected prefix."""
        return qualname.startswith(tuple(self.ignore_unexpected_prefixes))

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Load each weight in ``weights`` into ``param``, yielding the
        qualified name of every weight that was loaded.

        The parameter's own ``weight_loader`` attribute is used when present,
        otherwise ``default_weight_loader``.

        Raises:
            ValueError: If a weight has a non-empty remaining name (i.e. it
                addresses something nested under a plain parameter) and is
                not covered by ``skip_prefixes`` /
                ``ignore_unexpected_prefixes``.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
                    f"into a single parameter '{base_prefix}'")

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname,
                         param.shape)

            yield weight_qualname

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Recursively load ``weights`` (names relative to ``module``) into
        ``module``, yielding the qualified names of loaded weights.

        Nothing is loaded into a ``PPMissingLayer``. A submodule that
        defines a callable ``load_weights`` receives ``weights`` first.
        Submodule weights arrive as an iterator from ``_groupby_prefix``,
        so only weights that ``load_weights`` leaves unconsumed reach the
        generic child-module / parameter dispatch below.

        Raises:
            ValueError: If a weight names neither a child module nor a
                direct parameter of ``module`` and is not covered by
                ``skip_prefixes`` / ``ignore_unexpected_prefixes``.
        """
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself.
        # nn.Module has identity semantics, so compare identity explicitly.
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters "
                        "for module %s", module)
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                # Trailing dot so "layers.1" does not skip "layers.10".
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                msg = (f"There is no module or parameter named '{prefix}' "
                       f"in {type(self.module).__name__}")
                raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        *,
        mapper: Optional[WeightsMapper] = None,
    ) -> Set[str]:
        """
        Load ``weights`` into ``self.module``.

        Args:
            weights: ``(qualified_name, tensor)`` pairs, iterated once.
            mapper: Optional renaming applied to each name before loading;
                weights it maps to ``None`` are dropped.

        Returns:
            The set of qualified names of the weights that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
237
238


239
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: Optional[PretrainedConfig] = None,
    architectures: Optional[list[str]] = None,
) -> nn.Module:
    """
    Build an inner model registered with vLLM from the outer model's config.

    When ``hf_config`` and/or ``architectures`` are given, a copy of
    ``vllm_config`` carrying them (via ``with_hf_config``) is used for the
    inner model; otherwise ``vllm_config`` is used unchanged.
    """
    from vllm.model_executor.model_loader.loader import _initialize_model

    # Overriding the architectures needs a config to attach them to; fall
    # back to the outer model's HF config when none was given.
    if architectures is not None and hf_config is None:
        hf_config = vllm_config.model_config.hf_config

    if hf_config is None:
        inner_config = vllm_config
    else:
        inner_config = vllm_config.with_hf_config(hf_config,
                                                  architectures=architectures)

    return _initialize_model(vllm_config=inner_config, prefix=prefix)
261
262


263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Merge the leading ``B`` and ``N`` dimensions of batched multimodal inputs.

    - A tensor of shape ``(B, N, ...)`` becomes ``(B * N, ...)``;
      ``concat`` has no effect in this case.
    - A list of per-batch tensors is concatenated along dim 0 when
      ``concat=True``; otherwise it is unpacked into a flat list holding
      each ``x[b][n]`` slice in order.
    """
    if not isinstance(x, torch.Tensor):
        if concat:
            return torch.cat(x)
        # Iterating a tensor yields its slices along dim 0.
        return list(itertools.chain.from_iterable(x))

    return x.flatten(0, 1)


310
311
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Collapse (possibly nested) embeddings into a single tensor.

    Every leaf tensor has all but its last dimension flattened, and the
    leaves are concatenated along dim 0 in iteration order.
    """
    if not isinstance(embeddings, torch.Tensor):
        # A (nested) sequence: flatten each element, then join the rows.
        return torch.cat([_flatten_embeddings(item) for item in embeddings])

    # Leaf tensor: keep the trailing dimension, merge everything before it.
    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Describe how many embeddings ``embeddings`` holds, for error messages.

    Each leaf tensor contributes its leading dimensions joined by ``" x "``,
    and nested sequences join their parts with ``" + "``; e.g. tensors of
    shapes ``(2, 3, D)`` and ``(5, D)`` in a list give ``"2 x 3 + 5"``.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = map(_embedding_count_expression, embeddings)
        return " + ".join(parts)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(size) for size in leading_dims)


336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def merge_multimodal_embeddings_from_map(
        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
    placeholder map.

    The embeddings are flattened on all but their last dimension. The rows
    selected by ``placeholder_map.src`` are then written into the rows of
    ``inputs_embeds`` selected by ``placeholder_map.dest``.

    Note:
        This updates ``inputs_embeds`` in place (and also returns it).
    """
    source_rows = _flatten_embeddings(multimodal_embeddings)[
        placeholder_map.src]
    inputs_embeds[placeholder_map.dest] = source_rows
    return inputs_embeds
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Overwrite the positions of ``inputs_embeds`` selected by the boolean
    mask ``is_multimodal`` with the flattened ``multimodal_embeddings``.

    Raises:
        ValueError: If the number of flattened embeddings differs from the
            number of ``True`` entries in ``is_multimodal``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    num_provided = flattened.shape[0]
    if num_provided == num_expected_tokens:
        inputs_embeds[is_multimodal] = flattened
        return inputs_embeds

    expr = _embedding_count_expression(multimodal_embeddings)
    raise ValueError(
        f"Attempted to assign {expr} = {num_provided} "
        f"multimodal tokens to {num_expected_tokens} placeholders")
def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    multimodal_embeds: NestedTensors,
) -> torch.Tensor:
    """
    Embed token IDs and multimodal inputs and combine their embeddings.

    Positions where ``input_ids == multimodal_token_id`` are filled, in
    order, from the flattened ``multimodal_embeds``. All other positions are
    embedded with ``get_text_embeds``.

    Compared to ``merge_multimodal_embeddings``, this avoids running
    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
    which causes issues when the placeholder token ID exceeds the
    vocabulary size of the language model.

    ``input_ids`` may have any shape. The result has shape
    ``(*input_ids.shape, hidden_size)``, where ``hidden_size`` is the last
    dimension of the tensor returned by ``get_text_embeds``.

    Raises:
        ValueError: If the number of multimodal embeddings does not match
            the number of ``multimodal_token_id`` positions.
    """
    is_multimodal = input_ids == multimodal_token_id
    is_text = ~is_multimodal

    # Only text positions are embedded, so placeholder IDs (which may be
    # out of vocabulary) never reach ``get_text_embeds``.
    text_embeds = get_text_embeds(input_ids[is_text])
    merged_embeds = torch.empty(
        # Matches the mask shape so boolean-mask assignment works for
        # 1-D as well as batched ``input_ids``.
        (*input_ids.shape, text_embeds.shape[-1]),
        dtype=text_embeds.dtype,
        device=text_embeds.device,
    )

    merged_embeds[is_text] = text_embeds

    return _merge_multimodal_embeddings(
        merged_embeds,
        is_multimodal,
        multimodal_embeds,
    )


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` may be a list of token ids (e.g., the ids of
    img_start, img_break and img_end tokens). In that case the order of
    these tokens in ``input_ids`` MUST MATCH the order of their embeddings
    in ``multimodal_embeddings``, because the merge writes a contiguous
    flattened sequence rather than scattering each embedding individually.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from the vision
    encoder must be padded with embeddings of S, B, and E in the same order
    as in input_ids for a correct embedding merge.

    Raises:
        ValueError: If the number of flattened embeddings differs from the
            number of placeholder positions.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        # Any one of several IDs marks a multimodal position.
        id_tensor = torch.tensor(placeholder_token_id,
                                 device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, id_tensor)
    else:
        placeholder_mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(inputs_embeds, placeholder_mask,
                                        multimodal_embeddings)


462
463
class LayerFn(Protocol):
    """Structural type for a factory that builds one layer from its module
    name prefix; ``make_layers`` calls it as ``layer_fn(prefix=...)``."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


468
469
470
471
472
473
474
475
476
class PPMissingLayer(torch.nn.Identity):
    """
    Placeholder for a layer that is not held by this pipeline-parallel rank.

    It behaves like ``torch.nn.Identity``: it has no parameters and returns
    its input unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Accept and discard any constructor arguments, so the placeholder
        # can be created wherever the real layer would have been.
        del args, kwargs
        super().__init__()


477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
# Running total of parameter bytes moved to CPU by `maybe_offload_to_cpu`.
_CPU_OFFLOAD_BYTES = 0
# Offload budget in bytes; with the default of 0 nothing is offloaded.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget to ``max_bytes`` and reset the running
    count of already-offloaded bytes to zero."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move ``module``'s parameters to CPU while the global offload budget
    (see ``set_cpu_offload_max_bytes``) allows it.

    Offloading is per-parameter, so a module can end up partially offloaded.
    If at least one parameter was moved, ``module.forward`` is replaced by a
    wrapper. On each call the wrapper copies the module's state back to the
    original device and runs the module through
    ``torch.func.functional_call``.

    Modules without parameters, or whose first parameter is already on CPU,
    are returned unchanged.

    Returns:
        The same ``module`` object, possibly with a wrapped ``forward``.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    # Use a default: plain `next()` raises StopIteration for modules that
    # have no parameters at all.
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the original forward while `functional_call` runs the
            # module, so the call does not re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`; if the parameter
                    # is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                # Re-install the wrapper even if the forward raised;
                # otherwise later calls would run with the CPU weights.
                module.forward = forward

        module.forward = forward

    return module


543
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """
    Build ``num_hidden_layers`` layers with ``layer_fn``, materializing only
    the slice that belongs to this pipeline-parallel rank.

    Layers with index in ``[start_layer, end_layer)`` (from
    ``get_pp_indices``) are created as ``layer_fn(prefix=f"{prefix}.{idx}")``
    and passed through ``maybe_offload_to_cpu``. Every other slot holds a
    ``PPMissingLayer``, so list positions equal the global layer indices.

    Returns:
        ``(start_layer, end_layer, modules)``.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)

    before = [PPMissingLayer() for _ in range(start_layer)]
    owned = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    after = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    modules = torch.nn.ModuleList(before + owned + after)
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# NOTE(review): id() values can be reused once a model is garbage-collected,
# so a new model could hit a stale entry; a weakref-keyed cache would avoid
# that -- confirm whether models are ever recreated in one process.
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """
    Return the names of all ``PPMissingLayer`` submodules of ``model``, each
    with a trailing ``"."``. Results are memoized per ``id(model)``.
    """
    cache_key = id(model)
    cached = _model_to_pp_missing_layer_names.get(cache_key)
    if cached is not None:
        return cached

    # The trailing dot makes these component-wise prefixes: without it,
    # 'encoder.layer.1' would also match 'encoder.layer.11'.
    missing_layer_names = [
        f"{name}." for name, submodule in model.named_modules()
        if isinstance(submodule, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[cache_key] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """
    Return ``True`` if ``name`` refers to a layer this pipeline rank does
    not hold. That is the case when ``model`` itself is a
    ``PPMissingLayer``, or when ``name`` starts with the (dot-terminated)
    name of one of its ``PPMissingLayer`` submodules.
    """
    if isinstance(model, PPMissingLayer):
        return True

    # str.startswith accepts a tuple; an empty tuple never matches.
    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
594
595
596
597
598


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """
    Return a factory ``(batch_size, dtype, device) -> IntermediateTensors``
    that builds one zero tensor of shape ``(batch_size, hidden_size)`` per
    entry of ``keys``.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        zeros_by_key = {
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in keys
        }
        return IntermediateTensors(zeros_by_key)

    return make_empty_intermediate_tensors
610
611


612
613
614
615
616
617
618
619
620
621
622
def maybe_prefix(prefix: str, name: str) -> str:
    """Join ``prefix`` and ``name`` with a dot, unless ``prefix`` is empty.

    Args:
        prefix: The prefix to add; an empty string adds nothing.
        name: The name to potentially prefix.

    Returns:
        ``"prefix.name"`` when ``prefix`` is non-empty, otherwise ``name``.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643


def extract_layer_index(layer_name: str) -> int:
    """
    Extract the layer index from the module name.

    The name is split on ``"."`` and every component accepted by ``int()``
    counts as an index; exactly one such component must be present.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError

    Raises:
        ValueError: If ``layer_name`` contains no integer component or more
            than one.
    """
    int_vals: List[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    # Raise explicitly (as documented) rather than `assert`, which is
    # stripped under `python -O` and would then yield an IndexError or a
    # silently wrong index.
    if len(int_vals) != 1:
        raise ValueError(f"layer name {layer_name} should"
                         " only contain one integer")
    return int_vals[0]
644
645
646
647
648
649
650
651
652


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
653
    return tensors