utils.py 28.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from typing import Any, Literal, Protocol, overload
9

10
import torch
11
import torch.nn as nn
12
from torch.nn.modules.module import register_module_module_registration_hook
13
from transformers import PretrainedConfig
14

15
from vllm.config import VllmConfig
16
17
18
19
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
22
23
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
24
from vllm.model_executor.model_loader.reload import (
25
26
    support_quantized_model_reload_from_hp_weights,
)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
from vllm.model_executor.models.interfaces import supports_any_eagle
29
from vllm.multimodal import NestedTensors
30
from vllm.sequence import IntermediateTensors
31
from vllm.utils.math_utils import cdiv
32
from vllm.utils.platform_utils import (
33
34
    is_pin_memory_available,
)
35
36
37
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)
38
39

# Module-wide logger shared by the helpers in this file.
logger = init_logger(__name__)

# Maps an original weight-name fragment (substring, prefix or suffix) to its
# replacement fragment.
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
43

44

45
46
47
@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns."""

    # Rewrite rules applied to every weight name, in this order: substrings,
    # then prefixes, then suffixes. A replacement of `None` drops the weight.
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s by merging their mappings.

        For keys present in both, the entry from `other` wins.
        """
        field_names = (
            "orig_to_new_substr",
            "orig_to_new_prefix",
            "orig_to_new_suffix",
        )
        return WeightsMapper(
            **{
                attr: {**getattr(self, attr), **getattr(other, attr)}
                for attr in field_names
            }
        )

    def _map_name(self, key: str) -> str | None:
        """Rewrite `key` with the configured rules.

        Each matching rule replaces a single occurrence, and later rules see
        the output of earlier ones.

        Returns:
            The rewritten name, or `None` if a matching rule maps to `None`.
        """
        name = key

        for old, new in self.orig_to_new_substr.items():
            if old not in name:
                continue
            if new is None:
                return None
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_prefix.items():
            if not name.startswith(old):
                continue
            if new is None:
                return None
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_suffix.items():
            if not name.endswith(old):
                continue
            if new is None:
                return None
            # Split at the last occurrence so only the trailing match changes.
            name = new.join(name.rsplit(old, 1))

        return name

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename `(name, tensor)` pairs, dropping those mapped to `None`."""
        renamed = ((self._map_name(name), data) for name, data in weights)
        return ((name, data) for name, data in renamed if name is not None)

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename every name in `values`, dropping those mapped to `None`."""
        mapped = (self._map_name(value) for value in values)
        return [name for name in mapped if name is not None]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename the keys of `values`, dropping entries mapped to `None`."""
        result: dict[str, Any] = {}
        for name, value in values.items():
            new_name = self._map_name(name)
            if new_name is not None:
                result[new_name] = value
        return result

108
109

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a `load_weights` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a `weight_loader` method.

    Detailed weight loading information can be viewed by setting the
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
    """

    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_pos_emb.inv_freq",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
    ) -> None:
        """
        Args:
            module: Root module that the weights are loaded into.
            skip_prefixes: Weights whose qualified name starts with one of
                these are skipped.
            skip_substrs: Weights whose qualified name contains one of these
                are skipped. `ROTARY_EMBEDS_UNUSED_WEIGHTS` is always added.
            ignore_unexpected_prefixes: Weights with no matching module or
                parameter whose name starts with one of these are ignored
                instead of raising.
            ignore_unexpected_suffixes: Same, matched on the name's suffix.
        """
        super().__init__()

        self.module = module
        # Copy every list: the caller's lists must never be mutated (the
        # rotary entries are appended to `skip_substrs`, which previously
        # grew the caller's list on every construction).
        self.skip_prefixes = list(skip_prefixes or [])
        self.skip_substrs = [
            *(skip_substrs or []),
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = list(ignore_unexpected_prefixes or [])
        self.ignore_unexpected_suffixes = list(ignore_unexpected_suffixes or [])

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """
        Group *consecutive* weights by the first dot-separated name component.

        Yields `(prefix, group)` pairs. `group` yields `(rest, tensor)`, where
        `rest` is the name after the first `"."`, or `""` if there is no dot.
        """
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )

        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join two name fragments with `"."`, omitting empty fragments."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether `qualname` matches a skip prefix or contains a skip substring."""
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether `qualname` matches an ignore-unexpected prefix or suffix."""
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Load the weights routed to the single parameter `param`.

        Yields the qualified name of every weight that was loaded.

        Raises:
            ValueError: If a nested weight name (non-empty remainder) reaches
                a parameter and is not covered by the skip/ignore rules.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight {weight_qualname!r} "
                    f"into a single parameter {base_prefix!r}"
                )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.

        Statistics the module does not keep are skipped. For example, BatchNorm
        built with `track_running_stats=False` has `None` buffers, and those are
        absent from `state_dict()`. Indexing them directly used to raise
        `KeyError`.
        """
        if isinstance(
            module,
            (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
            ),
        ):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
                if stat_name in module_state_dict:
                    child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """
        Route `weights` (names relative to `module`) to the children and
        parameters of `module`, recursing into child modules.

        Yields the qualified name of every weight that was loaded.

        Raises:
            ValueError: If a weight matches neither a child module nor a
                parameter and is not covered by the skip/ignore rules.
        """
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters for module %s", module
                    )
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
                # NOTE(review): there is no early return here. If `weights`
                # is a one-shot iterator that `load_weights` fully consumed,
                # the loop below sees nothing. Any leftover items are routed
                # by the generic logic below.

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                # Use the same dotted join as everywhere else. The old
                # `base_prefix + k` produced e.g. "modellayers.0.weight".
                desc_param_keys = {
                    self._get_qualname(base_prefix, k)
                    for k, _ in module.named_parameters(recurse=True)
                }
                msg = (
                    f"There is no module or parameter named {prefix!r} "
                    f"in {self.module._get_name()}. "
                    f"The available parameters belonging to {base_prefix} "
                    f"({module._get_name()}) are: {desc_param_keys}"
                )
                raise ValueError(msg)

    @support_quantized_model_reload_from_hp_weights
    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: WeightsMapper | None = None,
    ) -> set[str]:
        """
        Load `weights` into `self.module`.

        Args:
            weights: `(name, tensor)` pairs.
            mapper: Optional name mapper applied before any other processing.

        Returns:
            The set of qualified names that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
341
342


343
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Build an inner model registered to vLLM from the outer model's config.

    Args:
        vllm_config: Config of the outer vLLM model.
        prefix: Weight-name prefix forwarded to `initialize_model`.
        hf_config: Optional HF config to install via `with_hf_config`.
        architectures: Optional architectures override. When given without
            `hf_config`, the outer model's HF config is reused so that the
            architectures field can be overridden.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    config_override = hf_config
    if config_override is None and architectures is not None:
        config_override = vllm_config.model_config.hf_config

    if config_override is not None:
        vllm_config = vllm_config.with_hf_config(
            config_override, architectures=architectures
        )

    return initialize_model(vllm_config=vllm_config, prefix=prefix)
364
365


366
@overload
367
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
368
369
370


@overload
371
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
372
373
374
375


@overload
def flatten_bn(
376
    x: list[torch.Tensor] | torch.Tensor,
377
378
    *,
    concat: Literal[True],
379
) -> torch.Tensor: ...
380
381


382
383
@overload
def flatten_bn(
384
    x: list[torch.Tensor] | torch.Tensor,
385
386
    *,
    concat: bool = False,
387
) -> list[torch.Tensor] | torch.Tensor: ...
388
389


390
def flatten_bn(
391
    x: list[torch.Tensor] | torch.Tensor,
392
393
    *,
    concat: bool = False,
394
) -> list[torch.Tensor] | torch.Tensor:
395
    """
396
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
397

398
    The input tensor should have shape `(B, N, ...)`.
399
400
401
402
403
404
405
406
407
408
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


409
410
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flatten `embeddings` on all but the last dimension and
    concatenate the pieces into a single tensor.
    """
    if not isinstance(embeddings, torch.Tensor):
        # A (possibly nested) sequence: flatten each element, then join them.
        pieces = [_flatten_embeddings(item) for item in embeddings]
        return torch.cat(pieces)

    # Collapse every leading dimension into one, keeping the last dimension.
    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a human-readable count of the embeddings in `embeddings` for
    error messages: a tensor renders as its leading dims joined with `" x "`,
    and a sequence renders as its elements' expressions joined with `" + "`.
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(dim) for dim in leading_dims)
432
433


434
435
436
437
438
439
440
441
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        index = num // interval
        ranges[index].append(num)
    return ranges


def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Overwrite the rows of `inputs_embeds` selected by the `is_multimodal`
    mask with the flattened `multimodal_embeddings`, in order.

    Note:
        This updates `inputs_embeds` in place and also returns it.

    Raises:
        ValueError: If the masked scatter fails; the message reports the
            embedding/placeholder count mismatch when there is one.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flat_mm = _flatten_embeddings(multimodal_embeddings)
    target_dtype = inputs_embeds.dtype

    try:
        # Indexing form, handy for debugging (it syncs device->host):
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), flat_mm.to(dtype=target_dtype)
        )
    except RuntimeError as err:
        actual = len(flat_mm)
        expected = is_multimodal.sum().item()

        if actual == expected:
            raise ValueError("Error during masked scatter operation") from err

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {actual} "
            f"multimodal tokens to {expected} placeholders"
        ) from err

    return inputs_embeds
485
486


487
488
489
490
491
492
493
494
495
496
497
498
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Element-wise `torch.isin` of `elements` against a Python int list.

    The list is materialized on the host, in pinned memory when available,
    and copied to `elements.device` with `non_blocking=True`.
    """
    use_pinned = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned)
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)


499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
class StageMissingLayer(nn.Module):
    def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
        super().__init__()

        self.stage_name = stage_name

        # Don't register this as a child module in order to
        # avoid missing keys when loading weights
        self.__dict__["module"] = module

    def __getattr__(self, name: str):
        return getattr(self.__dict__["module"], name)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(f"{self} should not be called")

    def extra_repr(self) -> str:
        return f"stage_name={self.stage_name!r}"


@contextmanager
def collect_children(
    module: nn.Module,
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, collect all direct child assignments to `module`,
    returning a list of children names that is internally updated until the
    context is exited.

    If `targets` is set, instead collect descendents of `module`
    that are an instance of `targets`, even if they aren't direct children.
    """
    children_names = list[str]()

    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                children_names.append(name)

        with register_module_module_registration_hook(hook):
            yield children_names
    else:
        yield children_names

        for name, module_ in module.named_modules():
            if isinstance(module_, targets):
                children_names.append(name)


@contextmanager
def no_init_weights(
    module: nn.Module,
    placeholder: Callable[[nn.Module], nn.Module],
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, prevent weight initialization from using device memory and
    replace direct child assignments to `module` with the result of `placeholder()`.

    If `targets` is set, instead prevent weight initialization and
    replace assignments where the child is an instance of `targets`,
    even if they aren't direct children of `module`.
    """
    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                return placeholder(submodule)

            return submodule

        with register_module_module_registration_hook(hook), torch.device("meta"):
            yield
    else:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if isinstance(module_, targets):
                submodule.to("meta")  # Free memory
            if isinstance(submodule, targets):
                submodule.to("meta")  # Free memory
                return placeholder(submodule)

            return submodule

        # Not all descendents are targeted, so we can't use a blanket
        # `torch.device("meta")` context
        with register_module_module_registration_hook(hook):
            yield


593
class LayerFn(Protocol):
    """Callable that builds one layer given its weight-name prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module: ...
595
596


597
598
599
600
601
602
603
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    Any constructor arguments are accepted and ignored, and `forward` passes
    its first input straight through.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are intentionally discarded; nothing to build here.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        if args:
            return args[0]
        return next(iter(kwargs.values()))
608
609
610


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Only layers in this pipeline rank's `[start_layer, end_layer)` range are
    built (and wrapped by the offloader). The rest are `PPMissingLayer`
    placeholders, so list indices match global layer indices.

    Args:
        num_hidden_layers: Total number of hidden layers in the model.
        layer_fn: Builds one layer given its weight-name prefix
            (`f"{prefix}.{idx}"`).
        prefix: Prefix for layer names.

    Returns:
        Tuple of (start_layer, end_layer, modules).
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices
    from vllm.model_executor.offloader import get_offloader

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    leading = [PPMissingLayer() for _ in range(start_layer)]
    local = get_offloader().wrap_modules(
        layer_fn(prefix=f"{prefix}.{idx}") for idx in range(start_layer, end_layer)
    )
    trailing = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    modules = torch.nn.ModuleList(leading + local + trailing)
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def _evict_pp_missing_layer_names_on_gc(
    model: torch.nn.Module, model_id: int
) -> None:
    """Drop `model_id`'s cache entry as soon as `model` is garbage-collected.

    Objects with non-overlapping lifetimes may share the same `id()`. Without
    eviction, a later model reusing the id would be served stale names.
    """
    import weakref

    try:
        finalizer = weakref.finalize(
            model, _model_to_pp_missing_layer_names.pop, model_id, None
        )
    except TypeError:
        # `model` does not support weak references; keep the plain id cache.
        return
    # Nothing needs cleaning up at interpreter shutdown.
    finalizer.atexit = False


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Each name is the qualified name of a `StageMissingLayer`/`PPMissingLayer`
    with a trailing `"."`, ready for prefix matching. Results are cached per
    model object. The entry is evicted when the model is garbage-collected.
    """
    model_id = id(model)
    cached = _model_to_pp_missing_layer_names.get(model_id)
    if cached is not None:
        return cached

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + ".")

    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    _evict_pp_missing_layer_names_on_gc(model, model_id)

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    If `model` is itself a placeholder, every parameter is missing.
    Otherwise `name` is checked against the dotted prefixes returned by
    `get_pp_missing_layer_names(model)`.
    """
    if isinstance(model, (StageMissingLayer, PPMissingLayer)):
        return True

    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
676
677


678
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """
    Return a factory that builds zero-filled `IntermediateTensors`.

    The factory's result holds, for each name in `keys`, a
    `(batch_size, hidden_size)` zero tensor with the given dtype and device.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
692
693


694
695
696
697
698
699
700
701
702
703
704
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot, unless `prefix` is empty.

    Args:
        prefix: Leading component; an empty string means "no prefix".
        name: The name to qualify.

    Returns:
        `"prefix.name"` for a non-empty prefix, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
705
706


707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
def get_draft_quant_config(
    vllm_config: VllmConfig,
) -> QuantizationConfig | None:
    """Get quantization config for Draft models.

    Draft models should use their own quantization config instead of the verifier/target
    model's config. This helper retrieves the draft model's quantization config.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The draft model's quantization config, or None when there is no
        speculative config or no draft model config.
    """
    # Guard against a missing speculative config: previously this raised
    # AttributeError instead of returning None as documented.
    speculative_config = vllm_config.speculative_config
    if speculative_config is None:
        return None

    draft_model_config = speculative_config.draft_model_config
    if not draft_model_config:
        return None

    return VllmConfig.get_quantization_config(
        draft_model_config, vllm_config.load_config
    )


def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.

    - If `num_attn_module == 1` or the name does not contain `"attn"`,
      exactly one integer must be present, and it is returned.
    - Otherwise one or two integers may be present. Two integers
      `(layer, sub)` give `layer * num_attn_module + sub`; a single integer
      is returned as is.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1

    Raises:
        ValueError: If `layer_name` has an unexpected number of integer
            components. This used to be an `assert`, which is stripped
            under `python -O`.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    if num_attn_module == 1 or "attn" not in layer_name:
        if len(int_vals) != 1:
            raise ValueError(
                f"layer name {layer_name} should only contain one integer"
            )
        return int_vals[0]

    # Previously zero integers fell through to an IndexError here.
    if not 1 <= len(int_vals) <= 2:
        raise ValueError(
            f"layer name {layer_name} should contain one or at most two integers"
        )
    if len(int_vals) == 2:
        return int_vals[0] * num_attn_module + int_vals[1]
    return int_vals[0]
763
764
765
766
767
768
769
770
771


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """
    Clamp `tensors` into its dtype's finite range if it holds inf or NaN.

    When any element is non-finite, the values are clamped to
    `[-(finfo.max - offset), finfo.max - offset]` and a new tensor is
    returned. Otherwise the input object itself is returned.
    """
    if not (~torch.isfinite(tensors)).any():
        return tensors

    limit = torch.finfo(tensors.dtype).max - offset
    return torch.clamp(tensors, min=-limit, max=limit)
773
774


775
776
777
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Top-k along `dim`, using `torch.max` for the common `k == 1` case.

    Args:
        values: Input tensor to find top-k values from.
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk.

    Returns:
        `(values, indices)`; for `k == 1` the reduced dim is kept with size 1
        so the shapes match `torch.topk(values, 1, dim=dim)`.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)
    # keepdim=True keeps the reduced dimension, matching topk's output rank.
    return torch.max(values, dim=dim, keepdim=True)
799
800


801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Return this TP rank's slice of `x` through the registered custom op."""
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """
    Return this tensor-parallel rank's contiguous slice of `x` along dim 0.

    `x` is zero-padded along dim 0 up to a multiple of the TP world size.
    It is then split into `tp_size` equal chunks, and chunk `tp_rank` is
    returned. Only dim 0 is padded, for input of any rank >= 1.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        # F.pad lists (left, right) pairs starting from the LAST dimension.
        # Pad every trailing dim by 0 and dim 0 by `pad_len`. For 2-D input
        # this is exactly the previous `(0, 0, 0, pad_len)`; the old
        # hard-coded tuple padded the wrong dim for 3-D+ input and failed
        # for 1-D input.
        pad = (0, 0) * (x.dim() - 1) + (0, pad_len)
        y = nn.functional.pad(x, pad)
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Fake (shape-only) implementation: an uninitialized tensor whose dim 0
    is `cdiv(x.size(0), tp_size)`, with the other dims, dtype and device of
    `x`."""
    tp_size = get_tensor_model_parallel_world_size()
    out_shape = (cdiv(x.size(0), tp_size), *x.shape[1:])
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register the chunking function as a custom op named
# "sequence_parallel_chunk_impl", with a fake implementation for tracing;
# `sequence_parallel_chunk` calls it via `torch.ops.vllm`.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Record on an EAGLE model whether it ships its own head/embedding weights.

    Call this for each weight name during loading. Models that do not support
    EAGLE are left untouched.

    Args:
        model: The model instance (must support EAGLE)
        name: The name of the weight to process
    """
    if supports_any_eagle(model):
        # To prevent overriding with target model's layers
        if "lm_head" in name:
            model.has_own_lm_head = True
        if "embed_tokens" in name:
            model.has_own_embed_tokens = True
864
865
866
867
868
869
870
871
872
873
874
875
876


def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Return how many visual-encoder layers are needed to reach a feature layer.

    Args:
        feature_layer_index: Signed index of the required layer in the visual
            encoder; negative values count from the end.
        num_hidden_layers: The total number of hidden layers in the visual encoder.
    """
    if feature_layer_index >= 0:
        return feature_layer_index
    # -1 maps to all layers, -2 to all but the last, and so on.
    return num_hidden_layers + feature_layer_index + 1