utils.py 28.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable, Mapping
6
from dataclasses import dataclass, field
7
from typing import Any, Literal, Protocol, overload
8

9
import torch
10
import torch.nn as nn
11
from torch.func import functional_call
12
from transformers import PretrainedConfig
13
from typing_extensions import deprecated
14

15
from vllm.config import VllmConfig
16
17
18
19
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
22
23
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
from vllm.model_executor.models.interfaces import supports_any_eagle
26
from vllm.multimodal import NestedTensors
27
from vllm.sequence import IntermediateTensors
28
from vllm.utils.math_utils import cdiv
29
from vllm.utils.platform_utils import (
30
31
32
    is_pin_memory_available,
    is_uva_available,
)
33
34
35
36
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    get_cuda_view_from_cpu_tensor,
)
37
38

logger = init_logger(__name__)

# Type of the rename tables held by `WeightsMapper`: original pattern -> new
# text, where a `None` value means "drop this weight".
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
42

43

44
45
46
@dataclass
class WeightsMapper:
    """Rename checkpoint weight names using substring, prefix and suffix rules.

    The three rule tables are applied in that order (substring, then prefix,
    then suffix). Within a table, rules are tried in iteration order and every
    matching rule rewrites the name produced so far. A rule whose replacement
    is `None` drops the weight entirely.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Return a new mapper holding the rules of both operands.

        When both operands define the same pattern, the rule from `other`
        (the right-hand operand) wins.
        """

        def merge(mine, theirs):
            combined = dict(mine)
            combined.update(theirs)
            return combined

        return WeightsMapper(
            orig_to_new_substr=merge(self.orig_to_new_substr, other.orig_to_new_substr),
            orig_to_new_prefix=merge(self.orig_to_new_prefix, other.orig_to_new_prefix),
            orig_to_new_suffix=merge(self.orig_to_new_suffix, other.orig_to_new_suffix),
        )

    def _map_name(self, key: str) -> str | None:
        """Return the renamed `key`, or None if a matching rule drops it."""
        name = key

        # Substring rules: replace the first occurrence of each match.
        for old, new in self.orig_to_new_substr.items():
            if old not in name:
                continue
            if new is None:
                return None
            name = name.replace(old, new, 1)

        # Prefix rules: the match is at position 0, so swap the leading text.
        for old, new in self.orig_to_new_prefix.items():
            if not name.startswith(old):
                continue
            if new is None:
                return None
            name = new + name[len(old) :]

        # Suffix rules: replace the right-most occurrence of the suffix.
        for old, new in self.orig_to_new_suffix.items():
            if not name.endswith(old):
                continue
            if new is None:
                return None
            name = new.join(name.rsplit(old, 1))

        return name

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename `(name, tensor)` pairs, dropping those mapped to None."""
        return (
            (new_name, tensor)
            for old_name, tensor in weights
            if (new_name := self._map_name(old_name)) is not None
        )

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename a list of names, dropping those mapped to None."""
        return [renamed for renamed in map(self._map_name, values) if renamed is not None]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename the keys of `values`, dropping entries mapped to None."""
        remapped: dict[str, Any] = {}
        for old_name, value in values.items():
            new_name = self._map_name(old_name)
            if new_name is not None:
                remapped[new_name] = value
        return remapped

107
108

class AutoWeightsLoader:
109
    """
110
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
111
112
113
114
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
115
    by defining a `load_weights` method.
116
117

    Similarly, the weight loading logic for individual parameters can be
118
    overridden by defining a `weight_loader` method.
119
120

    Detailed weight loading information can be viewed by setting the
121
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
122
    """
123

124
125
    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
126
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
127
        "rotary_pos_emb.inv_freq",
128
129
130
131
132
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

133
134
135
136
    def __init__(
        self,
        module: nn.Module,
        *,
137
138
139
140
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
141
142
143
144
145
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
146
        self.skip_substrs = skip_substrs or []
147
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
148
        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
149
150
        # update default skip_substrs
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
151
152
153

    def _groupby_prefix(
        self,
154
155
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
156
157
158
159
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )
160

161
        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
162
163
164
165
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
166
167
168
169
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
170
171
172
173
174
175
176
177
178
179
180
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
181
182
183
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )
184
185

    def _can_ignore_unexpected(self, qualname: str) -> bool:
186
187
188
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)
189
190
191
192
193

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
194
        weights: Iterable[tuple[str, torch.Tensor]],
195
    ) -> Iterable[str]:
196
197
198
199
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
200
201
                logger.debug("Skipping weight %s", weight_qualname)

202
203
204
                continue

            if weight_name != "":
205
206
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)
207

208
209
210
211
                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
212
213
                    f"into a single parameter '{base_prefix}'"
                )
214

215
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
216
217
            weight_loader(param, weight_data)

218
            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)
219

220
221
            yield weight_qualname

222
223
224
    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
225
226
227
228
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
229
230
231
        if isinstance(
            module,
            (
232
233
234
235
236
237
238
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
239
240
            ),
        ):
241
            module_state_dict = module.state_dict()
242
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
243
244
                child_params[stat_name] = module_state_dict[stat_name]

245
246
247
248
    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
249
        weights: Iterable[tuple[str, torch.Tensor]],
250
    ) -> Iterable[str]:
251
252
253
254
255
256
257
258
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
259
                loaded_params = module_load_weights(weights)
260
261
                if loaded_params is None:
                    logger.warning(
262
263
                        "Unable to collect loaded parameters for module %s", module
                    )
264
265
266
267
268
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
269
270
271
272

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

273
274
275
276
        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

277
278
279
280
        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
281
282
283
284
285
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

286
287
288
                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
289
            elif child_prefix in child_params:
290
291
292
293
294
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

295
296
297
                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
298
            else:
299
300
301
302
303
304
305
306
307
308
309
310
311
312
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

313
314
315
316
                msg = (
                    f"There is no module or parameter named '{prefix}' "
                    f"in {type(self.module).__name__}"
                )
317
                raise ValueError(msg)
318
319
320

    def load_weights(
        self,
321
        weights: Iterable[tuple[str, torch.Tensor]],
322
        *,
323
        mapper: WeightsMapper | None = None,
324
    ) -> set[str]:
325
326
        if mapper is not None:
            weights = mapper.apply(weights)
327
        # filter out weights with first-prefix/substr to skip in name
328
329
330
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )
331

332
        autoloaded_weights = set(self._load_module("", self.module, weights))
333
        return autoloaded_weights
334
335


336
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Build an inner model registered with vLLM, reusing the configuration of
    the outer vLLM model.

    If `hf_config` or `architectures` is given, the model is built from
    `vllm_config.with_hf_config(hf_config, architectures=architectures)`.
    When only `architectures` is given, the outer model's HF config is used
    so that just its architectures are overridden.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None:
        if architectures is not None:
            # Reuse the outer HF config so only `architectures` changes.
            hf_config = vllm_config.model_config.hf_config

    inner_config = (
        vllm_config
        if hf_config is None
        else vllm_config.with_hf_config(hf_config, architectures=architectures)
    )
    return initialize_model(vllm_config=inner_config, prefix=prefix)
357
358


359
@overload
360
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
361
362
363


@overload
364
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
365
366
367
368


@overload
def flatten_bn(
369
    x: list[torch.Tensor] | torch.Tensor,
370
371
    *,
    concat: Literal[True],
372
) -> torch.Tensor: ...
373
374


375
376
@overload
def flatten_bn(
377
    x: list[torch.Tensor] | torch.Tensor,
378
379
    *,
    concat: bool = False,
380
) -> list[torch.Tensor] | torch.Tensor: ...
381
382


383
def flatten_bn(
384
    x: list[torch.Tensor] | torch.Tensor,
385
386
    *,
    concat: bool = False,
387
) -> list[torch.Tensor] | torch.Tensor:
388
    """
389
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
390

391
    The input tensor should have shape `(B, N, ...)`.
392
393
394
395
396
397
398
399
400
401
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


402
403
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Collapse a (possibly nested) collection of embedding tensors into one
    tensor: every dimension except the last is flattened, and nested parts
    are concatenated in order along the flattened dimension.
    """
    if not isinstance(embeddings, torch.Tensor):
        pieces = [_flatten_embeddings(item) for item in embeddings]
        return torch.cat(tuple(pieces))

    # Keep the trailing (hidden) dimension; merge everything before it.
    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Render how many embeddings a NestedTensors value holds, for error
    messages: a tensor becomes its leading dimensions joined by " x ", and a
    nested collection becomes the " + "-joined renderings of its parts.
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(size) for size in leading_dims)
425
426


427
428
429
430
431
432
433
434
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """
    Bucket the non-negative integers in `lst` by `value // interval`.

    Bucket `i` holds, in input order, every value in
    `[i * interval, (i + 1) * interval)`. One bucket exists for each interval
    from 0 up to the one containing `max(lst)`, so some buckets may be empty.

    Args:
        lst: A 1-D tensor (or sequence) of non-negative integers.
        interval: Positive width of each bucket.

    Returns:
        The list of buckets; an empty input yields an empty list.
    """
    # `max()` raises on an empty sequence; there is simply nothing to bucket.
    # (`len` rather than truthiness so multi-element tensors work too.)
    if len(lst) == 0:
        return []

    ranges: list[list[int]] = [[] for _ in range(max(lst) // interval + 1)]
    for num in lst:
        ranges[num // interval].append(num)
    return ranges


Cyrus Leung's avatar
Cyrus Leung committed
435
436
437
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions of `inputs_embeds` selected by the boolean mask `is_multimodal`.

    The multimodal embeddings are flattened (all but the last dimension),
    cast to `inputs_embeds.dtype`, and written in order into the masked
    positions. If `multimodal_embeddings` is empty, `inputs_embeds` is
    returned unchanged.

    Note:
        This updates `inputs_embeds` in place.

    Raises:
        ValueError: If the scatter fails; when the number of embeddings does
            not match the number of masked positions, the message says so.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # NOTE: `masked_scatter_` avoids the device-to-host sync (#22105) of
        # boolean-index assignment (`inputs_embeds[is_multimodal] = ...`,
        # still handy for debugging), but it does NOT raise when
        # `is_multimodal.sum() < len(mm_embeds_flat)`.
        # The mask is broadcast over the last (hidden) dim; this assumes
        # `is_multimodal.shape == inputs_embeds.shape[:-1]` — TODO confirm.
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
        )
    except RuntimeError as e:
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
478
479


480
481
@deprecated(
    "`merge_multimodal_embeddings` has been replaced with "
    "`SupportsMultiModal.embed_input_ids` and will be "
    "removed in v0.12."
)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int | list[int],
) -> torch.Tensor:
    """
    Overwrite the positions of `inputs_embeds` whose token in `input_ids` is
    a placeholder with `multimodal_embeddings`.

    `placeholder_token_id` may be a single token id or a list of ids (e.g.
    the ids of img_start, img_break and img_end tokens). With a list, every
    position holding any of those ids is overwritten, so the embeddings in
    `multimodal_embeddings` MUST appear in the same order as the
    corresponding tokens appear in `input_ids`.

    For example, if `input_ids` is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token,

    then the image embeddings (for the I's) from the vision encoder must be
    padded with embeddings for S, B and E, in the same order as in
    `input_ids`, for a correct merge.

    Note:
        This updates `inputs_embeds` in place.
    """
    is_multimodal = (
        isin_list(input_ids, placeholder_token_id)
        if isinstance(placeholder_token_id, list)
        else input_ids == placeholder_token_id
    )

    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )


528
529
530
531
532
533
534
535
536
537
538
539
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Return a boolean mask marking which entries of `elements` occur in
    `test_elements_list` (`torch.isin` for a plain Python list).

    The list is first built as a CPU tensor (pinned when pinned memory is
    available) and then copied to `elements.device` non-blockingly.
    """
    candidates_cpu = torch.tensor(
        test_elements_list,
        pin_memory=is_pin_memory_available(),
    )
    candidates = candidates_cpu.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, candidates)


540
class LayerFn(Protocol):
    """Callable that builds a single layer for the given module-name prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module: ...
542
543


544
545
546
547
548
549
550
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    It acts as a pass-through: `forward` returns its first input unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any constructor arguments so the placeholder can
        # stand in for layers with arbitrary signatures.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs.

        Raises:
            TypeError: If called without any argument.
        """
        if args:
            return args[0]
        if kwargs:
            return next(iter(kwargs.values()))
        # Without this check `next()` would leak a bare StopIteration, which
        # is confusing and becomes a RuntimeError inside generators.
        raise TypeError("PPMissingLayer.forward() requires at least one input")
555
556


557
558
559
560
561
562
563
564
565
566
567
# Running total of parameter bytes offloaded by `maybe_offload_to_cpu`, and
# the budget after which it stops offloading further parameters.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU-offload budget to `max_bytes` and reset the running total."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Offload the parameters of `module` to CPU memory while the global
    CPU-offload budget (see `set_cpu_offload_max_bytes`) allows.

    Each offloaded parameter gets a CPU copy (pinned when available), kept
    alive on the parameter as `_vllm_offloaded_cpu_data`, and its `.data` is
    replaced by the view returned by `get_cuda_view_from_cpu_tensor`.
    Offloading is per parameter: once the running total reaches the budget,
    the remaining parameters are left where they are.

    Modules without parameters, or whose first parameter is already on the
    CPU, are returned untouched.

    Returns:
        The same `module` object (modified in place).

    Raises:
        RuntimeError: If offloading is required but UVA is unavailable.
    """
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module

    if first_param.device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    # use pin_memory if possible, which helps cudagraph capture speed
    pin_memory = is_pin_memory_available()

    # Offloading always goes through a UVA view of host memory (the former
    # non-UVA path with a `functional_call` forward wrapper was unreachable).
    # Raise instead of `assert` so the check also holds under `python -O`.
    if not is_uva_available():
        raise RuntimeError("V1 CPU offloading requires uva (pin memory) support")

    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        # keep the cpu data alive while the parameter uses a view of it
        p._vllm_offloaded_cpu_data = cpu_data
        p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    return module


634
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """
    Build `num_hidden_layers` layers, materializing only the slice owned by
    this pipeline-parallel rank.

    Layers with index in `[start_layer, end_layer)` (from `get_pp_indices`)
    are created by `layer_fn(prefix=f"{prefix}.{idx}")` and passed through
    `maybe_offload_to_cpu`; every other position holds a `PPMissingLayer`,
    so list indices line up with the full model.

    Returns:
        `(start_layer, end_layer, modules)`.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    leading = [PPMissingLayer() for _ in range(start_layer)]
    owned = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    trailing = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    return start_layer, end_layer, torch.nn.ModuleList(leading + owned + trailing)


# NOTE: don't use lru_cache here because it can prevent garbage collection.
# Entries are keyed by `id(model)` and evicted (via `weakref.finalize`) when
# the model is collected, so a recycled id never returns stale names.
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Returns the names of all `PPMissingLayer` submodules, each with a trailing
    "." so they can be used as prefix matches. Results are cached per model
    object; the cache entry is dropped when the model is garbage-collected.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = [
        # NOTE: the trailing dot is used to match the prefix of the layer.
        # without the dot, we could match a layer that is not missing,
        # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
        name + "."
        for name, module in model.named_modules()
        if isinstance(module, PPMissingLayer)
    ]

    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    import weakref

    # `id()` values are reused once an object is freed; evict the entry when
    # the model dies so a later model with the same id cannot hit it.
    try:
        finalizer = weakref.finalize(
            model, _model_to_pp_missing_layer_names.pop, model_id, None
        )
        finalizer.atexit = False
    except TypeError:
        # Object does not support weak references; keep the plain cache.
        pass

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """
    Return True if parameter `name` lives in a layer this pipeline-parallel
    rank does not hold.

    A `PPMissingLayer` passed as `model` is missing as a whole; otherwise
    `name` is matched against the prefixes returned by
    `get_pp_missing_layer_names(model)`.
    """
    if not isinstance(model, PPMissingLayer):
        missing_prefixes = get_pp_missing_layer_names(model)
        return any(map(name.startswith, missing_prefixes))
    return True
690
691


692
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """
    Return a function that builds zero-filled `IntermediateTensors`.

    The returned callable takes `(batch_size, dtype, device)` and returns an
    `IntermediateTensors` mapping each name in `keys` to a new zero tensor of
    shape `(batch_size, hidden_size)`.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
706
707


708
709
710
711
712
713
714
715
716
717
718
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot unless `prefix` is empty.

    Args:
        prefix: Leading component; when empty, `name` is returned unchanged.
        name: Trailing component.

    Returns:
        "prefix.name" for a non-empty prefix, otherwise just "name".
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
719
720


721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
def get_draft_quant_config(
    vllm_config: VllmConfig,
) -> QuantizationConfig | None:
    """Get quantization config for Draft models.

    Draft models should use their own quantization config instead of the
    verifier/target model's config. This helper derives it from the draft
    model config together with `vllm_config.load_config`.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The draft model's quantization config, or None when there is no
        speculative config or no draft model config.
    """
    speculative_config = vllm_config.speculative_config
    # Without this guard a missing speculative config raised AttributeError
    # instead of returning None as documented.
    if speculative_config is None:
        return None

    draft_model_config = speculative_config.draft_model_config
    if not draft_model_config:
        return None

    return VllmConfig.get_quantization_config(
        draft_model_config, vllm_config.load_config
    )


XuruiYang's avatar
XuruiYang committed
745
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.
    If `num_attn_module == 1` or "attn" does not occur in `layer_name`,
    exactly one integer must be present and it is returned. Otherwise one or
    two integers may be present; for two integers `(a, b)` the result is
    `a * num_attn_module + b`.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> AssertionError if num_attn_module == 1

    Raises:
        AssertionError: If the name holds an unexpected number of integers.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1, (
            f"layer name {layer_name} should only contain one integer"
        )
        return int_vals[0]

    # Lower bound too: an attn name without any integer used to fail with an
    # IndexError on `int_vals[0]`.
    assert 1 <= len(int_vals) <= 2, (
        f"layer name {layer_name} should contain one or two integers"
    )
    if len(int_vals) == 1:
        return int_vals[0]

    layer_idx, sub_idx = int_vals
    return layer_idx * num_attn_module + sub_idx
777
778
779
780
781
782
783
784
785


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
786
    return tensors
787
788


789
790
791
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the `topk` largest entries of `values` along `dim`, plus indices.

    For `topk == 1` this uses `torch.max(..., keepdim=True)`, which is
    cheaper than the general `torch.topk`; both paths keep `dim` in the
    output with size `topk`.

    Args:
        values: Input tensor.
        topk: Number of entries to select (k); must be > 0.
        dim: Dimension to select along.

    Returns:
        `(values, indices)` of the selected entries.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)
    # k == 1: a max-reduction is enough.
    return torch.max(values, dim=dim, keepdim=True)
813
814


815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
# Chunk x along the num_tokens axis for sequence parallelism.
# NOTE: This is wrapped in a torch custom op to work around the following
# issue: the output tensor can have a sequence length 0 at small input
# sequence lengths even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's chunk of `x` along dim 0.

    Dispatches to the registered `torch.ops.vllm.sequence_parallel_chunk_impl`.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """
    Eager implementation: zero-pad `x` so its length is a multiple of the
    tensor-parallel world size, then return this rank's contiguous range of
    rows (a view via `torch.narrow`).
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size.
    # The pad spec pads dim -2, so this assumes x is (num_tokens, hidden)
    # — TODO confirm for other ranks.
    pad_len = -x.size(0) % tp_size
    padded = nn.functional.pad(x, (0, 0, 0, pad_len)) if pad_len else x

    chunk = padded.shape[0] // tp_size
    return torch.narrow(padded, 0, tp_rank * chunk, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """
    Fake implementation registered as `fake_impl`: returns an uninitialized
    tensor shaped like `x` but with dim 0 = ceil(x.size(0) / tp_size).
    """
    tp_size = get_tensor_model_parallel_world_size()
    out_shape = [cdiv(x.size(0), tp_size), *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register `sequence_parallel_chunk_impl` as a torch custom op (called via
# `torch.ops.vllm.sequence_parallel_chunk_impl`) together with its fake
# implementation for tracing. The `needs_fixed_stride_order` tag asks the
# compiler to pass inputs with the same stride order as in eager mode.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Record whether an EAGLE model ships its own `lm_head` / `embed_tokens`.

    Call this for each weight name during weight loading. If
    `supports_any_eagle(model)` is true, `model.has_own_lm_head` and
    `model.has_own_embed_tokens` are set to True when `name` contains
    "lm_head" and "embed_tokens" respectively; other models are untouched.

    Args:
        model: The model being loaded.
        name: The name of the weight being processed.
    """
    if supports_any_eagle(model):
        # Mark weights the model provides itself so they are not overridden
        # with the target model's layers.
        if "lm_head" in name:
            model.has_own_lm_head = True
        if "embed_tokens" in name:
            model.has_own_embed_tokens = True