# utils.py
1
2
import itertools
from collections import UserDict
3
4
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                    Union, overload)
5

6
import torch
7
import torch.nn as nn
8
from torch.func import functional_call
9
from transformers import PretrainedConfig
10

11
12
13
14
15
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                         SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
16
from vllm.multimodal.base import NestedTensors
17
from vllm.sequence import IntermediateTensors
18
from vllm.utils import is_pin_memory_available
19
20


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class WeightsGroup(UserDict):
    """
    Wraps grouped weights dictionary for a more informative error message
    when attempting to access a weight component that does not exist.

    Keys are top-level weight-name prefixes. As built by
    ``group_weights_with_prefix``, each value is an iterable of
    ``(name, tensor)`` pairs belonging to that prefix.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            # Re-raise with the list of valid prefixes so a typo in a
            # component name is easy to diagnose.
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
    """
    Helper function to load weights for inner vLLM models.

    Lazily yields only the ``(name, tensor)`` pairs whose first dot-separated
    component of ``name`` equals ``prefix``, with that component (and its dot)
    stripped, e.g. ``"vision.proj.weight"`` becomes ``"proj.weight"`` for
    ``prefix="vision"``. A name that is exactly ``prefix`` yields ``""``.

    See also:
        :ref:`init_vllm_registered_model`
    """
    for name, loaded_weight in weights:
        head, _, rest = name.partition(".")
        if head == prefix:
            yield rest, loaded_weight


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> "WeightsGroup":
    """
    Helper function to group weights with prefix.

    Groups ``(name, tensor)`` pairs by the first dot-separated component of
    ``name``. Each value is a lazy ``filter_weights`` generator (prefix
    stripped) over its own ``itertools.tee`` copy of ``weights``.

    Note:
        Collecting the prefixes consumes the input once up front, so ``tee``
        keeps references to all pairs buffered until the group iterators
        have advanced past them.
    """
    scan, replay = itertools.tee(weights, 2)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # group order is deterministic (a set's iteration order is not).
    prefixes = list(dict.fromkeys(name.split(".")[0] for name, _ in scan))
    branches = itertools.tee(replay, len(prefixes))

    return WeightsGroup({
        prefix: filter_weights(branch, prefix)
        for branch, prefix in zip(branches, prefixes)
    })


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def init_vllm_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Build the inner vLLM model described by ``hf_config``.

    ``hf_config.architectures`` is resolved to a model class through
    ``ModelRegistry``. That class is then constructed by ``build_model`` with
    the same cache, quantization, LoRA, multimodal and scheduler configs that
    the outer vLLM model received.
    """
    resolved_cls, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)

    optional_configs = dict(
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        scheduler_config=scheduler_config,
    )
    return build_model(resolved_cls, hf_config, cache_config, quant_config,
                       **optional_configs)


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    - A tensor input is always flattened with ``x.flatten(0, 1)``, regardless
      of ``concat``.
    - A list input (one entry per batch item) is concatenated along dim 0
      with ``torch.cat`` when ``concat`` is true. Otherwise it is unrolled
      into a flat list of each entry's sub-tensors along dim 0.
    """
    if not isinstance(x, torch.Tensor):
        if concat:
            return torch.cat(x)
        per_item: List[torch.Tensor] = []
        for batch_entry in x:
            # Iterating a tensor yields its slices along dim 0.
            per_item.extend(batch_entry)
        return per_item

    return x.flatten(0, 1)


131
132
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    A tensor of shape ``(d0, ..., dk, H)`` becomes ``(d0 * ... * dk, H)``. A
    (possibly nested) sequence of tensors has each element flattened this way,
    and the results are concatenated along dim 0. Every leaf must therefore
    share the same last dimension. An empty sequence is not supported, since
    ``torch.cat`` rejects empty input.
    """
    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Describe, for error messages, how many embeddings ``embeddings`` holds.

    A tensor contributes its leading (all-but-last) dimensions joined as
    ``"a x b"``. A nested sequence joins its elements' descriptions with
    ``" + "``.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = [_embedding_count_expression(item) for item in embeddings]
        return " + ".join(parts)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))


157
158
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Args:
        input_ids: Token ids. Positions equal to ``placeholder_token_id`` are
            the ones overwritten. The mask built from it is used to index
            ``inputs_embeds``, so its shape must match the leading dimensions
            of ``inputs_embeds``.
        inputs_embeds: Embeddings to be updated.
        multimodal_embeddings: A tensor or nested sequence of tensors. It is
            flattened on all but its last dimension before assignment.
        placeholder_token_id: Token id marking multimodal placeholder slots.

    Returns:
        ``inputs_embeds`` itself (the same tensor object, after modification).

    Raises:
        ValueError: If the number of flattened multimodal embeddings differs
            from the number of placeholder tokens.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)

    num_expected_tokens = mask.sum().item()
    # Narrow `.item()`'s generic Python-number result to int for type
    # checkers; summing a boolean mask gives an integer count.
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened
    return inputs_embeds
182
183


184
185
186
187
188
189
190
191
192
class LayerFn(Protocol):
    """
    Structural type for a factory that builds a single layer module.

    ``make_layers`` calls it as ``layer_fn(prefix=f"{prefix}.{idx}")``, so
    ``prefix`` is the dotted name of the layer being constructed.
    """

    def __call__(self, prefix: str = "") -> torch.nn.Module:
        """Return a newly built layer module for ``prefix``."""
        ...


193
194
195
196
197
198
199
200
201
class PPMissingLayer(torch.nn.Identity):
    """
    Placeholder for a layer that this pipeline-parallel rank does not build.

    It behaves as a parameter-free ``nn.Identity``. Any constructor arguments
    are accepted and ignored, so it can be instantiated wherever a real layer
    would be.
    """

    def __init__(self, *_args, **_kwargs):
        super().__init__()


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Running total of parameter bytes moved to CPU by `maybe_offload_to_cpu`.
_CPU_OFFLOAD_BYTES = 0
# Budget for `_CPU_OFFLOAD_BYTES`. With the default of 0, nothing is
# offloaded, because the budget check `bytes >= max_bytes` holds immediately.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """
    Set the CPU-offload budget to ``max_bytes`` and reset the running count
    of offloaded bytes to zero.
    """
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move ``module``'s parameters to CPU memory while the global offload
    budget (see ``set_cpu_offload_max_bytes``) allows it.

    Offloading is per-parameter: if the budget runs out partway through the
    module, the remaining parameters stay where they are. When at least one
    parameter was offloaded, ``module.forward`` is patched. Each call then
    copies the whole state dict to the original device (``non_blocking``) and
    runs the module through ``torch.func.functional_call``.

    Returns:
        ``module`` itself, modified in place. Modules that have no parameters,
        already live on CPU, or arrive after the budget is exhausted are
        returned untouched.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    first_param = next(module.parameters(), None)
    if first_param is None:
        # Nothing to offload; without a default, `next()` would raise
        # StopIteration here.
        return module

    device = first_param.device
    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # functional_call runs the module itself, so the original forward
            # must be in place during the call or this wrapper would recurse.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a
                    # no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                # Reinstall the wrapper even if the call raised; otherwise
                # later calls would run the original forward against
                # CPU-resident weights.
                module.forward = forward

        module.forward = forward

    return module


268
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Only the layers in ``[start_layer, end_layer)`` assigned to this
    pipeline-parallel rank are built, via
    ``layer_fn(prefix=f"{prefix}.{idx}")`` followed by
    ``maybe_offload_to_cpu``. Every other slot holds a ``PPMissingLayer``.
    As a result, index ``i`` of the returned ``ModuleList`` always
    corresponds to global layer ``i``.

    Returns:
        ``(start_layer, end_layer, modules)``.
    """
    # Imported locally; presumably to avoid an import cycle with
    # vllm.distributed (NOTE(review): confirm before hoisting).
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# (the cache would hold a strong reference to every model passed in).
# NOTE(review): entries are keyed by id(model). Ids can be reused once a model
# is garbage-collected, so a later model at the same address would receive a
# stale entry. Consider a weakref.WeakKeyDictionary if that can happen.
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Returns the qualified names (from ``model.named_modules()``) of every
    ``PPMissingLayer`` submodule, each with a trailing ``"."``. The result is
    memoized per model object. The cached list itself is returned, so
    callers should not mutate it.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = [
        # NOTE: the trailing dot is used to match the prefix of the layer.
        # without the dot, we could match a layer that is not missing,
        # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
        name + '.'
        for name, module in model.named_modules()
        if isinstance(module, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    ``name`` counts as missing when it starts with any of the prefixes
    returned by ``get_pp_missing_layer_names(model)``.
    """
    return any(
        name.startswith(missing_prefix)
        for missing_prefix in get_pp_missing_layer_names(model))
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """
    Return a function ``(batch_size, dtype, device) -> IntermediateTensors``.

    The returned function maps every entry of ``keys`` to a zero tensor of
    shape ``(batch_size, hidden_size)`` with the given dtype and device.
    ``keys`` is captured by reference, so mutating the list afterwards changes
    what later calls produce.
    """

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        zero_tensors = {}
        for key in keys:
            zero_tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(zero_tensors)

    return make_empty_intermediate_tensors