utils.py 27.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable, Mapping
6
from dataclasses import dataclass, field
7
from typing import Any, Literal, Protocol, overload
8

9
import torch
10
import torch.nn as nn
11
from torch.func import functional_call
12
from transformers import PretrainedConfig
13
from typing_extensions import deprecated
14

15
from vllm.config import VllmConfig
16
17
18
19
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22
from vllm.model_executor.models.interfaces import supports_any_eagle
23
from vllm.multimodal import NestedTensors
24
from vllm.sequence import IntermediateTensors
25
from vllm.utils.math_utils import cdiv
26
from vllm.utils.platform_utils import (
27
28
29
    is_pin_memory_available,
    is_uva_available,
)
30
31
32
33
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    get_cuda_view_from_cpu_tensor,
)
34
35

logger = init_logger(__name__)
36

37
WeightsMapping = Mapping[str, str | None]
38
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
39

40

41
42
43
@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns."""
44

45
46
47
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
48

49
50
51
52
53
54
55
56
    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s by merging their mappings."""
        return WeightsMapper(
            orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
            orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
            orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
        )

57
    def _map_name(self, key: str) -> str | None:
58
59
60
61
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None
62

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                key = key.replace(prefix, new_key, 1)

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                key = new_key.join(key.rsplit(suffix, 1))

        return key
80

81
    def apply(
82
83
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
84
85
86
87
88
        return (
            (out_name, data)
            for name, data in weights
            if (out_name := self._map_name(name)) is not None
        )
89

90
91
    def apply_list(self, values: list[str]) -> list[str]:
        return [
92
93
            out_name
            for name in values
94
95
96
97
98
99
100
101
102
103
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }

104
105

class AutoWeightsLoader:
106
    """
107
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
108
109
110
111
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
112
    by defining a `load_weights` method.
113
114

    Similarly, the weight loading logic for individual parameters can be
115
    overridden by defining a `weight_loader` method.
116
117

    Detailed weight loading information can be viewed by setting the
118
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
119
    """
120

121
122
    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
123
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
124
        "rotary_pos_emb.inv_freq",
125
126
127
128
129
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

130
131
132
133
    def __init__(
        self,
        module: nn.Module,
        *,
134
135
136
137
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
138
139
140
141
142
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
143
        self.skip_substrs = skip_substrs or []
144
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
145
        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
146
147
        # update default skip_substrs
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
148
149
150

    def _groupby_prefix(
        self,
151
152
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
153
154
155
156
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )
157

158
        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
159
160
161
162
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
163
164
165
166
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
167
168
169
170
171
172
173
174
175
176
177
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
178
179
180
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )
181
182

    def _can_ignore_unexpected(self, qualname: str) -> bool:
183
184
185
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)
186
187
188
189
190

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
191
        weights: Iterable[tuple[str, torch.Tensor]],
192
    ) -> Iterable[str]:
193
194
195
196
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
197
198
                logger.debug("Skipping weight %s", weight_qualname)

199
200
201
                continue

            if weight_name != "":
202
203
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)
204

205
206
207
208
                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
209
210
                    f"into a single parameter '{base_prefix}'"
                )
211

212
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
213
214
            weight_loader(param, weight_data)

215
            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)
216

217
218
            yield weight_qualname

219
220
221
    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
222
223
224
225
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
226
227
228
        if isinstance(
            module,
            (
229
230
231
232
233
234
235
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
236
237
            ),
        ):
238
            module_state_dict = module.state_dict()
239
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
240
241
                child_params[stat_name] = module_state_dict[stat_name]

242
243
244
245
    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
246
        weights: Iterable[tuple[str, torch.Tensor]],
247
    ) -> Iterable[str]:
248
249
250
251
252
253
254
255
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
256
                loaded_params = module_load_weights(weights)
257
258
                if loaded_params is None:
                    logger.warning(
259
260
                        "Unable to collect loaded parameters for module %s", module
                    )
261
262
263
264
265
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
266
267
268
269

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

270
271
272
273
        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

274
275
276
277
        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
278
279
280
281
282
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

283
284
285
                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
286
            elif child_prefix in child_params:
287
288
289
290
291
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

292
293
294
                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
295
            else:
296
297
298
299
300
301
302
303
304
305
306
307
308
309
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

310
311
312
313
                msg = (
                    f"There is no module or parameter named '{prefix}' "
                    f"in {type(self.module).__name__}"
                )
314
                raise ValueError(msg)
315
316
317

    def load_weights(
        self,
318
        weights: Iterable[tuple[str, torch.Tensor]],
319
        *,
320
        mapper: WeightsMapper | None = None,
321
    ) -> set[str]:
322
323
        if mapper is not None:
            weights = mapper.apply(weights)
324
        # filter out weights with first-prefix/substr to skip in name
325
326
327
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )
328

329
        autoloaded_weights = set(self._load_module("", self.module, weights))
330
        return autoloaded_weights
331
332


333
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to vLLM,
    based on the arguments passed to the outer vLLM model.

    Args:
        vllm_config: Config of the outer model.
        prefix: Passed through to `initialize_model`.
        hf_config: Optional HF config to use for the inner model in place of
            the one in `vllm_config`.
        architectures: Optional architectures override, passed to
            `VllmConfig.with_hf_config`.

    Returns:
        The module built by `initialize_model`.
    """
    # Imported here rather than at module level (kept as in the original;
    # presumably to avoid an import cycle - TODO confirm).
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None and architectures is not None:
        # So that the architectures field is overridden
        hf_config = vllm_config.model_config.hf_config

    if hf_config is not None:
        vllm_config = vllm_config.with_hf_config(hf_config, architectures=architectures)

    return initialize_model(vllm_config=vllm_config, prefix=prefix)
354
355


356
@overload
357
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
358
359
360


@overload
361
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
362
363
364
365


@overload
def flatten_bn(
366
    x: list[torch.Tensor] | torch.Tensor,
367
368
    *,
    concat: Literal[True],
369
) -> torch.Tensor: ...
370
371


372
373
@overload
def flatten_bn(
374
    x: list[torch.Tensor] | torch.Tensor,
375
376
    *,
    concat: bool = False,
377
) -> list[torch.Tensor] | torch.Tensor: ...
378
379


380
def flatten_bn(
381
    x: list[torch.Tensor] | torch.Tensor,
382
383
    *,
    concat: bool = False,
384
) -> list[torch.Tensor] | torch.Tensor:
385
    """
386
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
387

388
    The input tensor should have shape `(B, N, ...)`.
389
390
391
392
393
394
395
396
397
398
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


399
400
def _flatten_embeddings(embeddings: "NestedTensors") -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    A tensor is reshaped so that every leading dimension is merged into one;
    any other input is iterated and its flattened items are concatenated
    along dim 0.
    """

    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: "NestedTensors") -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.

    A tensor contributes its leading dimensions joined by `" x "` (the last
    dimension is left out); nested containers are joined by `" + "`.
    """

    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
422
423


424
425
426
427
428
429
430
431
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """Bucket the values of `lst` by `value // interval`.

    Bucket `i` receives, in input order, every value whose floor division by
    `interval` equals `i`. The number of buckets is `max(lst) // interval + 1`,
    so buckets with no matching value are present as empty lists.

    NOTE(review): the values are appended exactly as produced by iterating
    `lst`; for a tensor input these are presumably 0-d tensors rather than
    Python ints - confirm against callers.
    """
    bucket_count = (max(lst) // interval) + 1
    buckets: list[list[int]] = []
    for _ in range(bucket_count):
        buckets.append([])

    for value in lst:
        buckets[value // interval].append(value)

    return buckets


Cyrus Leung's avatar
Cyrus Leung committed
432
433
434
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` that are selected by `is_multimodal`.

    Args:
        inputs_embeds: Embeddings to overwrite.
        multimodal_embeddings: Embeddings to insert; they are flattened to
            two dimensions and cast to the dtype of `inputs_embeds`.
        is_multimodal: Mask marking the positions to overwrite. It is
            unsqueezed on the last dim before being used as the scatter mask.

    Returns:
        `inputs_embeds` (the same tensor that was passed in).

    Raises:
        ValueError: If the masked scatter raises a `RuntimeError`. When the
            number of embeddings differs from the number of selected
            positions, the message reports both counts.

    Note:
        This updates `inputs_embeds` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
        )
    except RuntimeError as e:
        # The counts are only computed on the failure path because
        # `.sum().item()` has to read the mask back to the host.
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
475
476


477
478
@deprecated(
    "`merge_multimodal_embeddings` has been replaced with "
    "`SupportsMultiModal.embed_input_ids` and will be "
    "removed in v0.12."
)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int | list[int],
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` corresponding to placeholder tokens in
    `input_ids`.

    `placeholder_token_id` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the `input_ids` MUST MATCH the order of
    their embeddings in `multimodal_embeddings` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates `inputs_embeds` in place.
    """
    # Build the boolean placeholder mask: membership test for a list of ids,
    # plain equality for a single id.
    if isinstance(placeholder_token_id, list):
        is_multimodal = isin_list(input_ids, placeholder_token_id)
    else:
        is_multimodal = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )


525
526
527
528
529
530
531
532
533
534
535
536
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Return a boolean mask of which `elements` appear in `test_elements_list`.

    The list is turned into a tensor (pinned when
    `is_pin_memory_available()` says so), moved to the device of `elements`
    with a non-blocking copy, and handed to `torch.isin`.
    """
    host_candidates = torch.tensor(
        test_elements_list,
        pin_memory=is_pin_memory_available(),
    )
    device_candidates = host_candidates.to(
        device=elements.device,
        non_blocking=True,
    )

    return torch.isin(elements, device_candidates)


537
class LayerFn(Protocol):
    """Callable that builds one layer, given that layer's name prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module: ...
539
540


541
542
543
544
545
546
547
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    It accepts any constructor arguments and discards them, and its forward
    pass hands back its first input unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted but intentionally not forwarded.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        return args[0] if args else next(iter(kwargs.values()))
552
553


554
555
556
557
558
559
560
561
562
563
564
# Bytes offloaded so far, and the budget that total is checked against.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget to `max_bytes` and reset the running total."""
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES = 0, max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Offload parameters of `module` to CPU memory while budget remains.

    The module is returned untouched when it has no parameters, when its
    first parameter already lives on the CPU, or when the global offload
    budget (`_CPU_OFFLOAD_MAX_BYTES`) is already used up. Otherwise
    parameters are copied to CPU one at a time until the budget is reached,
    and each offloaded parameter's `.data` is replaced by the tensor that
    `get_cuda_view_from_cpu_tensor` returns for the CPU copy.

    Args:
        module: The module whose parameters may be offloaded.

    Returns:
        The same `module` object.

    Raises:
        AssertionError: If `is_uva_available()` is false.
    """
    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    uva_available = is_uva_available()

    assert uva_available, "V1 CPU offloading requires uva (pin memory) support"
    # NOTE(review): this is always True past the assert above, so every
    # `not uva_offloading` branch below is currently unreachable.
    uva_offloading = True

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        if not uva_offloading:
            p.data = cpu_data
        else:
            # keep the cpu data alive
            p._vllm_offloaded_cpu_data = cpu_data
            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the real forward for the duration of the call so that
            # `functional_call` does not re-enter this wrapper.
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module


631
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Args:
        num_hidden_layers: Total number of layers in the full model.
        layer_fn: Factory called as `layer_fn(prefix=f"{prefix}.{idx}")` for
            each layer index in `[start_layer, end_layer)`.
        prefix: Name prefix of the layer list.

    Returns:
        `(start_layer, end_layer, modules)`. `modules` has
        `num_hidden_layers` entries; indices outside
        `[start_layer, end_layer)` hold `PPMissingLayer` placeholders, and
        each real layer has been passed through `maybe_offload_to_cpu`.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size
    )
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)]
        + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ]
        + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
    )
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# NOTE(review): entries are keyed by `id(model)` and never evicted. If a model
# is garbage collected and a new one is allocated with the same id, the new
# model would be served the old model's names - confirm that models passed
# here outlive the process's use of this cache.
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    The result is computed once per model object and then served from a
    module-level cache. Each returned name ends with a trailing dot.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + ".")
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    Returns True when `model` itself is a `PPMissingLayer`, or when `name`
    starts with the (dot-terminated) name of any missing layer in `model`.
    """
    if isinstance(model, PPMissingLayer):
        return True

    return any(
        name.startswith(missing_layer_name)
        for missing_layer_name in get_pp_missing_layer_names(model)
    )
687
688


689
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """Return a factory that builds zero-filled `IntermediateTensors`.

    Args:
        keys: Names of the tensors each `IntermediateTensors` will hold.
        hidden_size: Size of the second dimension of every tensor.

    Returns:
        A function `(batch_size, dtype, device) -> IntermediateTensors` in
        which every key maps to a `(batch_size, hidden_size)` zeros tensor.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors(
            {
                key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
                for key in keys
            }
        )

    return make_empty_intermediate_tensors
703
704


705
706
707
708
709
710
711
712
713
714
715
def maybe_prefix(prefix: str, name: str) -> str:
    """Qualify `name` with `prefix` unless the prefix is empty.

    Args:
        prefix: Namespace to prepend. An empty prefix leaves `name` as is.
        name: The name to qualify.

    Returns:
        `"prefix.name"` for a non-empty prefix, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
716
717


XuruiYang's avatar
XuruiYang committed
718
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component of `layer_name` that parses as an integer
    is collected; the rest are ignored.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> AssertionError if num_attn_module == 1

    Args:
        layer_name: Dotted module name.
        num_attn_module: Number of attention modules per layer. When it is
            not 1 and `layer_name` contains "attn", a second integer in the
            name is treated as the attention-module index within the layer
            and the result is `layer * num_attn_module + module`.

    Raises:
        AssertionError: If the name holds a number of integers other than
            one (single-module case), or more than two (multi-module case).
    """
    subnames = layer_name.split(".")
    int_vals: list[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1, (
            f"layer name {layer_name} should only contain one integer"
        )

        return int_vals[0]
    else:
        assert len(int_vals) <= 2, (
            f"layer name {layer_name} should contain most two integers"
        )
        layer_index = (
            int_vals[0] * num_attn_module + int_vals[1]
            if len(int_vals) == 2
            else int_vals[0]
        )
        return layer_index
750
751
752
753
754
755
756
757
758


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """Clamp `tensors` into the finite range of its dtype if it overflowed.

    The clamp is applied only when the tensor contains an inf or a NaN; the
    bounds are `+/-(torch.finfo(dtype).max - offset)`. When nothing
    overflowed, the input tensor object itself is returned.

    Args:
        tensors: Floating-point tensor to check.
        offset: Margin subtracted from the dtype's maximum to get the bound.

    Returns:
        The input tensor, or a clamped copy of it.
    """
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
    return tensors
760
761


762
763
764
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Optimized topk implementation that uses torch.max for k=1 case.

    This function provides better performance for the common case of k=1
    by using torch.max instead of the more general torch.topk.

    Args:
        values: Input tensor to find top-k values from
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk

    Returns:
        Tuple of (values, indices) where values are the top-k values
        and indices are their corresponding indices in the input tensor
    """
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        # (keepdim=True keeps `dim` with size 1, as torch.topk would for k=1)
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
786
787


788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Dispatch `x` to the registered `vllm.sequence_parallel_chunk_impl` op."""
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's slice of `x` along dim 0.

    `x` is first padded so that its dim-0 length is a multiple of the
    tensor-parallel world size, then split into equal chunks; the chunk
    at this rank's index is returned as a narrowed view.

    NOTE(review): the pad spec `(0, 0, 0, pad_len)` extends the
    second-to-last dimension, which is dim 0 only for a 2-D input - confirm
    that `x` is always `(num_tokens, hidden)`.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    leftover = x.size(0) % world_size
    if leftover == 0:
        padded = x
    else:
        padded = nn.functional.pad(x, (0, 0, 0, world_size - leftover))

    rows_per_rank = padded.shape[0] // world_size
    return torch.narrow(padded, 0, rank * rows_per_rank, rows_per_rank)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Shape-only stand-in for `sequence_parallel_chunk_impl`.

    Returns an uninitialized tensor with the dtype and device of `x` whose
    dim-0 length is `cdiv(x.size(0), tp_size)` and whose remaining
    dimensions match `x`.
    """
    world_size = get_tensor_model_parallel_world_size()
    rows_per_rank = cdiv(x.size(0), world_size)
    out_shape = [rows_per_rank, *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register the chunking routine as a custom op together with its fake
# (shape-only) implementation.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Record on an EAGLE model which of its own weights appear in a checkpoint.

    Intended to be called once per weight name while loading. Models for
    which `supports_any_eagle` is false are left untouched.

    Args:
        model: The model instance (must support EAGLE)
        name: The name of the weight to process
    """
    if not supports_any_eagle(model):
        return

    # To prevent overriding with target model's layers
    marker_to_flag = (
        ("lm_head", "has_own_lm_head"),
        ("embed_tokens", "has_own_embed_tokens"),
    )
    for marker, flag in marker_to_flag:
        if marker in name:
            setattr(model, flag, True)