# utils.py (11.2 KB)
# (source-viewer header text; not part of the module)
1
2
import itertools
from collections import UserDict
3
4
from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
                    Tuple, Union, overload)
5

6
import torch
7
import torch.nn as nn
8
from torch.func import functional_call
9
from transformers import PretrainedConfig
10

11
12
13
14
15
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                         SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
16
from vllm.multimodal.base import NestedTensors
17
from vllm.sequence import IntermediateTensors
18
from vllm.utils import is_pin_memory_available
19
20


21
22
23
24
25
26
class WeightsGroup(UserDict):
    """
    Mapping from a weight-name prefix to the weights grouped under it.

    Behaves like a regular dictionary, except that looking up a prefix that
    does not exist raises a ``KeyError`` whose message also lists the
    prefixes that are available, so a wrong component name is easier to
    diagnose.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Return the weights stored under the prefix ``key``.

        Raises:
            KeyError: If ``key`` is not a stored prefix. The message names
                the requested prefix and the set of available ones; the
                original lookup error is chained as the cause.
        """
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
    """
    Helper function to load weights for inner vLLM models.

    Lazily yields only those weights whose name has ``prefix`` as its first
    dot-separated component. The yielded name has that leading component
    (and the dot after it) removed; a name that consists of the prefix alone
    is yielded as an empty string.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.
        prefix: First dot-separated name component to select.

    Yields:
        ``(name_without_prefix, tensor)`` pairs, in input order.

    See also:
        :ref:`init_vllm_registered_model`
    """
    for name, loaded_weight in weights:
        # Only the first component is compared, so "ab.c" never matches
        # the prefix "a".
        head, _, remainder = name.partition(".")
        if head == prefix:
            yield remainder, loaded_weight


51
def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
    """
    Helper function to group weights with prefix

    The prefix of a weight is the first dot-separated component of its name.
    The result maps every prefix found in ``weights`` to a lazy iterable of
    the weights under that prefix, with the prefix stripped from their names.

    Note:
        ``weights`` is iterated once in full to discover the prefixes, and
        the per-prefix iterables are independent ``itertools.tee`` copies of
        the same stream. ``itertools.tee`` buffers items that some copies
        have not consumed yet, so items stay referenced until every copy has
        moved past them.
    """
    # One copy is consumed right away to collect the prefixes; the other is
    # fanned out again, one copy per prefix.
    prefix_scan, replay = itertools.tee(weights, 2)
    weights_prefix = {name.split(".")[0] for name, _ in prefix_scan}
    per_prefix_streams = itertools.tee(replay, len(weights_prefix))

    return WeightsGroup({
        prefix: filter_weights(component, prefix)
        for component, prefix in zip(per_prefix_streams, weights_prefix)
    })


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def init_vllm_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Build an inner model that is registered to vLLM, reusing the configs
    that were passed to the outer vLLM model.

    The model class is resolved through ``ModelRegistry`` from
    ``hf_config.architectures`` and instantiated with ``build_model``.
    """
    resolved = ModelRegistry.resolve_model_cls(hf_config.architectures)
    # Only the class is needed; the second element of the pair is unused.
    inner_model_cls, _ = resolved

    optional_configs = {
        "lora_config": lora_config,
        "multimodal_config": multimodal_config,
        "scheduler_config": scheduler_config,
    }
    return build_model(inner_model_cls, hf_config, cache_config,
                       quant_config, **optional_configs)


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Merge the leading ``B`` and ``N`` dimensions of batched multimodal
    inputs into a single dimension.

    A tensor input is expected to have shape ``(B, N, ...)`` and has its
    first two dimensions flattened. A list input is treated as ``B`` entries
    that each hold ``N`` items: with ``concat=True`` the entries are
    concatenated into one tensor, otherwise the items are returned as one
    flat list.
    """
    if not isinstance(x, torch.Tensor):
        if concat:
            return torch.cat(x)

        flat_items = []
        for per_batch in x:
            flat_items.extend(per_batch)
        return flat_items

    # ``concat`` has no effect on tensor input.
    return x.flatten(0, 1)


130
131
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    A tensor is reshaped so that every dimension except the last is merged
    into one. Any other input is iterated, each element is flattened the
    same way, and the results are concatenated along the first dimension.
    """

    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a human-readable expression for how many embeddings the
    NestedTensors holds, for use in debugging messages.

    A tensor contributes its dimensions except the last, joined by ``" x "``;
    the expressions of nested elements are joined by ``" + "``.
    """

    if not isinstance(embeddings, torch.Tensor):
        nested_terms = [
            _embedding_count_expression(inner) for inner in embeddings
        ]
        return " + ".join(nested_terms)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))


156
157
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Args:
        input_ids: Token ids; positions equal to ``placeholder_token_id``
            mark where multimodal embeddings are written.
        inputs_embeds: Embeddings that are overwritten at those positions.
        multimodal_embeddings: Embeddings to write; they are flattened on
            all but the last dimension before assignment.
        placeholder_token_id: Token id that marks a placeholder position.

    Returns:
        The same ``inputs_embeds`` object that was passed in.

    Raises:
        ValueError: If the number of flattened multimodal embeddings differs
            from the number of placeholder positions.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        # Show how the actual count is composed to make the mismatch easier
        # to trace back to the individual inputs.
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened
    return inputs_embeds
181
182


183
184
class LayerFn(Protocol):
    """
    Structural type for a callable that builds one layer.

    The callable is invoked with the keyword argument ``prefix`` and returns
    the constructed ``torch.nn.Module``.
    """

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


189
190
191
192
193
194
195
196
197
class PPMissingLayer(torch.nn.Identity):
    """
    Stand-in for a layer that is absent in a pipeline parallel model.

    It is an identity module: any constructor arguments are accepted and
    ignored, so it can be created with the same call as a real layer.
    """

    def __init__(self, *_args, **_kwargs):
        # The constructor arguments are intentionally dropped.
        super().__init__()


198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Running total, in bytes, of parameter data moved to the CPU by
# ``maybe_offload_to_cpu``; reset to 0 by ``set_cpu_offload_max_bytes``.
_CPU_OFFLOAD_BYTES = 0
# Offloading budget in bytes. Offloading stops once the running total
# reaches it, so the initial value of 0 means nothing is offloaded.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """
    Set the CPU offloading budget to ``max_bytes`` and reset the count of
    bytes offloaded so far to zero.
    """
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move parameters of ``module`` to CPU memory while the global offloading
    budget allows it, and patch ``module.forward`` so that it still runs on
    the original device.

    The module is returned unchanged when it has no parameters, when its
    first parameter already lives on the CPU, or when the budget
    (``_CPU_OFFLOAD_MAX_BYTES``) is already used up.

    Args:
        module: The module whose parameters may be offloaded. It is modified
            in place.

    Returns:
        The same ``module`` object.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    first_param = next(module.parameters(), None)
    if first_param is None:
        # A module without parameters has nothing to offload; a bare
        # ``next(...)`` would raise StopIteration here.
        return module

    # The device of the first parameter is taken as the device the module
    # computes on.
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the real forward for the duration of the call so that
            # ``functional_call`` runs it rather than this wrapper.
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            # Re-install the wrapper for the next call.
            module.forward = forward
            return output

        module.forward = forward

    return module


264
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Args:
        num_hidden_layers: Total number of layers of the model.
        layer_fn: Callable that builds one layer; it is called with
            ``prefix=f"{prefix}.{idx}"`` for each layer index it builds.
        prefix: Name prefix shared by all layers.

    Returns:
        ``(start_layer, end_layer, modules)``. ``modules`` has
        ``num_hidden_layers`` entries: the layers with index in
        ``[start_layer, end_layer)`` are built with ``layer_fn`` and passed
        through ``maybe_offload_to_cpu``; every other position holds a
        ``PPMissingLayer``.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import at module load time - confirm.
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)

    leading_missing = [PPMissingLayer() for _ in range(start_layer)]
    local_layers = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    trailing_missing = [
        PPMissingLayer() for _ in range(end_layer, num_hidden_layers)
    ]
    modules = torch.nn.ModuleList(leading_missing + local_layers +
                                  trailing_missing)
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# Results of ``get_pp_missing_layer_names`` are cached by hand in this dict
# instead, keyed by ``id(model)``.
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    A layer counts as missing when it is a ``PPMissingLayer``. Each returned
    name ends with a dot so that it can be used as a parameter-name prefix.

    The result is cached per model in ``_model_to_pp_missing_layer_names``,
    so the same list object is returned on repeated calls for one model.
    """
    # NOTE(review): the cache key is id(model). Python may reuse an id after
    # the object is garbage collected, so this relies on cached models
    # staying alive - confirm against callers.
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    # NOTE: the trailing dot is used to match the prefix of the layer.
    # without the dot, we could match a layer that is not missing,
    # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
    missing_layer_names = [
        name + '.' for name, module in model.named_modules()
        if isinstance(module, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Tell whether the parameter ``name`` belongs to a layer that is missing
    in a pipeline parallel model, i.e. whether ``name`` starts with one of
    the missing layer names of ``model``."""
    return any(
        name.startswith(layer_prefix)
        for layer_prefix in get_pp_missing_layer_names(model))
313
314
315
316
317


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """
    Create a function that builds zero-filled ``IntermediateTensors``.

    Args:
        keys: Names of the tensors each ``IntermediateTensors`` will hold.
        hidden_size: Size of the second dimension of every tensor.

    Returns:
        A function ``(batch_size, dtype, device) -> IntermediateTensors``
        that holds, for each key, a zero tensor of shape
        ``(batch_size, hidden_size)`` with the given dtype and device.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
330
331
332
333
334
335
336
337
338
339
340
341
342


class LLMWrapper(nn.Module):
    """
    To align with the key names of LoRA trained with PEFT, we need to add an
    additional layer to the llm's implementation.

    The wrapped model is registered as a submodule under ``name``, so its
    parameter names gain the ``name.`` prefix. Attribute lookups that the
    wrapper itself cannot resolve are forwarded to the wrapped model, and
    calling the wrapper calls the wrapped model.
    """

    def __init__(self, llm: nn.Module, name: str) -> None:
        """
        Args:
            llm: The model to wrap.
            name: Attribute name under which ``llm`` is registered.
        """
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)

    def __getattr__(self, key: str):
        """Return the wrapped model for ``key == name``; otherwise forward
        the lookup to the wrapped model."""
        # Fetch the wrapped model through nn.Module's own lookup to avoid
        # re-entering this method.
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm

        return getattr(llm, key)

    # We need to explicitly override this
    # (``nn.Module`` defines ``__call__`` itself, so the ``__getattr__``
    # fallback above is never consulted for calls on the wrapper).
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped model with the given arguments."""
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)