utils.py 28.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
import weakref
from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload

import regex as re
import torch
import torch.nn as nn
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig

from vllm.config import VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
from vllm.model_executor.model_loader.reload import (
    support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import (
    is_pin_memory_available,
)
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)
39
40

# Module-level logger; AutoWeightsLoader reports skipped, ignored and loaded
# weights (and modules whose load_weights returns None) through it.
logger = init_logger(__name__)
41

42

43
44
@dataclass
class WeightsMapper:
45
46
47
    """Maps the name of each weight if they match the following patterns.

    If a key maps to a value of `None`, the corresponding weight is ignored."""
48

49
50
51
52
    orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
    orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
    orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
    orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
53

54
55
56
57
58
59
60
61
    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s by merging their mappings."""
        return WeightsMapper(
            orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
            orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
            orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
        )

62
    def _map_name(self, key: str) -> str | None:
63
64
65
66
67
68
69
        for pattern, new_key in self.orig_to_new_regex.items():
            if pattern.search(key):
                if new_key is None:
                    return None

                key = pattern.sub(new_key, key)

70
71
72
73
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None
74

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                key = key.replace(prefix, new_key, 1)

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                key = new_key.join(key.rsplit(suffix, 1))

        return key
92

93
    def apply(
94
95
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
96
97
98
99
100
        return (
            (out_name, data)
            for name, data in weights
            if (out_name := self._map_name(name)) is not None
        )
101

102
103
    def apply_list(self, values: list[str]) -> list[str]:
        return [
104
105
            out_name
            for name in values
106
107
108
109
110
111
112
113
114
115
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }

116
117

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a `load_weights` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a `weight_loader` method.

    Detailed weight loading information can be viewed by setting the
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
    """

    # Models trained using early version ColossalAI or quantized by
    # GPTQModel may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_pos_emb.inv_freq",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
    ) -> None:
        """
        Args:
            module: Root module that weights are loaded into.
            skip_prefixes: Weights whose qualified name starts with any of
                these are skipped.
            skip_substrs: Weights whose qualified name contains any of these
                are skipped. `ROTARY_EMBEDS_UNUSED_WEIGHTS` is always added.
            ignore_unexpected_prefixes: Weights with no matching module or
                parameter are ignored (instead of raising) if their name
                starts with any of these.
            ignore_unexpected_suffixes: Same as above, matched on suffixes.
        """
        super().__init__()

        self.module = module
        # Copy the caller's lists: extending `skip_substrs` in place would
        # otherwise grow shared lists (e.g. class-level constants passed in by
        # model implementations) every time a loader is constructed.
        self.skip_prefixes = list(skip_prefixes or [])
        self.skip_substrs = [
            *(skip_substrs or []),
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = list(ignore_unexpected_prefixes or [])
        self.ignore_unexpected_suffixes = list(ignore_unexpected_suffixes or [])

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """Group consecutive weights by the first `.`-separated name component.

        Yields `(prefix, group)` where `group` yields `(rest, tensor)` and
        `rest` is the name after the first `.` (or `""` if there is none).
        As with `itertools.groupby`, only consecutive names are grouped and
        each inner iterator is only valid until the outer one advances.
        """
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )

        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join `prefix` and `rest` with `.`, omitting empty components."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether `qualname` matches a skip prefix or contains a skip substring."""
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether an unmatched `qualname` may be ignored instead of raising."""
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load every weight in `weights` into the single tensor `param`.

        Uses `param.weight_loader` if present, else `default_weight_loader`.

        Yields:
            The qualified name of each weight that was loaded.

        Raises:
            ValueError: If a weight addresses a name nested below `param`
                that is neither skipped nor ignored.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            # A non-empty remainder means the checkpoint name goes deeper
            # than this parameter.
            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight {weight_qualname!r} "
                    f"into a single parameter {base_prefix!r}"
                )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats and registered buffers.
        """
        # Add persistent registered buffers.
        # Non-persistent buffers are excluded, matching PyTorch state_dict().
        non_persistent = getattr(module, "_non_persistent_buffers_set", set())
        for buf_name, buf in module.named_buffers(recurse=False):
            if buf_name not in child_params and buf_name not in non_persistent:
                child_params[buf_name] = buf

        if isinstance(
            module,
            (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
            ),
        ):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
                # With `track_running_stats=False` these buffers are `None`
                # and therefore absent from `state_dict()`; skip them rather
                # than failing with a KeyError.
                if stat_name in module_state_dict:
                    child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively load `weights` into `module` and its descendants.

        Args:
            base_prefix: Qualified name of `module` relative to `self.module`.
            module: Module to load into.
            weights: `(name, tensor)` pairs with names relative to `module`.

        Yields:
            Qualified names (relative to `self.module`) of loaded weights.

        Raises:
            ValueError: If a weight refers to a child that does not exist and
                is neither skipped nor ignored.
        """
        # Placeholder layers are skipped entirely.
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                # When `weights` is a one-shot iterator, this consumes it and
                # the generic loop below receives nothing further.
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters for module %s", module
                    )
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                # Trailing "." so module skips only match whole components.
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                named_parameters = module.named_parameters(recurse=True)
                desc_param_keys = {
                    maybe_prefix(base_prefix, k) for k, _ in named_parameters
                }
                msg = (
                    f"There is no module or parameter named {prefix!r} "
                    f"in {self.module._get_name()}. "
                    f"The available parameters belonging to {base_prefix} "
                    f"({module._get_name()}) are: {desc_param_keys}"
                )
                raise ValueError(msg)

    @support_quantized_model_reload_from_hp_weights
    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: WeightsMapper | None = None,
    ) -> set[str]:
        """Load `weights` into `self.module`.

        Args:
            weights: `(name, tensor)` pairs, e.g. from a checkpoint.
            mapper: Optional renaming applied to the names before loading.

        Returns:
            The set of qualified names that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
357
358


359
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Build an inner model registered to vLLM from the outer model's config.

    Args:
        vllm_config: Config of the outer vLLM model.
        prefix: Module prefix for the inner model.
        hf_config: Optional HF config for the inner model.
        architectures: Optional architectures override. If given without
            `hf_config`, the outer model's HF config is reused so that only
            the architectures field changes.

    Returns:
        The initialized inner model.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None:
        if architectures is not None:
            # Reuse the outer HF config purely to override `architectures`.
            hf_config = vllm_config.model_config.hf_config

    inner_config = vllm_config
    if hf_config is not None:
        inner_config = vllm_config.with_hf_config(
            hf_config, architectures=architectures
        )

    return initialize_model(vllm_config=inner_config, prefix=prefix)
380
381


382
@overload
383
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
384
385
386


@overload
387
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
388
389
390
391


@overload
def flatten_bn(
392
    x: list[torch.Tensor] | torch.Tensor,
393
394
    *,
    concat: Literal[True],
395
) -> torch.Tensor: ...
396
397


398
399
@overload
def flatten_bn(
400
    x: list[torch.Tensor] | torch.Tensor,
401
402
    *,
    concat: bool = False,
403
) -> list[torch.Tensor] | torch.Tensor: ...
404
405


406
def flatten_bn(
407
    x: list[torch.Tensor] | torch.Tensor,
408
409
    *,
    concat: bool = False,
410
) -> list[torch.Tensor] | torch.Tensor:
411
    """
412
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
413

414
    The input tensor should have shape `(B, N, ...)`.
415
416
417
418
419
420
421
422
423
424
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


425
426
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Collapse arbitrarily nested tensors into one 2D tensor.

    Every leaf tensor keeps only its last dimension (all leading dimensions are
    flattened together), and the results are concatenated along dim 0.
    """
    if not isinstance(embeddings, torch.Tensor):
        pieces = [_flatten_embeddings(item) for item in embeddings]
        return torch.cat(pieces)

    # Leaf: merge everything except the embedding dimension.
    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Describe how many embeddings a NestedTensors value holds, for error messages.

    Leaf tensors render as their leading dimensions joined by " x ", and nested
    items are joined with " + ".
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(size) for size in leading_dims)
448
449


450
451
452
453
454
455
456
457
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        index = num // interval
        ranges[index].append(num)
    return ranges


Cyrus Leung's avatar
Cyrus Leung committed
458
459
460
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Write the flattened `multimodal_embeddings` into the rows of
    `inputs_embeds` selected by the mask `is_multimodal`.

    The embeddings are cast to `inputs_embeds.dtype` before assignment.

    Note:
        `inputs_embeds` is modified in place and also returned.

    Raises:
        ValueError: If the index-put fails; when the number of embeddings
            differs from `is_multimodal.sum()`, the message reports both counts.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flat_embeds = _flatten_embeddings(multimodal_embeddings)
    target_dtype = inputs_embeds.dtype

    try:
        # If is_multimodal is on CPU this avoids a D2H sync
        inputs_embeds[is_multimodal] = flat_embeds.to(dtype=target_dtype)
    except RuntimeError as err:
        n_actual = len(flat_embeds)
        n_expected = is_multimodal.sum().item()

        if n_actual == n_expected:
            raise ValueError("Error during index put operation") from err

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {n_actual} "
            f"multimodal tokens to {n_expected} placeholders"
        ) from err

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
495
496


497
498
499
500
501
502
503
504
505
506
507
508
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Return `torch.isin(elements, test_elements_list)` for a Python list.

    The list is first built as a CPU tensor (pinned when pinned memory is
    available), then copied to `elements.device` with `non_blocking=True`.
    """
    staged = torch.tensor(
        test_elements_list,
        pin_memory=is_pin_memory_available(),
    )
    on_device = staged.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, on_device)


509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
class StageMissingLayer(nn.Module):
    def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
        super().__init__()

        self.stage_name = stage_name

        # Don't register this as a child module in order to
        # avoid missing keys when loading weights
        self.__dict__["module"] = module

    def __getattr__(self, name: str):
        return getattr(self.__dict__["module"], name)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(f"{self} should not be called")

    def extra_repr(self) -> str:
        return f"stage_name={self.stage_name!r}"


@contextmanager
def collect_children(
    module: nn.Module,
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, collect all direct child assignments to `module`,
    returning a list of children names that is internally updated until the
    context is exited.

    If `targets` is set, instead collect descendents of `module`
    that are an instance of `targets`, even if they aren't direct children.
    """
    children_names = list[str]()

    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                children_names.append(name)

        with register_module_module_registration_hook(hook):
            yield children_names
    else:
        yield children_names

        for name, module_ in module.named_modules():
            if isinstance(module_, targets):
                children_names.append(name)


@contextmanager
def no_init_weights(
    module: nn.Module,
    placeholder: Callable[[nn.Module], nn.Module],
    *,
    targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
    """
    Within this context, prevent weight initialization from using device memory and
    replace direct child assignments to `module` with the result of `placeholder()`.

    If `targets` is set, instead prevent weight initialization and
    replace assignments where the child is an instance of `targets`,
    even if they aren't direct children of `module`.
    """
    if targets is None:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if module_ is module:
                return placeholder(submodule)

            return submodule

        with register_module_module_registration_hook(hook), torch.device("meta"):
            yield
    else:

        def hook(module_: nn.Module, name: str, submodule: nn.Module):
            if isinstance(module_, targets):
                submodule.to("meta")  # Free memory
            if isinstance(submodule, targets):
                submodule.to("meta")  # Free memory
                return placeholder(submodule)

            return submodule

        # Not all descendents are targeted, so we can't use a blanket
        # `torch.device("meta")` context
        with register_module_module_registration_hook(hook):
            yield


603
class LayerFn(Protocol):
    """Factory used by `make_layers`: builds one layer for a module prefix.

    It is called with the keyword argument `prefix` (e.g. "model.layers.3").
    """

    def __call__(self, prefix: str) -> torch.nn.Module: ...
605
606


607
608
609
610
611
612
613
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    Constructor arguments are accepted and ignored. When called, the first
    input is passed through unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs.

        Raises:
            TypeError: If called without any positional or keyword argument.
        """
        if args:
            return args[0]
        if kwargs:
            return next(iter(kwargs.values()))
        # Report misuse explicitly: a bare StopIteration escaping a regular
        # function is easily swallowed by surrounding iteration machinery
        # (or turned into a RuntimeError inside generators, PEP 479).
        raise TypeError(f"{type(self).__name__}.forward() expects at least one input")
618
619
620


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Build the layer list for this pipeline-parallel rank.

    Only layers in `[start_layer, end_layer)` are created with `layer_fn`
    (and passed through the offloader's `wrap_modules`). All other positions
    are filled with `PPMissingLayer` placeholders, so indices stay global.

    Args:
        num_hidden_layers: Total number of hidden layers in the model.
        layer_fn: Function that creates a layer given its `prefix`.
        prefix: Prefix for layer names; layer `i` gets `f"{prefix}.{i}"`.

    Returns:
        Tuple of (start_layer, end_layer, modules).
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices
    from vllm.model_executor.offloader import get_offloader

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    leading = [PPMissingLayer() for _ in range(start_layer)]
    owned = get_offloader().wrap_modules(
        layer_fn(prefix=f"{prefix}.{idx}") for idx in range(start_layer, end_layer)
    )
    trailing = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    modules = torch.nn.ModuleList(leading + owned + trailing)
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection.
# Entries are keyed weakly by the model object rather than by `id(model)`:
# an id can be reused by a new object after the old model is freed (which
# would return a stale list), and weak keys drop the entry with its model.
_model_to_pp_missing_layer_names: weakref.WeakKeyDictionary[
    torch.nn.Module, list[str]
] = weakref.WeakKeyDictionary()


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Returns:
        The qualified name of every `StageMissingLayer` / `PPMissingLayer`
        inside `model`, each with a trailing "." so it can be used as a
        prefix. The result is computed on first use and cached per model.
    """
    cached = _model_to_pp_missing_layer_names.get(model)
    if cached is not None:
        return cached

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, (StageMissingLayer, PPMissingLayer)):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + ".")

    _model_to_pp_missing_layer_names[model] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Return whether parameter `name` lives under a missing (placeholder) layer.

    If `model` itself is a placeholder, every name counts as missing.
    """
    if isinstance(model, (StageMissingLayer, PPMissingLayer)):
        return True

    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
686
687


688
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """Return a factory producing zero-filled `IntermediateTensors`.

    The factory takes `(batch_size, dtype, device)` and creates, for each name
    in `keys`, a zero tensor of shape `(batch_size, hidden_size)`.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
702
703


704
705
706
707
708
709
710
711
712
713
714
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot, unless `prefix` is empty.

    Args:
        prefix: Dotted prefix; an empty string means "no prefix".
        name: The name to qualify.

    Returns:
        "prefix.name" for a non-empty prefix, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
715
716


717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
def get_draft_quant_config(
    vllm_config: VllmConfig,
) -> QuantizationConfig | None:
    """Get quantization config for Draft models.

    Draft models should use their own quantization config instead of the verifier/target
    model's config. This helper retrieves the draft model's quantization config.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The draft model's quantization config, or None when there is no
        speculative config or no draft model config.
    """
    speculative_config = vllm_config.speculative_config
    if speculative_config is None:
        # No speculative decoding configured, hence no draft model.
        return None

    draft_model_config = speculative_config.draft_model_config
    if not draft_model_config:
        return None

    return VllmConfig.get_quantization_config(
        draft_model_config, vllm_config.load_config
    )


XuruiYang's avatar
XuruiYang committed
741
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.
    Normally exactly one is expected. When `num_attn_module > 1` and the name
    contains "attn", a second integer selects the attention sub-module, and
    the index is `first * num_attn_module + second`.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1

    Args:
        layer_name: Dotted module name.
        num_attn_module: Number of attention sub-modules per layer.

    Returns:
        The (flattened) layer index.

    Raises:
        ValueError: If the name does not contain the expected number of
            integer components.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    # Explicit raises instead of `assert` so validation survives `python -O`
    # and matches the ValueError documented above.
    if num_attn_module == 1 or "attn" not in layer_name:
        if len(int_vals) != 1:
            raise ValueError(
                f"layer name {layer_name} should only contain one integer"
            )
        return int_vals[0]

    if len(int_vals) not in (1, 2):
        raise ValueError(
            f"layer name {layer_name} should contain one or two integers"
        )
    if len(int_vals) == 2:
        return int_vals[0] * num_attn_module + int_vals[1]
    return int_vals[0]
773
774
775
776
777
778
779
780
781


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """Clamp `tensors` into its dtype's finite range when it holds inf/NaN.

    If every element is finite, `tensors` is returned as-is. Otherwise the
    result is `tensors` clamped to `[-(max - offset), max - offset]`, where
    `max` is `torch.finfo(tensors.dtype).max`.
    """
    if torch.isfinite(tensors).all():
        return tensors

    limit = torch.finfo(tensors.dtype).max - offset
    return torch.clamp(tensors, min=-limit, max=limit)
783
784


785
786
787
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Top-k along `dim`, using `torch.max` as a shortcut when k == 1.

    For k == 1, `torch.max(..., keepdim=True)` is used instead of the general
    `torch.topk`, which the original authors found faster for this common case.

    Args:
        values: Input tensor to select from.
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to select.

    Returns:
        (values, indices) of the top-k entries; for k == 1 the reduced
        dimension is kept with size 1.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)

    # keepdim=True keeps the output rank identical to torch.topk's.
    return torch.max(values, dim=dim, keepdim=True)
809
810


811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's slice of `x` along dim 0.

    Dispatches to the `vllm::sequence_parallel_chunk_impl` custom op.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """Pad `x` along dim 0 to a multiple of the TP size and return this rank's chunk.

    The result has `cdiv(x.size(0), tp_size)` rows. Trailing dimensions are
    left untouched, so inputs of any rank >= 1 are supported.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        # `F.pad` lists padding from the last dimension backwards: leave every
        # trailing dimension unpadded and append `pad_len` rows to dim 0.
        # For 2D inputs this is exactly (0, 0, 0, pad_len).
        padding = (0, 0) * (x.dim() - 1) + (0, pad_len)
        y = nn.functional.pad(x, padding)
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Fake implementation registered for the chunk op.

    It allocates an uninitialized tensor whose first dimension is
    `cdiv(x.size(0), tp_size)` and whose other dimensions match `x`.
    """
    tp_size = get_tensor_model_parallel_world_size()
    out_shape = [cdiv(x.size(0), tp_size), *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Expose `sequence_parallel_chunk_impl` as the `vllm::sequence_parallel_chunk_impl`
# custom op (called via `torch.ops.vllm` in `sequence_parallel_chunk`), with a
# shape-only fake implementation.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    tags=(torch.Tag.needs_fixed_stride_order,),
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
)
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Set EAGLE ownership flags on `model` from the name of a loaded weight.

    Call this while loading weights. If `model` supports EAGLE, a name
    containing "lm_head" sets `model.has_own_lm_head = True`, and a name
    containing "embed_tokens" sets `model.has_own_embed_tokens = True`.
    Models without EAGLE support are left untouched.

    Args:
        model: The model instance.
        name: The name of the weight that was loaded.
    """
    if supports_any_eagle(model):
        # To prevent overriding with target model's layers
        if "lm_head" in name:
            model.has_own_lm_head = True
        if "embed_tokens" in name:
            model.has_own_embed_tokens = True
874
875
876
877
878
879
880
881
882
883
884
885
886


def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Convert a possibly negative vision feature layer index into the number
    of hidden layers required to reach it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder;
            negative values count from the end (-1 is the last layer).
        num_hidden_layers: The total number of hidden layers in the visual encoder.

    Returns:
        `feature_layer_index` itself when non-negative, otherwise
        `num_hidden_layers + feature_layer_index + 1`.
    """
    if feature_layer_index >= 0:
        return feature_layer_index
    return num_hidden_layers + feature_layer_index + 1