utils.py 24.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from collections.abc import Iterable, Mapping
5
from dataclasses import dataclass, field
6
from typing import Callable, Literal, Optional, Protocol, Union, overload
7

8
import torch
9
import torch.nn as nn
10
from torch.func import functional_call
11
from transformers import PretrainedConfig
12

13
import vllm.envs as envs
14
from vllm.config import VllmConfig
15
from vllm.logger import init_logger
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
18
from vllm.sequence import IntermediateTensors
19
20
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)
21
22

# Module-level logger for this file, created with vLLM's `init_logger` helper.
logger = init_logger(__name__)
23

24
25
# Maps an original weight-name fragment to its replacement.
# A value of ``None`` means that any weight matching the key is dropped.
WeightsMapping = Mapping[str, Optional[str]]


@dataclass
class WeightsMapper:
    """Rename checkpoint weights using substring, prefix and suffix rules.

    The rules run in three passes, in this order: every matching substring
    rule, then every matching prefix rule, then every matching suffix rule.
    Each rule sees the name produced by the rules before it. As soon as a
    matching rule maps to ``None`` the weight is dropped.
    """

    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the renamed ``key``, or ``None`` if it must be dropped."""
        name = key

        for old, new in self.orig_to_new_substr.items():
            if old not in name:
                continue
            if new is None:
                return None
            # Only the first occurrence of the substring is rewritten.
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_prefix.items():
            if not name.startswith(old):
                continue
            if new is None:
                return None
            # `name` starts with `old`, so its first occurrence is the prefix.
            name = name.replace(old, new, 1)

        for old, new in self.orig_to_new_suffix.items():
            if not name.endswith(old):
                continue
            if new is None:
                return None
            # Split at the last occurrence (the suffix) and glue with `new`.
            head_and_tail = name.rsplit(old, 1)
            name = new.join(head_and_tail)

        return name

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename ``weights``, skipping entries mapped to ``None``."""
        renamed = ((self._map_name(name), data) for name, data in weights)
        return ((name, data) for name, data in renamed if name is not None)
65

66
67

class AutoWeightsLoader:
68
    """
69
    Helper class to load weights into a {class}`torch.nn.Module`. It is able
70
71
72
73
74
75
76
77
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.
78
79
80

    Detailed weight loading information can be viewed by setting the
    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
81
    """
82

83
84
85
86
87
88
89
90
    # Models trained using early version ColossalAI
    # may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

91
92
93
94
    def __init__(
        self,
        module: nn.Module,
        *,
95
        skip_prefixes: Optional[list[str]] = None,
96
        skip_substrs: Optional[list[str]] = None,
97
        ignore_unexpected_prefixes: Optional[list[str]] = None,
98
99
100
101
102
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
103
        self.skip_substrs = skip_substrs or []
104
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
105
106
        # update default skip_substrs
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
107
108
109

    def _groupby_prefix(
        self,
110
111
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
134
135
        return (any(qualname.startswith(p) for p in self.skip_prefixes)
                or any(substr in qualname for substr in self.skip_substrs))
136
137
138
139
140
141
142
143
144

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        return any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
145
        weights: Iterable[tuple[str, torch.Tensor]],
146
    ) -> Iterable[str]:
147
148
149
150
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
151
152
                logger.debug("Skipping weight %s", weight_qualname)

153
154
155
                continue

            if weight_name != "":
156
157
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)
158

159
160
161
162
163
                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
                    f"into a single parameter '{base_prefix}'")
164
165
166
167
168

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

169
170
171
            logger.debug("Loaded weight %s with shape %s", weight_qualname,
                         param.shape)

172
173
            yield weight_qualname

174
    def _add_loadable_non_param_tensors(self, module: nn.Module,
175
                                        child_params: dict[str, torch.Tensor]):
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

194
195
196
197
    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
198
        weights: Iterable[tuple[str, torch.Tensor]],
199
    ) -> Iterable[str]:
200
201
202
203
204
205
206
207
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
208
                loaded_params = module_load_weights(weights)
209
210
211
212
213
214
215
216
217
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters "
                        "for module %s", module)
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
218
219
220
221

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

222
223
224
225
        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

226
227
228
229
        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
230
231
232
233
234
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

235
236
237
                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
238
            elif child_prefix in child_params:
239
240
241
242
243
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

244
245
                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
246
            else:
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                msg = (f"There is no module or parameter named '{prefix}' "
                       f"in {type(self.module).__name__}")
                raise ValueError(msg)
264
265
266

    def load_weights(
        self,
267
        weights: Iterable[tuple[str, torch.Tensor]],
268
269
        *,
        mapper: Optional[WeightsMapper] = None,
270
    ) -> set[str]:
271
272
        if mapper is not None:
            weights = mapper.apply(weights)
273
274
275
        # filter out weights with first-prefix/substr to skip in name
        weights = ((name, weight) for name, weight in weights
                   if not self._can_skip(name))
276

277
        autoloaded_weights = set(self._load_module("", self.module, weights))
278
        return autoloaded_weights
279
280


281
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: Optional[PretrainedConfig] = None,
    architectures: Optional[list[str]] = None,
) -> nn.Module:
    """
    Build an inner model registered to vLLM from the outer model's config.

    If ``hf_config`` and/or ``architectures`` are given, the config returned
    by ``vllm_config.with_hf_config(...)`` is used; otherwise ``vllm_config``
    is passed through unchanged.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None:
        if architectures is not None:
            # Re-use the outer HF config so that only its `architectures`
            # field ends up overridden.
            hf_config = vllm_config.model_config.hf_config

    config = vllm_config
    if hf_config is not None:
        config = vllm_config.with_hf_config(hf_config,
                                            architectures=architectures)

    return initialize_model(vllm_config=config, prefix=prefix)
303
304


305
306
307
308
309
310
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """
    Merge the leading ``B`` and ``N`` dimensions of batched multimodal inputs.

    - A tensor of shape ``(B, N, ...)`` becomes ``(B * N, ...)``; ``concat``
      is ignored in this case.
    - A list of ``B`` tensors is concatenated along dim 0 when ``concat`` is
      true, and otherwise unrolled into a flat list of its per-item slices.
    """
    if not isinstance(x, torch.Tensor):
        if concat:
            return torch.cat(x)
        return [item for batch in x for item in batch]

    return x.flatten(0, 1)


352
353
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flatten every tensor in ``embeddings`` on all but its last
    dimension and concatenate the results in iteration order.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = [_flatten_embeddings(inner) for inner in embeddings]
        return torch.cat(parts)

    # Keep the last dimension, merge all leading ones.
    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a debugging string describing how many embeddings ``embeddings``
    holds, e.g. ``"2 x 3 + 4"``: each tensor contributes its leading
    dimensions joined by ``" x "``, and siblings are joined by ``" + "``.
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(size) for size in leading_dims)


378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
def merge_multimodal_embeddings_from_map(
        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the given
    placeholder index map: the entries ``placeholder_map.src`` of the
    flattened multimodal embeddings are written to the entries
    ``placeholder_map.dest`` of ``inputs_embeds``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    src_rows = _flatten_embeddings(multimodal_embeddings)[placeholder_map.src]
    inputs_embeds[placeholder_map.dest] = src_rows
    return inputs_embeds


Cyrus Leung's avatar
Cyrus Leung committed
394
395
396
397
398
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Overwrite the positions of ``inputs_embeds`` selected by the mask
    ``is_multimodal`` with the flattened ``multimodal_embeddings``.

    Raises:
        ValueError: if the number of flattened multimodal embeddings differs
            from the number of positions selected by ``is_multimodal``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    num_actual_tokens = flattened.shape[0]
    if num_actual_tokens == num_expected_tokens:
        inputs_embeds[is_multimodal] = flattened
        return inputs_embeds

    expr = _embedding_count_expression(multimodal_embeddings)
    raise ValueError(
        f"Attempted to assign {expr} = {num_actual_tokens} "
        f"multimodal tokens to {num_expected_tokens} placeholders")
419
420


Cyrus Leung's avatar
Cyrus Leung committed
421
422
423
424
def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    multimodal_embeds: NestedTensors,
) -> torch.Tensor:
    """
    Embed ``input_ids`` and combine the result with multimodal embeddings.

    Positions whose token ID equals ``multimodal_token_id`` are filled from
    ``multimodal_embeds``; every other position is embedded with
    ``get_text_embeds``.

    Compared to ``merge_multimodal_embeddings``, this never passes the
    placeholder token IDs to ``get_text_embeds``, which avoids issues when
    the placeholder token ID exceeds the vocabulary size of the language
    model.
    """
    mm_mask = input_ids == multimodal_token_id
    text_mask = ~mm_mask

    text_embeds = get_text_embeds(input_ids[text_mask])
    num_tokens = input_ids.shape[0]
    hidden_size = text_embeds.shape[1]
    # Multimodal rows are left uninitialised here and are filled below.
    merged_embeds = torch.empty(
        (num_tokens, hidden_size),
        dtype=text_embeds.dtype,
        device=text_embeds.device,
    )
    merged_embeds[text_mask] = text_embeds

    return _merge_multimodal_embeddings(
        merged_embeds,
        mm_mask,
        multimodal_embeds,
    )


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` may also be a list of token IDs (e.g. the IDs
    of the img_start, img_break and img_end tokens). Every position whose ID
    is in that list is overwritten, in order, so the flattened
    ``multimodal_embeddings`` MUST list embeddings for those tokens in
    exactly the order in which the tokens appear in ``input_ids``; the merge
    slices rather than scattering per token type.

    For example, if ``input_ids`` is "TTTTTSIIIBIIIBIIIETTT", where

    - T is a text token
    - S is the image start token
    - I is an image embedding token
    - B is the image break token
    - E is the image end token,

    then the image embeddings (corresponding to the I's) produced by the
    vision encoder must be padded with the embeddings of S, B and E in the
    same order as in ``input_ids`` for the merge to be correct.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        placeholder_ids = torch.tensor(placeholder_token_id,
                                       device=input_ids.device)
        is_multimodal = torch.isin(input_ids, placeholder_ids)
    else:
        is_multimodal = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        is_multimodal,
        multimodal_embeddings,
    )


504
505
class LayerFn(Protocol):
    """Structural type of a callable that builds one layer from a prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        """Construct and return a layer for the module path ``prefix``."""
        ...


510
511
512
513
514
515
516
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    It returns its (first) input unchanged.
    """

    def __init__(self, *args, **kwargs):
        """
        All arguments except the ``return_tuple`` keyword (default
        ``False``) are accepted and ignored.
        """
        super().__init__()
        self.return_tuple = kwargs.get("return_tuple", False)

    def forward(self, *args, **kwargs):
        """
        Return the first arg from args or the first value from kwargs.

        Wraps the input in a tuple if `self.return_tuple` is True.

        Raises:
            TypeError: if called without any positional or keyword argument.
        """
        if args:
            value = args[0]
        elif kwargs:
            value = next(iter(kwargs.values()))
        else:
            # Previously `next()` leaked a bare StopIteration here, which can
            # silently terminate an enclosing iteration (e.g. a `for` loop
            # over `map(layer, ...)`) instead of surfacing the error.
            raise TypeError(
                f"{type(self).__name__}.forward() requires at least one "
                "positional or keyword argument")
        return (value, ) if self.return_tuple else value
527
528


529
530
531
532
533
534
535
536
537
538
539
# Bytes of parameters offloaded to CPU so far by `maybe_offload_to_cpu`, and
# the budget it may reach. Both start at 0, so `maybe_offload_to_cpu` offloads
# nothing until `set_cpu_offload_max_bytes` raises the budget.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget and reset the offloaded-bytes counter."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Move as many of ``module``'s parameters to CPU memory as the global
    offload budget allows, and return ``module``.

    ``module`` is returned untouched if it has no parameters, if its first
    parameter is already on CPU, or if the budget set through
    ``set_cpu_offload_max_bytes`` is already used up. The budget is checked
    before each parameter, so a module may end up partially offloaded.

    When ``envs.VLLM_USE_V1`` is set, UVA is required: parameters become
    views (obtained via ``get_cuda_view_from_cpu_tensor``) of CPU tensors
    and ``forward`` is left alone. Otherwise parameters live on CPU and
    ``forward`` is wrapped to move the state dict to the original device on
    every call.
    """
    global _CPU_OFFLOAD_BYTES

    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    uva_available = is_uva_available()

    if envs.VLLM_USE_V1:
        assert uva_available, ("V1 CPU offloading requires"
                               " uva (pin memory) support")
        uva_offloading = True
    else:
        uva_offloading = False

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        if not uva_offloading:
            p.data = cpu_data
        else:
            # keep the cpu data alive
            p._vllm_offloaded_cpu_data = cpu_data
            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore the real forward: `functional_call` invokes
            # the module, which would otherwise re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a
                    # no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                # Re-install the wrapper even if the call raised; before,
                # an exception left the plain forward in place, so later
                # calls ran with the CPU-resident weights.
                module.forward = forward

        module.forward = forward

    return module


611
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Only the slots in ``[start_layer, end_layer)`` returned by
    ``get_pp_indices`` for this rank are built with ``layer_fn`` (each one
    passed through ``maybe_offload_to_cpu``). Every other slot holds a
    ``PPMissingLayer``, so the list always has ``num_hidden_layers`` entries.

    Returns:
        ``(start_layer, end_layer, modules)``.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    start_layer, end_layer = get_pp_indices(
        num_hidden_layers,
        get_pp_group().rank_in_group,
        get_pp_group().world_size,
    )

    layers: list[torch.nn.Module] = [
        PPMissingLayer() for _ in range(start_layer)
    ]
    layers.extend(
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer))
    layers.extend(
        PPMissingLayer() for _ in range(end_layer, num_hidden_layers))

    return start_layer, end_layer, torch.nn.ModuleList(layers)


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Each name has a trailing ``"."`` appended so callers can use it as a
    prefix to match parameter names against.

    Results are cached in ``_model_to_pp_missing_layer_names`` keyed by
    ``id(model)``. An id may be handed to a new object once the original one
    is freed, so the entry is evicted when ``model`` is garbage collected;
    the finalizer holds only the id, never the model itself.
    """
    import weakref

    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')

    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    # Drop the entry when `model` is collected so that a later object that
    # happens to receive the same id() cannot get these stale names.
    weakref.finalize(model, _model_to_pp_missing_layer_names.pop, model_id,
                     None)

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    ``model`` itself being a ``PPMissingLayer`` means every parameter is
    missing; otherwise ``name`` is missing when it starts with one of the
    names returned by ``get_pp_missing_layer_names(model)``.
    """
    if isinstance(model, PPMissingLayer):
        return True

    # `str.startswith` accepts a tuple of prefixes; an empty tuple never
    # matches.
    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
662
663


664
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """
    Return a factory for zero-filled ``IntermediateTensors``.

    The returned callable takes ``batch_size``, ``dtype`` and ``device``, and
    builds one ``(batch_size, hidden_size)`` zero tensor per key in ``keys``.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
678
679


680
681
682
683
684
685
686
687
688
689
690
def maybe_prefix(prefix: str, name: str) -> str:
    """Qualify ``name`` with ``prefix`` unless the prefix is empty.

    Args:
        prefix: Parent path; any falsy value means "no prefix".
        name: The name to qualify.

    Returns:
        ``f"{prefix}.{name}"`` when ``prefix`` is truthy, otherwise ``name``.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
691
692
693
694
695
696
697
698
699
700
701
702


def extract_layer_index(layer_name: str) -> int:
    """
    Extract the layer index from the module name.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError

    Raises:
        ValueError: if ``layer_name`` does not contain exactly one
            dot-separated component that parses as an integer.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    # An explicit raise (instead of `assert`) honours the documented
    # ValueError contract and is not stripped under `python -O`.
    if len(int_vals) != 1:
        raise ValueError(f"layer name {layer_name} should"
                         " only contain one integer")
    return int_vals[0]
712
713
714
715
716
717
718
719
720


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """
    Clamp ``tensors`` to ``[-(max - offset), max - offset]`` if it holds any
    ``inf`` or ``NaN``, where ``max`` is the largest finite value of its
    dtype. Otherwise the very same tensor object is returned unchanged.
    """
    if torch.isfinite(tensors).all():
        return tensors

    limit = torch.finfo(tensors.dtype).max - offset
    return torch.clamp(tensors, min=-limit, max=limit)
722
723
724
725
726
727
728
729
730


def fast_topk(values, topk, dim):
    """
    Return the ``topk`` largest entries of ``values`` along ``dim`` as a
    ``(values, indices)`` named tuple.

    For ``topk == 1`` this uses ``torch.max(..., keepdim=True)``, whose
    output keeps a size-1 ``dim``, matching the shape of ``torch.topk`` with
    ``k=1``.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)
    return torch.max(values, dim=dim, keepdim=True)