"vscode:/vscode.git/clone" did not exist on "8477fe427d174df15204aae2819922cce384ecf2"
utils.py 30.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
import weakref
from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload

import torch
import torch.nn as nn
from torch.func import functional_call
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
from vllm.model_executor.model_loader.reload import (
    support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import (
    is_pin_memory_available,
    is_uva_available,
)
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    get_accelerator_view_from_cpu_tensor,
)
42
43

logger = init_logger(__name__)

# Maps an original weight-name fragment (a substring, prefix or suffix,
# depending on which `WeightsMapper` field holds it) to its replacement.
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
47

48

49
50
51
@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns.

    Names are rewritten in three passes, each operating on the output of the
    previous one: substrings (only the first occurrence is replaced), then
    prefixes, then suffixes. Within a pass every matching rule is applied, in
    mapping order. A rule whose replacement is `None` drops the weight.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s by merging their mappings.

        When both define the same key, the rule from `other` wins.
        """
        substr = {**self.orig_to_new_substr, **other.orig_to_new_substr}
        prefix = {**self.orig_to_new_prefix, **other.orig_to_new_prefix}
        suffix = {**self.orig_to_new_suffix, **other.orig_to_new_suffix}
        return WeightsMapper(
            orig_to_new_substr=substr,
            orig_to_new_prefix=prefix,
            orig_to_new_suffix=suffix,
        )

    def _map_name(self, key: str) -> str | None:
        """Return the rewritten name for `key`, or `None` to drop the weight."""
        name = key

        for fragment, replacement in self.orig_to_new_substr.items():
            if fragment not in name:
                continue
            if replacement is None:
                return None
            name = name.replace(fragment, replacement, 1)

        for head, replacement in self.orig_to_new_prefix.items():
            if not name.startswith(head):
                continue
            if replacement is None:
                return None
            name = name.replace(head, replacement, 1)

        for tail, replacement in self.orig_to_new_suffix.items():
            if not name.endswith(tail):
                continue
            if replacement is None:
                return None
            # Swap only the last occurrence, i.e. the matched suffix.
            name = replacement.join(name.rsplit(tail, 1))

        return name

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename `(name, tensor)` pairs, dropping ones mapped to `None`."""
        return (
            (new_name, tensor)
            for old_name, tensor in weights
            if (new_name := self._map_name(old_name)) is not None
        )

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename each name in `values`, dropping the ones mapped to `None`."""
        mapped = map(self._map_name, values)
        return [name for name in mapped if name is not None]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename the keys of `values`, dropping entries mapped to `None`."""
        renamed: dict[str, Any] = {}
        for old_key, value in values.items():
            new_key = self._map_name(old_key)
            if new_key is not None:
                renamed[new_key] = value
        return renamed

112
113

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a `load_weights` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a `weight_loader` method.

    Detailed weight loading information can be viewed by setting the
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
    """

    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_pos_emb.inv_freq",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
    ) -> None:
        """
        Args:
            module: The root module that weights are loaded into.
            skip_prefixes: Weights whose qualified name starts with any of
                these prefixes are dropped.
            skip_substrs: Weights whose qualified name contains any of these
                substrings are dropped. `ROTARY_EMBEDS_UNUSED_WEIGHTS` is
                always added on top of these.
            ignore_unexpected_prefixes: Weights that match no module or
                parameter are ignored (logged at debug level) instead of
                raising if their qualified name starts with one of these.
            ignore_unexpected_suffixes: Same as `ignore_unexpected_prefixes`,
                but matched against the end of the qualified name.
        """
        super().__init__()

        self.module = module
        # Keep private copies of the caller's lists. `skip_substrs` in
        # particular is extended below; extending the caller's list in place
        # would mutate their argument and, if that list is reused for several
        # loaders, append the rotary entries again on every construction.
        self.skip_prefixes = list(skip_prefixes or [])
        self.skip_substrs = [
            *(skip_substrs or []),
            # update default skip_substrs
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = list(ignore_unexpected_prefixes or [])
        self.ignore_unexpected_suffixes = list(ignore_unexpected_suffixes or [])

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """
        Group weights by the first dotted component of their name.

        Yields `(prefix, group)` where `group` yields `(rest, tensor)` with
        `rest` being the name after the first "." ("" if there is none).
        Only *consecutive* weights sharing a prefix are grouped
        (`itertools.groupby`), so each group must be consumed before the
        next one is requested.
        """
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )

        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join `prefix` and `rest` with ".", omitting whichever is empty."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether `qualname` matches a skip prefix or skip substring."""
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether an unmatched `qualname` may be ignored instead of raising."""
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Load the weights addressed to the single parameter `param`.

        Uses `param.weight_loader` if present, else `default_weight_loader`.

        Yields:
            The qualified name of each weight that was loaded.

        Raises:
            ValueError: If a weight has a nested name below the parameter
                (e.g. `param.extra`) that is neither skippable nor ignorable.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight {weight_qualname!r} "
                    f"into a single parameter {base_prefix!r}"
                )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(
            module,
            (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
            ),
        ):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
                # With `track_running_stats=False` these buffers are None and
                # are left out of the state dict; skip them instead of
                # failing with a KeyError.
                if stat_name in module_state_dict:
                    child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Recursively load `weights` (names relative to `module`) into `module`.

        Placeholder layers (`StageMissingLayer`, `PPMissingLayer`) are skipped
        entirely. A non-root module with a callable `load_weights` is handed
        the weights first, and the names it reports are yielded with
        `base_prefix` prepended. Whatever `weights` still yields afterwards is
        grouped by its first name component and dispatched to the matching
        child module or direct parameter.

        Yields:
            Qualified names of the weights that were loaded.

        Raises:
            ValueError: If a weight matches no child module or parameter and
                is neither skippable nor ignorable.
        """
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters for module %s", module
                    )
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                # Use `_get_qualname` so the "." separator is included when
                # `base_prefix` is non-empty (plain concatenation produced
                # names like "modellayers.0.weight").
                desc_param_keys = {
                    self._get_qualname(base_prefix, k)
                    for k, _ in module.named_parameters(recurse=True)
                }
                msg = (
                    f"There is no module or parameter named {prefix!r} "
                    f"in {self.module._get_name()}. "
                    f"The available parameters belonging to {base_prefix} "
                    f"({module._get_name()}) are: {desc_param_keys}"
                )
                raise ValueError(msg)

    @support_quantized_model_reload_from_hp_weights
    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: WeightsMapper | None = None,
    ) -> set[str]:
        """
        Load `weights` into the root module.

        Args:
            weights: `(qualified_name, tensor)` pairs.
            mapper: Optional `WeightsMapper` applied to the names first;
                weights it maps to `None` are dropped.

        Returns:
            The qualified names of all weights that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
345
346


347
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Build an inner vLLM-registered model from the outer model's config.

    Args:
        vllm_config: Config of the outer vLLM model.
        prefix: Weight-name prefix for the inner model.
        hf_config: Optional HF config that replaces the one in `vllm_config`.
        architectures: Optional architectures override. When it is given
            without `hf_config`, the outer model's HF config is reused so
            that only the architectures field is overridden.

    Returns:
        The module produced by `initialize_model`.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None and architectures is not None:
        hf_config = vllm_config.model_config.hf_config

    inner_config = (
        vllm_config
        if hf_config is None
        else vllm_config.with_hf_config(hf_config, architectures=architectures)
    )
    return initialize_model(vllm_config=inner_config, prefix=prefix)
368
369


370
@overload
371
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
372
373
374


@overload
375
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
376
377
378
379


@overload
def flatten_bn(
380
    x: list[torch.Tensor] | torch.Tensor,
381
382
    *,
    concat: Literal[True],
383
) -> torch.Tensor: ...
384
385


386
387
@overload
def flatten_bn(
388
    x: list[torch.Tensor] | torch.Tensor,
389
390
    *,
    concat: bool = False,
391
) -> list[torch.Tensor] | torch.Tensor: ...
392
393


394
def flatten_bn(
395
    x: list[torch.Tensor] | torch.Tensor,
396
397
    *,
    concat: bool = False,
398
) -> list[torch.Tensor] | torch.Tensor:
399
    """
400
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
401

402
    The input tensor should have shape `(B, N, ...)`.
403
404
405
406
407
408
409
410
411
412
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


413
414
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
415
416
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
417
418
419
    """

    if isinstance(embeddings, torch.Tensor):
420
421
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)
422
423
424
425
426
427
428
429
430
431
432
433
434

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """

    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

435
    return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
436
437


438
439
440
441
442
443
444
445
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        index = num // interval
        ranges[index].append(num)
    return ranges


def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions of `inputs_embeds` selected by the `is_multimodal` mask.

    Args:
        inputs_embeds: Embeddings to update; modified in place.
        multimodal_embeddings: Possibly nested multimodal embeddings. All
            but their last dimension are flattened, and the resulting rows
            are written, in order, into the masked positions (cast to the
            dtype of `inputs_embeds`).
        is_multimodal: Boolean mask over the positions of `inputs_embeds`.
            It is broadcast across the last (embedding) dimension via
            `unsqueeze(-1)`.

    Returns:
        `inputs_embeds` after the in-place update. It is returned untouched
        when `multimodal_embeddings` is empty.

    Raises:
        ValueError: If the scatter fails. When the number of flattened
            embeddings differs from the number of masked positions, the
            message reports both counts.

    Note:
        This updates `inputs_embeds` in place. Supplying *more* embeddings
        than masked positions is not detected (see the NOTE below).
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
        )
    except RuntimeError as e:
        # Only on failure do we pay for the D2H sync needed to count tokens.
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
489
490


491
492
493
494
495
496
497
498
499
500
501
502
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Return `torch.isin(elements, test_elements_list)` for a Python list.

    The list is first built into a tensor with `torch.tensor` (pinned when
    `is_pin_memory_available()`), then copied to `elements.device` with
    `non_blocking=True`.
    """
    use_pinned_memory = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned_memory)
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)


503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
class StageMissingLayer(nn.Module):
    def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
        super().__init__()

        self.stage_name = stage_name

        # Don't register this as a child module in order to
        # avoid missing keys when loading weights
        self.__dict__["module"] = module

    def __getattr__(self, name: str):
        return getattr(self.__dict__["module"], name)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(f"{self} should not be called")

    def extra_repr(self) -> str:
        return f"stage_name={self.stage_name!r}"


@contextmanager
def collect_children(
    module: nn.Module,
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, collect all direct child assignments to `module`,
    returning a list of children names that is internally updated until the
    context is exited.

    If `targets` is set, instead collect descendents of `module`
    that are an instance of `targets`, even if they aren't direct children.
    """
    children_names = list[str]()

    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                children_names.append(name)

        with register_module_module_registration_hook(hook):
            yield children_names
    else:
        yield children_names

        for name, module_ in module.named_modules():
            if isinstance(module_, targets):
                children_names.append(name)


@contextmanager
def no_init_weights(
    module: nn.Module,
    placeholder: Callable[[nn.Module], nn.Module],
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, prevent weight initialization from using device memory and
    replace direct child assignments to `module` with the result of `placeholder()`.

    If `targets` is set, instead prevent weight initialization and
    replace assignments where the child is an instance of `targets`,
    even if they aren't direct children of `module`.
    """
    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                return placeholder(submodule)

            return submodule

        with register_module_module_registration_hook(hook), torch.device("meta"):
            yield
    else:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if isinstance(module_, targets):
                submodule.to("meta")  # Free memory
            if isinstance(submodule, targets):
                submodule.to("meta")  # Free memory
                return placeholder(submodule)

            return submodule

        # Not all descendents are targeted, so we can't use a blanket
        # `torch.device("meta")` context
        with register_module_module_registration_hook(hook):
            yield


597
class LayerFn(Protocol):
    """Callable that builds one layer given its weight-name prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        """Construct the layer whose parameters live under `prefix`."""
        ...
599
600


601
602
603
604
605
606
607
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    It passes its first input straight through.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and discarded (presumably so
        # it can stand in for arbitrary layer constructors).
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        if args:
            return args[0]
        return next(iter(kwargs.values()))
612
613


614
615
616
617
618
619
620
621
622
623
624
# Running total of parameter bytes moved to CPU by `maybe_offload_to_cpu`,
# and the budget that total may not exceed.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget to `max_bytes` and reset the running total."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move `module`'s parameters to CPU while the global offload budget allows.

    Parameters are offloaded one at a time until `_CPU_OFFLOAD_BYTES`
    reaches `_CPU_OFFLOAD_MAX_BYTES`, so a module may end up only partially
    offloaded. CPU copies are pinned when pinning is available and not
    disabled.

    - With UVA offloading, each parameter's data becomes
      `get_accelerator_view_from_cpu_tensor(cpu_copy)`.
    - Otherwise `module.forward` is wrapped. Each call copies the state dict
      back to the original device and runs the module via `functional_call`.

    Returns:
        `module` itself, modified in place. It is returned unchanged if it has
        no parameters, its first parameter is on CPU, or the budget is
        already exhausted.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = (
        is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
    )
    uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        cpu_data = p.data.to(device="cpu")
        if pin_memory:
            cpu_data = cpu_data.pin_memory()

        if not uva_offloading:
            p.data = cpu_data
        else:
            p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
            p._vllm_is_uva_offloaded = True

        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # `functional_call` runs the module (and hence `module.forward`),
            # so temporarily restore the original to avoid re-entering this
            # wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }

                # set `tie_weights=False` as tied weights in original model
                # become untied when calling .to(device) individually
                return functional_call(
                    module, device_state, args=args, kwargs=kwargs, tie_weights=False
                )
            finally:
                # Reinstall the wrapper even if the call raised; otherwise
                # later calls would run against the CPU-resident weights.
                module.forward = forward

        module.forward = forward

    return module


689
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """
    Build the layer list for this pipeline-parallel rank.

    Layers in `[start_layer, end_layer)` (as given by `get_pp_indices` for
    this rank) are built with `layer_fn(prefix=f"{prefix}.{idx}")` and passed
    through `maybe_offload_to_cpu`. Every other index gets a
    `PPMissingLayer`.

    Returns:
        `(start_layer, end_layer, modules)`.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    leading = [PPMissingLayer() for _ in range(start_layer)]
    owned = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    trailing = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    modules = torch.nn.ModuleList([*leading, *owned, *trailing])
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection.
# The cache holds the model only through a weak reference: entries disappear
# when the model is garbage-collected, and a new model can never be handed a
# stale entry left by a dead model whose `id()` it happens to reuse (which
# an `id(model)`-keyed dict allows).
_model_to_pp_missing_layer_names: (
    "weakref.WeakKeyDictionary[torch.nn.Module, list[str]]"
) = weakref.WeakKeyDictionary()


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    A layer counts as missing if it is a `StageMissingLayer` or a
    `PPMissingLayer`. Each name carries a trailing "." so it can be used as
    a prefix match. The result is computed once per model object and then
    cached (the same list object is returned on later calls).
    """
    cached = _model_to_pp_missing_layer_names.get(model)
    if cached is not None:
        return cached

    missing_layer_names = [
        # NOTE: the trailing dot is used to match the prefix of the layer.
        # without the dot, we could match a layer that is not missing,
        # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
        name + "."
        for name, module in model.named_modules()
        if isinstance(module, (StageMissingLayer, PPMissingLayer))
    ]
    _model_to_pp_missing_layer_names[model] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    True if `model` itself is a placeholder layer, or if `name` starts with
    the (dot-terminated) name of any placeholder layer inside `model`.
    """
    if isinstance(model, (StageMissingLayer, PPMissingLayer)):
        return True

    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
745
746


747
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """
    Return a factory for zero-filled `IntermediateTensors`.

    The factory takes `(batch_size, dtype, device)` and returns one
    `(batch_size, hidden_size)` zero tensor per entry in `keys`.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
761
762


763
764
765
766
767
768
769
770
771
772
773
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot, unless `prefix` is empty.

    Args:
        prefix: Leading component; an empty string means "no prefix".
        name: The name to qualify.

    Returns:
        `"prefix.name"` for a non-empty prefix, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
774
775


776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
def get_draft_quant_config(
    vllm_config: VllmConfig,
) -> QuantizationConfig | None:
    """Get quantization config for Draft models.

    Draft models should use their own quantization config instead of the verifier/target
    model's config. This helper retrieves the draft model's quantization config.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The draft model's quantization config, or None when there is no
        speculative config or no draft model config.
    """
    speculative_config = vllm_config.speculative_config
    if speculative_config is None:
        # No speculative decoding configured: treat it like "no draft model"
        # instead of failing on `None.draft_model_config`.
        return None

    draft_model_config = speculative_config.draft_model_config
    if not draft_model_config:
        return None

    return VllmConfig.get_quantization_config(
        draft_model_config, vllm_config.load_config
    )


class _LayerIndexError(ValueError, AssertionError):
    """Raised by `extract_layer_index` for a malformed layer name.

    It is a `ValueError` (the documented contract) and also an
    `AssertionError` (what the earlier `assert`-based checks raised), so
    existing handlers for either keep working.
    """


def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.
    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1

    If `num_attn_module > 1` and the name contains "attn", one or two
    integer components are allowed. With two, the index is
    `first * num_attn_module + second`.

    Raises:
        ValueError: If the name does not contain the required number of
            integer components. The exception is also an `AssertionError`
            for backward compatibility.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    if num_attn_module == 1 or "attn" not in layer_name:
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(int_vals) != 1:
            raise _LayerIndexError(
                f"layer name {layer_name} should only contain one integer"
            )
        return int_vals[0]

    # Zero integers must be rejected too; previously it hit an IndexError.
    if not 1 <= len(int_vals) <= 2:
        raise _LayerIndexError(
            f"layer name {layer_name} should contain one or two integers"
        )
    if len(int_vals) == 2:
        return int_vals[0] * num_attn_module + int_vals[1]
    return int_vals[0]
832
833
834
835
836
837
838
839
840


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """
    Clamp `tensors` into its dtype's finite range if it holds inf or NaN.

    The bound is `torch.finfo(dtype).max - offset` (symmetric). A tensor
    with only finite values is returned as-is (the same object). Clamping
    leaves NaN entries as NaN; only +/-inf are actually replaced.
    """
    # A single `isfinite` reduction is equivalent to `isinf | isnan`.
    if bool(torch.isfinite(tensors).all()):
        return tensors

    limit = torch.finfo(tensors.dtype).max - offset
    return tensors.clamp(min=-limit, max=limit)
842
843


844
845
846
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Top-k along `dim`, using `torch.max` as a fast path when `topk == 1`.

    Args:
        values: Input tensor to find top-k values from
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk

    Returns:
        `(values, indices)`. For `topk == 1` the reduced dimension is kept
        (size 1), matching the layout `torch.topk` produces.
    """
    if topk != 1:
        # General case.
        return torch.topk(values, topk, dim=dim)

    # k == 1: a max reduction with keepdim yields the same layout.
    return torch.max(values, dim=dim, keepdim=True)
868
869


870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """
    Chunk x along the num_tokens axis for sequence parallelism.

    NOTE: This is wrapped in a torch custom op to work around the following
    issue: the output tensor can have a sequence length 0 at small input
    sequence lengths even though we explicitly pad to avoid this.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """
    Return this tensor-parallel rank's slice of `x` along dim 0.

    `x` is first zero-padded at the end of dim 0 so that its length is
    divisible by the TP world size. Each rank then takes the contiguous chunk
    `[rank * chunk, (rank + 1) * chunk)`. All other dims are untouched, which
    matches the output shape of `sequence_parallel_chunk_impl_fake`.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        # `F.pad` lists (before, after) pairs starting from the LAST dim, so
        # pad only the end of dim 0 and leave every other dim alone. A fixed
        # `(0, 0, 0, pad_len)` would hit dim 0 only for 2-D inputs.
        pad = (0, 0) * (x.dim() - 1) + (0, pad_len)
        y = nn.functional.pad(x, pad)
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """
    Shape-only (fake) counterpart of `sequence_parallel_chunk_impl`.

    Returns an uninitialized tensor with `x`'s dtype and device, where dim 0
    is `cdiv(x.size(0), tp_size)` and all other dims are unchanged.
    """
    tp_size = get_tensor_model_parallel_world_size()
    out_shape = (cdiv(x.size(0), tp_size), *x.shape[1:])
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register `sequence_parallel_chunk_impl` (with its shape-only fake) as a
# custom op; `sequence_parallel_chunk` dispatches to it via `torch.ops.vllm`.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Record on an EAGLE model whether it ships its own head/embedding weights.

    This is a no-op unless `supports_any_eagle(model)`. It is meant to be
    called for each loaded weight name:
    - a name containing "lm_head" sets `model.has_own_lm_head = True`;
    - a name containing "embed_tokens" sets `model.has_own_embed_tokens = True`.

    Args:
        model: The model instance (must support EAGLE)
        name: The name of the weight to process
    """
    if not supports_any_eagle(model):
        return

    # To prevent overriding with target model's layers
    flag_for_fragment = (
        ("lm_head", "has_own_lm_head"),
        ("embed_tokens", "has_own_embed_tokens"),
    )
    for fragment, flag in flag_for_fragment:
        if fragment in name:
            setattr(model, flag, True)
933
934
935
936
937
938
939
940
941
942
943
944
945


def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Return how many encoder hidden layers are needed to reach a feature
    layer, given its (possibly negative) index.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
            Negative values count from the end, so -1 needs all layers.
        num_hidden_layers: The total number of hidden layers in the visual encoder.
    """
    if feature_layer_index >= 0:
        return feature_layer_index
    return num_hidden_layers + feature_layer_index + 1