utils.py 26.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable, Mapping
6
from dataclasses import dataclass, field
7
from typing import Any, Literal, Protocol, overload
8

9
import torch
10
import torch.nn as nn
11
from torch.func import functional_call
12
from transformers import PretrainedConfig
13
from typing_extensions import deprecated
14

15
from vllm.config import VllmConfig
16
17
18
19
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22
from vllm.multimodal import NestedTensors
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils.math_utils import cdiv
25
from vllm.utils.platform_utils import (
26
27
28
    is_pin_memory_available,
    is_uva_available,
)
29
30
31
32
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    get_cuda_view_from_cpu_tensor,
)
33
34

# Module-level logger used throughout this file.
logger = init_logger(__name__)
35

36
WeightsMapping = Mapping[str, str | None]
37
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
38

39

40
41
42
@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns."""
43

44
45
46
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
47

48
49
50
51
52
53
54
55
    def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
        """Combine two `WeightsMapper`s by merging their mappings."""
        return WeightsMapper(
            orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
            orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
            orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
        )

56
    def _map_name(self, key: str) -> str | None:
57
58
59
60
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None
61

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                key = key.replace(prefix, new_key, 1)

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                key = new_key.join(key.rsplit(suffix, 1))

        return key
79

80
    def apply(
81
82
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
83
84
85
86
87
        return (
            (out_name, data)
            for name, data in weights
            if (out_name := self._map_name(name)) is not None
        )
88

89
90
    def apply_list(self, values: list[str]) -> list[str]:
        return [
91
92
            out_name
            for name in values
93
94
95
96
97
98
99
100
101
102
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }

103
104

class AutoWeightsLoader:
105
    """
106
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
107
108
109
110
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
111
    by defining a `load_weights` method.
112
113

    Similarly, the weight loading logic for individual parameters can be
114
    overridden by defining a `weight_loader` method.
115
116

    Detailed weight loading information can be viewed by setting the
117
    environment variable `VLLM_LOGGING_LEVEL=DEBUG`.
118
    """
119

120
121
122
123
124
125
126
127
    # Models trained using early version ColossalAI
    # may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

128
129
130
131
    def __init__(
        self,
        module: nn.Module,
        *,
132
133
134
135
        skip_prefixes: list[str] | None = None,
        skip_substrs: list[str] | None = None,
        ignore_unexpected_prefixes: list[str] | None = None,
        ignore_unexpected_suffixes: list[str] | None = None,
136
137
138
139
140
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
141
        self.skip_substrs = skip_substrs or []
142
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
143
        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
144
145
        # update default skip_substrs
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
146
147
148

    def _groupby_prefix(
        self,
149
150
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
151
152
153
154
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )
155

156
        for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]):
157
158
159
160
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
161
162
163
164
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
165
166
167
168
169
170
171
172
173
174
175
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
176
177
178
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )
179
180

    def _can_ignore_unexpected(self, qualname: str) -> bool:
181
182
183
        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
        return any(iup) or any(ius)
184
185
186
187
188

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
189
        weights: Iterable[tuple[str, torch.Tensor]],
190
    ) -> Iterable[str]:
191
192
193
194
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
195
196
                logger.debug("Skipping weight %s", weight_qualname)

197
198
199
                continue

            if weight_name != "":
200
201
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)
202

203
204
205
206
                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
207
208
                    f"into a single parameter '{base_prefix}'"
                )
209

210
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
211
212
            weight_loader(param, weight_data)

213
            logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape)
214

215
216
            yield weight_qualname

217
218
219
    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: dict[str, torch.Tensor]
    ):
220
221
222
223
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
224
225
226
        if isinstance(
            module,
            (
227
228
229
230
231
232
233
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
234
235
            ),
        ):
236
            module_state_dict = module.state_dict()
237
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
238
239
                child_params[stat_name] = module_state_dict[stat_name]

240
241
242
243
    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
244
        weights: Iterable[tuple[str, torch.Tensor]],
245
    ) -> Iterable[str]:
246
247
248
249
250
251
252
253
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
254
                loaded_params = module_load_weights(weights)
255
256
                if loaded_params is None:
                    logger.warning(
257
258
                        "Unable to collect loaded parameters for module %s", module
                    )
259
260
261
262
263
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
264
265
266
267

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

268
269
270
271
        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

272
273
274
275
        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
276
277
278
279
280
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

281
282
283
                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
284
            elif child_prefix in child_params:
285
286
287
288
289
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

290
291
292
                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
293
            else:
294
295
296
297
298
299
300
301
302
303
304
305
306
307
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

308
309
310
311
                msg = (
                    f"There is no module or parameter named '{prefix}' "
                    f"in {type(self.module).__name__}"
                )
312
                raise ValueError(msg)
313
314
315

    def load_weights(
        self,
316
        weights: Iterable[tuple[str, torch.Tensor]],
317
        *,
318
        mapper: WeightsMapper | None = None,
319
    ) -> set[str]:
320
321
        if mapper is not None:
            weights = mapper.apply(weights)
322
        # filter out weights with first-prefix/substr to skip in name
323
324
325
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )
326

327
        autoloaded_weights = set(self._load_module("", self.module, weights))
328
        return autoloaded_weights
329
330


331
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: PretrainedConfig | None = None,
    architectures: list[str] | None = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to vLLM,
    based on the arguments passed to the outer vLLM model.

    Args:
        vllm_config: Config of the outer model.
        prefix: Weight-name prefix for the inner model.
        hf_config: HF config for the inner model. If omitted but
            `architectures` is given, the outer model's HF config is reused
            with only its architectures overridden.
        architectures: Architectures to set on the inner model's config.
    """
    from vllm.model_executor.model_loader.utils import initialize_model

    target_hf_config = hf_config
    if target_hf_config is None and architectures is not None:
        # Reuse the outer HF config so that only `architectures` changes.
        target_hf_config = vllm_config.model_config.hf_config

    if target_hf_config is None:
        inner_config = vllm_config
    else:
        inner_config = vllm_config.with_hf_config(
            target_hf_config, architectures=architectures
        )

    return initialize_model(vllm_config=inner_config, prefix=prefix)
352
353


354
@overload
355
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...
356
357
358


@overload
359
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ...
360
361
362
363


@overload
def flatten_bn(
364
    x: list[torch.Tensor] | torch.Tensor,
365
366
    *,
    concat: Literal[True],
367
) -> torch.Tensor: ...
368
369


370
371
@overload
def flatten_bn(
372
    x: list[torch.Tensor] | torch.Tensor,
373
374
    *,
    concat: bool = False,
375
) -> list[torch.Tensor] | torch.Tensor: ...
376
377


378
def flatten_bn(
379
    x: list[torch.Tensor] | torch.Tensor,
380
381
    *,
    concat: bool = False,
382
) -> list[torch.Tensor] | torch.Tensor:
383
    """
384
    Flatten the `B` and `N` dimensions of batched multimodal inputs.
385

386
    The input tensor should have shape `(B, N, ...)`.
387
388
389
390
391
392
393
394
395
396
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


397
398
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    A tensor has every dim except the last merged via `flatten(0, -2)`; a
    sequence is flattened element-wise and the pieces are concatenated.
    """
    if not isinstance(embeddings, torch.Tensor):
        return torch.cat([_flatten_embeddings(item) for item in embeddings])

    return embeddings.flatten(0, -2)


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.

    A tensor contributes its leading dims joined by " x "; nested sequences
    join their children's expressions with " + ".
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    return " x ".join(map(str, embeddings.shape[:-1]))
420
421


422
423
424
425
426
427
428
429
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """
    Bucket the values of `lst` into consecutive ranges of width `interval`.

    Bucket `i` holds the values in `[i * interval, (i + 1) * interval)`, in
    their original order. There are `max(lst) // interval + 1` buckets, and
    buckets with no values are empty lists.

    Args:
        lst: Non-negative integers (a 1-D tensor or a sequence of ints).
        interval: Width of each bucket.

    Returns:
        The list of buckets, or an empty list if `lst` is empty.

    Raises:
        ValueError: If any value is negative. Without this check, Python's
            negative indexing would silently put it in the wrong bucket.
    """
    if len(lst) == 0:
        # `max()` of an empty sequence would raise.
        return []

    ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        if num < 0:
            raise ValueError(
                f"split_list_into_ranges expects non-negative values, got {num}"
            )
        ranges[num // interval].append(num)
    return ranges


Cyrus Leung's avatar
Cyrus Leung committed
430
431
432
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions of `inputs_embeds` selected by the mask `is_multimodal`.

    The mask is unsqueezed on its last dim and broadcast over the embedding
    dim. The flattened embeddings are cast to `inputs_embeds.dtype` and
    written in order.

    Note:
        This updates `inputs_embeds` in place.

    Raises:
        ValueError: If the scatter fails, with a count breakdown when the
            number of embeddings and placeholders differ.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flat_embeds = _flatten_embeddings(multimodal_embeddings)
    target_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = flat_embeds.to(dtype=target_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(flat_embeds)
        scatter_mask = is_multimodal.unsqueeze(-1)
        inputs_embeds.masked_scatter_(scatter_mask, flat_embeds.to(dtype=target_dtype))
    except RuntimeError as err:
        actual_count = len(flat_embeds)
        expected_count = is_multimodal.sum().item()

        if actual_count == expected_count:
            raise ValueError("Error during masked scatter operation") from err

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {actual_count} "
            f"multimodal tokens to {expected_count} placeholders"
        ) from err

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
473
474


475
476
@deprecated(
    "`merge_multimodal_embeddings` has been replaced with "
    "`SupportsMultiModal.embed_input_ids` and will be "
    "removed in v0.12."
)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int | list[int],
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` corresponding to placeholder tokens in
    `input_ids`.

    `placeholder_token_id` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the `input_ids` MUST MATCH the order of
    their embeddings in `multimodal_embeddings` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates `inputs_embeds` in place.
    """
    placeholder_mask = (
        isin_list(input_ids, placeholder_token_id)
        if isinstance(placeholder_token_id, list)
        else input_ids == placeholder_token_id
    )

    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=placeholder_mask,
    )


523
524
525
526
527
528
529
530
531
532
533
534
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Element-wise membership test of `elements` against a Python list.

    The list is materialized as a CPU tensor (pinned when pinned memory is
    available) and copied to `elements.device` with `non_blocking=True`
    before calling `torch.isin`.
    """
    use_pinned = is_pin_memory_available()
    candidates = torch.tensor(test_elements_list, pin_memory=use_pinned)
    candidates = candidates.to(device=elements.device, non_blocking=True)

    return torch.isin(elements, candidates)


535
class LayerFn(Protocol):
    """Callable that builds one layer, given that layer's name prefix.

    `make_layers` calls it as `layer_fn(prefix=...)` once per owned layer.
    """

    def __call__(self, prefix: str) -> nn.Module: ...
537
538


539
540
541
542
543
544
545
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.

    Any constructor arguments are accepted and ignored, so it can stand in
    for an arbitrary layer. `forward` passes its first input through
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs.

        Raises:
            TypeError: If called with no inputs at all.
        """
        if args:
            return args[0]
        if kwargs:
            return next(iter(kwargs.values()))
        # Without this check `next()` would leak a bare StopIteration, which
        # is confusing and turns into a RuntimeError inside generators.
        raise TypeError("PPMissingLayer.forward() requires at least one input")
550
551


552
553
554
555
556
557
558
559
560
561
562
# Running total of parameter bytes offloaded so far, and the budget that
# `maybe_offload_to_cpu` may use. Both are reset by the setter below.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget and reset the running offload counter."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """
    Offload the parameters of `module` to CPU memory, within the global
    budget set by `set_cpu_offload_max_bytes`.

    Offloading is per parameter: once the running byte total reaches the
    budget, the remaining parameters stay where they are. Each offloaded
    parameter keeps its CPU copy alive in `p._vllm_offloaded_cpu_data`, and
    its `.data` is replaced by the view returned from
    `get_cuda_view_from_cpu_tensor`.

    Modules with no parameters, or whose first parameter is on CPU, and
    calls made after the budget is exhausted, return `module` untouched.

    Returns:
        The same `module` object.

    Raises:
        RuntimeError: If offloading would happen but UVA is not available.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    uva_available = is_uva_available()

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let the UVA-only path below run without UVA.
    if not uva_available:
        raise RuntimeError("V1 CPU offloading requires uva (pin memory) support")
    # Always True here, so the `not uva_offloading` branches below are
    # currently unreachable.
    uva_offloading = True

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        if not uva_offloading:
            p.data = cpu_data
        else:
            # keep the cpu data alive
            p._vllm_offloaded_cpu_data = cpu_data
            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module


629
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Only the layers in `[start_layer, end_layer)`, as computed by
    `get_pp_indices` for this pipeline rank, are built with
    `layer_fn(prefix=f"{prefix}.{idx}")` and passed through
    `maybe_offload_to_cpu`. The slots before and after that range are filled
    with `PPMissingLayer` placeholders.

    Returns:
        `(start_layer, end_layer, modules)`.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_group = get_pp_group()
    start_layer, end_layer = get_pp_indices(
        num_hidden_layers, pp_group.rank_in_group, pp_group.world_size
    )

    leading = [PPMissingLayer() for _ in range(start_layer)]
    owned = [
        maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
        for idx in range(start_layer, end_layer)
    ]
    trailing = [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    return start_layer, end_layer, torch.nn.ModuleList(leading + owned + trailing)


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Each returned name ends with "." so it can be used for prefix matching.
    Results are cached per model object, keyed by `id(model)`. The cache
    entry is evicted when the model is garbage-collected, so a different
    model that later reuses the same `id` does not get a stale result.
    """
    import weakref

    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + ".")

    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    # Evict on GC without holding a strong reference to `model`. Otherwise,
    # once the model is freed, its `id` can be reused by another model,
    # which would then receive this stale entry.
    weakref.finalize(model, _model_to_pp_missing_layer_names.pop, model_id, None)

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    True when `model` itself is a `PPMissingLayer`, or when `name` starts
    with the (dot-terminated) name of any `PPMissingLayer` inside `model`.
    """
    if isinstance(model, PPMissingLayer):
        return True

    # `str.startswith` accepts a tuple; an empty tuple never matches.
    missing_prefixes = tuple(get_pp_missing_layer_names(model))
    return name.startswith(missing_prefixes)
685
686


687
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """
    Build a factory for zero-filled `IntermediateTensors`.

    The returned callable takes `(batch_size, dtype, device)` and returns an
    `IntermediateTensors` holding one `(batch_size, hidden_size)` zero tensor
    for every name in `keys`.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        shape = (batch_size, hidden_size)
        zero_tensors = {
            name: torch.zeros(shape, dtype=dtype, device=device) for name in keys
        }
        return IntermediateTensors(zero_tensors)

    return make_empty_intermediate_tensors
701
702


703
704
705
706
707
708
709
710
711
712
713
def maybe_prefix(prefix: str, name: str) -> str:
    """Add a prefix to a name if the prefix is non-empty.

    Args:
        prefix: The prefix to add. If empty, no prefix will be added.
        name: The name to potentially prefix.

    Returns:
        The string "prefix.name" if prefix was non-empty, otherwise just "name".
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
714
715


XuruiYang's avatar
XuruiYang committed
716
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    Every dot-separated component that parses as an integer is collected.

    - If `num_attn_module == 1` or the name does not contain "attn", exactly
      one integer must be present, and it is returned.
    - Otherwise one or two integers may be present. With two integers
      `(a, b)` the result is `a * num_attn_module + b`; with one, it is that
      integer.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1

    Raises:
        ValueError: If the name contains an unexpected number of integers.
    """
    int_vals: list[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue

    # Explicit ValueErrors (not asserts) so the documented contract holds
    # even under `python -O`.
    if num_attn_module == 1 or "attn" not in layer_name:
        if len(int_vals) != 1:
            raise ValueError(
                f"layer name {layer_name} should only contain one integer"
            )
        return int_vals[0]

    if not 1 <= len(int_vals) <= 2:
        raise ValueError(
            f"layer name {layer_name} should contain one or two integers"
        )
    if len(int_vals) == 2:
        return int_vals[0] * num_attn_module + int_vals[1]
    return int_vals[0]
748
749
750
751
752
753
754
755
756


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    """
    Clamp `tensors` to `[-(finfo.max - offset), finfo.max - offset]` of its
    dtype if it contains any inf or NaN value.

    If every value is finite, the input tensor object is returned unchanged.
    NOTE(review): NaN entries are presumably left as NaN by `torch.clamp`;
    only infinities become finite. Confirm this is the intended behavior.
    """
    if torch.isfinite(tensors).all():
        return tensors

    limit = torch.finfo(tensors.dtype).max - offset
    return torch.clamp(tensors, min=-limit, max=limit)
758
759


760
761
762
def fast_topk(
    values: torch.Tensor, topk: int, dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the top-k `(values, indices)` of `values` along `dim`.

    For the common k == 1 case this uses `torch.max` with `keepdim=True`,
    which is faster than the general `torch.topk` and likewise keeps `dim`
    in the output (with size 1).

    Args:
        values: Input tensor to find top-k values from.
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk.

    Returns:
        Tuple of (values, indices) where values are the top-k values
        and indices are their corresponding indices in the input tensor.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)

    return torch.max(values, dim=dim, keepdim=True)
784
785


786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """
    Chunk x along the num_tokens axis for sequence parallelism.

    NOTE: This is wrapped in a torch custom op to work around the following
    issue: the output tensor can have a sequence length 0 at small input
    sequence lengths even though we explicitly pad to avoid this.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """
    Return this tensor-parallel rank's chunk of `x` along dim 0.

    `x` is zero-padded at the end of dim 0 up to a multiple of the TP world
    size. Every rank therefore gets `cdiv(x.size(0), tp_size)` rows, which
    matches `sequence_parallel_chunk_impl_fake`. Padding works for inputs
    of any rank.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        # `F.pad` takes (before, after) pairs starting from the LAST dim, so
        # leave dims 1..n-1 untouched and pad the end of dim 0. A fixed
        # (0, 0, 0, pad_len) would only target dim 0 for 2-D inputs.
        padding = (0, 0) * (x.dim() - 1) + (0, pad_len)
        y = nn.functional.pad(x, padding)
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """
    Fake (shape-only) implementation of `sequence_parallel_chunk_impl`.

    Returns an uninitialized tensor with dim 0 equal to
    `cdiv(x.size(0), tp_size)` and all other dims, dtype and device taken
    from `x`.
    """
    rows_per_rank = cdiv(x.size(0), get_tensor_model_parallel_world_size())
    return torch.empty(
        (rows_per_rank, *x.shape[1:]), dtype=x.dtype, device=x.device
    )


# Register `sequence_parallel_chunk_impl` (invoked by `sequence_parallel_chunk`
# as `torch.ops.vllm.sequence_parallel_chunk_impl`) together with its
# shape-only fake implementation.
direct_register_custom_op(
    op_func=sequence_parallel_chunk_impl,
    op_name="sequence_parallel_chunk_impl",
    tags=(torch.Tag.needs_fixed_stride_order,),
    fake_impl=sequence_parallel_chunk_impl_fake,
)