utils.py 16.1 KB
Newer Older
1
import itertools
2
3
4
from dataclasses import dataclass, field
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Protocol, Tuple, Union, overload)
5

6
import torch
7
import torch.nn as nn
8
from torch.func import functional_call
9
from transformers import PretrainedConfig
10

11
12
13
14
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                         SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
15
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
16
from vllm.model_executor.models import ModelRegistry
17
from vllm.multimodal.base import NestedTensors
18
from vllm.sequence import IntermediateTensors
19
from vllm.utils import is_pin_memory_available
20

21
22
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""


@dataclass
class WeightsMapper:
    """
    Rename weights by substring, prefix and suffix rules.

    The three rule sets are applied in that order (substring, then prefix,
    then suffix). Each rule sees the name as already rewritten by the rules
    before it. A rule that matches and maps to ``None`` drops the weight.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the rewritten ``key``, or ``None`` if the weight is dropped."""
        # Substring rules: only the first occurrence is replaced.
        for old, new in self.orig_to_new_substr.items():
            if old not in key:
                continue
            if new is None:
                return None
            key = key.replace(old, new, 1)

        # Prefix rules: swap the leading ``old`` for ``new``.
        for old, new in self.orig_to_new_prefix.items():
            if not key.startswith(old):
                continue
            if new is None:
                return None
            key = new + key[len(old):]

        # Suffix rules: swap the trailing ``old`` for ``new``.
        for old, new in self.orig_to_new_suffix.items():
            if not key.endswith(old):
                continue
            if new is None:
                return None
            head, _, tail = key.rpartition(old)
            key = head + new + tail

        return key

    def apply(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Lazily rename ``weights``, omitting those whose name maps to None."""
        renamed = ((self._map_name(name), data) for name, data in weights)
        return ((name, data) for name, data in renamed if name is not None)
62

63
64

class AutoWeightsLoader:
    """
    Helper class to load weights into a :class:`torch.nn.Module`. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.

    Args:
        module: Root module that weights are loaded into.
        skip_prefixes: Weights whose qualified name starts with any of these
            strings are skipped without error.
        ignore_unexpected_prefixes: Weights whose qualified name starts with
            any of these strings are ignored, instead of raising
            ``ValueError``, when they match no module or parameter.
    """

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[List[str]] = None,
        ignore_unexpected_prefixes: Optional[List[str]] = None,
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
        """
        Group weights by the first dot-separated component of their names.

        Yields ``(prefix, group)`` pairs. ``group`` iterates over
        ``(rest, tensor)``, where ``rest`` is the name after the first ``.``,
        or ``""`` if the name has no dot.

        As with :func:`itertools.groupby`, only *consecutive* weights that
        share a prefix are grouped. A group that is not consumed before the
        next one is requested is discarded.
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join ``prefix`` and ``rest`` with a dot, omitting empty parts."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether ``qualname`` falls under one of ``skip_prefixes``."""
        return any(qualname.startswith(p) for p in self.skip_prefixes)

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether ``qualname`` falls under ``ignore_unexpected_prefixes``."""
        return any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Load the weights addressed directly at ``param``.

        Each weight with an empty remaining name is written through
        ``param.weight_loader`` if it exists, otherwise through
        ``default_weight_loader``. Its qualified name is yielded.

        Raises:
            ValueError: If a weight addresses a name nested *under* the
                parameter and that name is not skippable or ignorable.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                continue

            if weight_name != "":
                if not self._can_ignore_unexpected(weight_qualname):
                    raise ValueError(
                        f"Attempted to load nested weight '{weight_qualname}' "
                        f"into a single parameter '{base_prefix}'")

                continue

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

            yield weight_qualname

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Recursively load ``weights`` into ``module``.

        Yields the qualified names of loaded weights. ``PPMissingLayer``
        modules are skipped entirely. A non-root module that defines a
        callable ``load_weights`` receives all its weights at once.

        Raises:
            ValueError: If a weight matches no child module or parameter and
                is not ignorable.
        """
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_names = module_load_weights(weights)
                # Report what the child loaded as well, qualified relative to
                # the root module. Children whose ``load_weights`` returns
                # None contribute no names.
                if loaded_names is not None:
                    yield from (self._get_qualname(base_prefix, name)
                                for name in loaded_names)
                return

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if self._can_skip(prefix):
                continue

            if child_prefix in child_modules:
                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
            elif child_prefix in child_params:
                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
            else:
                if not self._can_ignore_unexpected(prefix):
                    msg = (f"There is no module or parameter named '{prefix}' "
                           f"in {type(self.module).__name__}")
                    raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        *,
        mapper: Optional[WeightsMapper] = None,
    ) -> List[str]:
        """
        Load ``weights`` into the root module.

        Args:
            weights: ``(name, tensor)`` pairs, consumed once.
            mapper: Optional renaming applied to ``weights`` before loading.

        Returns:
            Qualified names of the loaded weights. This includes names
            returned by child modules' own ``load_weights`` methods.
        """
        if mapper is not None:
            weights = mapper.apply(weights)

        autoloaded_weights = list(self._load_module("", self.module, weights))
        return autoloaded_weights
198
199


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def init_vllm_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Build an inner model registered with vLLM for ``hf_config.architectures``.

    The configuration objects passed to the outer vLLM model are forwarded
    to ``build_model`` for the model class that ``ModelRegistry`` resolves.
    """
    resolved_cls, _unused = ModelRegistry.resolve_model_cls(
        hf_config.architectures)

    optional_configs = dict(
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        scheduler_config=scheduler_config,
    )
    return build_model(resolved_cls, hf_config, cache_config, quant_config,
                       **optional_configs)


226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    # Catch-all so that calls passing a non-literal ``concat`` (for example
    # a ``bool`` variable) still match an overload for type checkers.
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    Args:
        x: Either a tensor of shape ``(B, N, ...)``, or a list of ``B``
            tensors where the ``b``-th tensor has shape ``(N_b, ...)``.
        concat: Only used when ``x`` is a list. If ``True``, concatenate the
            list along its first dimension into a single tensor of shape
            ``(sum(N_b), ...)``. Otherwise return a flat list of the
            ``sum(N_b)`` per-item tensors.

    Returns:
        For tensor input, ``x.flatten(0, 1)`` with shape ``(B * N, ...)``,
        regardless of ``concat``. For list input, see ``concat``.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    # Iterating a tensor yields its slices along dim 0, one per item.
    return [x_n for x_b in x for x_n in x_b]


264
265
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Collapse ``embeddings`` into one tensor.

    A tensor has every dimension except the last merged into one. A
    (possibly nested) sequence is flattened element by element and the
    results are concatenated along the first dimension.
    """
    if not isinstance(embeddings, torch.Tensor):
        flattened_items = [_flatten_embeddings(item) for item in embeddings]
        return torch.cat(flattened_items)

    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Describe, for error messages, how many embeddings ``embeddings`` holds.

    A tensor is rendered as its leading dimensions joined by ``" x "``. A
    sequence is rendered as its elements' expressions joined by ``" + "``.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = [_embedding_count_expression(item) for item in embeddings]
        return " + ".join(parts)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))


290
291
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Args:
        input_ids: Token ids. Positions equal to ``placeholder_token_id``
            receive multimodal embeddings.
        inputs_embeds: Embeddings indexed by the boolean mask computed from
            ``input_ids``.
        multimodal_embeddings: Possibly nested embeddings, flattened on all
            but the last dimension before assignment.
        placeholder_token_id: Token id that marks multimodal placeholder slots.

    Returns:
        ``inputs_embeds`` itself, after in-place modification.

    Raises:
        ValueError: If the number of flattened multimodal embeddings differs
            from the number of placeholder tokens.

    Note:
        This updates ``inputs_embeds`` in place. The multimodal embeddings
        are first cast to ``inputs_embeds.dtype``.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    # Boolean-mask assignment requires the source and destination dtypes to
    # match. Cast, e.g. float32 embeddings written into a bfloat16 buffer;
    # ``.to`` is a no-op when the dtypes already agree.
    inputs_embeds[mask] = flattened.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
315
316


317
318
class LayerFn(Protocol):
    """
    Structural type for a callable that constructs a single layer.

    ``make_layers`` invokes it as ``layer_fn(prefix=...)``, passing the
    name prefix for the layer being built.
    """

    def __call__(self, prefix: str) -> torch.nn.Module:
        """Build and return one layer for the given ``prefix``."""
        ...


323
324
325
326
327
328
329
330
331
class PPMissingLayer(torch.nn.Identity):
    """
    Stand-in for a layer that is not present in this pipeline-parallel stage.

    It inherits :class:`torch.nn.Identity`, so it has no parameters and its
    forward pass returns the input unchanged.
    """

    def __init__(self, *_args, **_kwargs):
        # Any constructor arguments are accepted and discarded.
        super(PPMissingLayer, self).__init__()


332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# Running total of parameter bytes offloaded so far, and the cap on it.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU-offload byte budget and reset the running total to 0."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move ``module``'s parameters to CPU while the offload budget allows.

    Parameters are moved one at a time until ``_CPU_OFFLOAD_BYTES`` reaches
    ``_CPU_OFFLOAD_MAX_BYTES``, so a module may end up partially offloaded.
    If anything was offloaded, ``module.forward`` is patched. Each call then
    copies the module's state dict to the original device and runs the
    module via ``functional_call``.

    Modules without parameters, or whose first parameter is already on CPU,
    are returned untouched.

    Returns:
        The same ``module`` object.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    # A parameter-less module has nothing to offload. Plain
    # ``next(module.parameters())`` would raise StopIteration here.
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # functional_call invokes the module, which would dispatch back
            # into this wrapper. Restore the real forward for the duration
            # of the call, and always re-install the wrapper afterwards, even
            # if the call raises.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be
                    # a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                module.forward = forward

        module.forward = forward

    return module


398
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """
    Build ``num_hidden_layers`` layers, materializing only those owned by
    this pipeline-parallel rank.

    Layers outside ``[start_layer, end_layer)`` are ``PPMissingLayer``
    placeholders. Owned layers are created via
    ``layer_fn(prefix=f"{prefix}.{idx}")`` and passed through
    ``maybe_offload_to_cpu``.

    Returns:
        ``(start_layer, end_layer, modules)``, where ``modules`` holds all
        ``num_hidden_layers`` entries in order.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)

    before = [PPMissingLayer() for _ in range(start_layer)]
    owned = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    after = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    modules = torch.nn.ModuleList(before + owned + after)
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """
    Return the names of all ``PPMissingLayer`` submodules of ``model``,
    each with a trailing ``.``.

    The result is memoized by ``id(model)``, and later calls return the
    same list object.
    """
    # NOTE(review): entries are keyed by ``id(model)`` and never evicted. If
    # a model is garbage-collected and a new one reuses its id, the stale
    # entry would be returned. Confirm this cannot happen in practice.
    model_id = id(model)
    try:
        return _model_to_pp_missing_layer_names[model_id]
    except KeyError:
        pass

    # The trailing dot anchors prefix matches at a layer boundary. Without
    # it, 'encoder.layer.1' would also match 'encoder.layer.11'.
    missing_layer_names = [
        f"{name}." for name, submodule in model.named_modules()
        if isinstance(submodule, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """
    Report whether parameter ``name`` lives in a layer absent from this
    pipeline-parallel rank.

    ``True`` if ``model`` is itself a ``PPMissingLayer``, or if ``name``
    starts with any prefix from ``get_pp_missing_layer_names(model)``.
    """
    if isinstance(model, PPMissingLayer):
        return True

    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
449
450
451
452
453


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """
    Create a factory for zero-filled ``IntermediateTensors``.

    Args:
        keys: Names of the tensors to create. The list is captured by
            reference, so later mutations are visible to the factory.
        hidden_size: Size of the second dimension of every tensor.

    Returns:
        A function ``(batch_size, dtype, device) -> IntermediateTensors``
        that maps each key to a zero tensor of shape
        ``(batch_size, hidden_size)`` with the requested dtype and device.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
466
467
468
469
470
471
472
473
474
475
476
477
478


class LLMWrapper(nn.Module):
    """
    Wrap a language model as a submodule named ``name``.

    The extra nesting adds ``name`` as a prefix to parameter names, aligning
    them with the key names of LoRA checkpoints trained with PEFT. Attribute
    lookups that fail on the wrapper, and calls to the wrapper, are forwarded
    to the wrapped model.
    """

    def __init__(self, llm: nn.Module, name: str) -> None:
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)

    def __getattr__(self, key: str):
        # Only reached when normal attribute lookup on the wrapper fails.
        inner = super().__getattr__(self.model_name)
        return inner if key == self.model_name else getattr(inner, key)

    # We need to explicitly override this
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        inner = super().__getattr__(self.model_name)
        return inner(*args, **kwargs)