utils.py 24.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable, Mapping
6
from dataclasses import dataclass, field
7
from typing import Any, Literal, Optional, Protocol, Union, overload
8

9
import torch
10
import torch.nn as nn
11
from torch.func import functional_call
12
from transformers import PretrainedConfig
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
from vllm.logger import init_logger
17
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
18
from vllm.multimodal import NestedTensors
19
from vllm.sequence import IntermediateTensors
20
21
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)
22
23

logger = init_logger(__name__)
24

25
26
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""


@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns.

    Rules are applied in the order substring -> prefix -> suffix, and every
    matching rule within a group is applied in mapping order. As soon as a
    matching rule maps to `None`, the weight is dropped.
    """

    # Each mapping goes from the original fragment to its replacement
    # (or `None` to drop the weight).
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the renamed `key`, or `None` if the weight is ignored."""
        for substr, new_key in self.orig_to_new_substr.items():
            if substr not in key:
                continue
            if new_key is None:
                return None

            # Only the first occurrence of the substring is rewritten.
            key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if not key.startswith(prefix):
                continue
            if new_key is None:
                return None

            key = new_key + key[len(prefix):]

        for suffix, new_key in self.orig_to_new_suffix.items():
            if not key.endswith(suffix):
                continue
            if new_key is None:
                return None

            # Slice by length so that an empty suffix is handled as well.
            key = key[:len(key) - len(suffix)] + new_key

        return key

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename `(name, tensor)` pairs, dropping ignored weights."""
        return ((out_name, data) for name, data in weights
                if (out_name := self._map_name(name)) is not None)

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename every name in `values`, dropping ignored names."""
        return [
            out_name for name in values
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename the keys of `values`, dropping entries with ignored keys."""
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }

80
81

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.

    Detailed weight loading information can be viewed by setting the
    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
    """

    # Models trained using early version ColossalAI
    # may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[list[str]] = None,
        skip_substrs: Optional[list[str]] = None,
        ignore_unexpected_prefixes: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            module: The root module to load weights into.
            skip_prefixes: Weights whose qualified name starts with any of
                these are skipped.
            skip_substrs: Weights whose qualified name contains any of
                these are skipped. ``ROTARY_EMBEDS_UNUSED_WEIGHTS`` is
                always appended.
            ignore_unexpected_prefixes: Weights with these prefixes are
                ignored instead of raising when they match nothing in
                the module.
        """
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        # Build a fresh list: extending `skip_substrs` in place would
        # mutate the list object owned by the caller.
        self.skip_substrs = [
            *(skip_substrs or []),
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """Group consecutive weights by the first dotted name component.

        Yields ``(prefix, sub_weights)`` where each name in ``sub_weights``
        has the prefix stripped (``""`` when nothing follows the prefix).
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join `prefix` and `rest` with a dot, omitting empty parts."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Whether `qualname` matches a skip prefix or skip substring."""
        return (any(qualname.startswith(p) for p in self.skip_prefixes)
                or any(substr in qualname for substr in self.skip_substrs))

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Whether `qualname` starts with an ignore-unexpected prefix."""
        return any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load `weights` into a single parameter.

        Yields the qualified name of each weight that was loaded.

        Raises:
            ValueError: If a weight has a non-empty remaining name (i.e. it
                is nested below the parameter) and is not ignorable.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
                    f"into a single parameter '{base_prefix}'")

            # A parameter may carry its own `weight_loader`; otherwise
            # fall back to the default loader.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname,
                         param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(self, module: nn.Module,
                                        child_params: dict[str, torch.Tensor]):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively load `weights` into `module`.

        Yields the qualified names of the weights that were loaded.

        Raises:
            ValueError: If a weight matches no child module or parameter
                and is neither skippable nor ignorable.
        """
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters "
                        "for module %s", module)
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
                # NOTE(review): control falls through to the generic loop
                # below; presumably `weights` is a one-shot iterator that
                # the module's own `load_weights` has already consumed
                # -- confirm.

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                # The trailing dot makes the check match on the module
                # boundary rather than on a longer sibling name.
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
            else:
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                msg = (f"There is no module or parameter named '{prefix}' "
                       f"in {type(self.module).__name__}")
                raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: Optional[WeightsMapper] = None,
    ) -> set[str]:
        """Load `weights` into the wrapped module.

        Args:
            weights: ``(name, tensor)`` pairs to load.
            mapper: Optional name mapper applied before loading.

        Returns:
            The set of qualified names that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = ((name, weight) for name, weight in weights
                   if not self._can_skip(name))

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
293
294


295
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: Optional[PretrainedConfig] = None,
    architectures: Optional[list[str]] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to vLLM,
    based on the arguments passed to the outer vLLM model.

    Args:
        vllm_config: Config of the outer model.
        prefix: Prefix forwarded to ``initialize_model``.
        hf_config: Optional HF config to use for the inner model.
        architectures: Optional architectures override. When given without
            ``hf_config``, the outer model's HF config is reused.

    Returns:
        The module returned by ``initialize_model``.
    """
    # Imported here rather than at module level, as in the original code.
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None and architectures is not None:
        # So that the architectures field is overridden
        hf_config = vllm_config.model_config.hf_config

    if hf_config is not None:
        vllm_config = vllm_config.with_hf_config(hf_config,
                                                 architectures=architectures)

    return initialize_model(vllm_config=vllm_config, prefix=prefix)
317
318


319
320
321
322
323
324
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    Args:
        x: A tensor, or a list of per-batch tensors.
        concat: For list input, concatenate the items with ``torch.cat``
            instead of returning a flat list. Ignored for tensor input.

    Returns:
        For tensor input, the tensor with its first two dimensions merged.
        For list input, either one concatenated tensor (``concat=True``) or
        a flat list of the items obtained by iterating each element.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


366
367
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """

    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    # Not a tensor: flatten every element and concatenate the results.
    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a human-readable expression for the number of embeddings held in
    the NestedTensors, for use in debugging messages.

    A tensor contributes its leading dimensions (all but the last) joined
    by `` x ``; a nested container joins its members' expressions by
    `` + ``.
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = (_embedding_count_expression(inner) for inner in embeddings)
        return " + ".join(parts)

    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))


Cyrus Leung's avatar
Cyrus Leung committed
392
393
394
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` selected by the ``is_multimodal`` mask.

    Args:
        inputs_embeds: Embeddings to overwrite.
        multimodal_embeddings: Embeddings to scatter in; they are flattened
            on all but the last dimension and cast to the dtype of
            ``inputs_embeds`` first.
        is_multimodal: Boolean mask over the positions of ``inputs_embeds``
            (presumably one entry per token -- confirm against callers).

    Returns:
        ``inputs_embeds`` itself.

    Raises:
        ValueError: If the masked scatter fails.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
                                      mm_embeds_flat.to(dtype=input_dtype))
    except RuntimeError as e:
        # Only on failure do we pay for the count (and its `.item()` call)
        # in order to produce a more specific error message.
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
434
435
436
437
438
439


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings`` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    # Build the boolean mask of placeholder positions.
    if isinstance(placeholder_token_id, list):
        is_multimodal = isin_list(input_ids, placeholder_token_id)
    else:
        is_multimodal = (input_ids == placeholder_token_id)

    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )


479
480
481
482
483
484
485
486
487
488
489
490
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Element-wise membership test of `elements` against a list of ints.

    The list is turned into a tensor (pinned when pinned memory is
    available), moved to the device of `elements` with a non-blocking copy,
    and passed to ``torch.isin``.
    """
    host_values = torch.tensor(test_elements_list,
                               pin_memory=is_pin_memory_available())
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)


491
492
class LayerFn(Protocol):
    """Callable that builds a layer module from its name prefix."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


497
498
499
500
501
502
503
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and discarded.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        return args[0] if args else next(iter(kwargs.values()))
508
509


510
511
512
513
514
515
516
517
518
519
520
# Module-level bookkeeping for CPU offloading: the bytes counted so far
# and the configured upper bound.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload limit to `max_bytes` and reset the byte counter."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES = 0, max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Offload parameters of `module` to CPU while the offload budget lasts.

    Parameters are offloaded one by one until ``_CPU_OFFLOAD_BYTES`` reaches
    ``_CPU_OFFLOAD_MAX_BYTES``, so a module may end up partially offloaded.
    The module is returned unchanged if it has no parameters, if its first
    parameter is already on CPU, or if the budget is already exhausted.

    When ``envs.VLLM_USE_V1`` is set, each parameter is replaced with the
    result of ``get_cuda_view_from_cpu_tensor`` on its CPU copy. Otherwise
    the parameter data is replaced by the CPU copy and ``module.forward``
    is wrapped so that the state is moved back to the original device for
    every call.

    Returns:
        The same `module` object.
    """
    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    uva_available = is_uva_available()

    if envs.VLLM_USE_V1:
        assert uva_available, ("V1 CPU offloading requires"
                               " uva (pin memory) support")
        uva_offloading = True
    else:
        uva_offloading = False

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        if not uva_offloading:
            p.data = cpu_data
        else:
            # keep the cpu data alive
            p._vllm_offloaded_cpu_data = cpu_data
            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore the real forward so that
            # `functional_call` does not re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be
                    # a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module,
                                       device_state,
                                       args=args,
                                       kwargs=kwargs)
            finally:
                # Reinstall the wrapper even if the call raised; otherwise
                # later calls would run without moving state to `device`.
                module.forward = forward

        module.forward = forward

    return module


592
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Args:
        num_hidden_layers: Total number of layers in the model.
        layer_fn: Called as ``layer_fn(prefix=f"{prefix}.{idx}")`` to build
            each layer in ``[start_layer, end_layer)``.
        prefix: Name prefix of the layer list.

    Returns:
        ``(start_layer, end_layer, modules)`` where ``modules`` has
        ``num_hidden_layers`` entries; entries outside
        ``[start_layer, end_layer)`` are ``PPMissingLayer`` placeholders.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            pp_group.rank_in_group,
                                            pp_group.world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# NOTE(review): the cache is keyed by `id(model)` and entries are never
# removed; presumably a model is never freed and its id reused while this
# cache is live -- confirm.
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Each returned name ends with a dot so it can be used as a prefix match.
    The result is cached per model object.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    Returns ``True`` when `model` itself is a ``PPMissingLayer``, or when
    `name` starts with the name of any missing layer inside `model`.
    """
    if isinstance(model, PPMissingLayer):
        return True

    return any(
        name.startswith(missing_layer_name)
        for missing_layer_name in get_pp_missing_layer_names(model))
643
644


645
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """Return a factory that builds zero-filled ``IntermediateTensors``.

    The returned function creates one ``(batch_size, hidden_size)`` zero
    tensor per entry of `keys`, with the requested dtype and device.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors({
            key:
            torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
659
660


661
662
663
664
665
666
667
668
669
670
671
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot unless `prefix` is empty.

    Args:
        prefix: Leading part of the qualified name; may be empty.
        name: Trailing part of the qualified name.

    Returns:
        ``"prefix.name"`` for a non-empty prefix, otherwise `name` unchanged.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
672
673


XuruiYang's avatar
XuruiYang committed
674
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.
    With ``num_attn_module == 1``, or when the name does not contain
    ``"attn"``, exactly one integer must be present and it is returned.
    Otherwise at most two integers are allowed: ``first`` alone yields
    ``first``, and ``first, second`` yields
    ``first * num_attn_module + second``.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> AssertionError if
      num_attn_module == 1

    Raises:
        AssertionError: If the number of integers found in the name does
            not match the rules above.
    """
    subnames = layer_name.split(".")
    int_vals: list[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            # Not an integer component; ignore it.
            continue
    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                    " only contain one integer")

        return int_vals[0]
    else:
        assert len(int_vals) <= 2, (f"layer name {layer_name} should"
                                    " contain most two integers")
        layer_index = int_vals[0] * num_attn_module + int_vals[1] if len(
            int_vals) == 2 else int_vals[0]
        return layer_index
701
702
703
704
705
706
707
708
709


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """Clamp `tensors` into the finite range of its dtype if needed.

    When the tensor contains any inf or NaN value, it is clamped to
    ``[-(max - offset), max - offset]`` where ``max`` is the largest finite
    value of the tensor's dtype. Otherwise the input is returned as is.

    NOTE(review): NaN entries trigger the clamp, but clamping presumably
    leaves NaN values unchanged -- confirm.
    """
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
    return tensors
711
712


713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
def fast_topk(values: torch.Tensor, topk: int,
              dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Optimized topk implementation that uses torch.max for k=1 case.

    This function provides better performance for the common case of k=1
    by using torch.max instead of the more general torch.topk.

    Args:
        values: Input tensor to find top-k values from
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk

    Returns:
        Tuple of (values, indices) where values are the top-k values
        and indices are their corresponding indices in the input tensor
    """
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
736
737
738
739
740
741
742


def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
    """Return the hidden size declared by `hf_config`.

    Uses ``hf_config.hidden_size`` when that attribute exists, otherwise
    the ``hidden_size`` of ``hf_config.get_text_config()``.
    """
    missing = object()
    hidden_size = getattr(hf_config, "hidden_size", missing)
    if hidden_size is not missing:
        return hidden_size

    return hf_config.get_text_config().hidden_size