utils.py 31.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from typing import Any, Literal, Protocol, overload
9

10
import torch
11
import torch.nn as nn
12
from torch.func import functional_call
13
from torch.nn.modules.module import register_module_module_registration_hook
14
from transformers import PretrainedConfig
15

16
import vllm.envs as envs
17
from vllm.config import VllmConfig
18
19
20
21
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
22
from vllm.logger import init_logger
23
24
25
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
26
from vllm.model_executor.model_loader.reload import (
27
28
    support_quantized_model_reload_from_hp_weights,
)
29
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30
from vllm.model_executor.models.interfaces import supports_any_eagle
31
from vllm.multimodal import NestedTensors
32
from vllm.sequence import IntermediateTensors
33
from vllm.utils.math_utils import cdiv
34
from vllm.utils.mem_utils import format_gib
35
from vllm.utils.platform_utils import (
36
37
38
    is_pin_memory_available,
    is_uva_available,
)
39
40
from vllm.utils.torch_utils import (
    direct_register_custom_op,
41
    get_accelerator_view_from_cpu_tensor,
42
)
43
44

logger = init_logger(__name__)

# Weight-name rewrite rules: original name fragment -> replacement fragment.
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
48

49

50
51
52
@dataclass
class WeightsMapper:
    """Rename checkpoint weight names using substring, prefix and suffix rules.

    For each name, every matching substring rule is applied first, then every
    matching prefix rule, then every matching suffix rule. Each rule rewrites
    the (possibly already rewritten) name at most once. A rule whose
    replacement is `None` drops the weight altogether.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s; on key collisions `other` wins."""
        rule_tables = ("orig_to_new_substr", "orig_to_new_prefix", "orig_to_new_suffix")
        merged = {
            table: {**getattr(self, table), **getattr(other, table)}
            for table in rule_tables
        }
        return WeightsMapper(**merged)

    def _map_name(self, key: str) -> str | None:
        """Return the rewritten `key`, or `None` if a rule says to drop it."""
        name = key

        for old, new in self.orig_to_new_substr.items():
            if old not in name:
                continue
            if new is None:
                return None
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_prefix.items():
            if not name.startswith(old):
                continue
            if new is None:
                return None
            # `old` is a prefix, so its first occurrence is the leading one.
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_suffix.items():
            if not name.endswith(old):
                continue
            if new is None:
                return None
            # Replace only the trailing occurrence of the suffix.
            name = new.join(name.rsplit(old, 1))

        return name

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename `(name, tensor)` pairs, dropping those mapped to None."""
        return (
            (mapped, tensor)
            for original, tensor in weights
            if (mapped := self._map_name(original)) is not None
        )

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename every name in `values`, dropping those mapped to None."""
        mapped_names = map(self._map_name, values)
        return [mapped for mapped in mapped_names if mapped is not None]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename the keys of `values`, dropping entries mapped to None."""
        renamed: dict[str, Any] = {}
        for original, value in values.items():
            mapped = self._map_name(original)
            if mapped is not None:
                renamed[mapped] = value
        return renamed

113
114

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a `load_weights` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a `weight_loader` method.

    Detailed weight loading information can be viewed by setting the
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
    """

    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_pos_emb.inv_freq",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
    ) -> None:
        """
        Args:
            module: Root module that weights are loaded into.
            skip_prefixes: Weights whose qualified name starts with any of
                these are skipped.
            skip_substrs: Weights whose qualified name contains any of these
                are skipped. `ROTARY_EMBEDS_UNUSED_WEIGHTS` is always added.
            ignore_unexpected_prefixes: Weights that cannot be placed (no
                matching module/parameter, or a nested name under a single
                parameter) are ignored instead of raising when their
                qualified name starts with any of these.
            ignore_unexpected_suffixes: Same as above, matched on suffixes.
        """
        super().__init__()

        self.module = module
        # Always build fresh lists: the arguments may be lists owned by the
        # caller (e.g. shared constants). Extending them in place would leak
        # the rotary-embedding names into the caller's list and grow it on
        # every instantiation.
        self.skip_prefixes = list(skip_prefixes or [])
        self.skip_substrs = [
            *(skip_substrs or []),
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = list(ignore_unexpected_prefixes or [])
        self.ignore_unexpected_suffixes = list(ignore_unexpected_suffixes or [])

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """Group consecutive weights by the first dotted component of the name.

        Yields `(prefix, group)` where `group` yields `(rest, tensor)` and
        `rest` is the name after the first ".", or "" if there is no dot.
        """
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )

        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join `prefix` and `rest` with ".", omitting empty components."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether `qualname` matches a skip prefix or contains a skip substring."""
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether an unplaceable weight `qualname` may be silently ignored."""
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load weights addressing `param` itself and yield their qualnames.

        Raises:
            ValueError: If a weight has a nested name below `param` and is
                neither skippable nor ignorable.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight {weight_qualname!r} "
                    f"into a single parameter {base_prefix!r}"
                )

            # A parameter may provide its own loader (e.g. for sharding).
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(
            module,
            (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
            ),
        ):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
                # These buffers are absent from the state dict when the norm
                # layer was built with track_running_stats=False.
                if stat_name in module_state_dict:
                    child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively load `weights` (names relative to `module`) into it.

        Yields the qualified names of the weights that were loaded.

        Raises:
            ValueError: If a weight names neither a child module nor a direct
                parameter (or loadable BatchNorm statistic) of `module`, and
                is neither skippable nor ignorable.
        """
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                # Any weights the custom loader leaves unconsumed fall
                # through to the generic logic below.
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters for module %s", module
                    )
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                # Trailing "." lets module-style skip prefixes such as
                # "lm_head." match the module name itself.
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                # Join with "." so the listed names are real qualified names.
                desc_param_keys = {
                    self._get_qualname(base_prefix, k)
                    for k, _ in module.named_parameters(recurse=True)
                }
                msg = (
                    f"There is no module or parameter named {prefix!r} "
                    f"in {self.module._get_name()}. "
                    f"The available parameters belonging to {base_prefix} "
                    f"({module._get_name()}) are: {desc_param_keys}"
                )
                raise ValueError(msg)

    @support_quantized_model_reload_from_hp_weights
    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: WeightsMapper | None = None,
    ) -> set[str]:
        """Load `weights` into `self.module`.

        Args:
            weights: `(qualified_name, tensor)` pairs, consumed once.
            mapper: Optional renaming applied before anything else.

        Returns:
            The set of qualified names that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
346
347


348
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Build an inner vLLM-registered model from the outer model's config.

    Args:
        vllm_config: Configuration of the outer (wrapping) model.
        prefix: Module-name prefix given to the inner model.
        hf_config: Optional HF config to use for the inner model.
        architectures: Optional architecture override. When given without
            `hf_config`, the outer model's own HF config is reused so that
            its `architectures` field is overridden.

    Returns:
        The module built by `initialize_model`.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    effective_hf_config = hf_config
    if effective_hf_config is None and architectures is not None:
        effective_hf_config = vllm_config.model_config.hf_config

    if effective_hf_config is None:
        inner_config = vllm_config
    else:
        inner_config = vllm_config.with_hf_config(
            effective_hf_config, architectures=architectures
        )

    return initialize_model(vllm_config=inner_config, prefix=prefix)
369
370


371
@overload
372
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
373
374
375


@overload
376
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
377
378
379
380


@overload
def flatten_bn(
381
    x: list[torch.Tensor] | torch.Tensor,
382
383
    *,
    concat: Literal[True],
384
) -> torch.Tensor: ...
385
386


387
388
@overload
def flatten_bn(
389
    x: list[torch.Tensor] | torch.Tensor,
390
391
    *,
    concat: bool = False,
392
) -> list[torch.Tensor] | torch.Tensor: ...
393
394


395
def flatten_bn(
396
    x: list[torch.Tensor] | torch.Tensor,
397
398
    *,
    concat: bool = False,
399
) -> list[torch.Tensor] | torch.Tensor:
400
    """
401
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
402

403
    The input tensor should have shape `(B, N, ...)`.
404
405
406
407
408
409
410
411
412
413
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


414
415
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
416
417
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
418
419
420
    """

    if isinstance(embeddings, torch.Tensor):
421
422
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)
423
424
425
426
427
428
429
430
431
432
433
434
435

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """

    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

436
    return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
437
438


439
440
441
442
443
444
445
446
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """Bucket integers by `value // interval`.

    Returns a list whose `i`-th entry holds, in input order, every element
    `num` with `num // interval == i`. There are `max(lst) // interval + 1`
    buckets, some possibly empty. Elements are appended as-is; for a tensor
    input they are 0-d tensors. An empty input yields `[]`.

    Assumes all elements are non-negative (a negative bucket index would wrap
    around) -- NOTE(review): confirm with callers.

    Example: `split_list_into_ranges([0, 1, 5, 9], 4) == [[0, 1], [5], [9]]`.
    """
    # `max()` of an empty sequence raises; there is nothing to bucket anyway.
    # `len()` (not truthiness) works for both lists and 1-D tensors.
    if len(lst) == 0:
        return []

    ranges: list[list[int]] = [[] for _ in range(max(lst) // interval + 1)]
    for num in lst:
        ranges[num // interval].append(num)
    return ranges


Cyrus Leung's avatar
Cyrus Leung committed
447
448
449
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Overwrite the rows of `inputs_embeds` selected by the mask `is_multimodal`
    with the (flattened) `multimodal_embeddings`.

    Note:
        This updates `inputs_embeds` in place.

    Args:
        inputs_embeds: Embeddings to write into.
        multimodal_embeddings: Possibly nested embeddings; flattened over all
            but the last dimension and cast to `inputs_embeds.dtype`.
        is_multimodal: Mask of placeholder positions; unsqueezed on the last
            dim before `masked_scatter_`.

    Returns:
        `inputs_embeds` itself (returned unchanged when there is nothing
        to merge).

    Raises:
        ValueError: If the scatter raises a RuntimeError; the message reports
            the embedding/placeholder counts when they differ.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flat_mm = _flatten_embeddings(multimodal_embeddings)
    target_dtype = inputs_embeds.dtype

    try:
        # Boolean-index assignment (`inputs_embeds[is_multimodal] = ...`) is
        # handy for debugging, but masked_scatter_ avoids a D2H sync (#22105).
        # Caveat: it does not raise if is_multimodal.sum() < len(flat_mm).
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), flat_mm.to(dtype=target_dtype)
        )
    except RuntimeError as err:
        num_actual_tokens = len(flat_mm)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens == num_expected_tokens:
            raise ValueError("Error during masked scatter operation") from err

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {num_actual_tokens} "
            f"multimodal tokens to {num_expected_tokens} placeholders"
        ) from err

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
490
491


492
493
494
495
496
497
498
499
500
501
502
503
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Elementwise `torch.isin` of `elements` against a plain Python list.

    The list is first built as a CPU tensor (pinned when pinning is
    available), then copied to `elements.device` with `non_blocking=True`.
    """
    use_pinned = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned)
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)


504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
class StageMissingLayer(nn.Module):
    def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
        super().__init__()

        self.stage_name = stage_name

        # Don't register this as a child module in order to
        # avoid missing keys when loading weights
        self.__dict__["module"] = module

    def __getattr__(self, name: str):
        return getattr(self.__dict__["module"], name)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(f"{self} should not be called")

    def extra_repr(self) -> str:
        return f"stage_name={self.stage_name!r}"


@contextmanager
def collect_children(
    module: nn.Module,
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, collect all direct child assignments to `module`,
    returning a list of children names that is internally updated until the
    context is exited.

    If `targets` is set, instead collect descendents of `module`
    that are an instance of `targets`, even if they aren't direct children.
    """
    children_names = list[str]()

    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                children_names.append(name)

        with register_module_module_registration_hook(hook):
            yield children_names
    else:
        yield children_names

        for name, module_ in module.named_modules():
            if isinstance(module_, targets):
                children_names.append(name)


@contextmanager
def no_init_weights(
    module: nn.Module,
    placeholder: Callable[[nn.Module], nn.Module],
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, prevent weight initialization from using device memory and
    replace direct child assignments to `module` with the result of `placeholder()`.

    If `targets` is set, instead prevent weight initialization and
    replace assignments where the child is an instance of `targets`,
    even if they aren't direct children of `module`.
    """
    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                return placeholder(submodule)

            return submodule

        with register_module_module_registration_hook(hook), torch.device("meta"):
            yield
    else:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if isinstance(module_, targets):
                submodule.to("meta")  # Free memory
            if isinstance(submodule, targets):
                submodule.to("meta")  # Free memory
                return placeholder(submodule)

            return submodule

        # Not all descendents are targeted, so we can't use a blanket
        # `torch.device("meta")` context
        with register_module_module_registration_hook(hook):
            yield


598
class LayerFn(Protocol):
    """Callable that builds one layer given that layer's module-name prefix.

    `make_layers` invokes it as `layer_fn(prefix=f"{prefix}.{idx}")`.
    """

    def __call__(self, prefix: str) -> torch.nn.Module: ...
600
601


602
603
604
605
606
607
608
class PPMissingLayer(torch.nn.Identity):
    """
    Stand-in for a layer that is not materialized on this pipeline-parallel
    rank. It has no parameters and passes its input straight through.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and ignored.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        if args:
            return args[0]
        first_kwarg_value = next(iter(kwargs.values()))
        return first_kwarg_value
613
614


615
616
# Bytes of parameter data offloaded to CPU so far by `maybe_offload_to_cpu`.
_CPU_OFFLOAD_BYTES = 0
# Offloading budget in bytes; with the default of 0 nothing is offloaded.
_CPU_OFFLOAD_MAX_BYTES = 0
# Optional allow-list of parameter-name segments to offload; empty = no filter.
_CPU_OFFLOAD_PARAMS: set[str] = set()
618
619
620
621
622
623
624
625


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU weight-offloading budget (bytes) and reset the usage counter.

    `maybe_offload_to_cpu` stops offloading once the counter reaches this
    budget, so a budget of 0 effectively disables offloading.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


626
627
628
629
630
def set_cpu_offload_params(params: set[str]) -> None:
    """Restrict CPU offloading to parameters whose name contains these segments.

    An empty set means no restriction. A private copy is stored so that later
    mutation of the caller's set cannot silently change which parameters
    get offloaded.
    """
    global _CPU_OFFLOAD_PARAMS
    _CPU_OFFLOAD_PARAMS = set(params)


631
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Offload some of `module`'s parameters to CPU memory, within a budget.

    Parameters are offloaded one at a time until `_CPU_OFFLOAD_BYTES` reaches
    `_CPU_OFFLOAD_MAX_BYTES`, so a module may end up partially offloaded.
    When `_CPU_OFFLOAD_PARAMS` is non-empty, only parameters whose dotted name
    contains one of its entries as whole segments are considered. Modules
    without parameters, or whose first parameter is already on CPU, are
    returned untouched.

    With UVA offloading each parameter becomes an accelerator view of its CPU
    copy. Otherwise the data moves to CPU and `module.forward` is wrapped so
    that every call runs on a copy of the state dict moved back to the
    original device.

    Returns:
        `module` itself (modified in place).
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = (
        is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
    )
    uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for name, p in module.named_parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        if _CPU_OFFLOAD_PARAMS:
            # Check if parameter belongs to the offloading set
            # Add dots here to ensure we match full segments only
            # e.g., "experts.w2_weight" matches "mlp.experts.w2_weight" but not
            # "mlp.experts.w2_weight_scale"
            should_offload = any(
                f".{param}." in f".{name}." for param in _CPU_OFFLOAD_PARAMS
            )
            if not should_offload:
                continue

        cpu_data = p.data.to(device="cpu")
        if pin_memory:
            cpu_data = cpu_data.pin_memory()

        if not uva_offloading:
            p.data = cpu_data
        else:
            p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
            # Marker attribute; not read in this function.
            p._vllm_is_uva_offloaded = True

        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # functional_call() invokes module(...), which dispatches to
            # module.forward; restore the original temporarily so this
            # wrapper does not recurse into itself.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }

                # set `tie_weights=False` as tied weights in original model
                # become untied when calling .to(device) individually
                return functional_call(
                    module, device_state, args=args, kwargs=kwargs, tie_weights=False
                )
            finally:
                # Reinstall the wrapper even if the call raised; otherwise
                # later calls would run directly on the CPU-resident params.
                module.forward = forward

        module.forward = forward

    return module


707
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Build a ModuleList of layers, materializing only this PP rank's share.

    Indices in `[start_layer, end_layer)` (from `get_pp_indices` for the
    current pipeline-parallel rank) are built with
    `layer_fn(prefix=f"{prefix}.{idx}")` and passed through
    `maybe_offload_to_cpu`. Indices before `start_layer` and from
    `end_layer` up to `num_hidden_layers` get a `PPMissingLayer`.

    Returns:
        `(start_layer, end_layer, modules)`.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    layers: list[torch.nn.Module] = [PPMissingLayer() for _ in range(start_layer)]
    layers.extend(
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    )
    layers.extend(PPMissingLayer() for _ in range(end_layer, num_hidden_layers))
    modules = torch.nn.ModuleList(layers)

    # The counter is global, so this reports the running total so far.
    if _CPU_OFFLOAD_MAX_BYTES > 0:
        logger.info(
            "Total CPU offloaded parameters: %s GBs", format_gib(_CPU_OFFLOAD_BYTES)
        )

    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection.
# Entries are keyed by `id(model)` and evicted (via `weakref.finalize`) when the
# model is collected, so a later model that reuses the same id never receives
# a stale entry.
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Returns the qualified name (with a trailing ".") of every
    `StageMissingLayer` / `PPMissingLayer` found by `model.named_modules()`.
    The result is cached per live model object.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    import weakref

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + ".")

    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    try:
        # CPython reuses the ids of collected objects; drop the entry when the
        # model dies. The finalizer holds no strong reference to `model`.
        weakref.finalize(model, _model_to_pp_missing_layer_names.pop, model_id, None)
    except TypeError:
        # Not weak-referenceable: fall back to id-only caching.
        pass

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    True when `model` itself is a `StageMissingLayer`/`PPMissingLayer`, or when
    `name` starts with the (dot-terminated) name of such a placeholder
    inside `model`.
    """
    if not isinstance(model, (StageMissingLayer, PPMissingLayer)):
        missing_prefixes = tuple(get_pp_missing_layer_names(model))
        # str.startswith accepts a tuple; an empty tuple yields False.
        return name.startswith(missing_prefixes)
    return True
767
768


769
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """Return a factory producing zero-filled `IntermediateTensors`.

    The returned callable takes `(batch_size, dtype, device)` and builds one
    `torch.zeros((batch_size, hidden_size))` tensor per entry of `keys`.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        zero_tensors = {}
        for key in keys:
            zero_tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(zero_tensors)

    return make_empty_intermediate_tensors
783
784


785
786
787
788
789
790
791
792
793
794
795
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot unless `prefix` is empty.

    Args:
        prefix: Leading component; a falsy value (e.g. "") adds nothing.
        name: The name to potentially prefix.

    Returns:
        `"prefix.name"` when `prefix` is truthy, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
796
797


798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
def get_draft_quant_config(
    vllm_config: VllmConfig,
) -> QuantizationConfig | None:
    """Get quantization config for Draft models.

    Draft models should use their own quantization config instead of the
    verifier/target model's config. This helper resolves it from the draft
    model config together with the shared load config.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The result of `VllmConfig.get_quantization_config` for the draft model
        config (which may itself be None). Returns None when there is no
        speculative config or no draft model config.
    """
    speculative_config = vllm_config.speculative_config
    if speculative_config is None:
        return None

    draft_model_config = speculative_config.draft_model_config
    if not draft_model_config:
        return None

    draft_load_config = vllm_config.load_config
    return VllmConfig.get_quantization_config(draft_model_config, draft_load_config)


XuruiYang's avatar
XuruiYang committed
822
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.
    With `num_attn_module == 1`, or when "attn" does not occur in
    `layer_name`, exactly one integer is required and returned. Otherwise one
    or two integers are allowed, and two integers `(layer, sub)` map to
    `layer * num_attn_module + sub`.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.layers.3.attn.1" with num_attn_module=2 -> 7
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1

    Raises:
        ValueError: If `layer_name` has an unsupported number of integer
            components. An explicit raise is used instead of `assert` so the
            check survives `python -O`.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    if num_attn_module == 1 or "attn" not in layer_name:
        if len(int_vals) != 1:
            raise ValueError(
                f"layer name {layer_name} should only contain one integer"
            )
        return int_vals[0]

    if not 1 <= len(int_vals) <= 2:
        raise ValueError(
            f"layer name {layer_name} should contain one or two integers"
        )
    if len(int_vals) == 2:
        return int_vals[0] * num_attn_module + int_vals[1]
    return int_vals[0]
854
855
856
857
858
859
860
861
862


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
863
    return tensors
864
865


866
867
868
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ``topk`` largest entries of ``values`` along ``dim``.

    The k == 1 case is dispatched to ``torch.max`` (with ``keepdim=True`` so
    the reduced dimension is kept, as ``torch.topk`` does), which is meant to
    be faster than the general ``torch.topk`` path.

    Args:
        values: Tensor to search.
        topk: How many of the largest entries to keep (k). Expected to be > 0.
        dim: Dimension to search along.

    Returns:
        A ``(values, indices)`` pair, where ``indices`` locate the selected
        entries in the input along ``dim``.
    """
    if topk != 1:
        # General case.
        return torch.topk(values, topk, dim=dim)
    # Single best entry: torch.max yields both the value and its index.
    return torch.max(values, dim=dim, keepdim=True)


def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Chunk ``x`` along the num_tokens axis for sequence parallelism.

    The work is dispatched through the ``vllm.sequence_parallel_chunk_impl``
    torch custom op rather than being done inline.

    NOTE: The custom-op wrapper works around an issue where the output tensor
    could end up with sequence length 0 for small input sequence lengths, even
    though the input is explicitly padded to avoid this.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's slice of ``x`` along dim 0.

    ``x`` is zero-padded at the end of dim 0 up to a multiple of the tensor
    parallel world size and then split into ``tp_size`` equal contiguous
    chunks; chunk number ``tp_rank`` is returned as a ``torch.narrow`` view of
    the (possibly padded) tensor.

    Args:
        x: Tensor whose first dimension is the token axis. Any number of
            trailing dimensions is supported.

    Returns:
        A tensor of shape ``(ceil(x.size(0) / tp_size), *x.shape[1:])``.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    pad_len = -x.size(0) % tp_size
    if pad_len:
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # leave every trailing dimension untouched and pad only the end of
        # dim 0. For 2-D input this is (0, 0, 0, pad_len), as before.
        y = nn.functional.pad(x, (0, 0) * (x.dim() - 1) + (0, pad_len))
    else:
        y = x

    chunk = y.size(0) // tp_size
    return torch.narrow(y, 0, tp_rank * chunk, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Shape-only stand-in for ``sequence_parallel_chunk_impl``.

    Registered as the op's fake implementation. It allocates an uninitialized
    tensor with the per-rank output shape: dim 0 becomes
    ``cdiv(x.size(0), tp_size)``, all other dims are kept, and dtype/device
    follow ``x``.
    """
    world = get_tensor_model_parallel_world_size()
    out_shape = [cdiv(x.size(0), world), *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register the real and fake (shape-only) implementations of the chunking op.
# The op_name matches the `torch.ops.vllm.sequence_parallel_chunk_impl` call in
# sequence_parallel_chunk (namespace presumably supplied by the helper).
# needs_fixed_stride_order presumably asks the compiler to pass inputs with
# their eager-mode stride order — confirm against the PyTorch Tag docs.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    fake_impl=sequence_parallel_chunk_impl_fake,
    op_func=sequence_parallel_chunk_impl,
    tags=(torch.Tag.needs_fixed_stride_order,),
)


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Record whether an EAGLE model ships its own head/embedding weights.

    Meant to be called for each weight name during weight loading.
    - If the name contains ``lm_head``, ``model.has_own_lm_head`` is set to True.
    - If the name contains ``embed_tokens``, ``model.has_own_embed_tokens`` is
      set to True.

    The flags are only ever set here, never cleared. Models for which
    ``supports_any_eagle`` is False are left untouched.

    Args:
        model: The model instance being loaded.
        name: Name of the weight being processed.
    """
    if supports_any_eagle(model):
        # Mark these weights as the model's own to prevent overriding them
        # with the target model's layers.
        if "lm_head" in name:
            model.has_own_lm_head = True
        if "embed_tokens" in name:
            model.has_own_embed_tokens = True


def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Map a signed vision feature layer to the hidden-layer count it needs.

    Non-negative indices are returned unchanged. Negative indices count from
    the end, so -1 maps to ``num_hidden_layers``, -2 to
    ``num_hidden_layers - 1``, and so on.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual encoder.
    """
    return (
        feature_layer_index
        if feature_layer_index >= 0
        else feature_layer_index + num_hidden_layers + 1
    )