utils.py 26.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable, Mapping
6
from dataclasses import dataclass, field
7
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
8

9
import torch
10
import torch.nn as nn
11
from torch.func import functional_call
12
from transformers import PretrainedConfig
cx's avatar
cx committed
13
from typing_extensions import deprecated
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
19
from vllm.logger import init_logger
20
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
from vllm.multimodal import NestedTensors
22
from vllm.sequence import IntermediateTensors
23
24
from vllm.utils import (cdiv, direct_register_custom_op,
                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
25
                        is_uva_available)
26
27

# Module-level logger, obtained from `init_logger` with this module's name.
logger = init_logger(__name__)
28

29
30
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""


@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns.

    Three rule tables are applied in order: substring rules, then prefix
    rules, then suffix rules. Each matching rule rewrites the name (first
    occurrence only for substrings, last occurrence for suffixes), and later
    rules see the already-rewritten name. A rule whose replacement is
    ``None`` drops the weight entirely.
    """

    # Rules matched anywhere inside the name.
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    # Rules matched against the start of the name.
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    # Rules matched against the end of the name.
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the rewritten name for ``key``, or ``None`` to drop it."""
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None

                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                # `startswith` guarantees the first occurrence is the prefix.
                key = key.replace(prefix, new_key, 1)

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                # `endswith` guarantees the last occurrence is the suffix, so
                # splitting once from the right and re-joining swaps only it.
                key = new_key.join(key.rsplit(suffix, 1))

        return key

    def apply(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily rename ``(name, tensor)`` pairs, omitting dropped names."""
        return ((out_name, data) for name, data in weights
                if (out_name := self._map_name(name)) is not None)

    def apply_list(self, values: list[str]) -> list[str]:
        """Rename every entry of ``values``, omitting dropped names."""
        return [
            out_name for name in values
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        """Rename every key of ``values``, omitting dropped names."""
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }

84
85

class AutoWeightsLoader:
    """
    Helper class to load weights into a [`torch.nn.Module`][]. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.

    Detailed weight loading information can be viewed by setting the
    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
    """

    # Models trained using early version ColossalAI
    # may include these tensors in checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[list[str]] = None,
        skip_substrs: Optional[list[str]] = None,
        ignore_unexpected_prefixes: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            module: Root module that weights are loaded into.
            skip_prefixes: Weights whose qualified name starts with any of
                these are skipped.
            skip_substrs: Weights whose qualified name contains any of these
                are skipped. ``ROTARY_EMBEDS_UNUSED_WEIGHTS`` is always
                appended.
            ignore_unexpected_prefixes: Weights with these prefixes are
                ignored instead of raising when they have no destination.
        """
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        # Build a new list instead of extending in place: `skip_substrs or []`
        # is the caller's own list when non-empty, and `+=` would mutate it
        # (and grow it again on every construction).
        self.skip_substrs = [
            *(skip_substrs or []),
            *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
        ]
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
        """Group consecutive weights by the first dotted component of their
        name, yielding ``(prefix, [(rest_of_name, tensor), ...])``.

        Grouping uses `itertools.groupby`, so only adjacent weights with the
        same prefix end up in the same group.
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join ``prefix`` and ``rest`` with a dot, omitting empty parts."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Return whether ``qualname`` matches a skip prefix or substring."""
        return (any(qualname.startswith(p) for p in self.skip_prefixes)
                or any(substr in qualname for substr in self.skip_substrs))

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Return whether ``qualname`` matches an ignorable prefix."""
        return any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load ``weights`` into ``param`` and yield each loaded name.

        Raises:
            ValueError: If a weight still has a nested name below the
                parameter and is not covered by
                ``ignore_unexpected_prefixes``.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)

                continue

            # A non-empty remainder means the checkpoint treats this
            # parameter as a container, which cannot be loaded.
            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring weight %s", weight_qualname)

                    continue

                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
                    f"into a single parameter '{base_prefix}'")

            # Parameters may carry their own `weight_loader`; otherwise fall
            # back to the default loader.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

            logger.debug("Loaded weight %s with shape %s", weight_qualname,
                         param.shape)

            yield weight_qualname

    def _add_loadable_non_param_tensors(self, module: nn.Module,
                                        child_params: dict[str, torch.Tensor]):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively load ``weights`` into ``module`` and yield the
        qualified names of everything that was loaded.

        Raises:
            ValueError: If a weight matches neither a child module nor a
                parameter and is neither skippable nor ignorable.
        """
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module is not self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters "
                        "for module %s", module)
                else:
                    yield from map(
                        lambda x: self._get_qualname(base_prefix, x),
                        loaded_params,
                    )
                # NOTE(review): execution continues into the generic child
                # traversal below; this presumably relies on `weights`
                # having been consumed by `load_weights` - confirm.

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                # The trailing dot makes module prefixes match whole
                # components only.
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)

                    continue

                yield from self._load_module(prefix,
                                             child_modules[child_prefix],
                                             child_weights)
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)

                    continue

                yield from self._load_param(prefix, child_params[child_prefix],
                                            child_weights)
            else:
                # Unknown name: it could have been a module or a parameter,
                # so both spellings are checked against the filters.
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)

                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)

                    continue

                msg = (f"There is no module or parameter named '{prefix}' "
                       f"in {type(self.module).__name__}")
                raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        *,
        mapper: Optional["WeightsMapper"] = None,
    ) -> set[str]:
        """Load ``weights`` into the root module.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.
            mapper: Optional name mapper applied before loading.

        Returns:
            The set of qualified names that were loaded.
        """
        if mapper is not None:
            weights = mapper.apply(weights)
        # filter out weights with first-prefix/substr to skip in name
        weights = ((name, weight) for name, weight in weights
                   if not self._can_skip(name))

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
297
298


299
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: Optional[PretrainedConfig] = None,
    architectures: Optional[list[str]] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to vLLM,
    based on the arguments passed to the outer vLLM model.

    Args:
        vllm_config: Configuration of the outer model.
        prefix: Prefix forwarded to `initialize_model`.
        hf_config: Optional HF config to use for the inner model instead of
            the one carried by `vllm_config`.
        architectures: Optional architecture names to override on the
            config used for the inner model.

    Returns:
        The module returned by `initialize_model`.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from vllm.model_executor.model_loader.utils import initialize_model

    if hf_config is None and architectures is not None:
        # So that the architectures field is overridden
        hf_config = vllm_config.model_config.hf_config

    if hf_config is not None:
        vllm_config = vllm_config.with_hf_config(hf_config,
                                                 architectures=architectures)

    return initialize_model(vllm_config=vllm_config, prefix=prefix)
321
322


323
324
325
326
327
328
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[list[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    Args:
        x: Either a single tensor, or a list with one tensor per batch item.
        concat: Only used for list input. If true, the list entries are
            concatenated along dimension 0 into one tensor; otherwise a flat
            list of the per-item slices is returned.

    Returns:
        A tensor with the first two dimensions merged (tensor input, or list
        input with ``concat=True``), or a flat list of tensors.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    # Iterating a tensor yields its slices along dimension 0.
    return [x_n for x_b in x for x_n in x_b]


370
371
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.

    Returns:
        A tensor in which every leaf tensor has had all of its leading
        dimensions merged, concatenated along dimension 0.
    """

    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a human-readable expression for the number of embeddings held in
    the NestedTensors, for use in debugging messages.

    A tensor contributes its leading dimensions (everything except the last)
    joined by `` x ``; a nested container joins its children with `` + ``.
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))

    leading_dims = embeddings.shape[:-1]
    return " x ".join([str(size) for size in leading_dims])

cx's avatar
cx committed
395
396
397
398
399
400
401
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """Bucket the values of ``lst`` into consecutive ranges of width
    ``interval``.

    Bucket ``i`` receives every value ``v`` with ``v // interval == i``, in
    input order. The number of buckets is ``max(lst) // interval + 1``.

    Args:
        lst: 1-D collection of integer values (tensor or plain sequence).
        interval: Width of each bucket.

    Returns:
        A list of buckets holding plain Python ints; an empty list when
        ``lst`` is empty.
    """
    # Convert once so buckets hold plain ints (as the return type states)
    # rather than 0-dim tensors, and so `max` runs over Python values.
    values = lst.tolist() if isinstance(lst, torch.Tensor) else list(lst)
    if not values:
        return []

    ranges: list[list[int]] = [[] for _ in range(max(values) // interval + 1)]
    for num in values:
        ranges[num // interval].append(num)
    return ranges

402

Cyrus Leung's avatar
Cyrus Leung committed
403
404
405
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` that are selected by the `is_multimodal`
    mask.

    Args:
        inputs_embeds: Embeddings to be updated.
        multimodal_embeddings: Embeddings to write; flattened with
            `_flatten_embeddings` before being scattered.
        is_multimodal: Mask over `inputs_embeds` positions. It is
            unsqueezed on the last dimension before use.

    Returns:
        `inputs_embeds` (the same tensor that was passed in).

    Raises:
        ValueError: If the scatter fails, with a message describing the
            mismatch when the number of embeddings differs from the number
            of selected positions.

    Note:
        This updates `inputs_embeds` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(
            is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
        )
    except RuntimeError as e:
        # Only computed on the failure path, so the `.item()` sync is not
        # paid on success.
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
Cyrus Leung's avatar
Cyrus Leung committed
446
447


cx's avatar
cx committed
448
449
450
451
452
@deprecated(
    "`merge_multimodal_embeddings` has been replaced with "
    "`SupportsMultiModal.get_input_embeddings` and will be "
    "removed in v0.12."
)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` corresponding to placeholder tokens in
    `input_ids`.

    `placeholder_token_id` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the `input_ids` MUST MATCH the order of
    their embeddings in `multimodal_embeddings` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates `inputs_embeds` in place.
    """
    # Build the boolean placeholder mask, then delegate the actual merge.
    if isinstance(placeholder_token_id, list):
        is_multimodal = isin_list(input_ids, placeholder_token_id)
    else:
        is_multimodal = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )


cx's avatar
cx committed
496
497
498
499
500
501
502
503
504
505
506
507
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Return a boolean mask marking which entries of ``elements`` appear in
    ``test_elements_list``.

    The list is first turned into a tensor (pinned when
    `is_pin_memory_available()` says so), moved to ``elements``' device with
    a non-blocking copy, and then handed to `torch.isin`.
    """
    use_pinned = is_pin_memory_available()
    candidates = torch.tensor(test_elements_list, pin_memory=use_pinned)
    candidates = candidates.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, candidates)


508
509
class LayerFn(Protocol):
    """Structural type for a layer factory: a callable that takes the
    layer's ``prefix`` as a keyword-compatible argument and returns the
    constructed module."""

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


514
515
516
517
518
519
520
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and discarded, so the
        # placeholder can be built with any layer's arguments.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs.

        Raises:
            TypeError: If called with no arguments at all.
        """
        if args:
            return args[0]
        if kwargs:
            return next(iter(kwargs.values()))
        # Previously this leaked a bare StopIteration from `next(...)`.
        raise TypeError(
            "PPMissingLayer.forward() requires at least one argument")
525
526


527
528
529
530
531
532
533
534
535
536
537
# Running total of parameter bytes that `maybe_offload_to_cpu` has offloaded.
_CPU_OFFLOAD_BYTES = 0
# Offloading budget in bytes; `maybe_offload_to_cpu` stops once it is reached.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU-offload budget to ``max_bytes`` and reset the running
    total of offloaded bytes to zero."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Offload parameters of ``module`` to CPU memory while the global
    offload budget allows it.

    Parameters are offloaded one at a time until ``_CPU_OFFLOAD_BYTES``
    reaches ``_CPU_OFFLOAD_MAX_BYTES``, so a module can end up partially
    offloaded. When ``envs.VLLM_USE_V1`` is set, each parameter's data is
    replaced by the result of `get_cuda_view_from_cpu_tensor` on the CPU
    copy; otherwise the data becomes the CPU copy and ``module.forward`` is
    wrapped to move the state to the original device on every call.

    Args:
        module: Module whose parameters may be offloaded.

    Returns:
        The same ``module`` object (modified in place when offloading
        happened).
    """
    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    uva_available = is_uva_available()

    if envs.VLLM_USE_V1:
        assert uva_available, ("V1 CPU offloading requires"
                               " uva (pin memory) support")
        uva_offloading = True
    else:
        uva_offloading = False

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        if not uva_offloading:
            p.data = cpu_data
        else:
            # keep the cpu data alive
            p._vllm_offloaded_cpu_data = cpu_data
            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore the real forward so `functional_call`
            # runs it (and not this wrapper) with the device-side state.
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            # Re-install the wrapper for the next call.
            module.forward = forward
            return output

        module.forward = forward

    return module


609
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.

    Layers in ``[start_layer, end_layer)`` (as returned by `get_pp_indices`
    for this pipeline rank) are built with ``layer_fn`` and passed through
    `maybe_offload_to_cpu`; every other index holds a `PPMissingLayer`.

    Args:
        num_hidden_layers: Total number of layers across all ranks.
        layer_fn: Factory called as ``layer_fn(prefix=f"{prefix}.{idx}")``.
        prefix: Name prefix of the layer list.

    Returns:
        ``(start_layer, end_layer, modules)`` where ``modules`` has
        ``num_hidden_layers`` entries.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
# NOTE(review): entries are keyed by `id(model)` and never removed; an id can
# be reused after a model is collected, which would return another model's
# cached names - confirm this cannot happen in practice.
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    The result is cached per model object. Each returned name ends with a
    trailing dot so it can be used directly as a prefix.
    """
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    Args:
        name: Qualified parameter name.
        model: Model to check against.

    Returns:
        True if ``model`` itself is a `PPMissingLayer`, or if ``name``
        starts with the (dot-terminated) name of one of the model's missing
        layers.
    """
    if isinstance(model, PPMissingLayer):
        return True

    return any(
        name.startswith(missing_layer_name)
        for missing_layer_name in get_pp_missing_layer_names(model))
660
661


662
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
    """Return a factory that builds zero-filled `IntermediateTensors`.

    Args:
        keys: Names of the tensors each `IntermediateTensors` will hold.
        hidden_size: Size of the second dimension of every tensor.

    Returns:
        A function ``(batch_size, dtype, device) -> IntermediateTensors``
        producing one ``(batch_size, hidden_size)`` zero tensor per key.
    """

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        return IntermediateTensors({
            key:
            torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
676
677


678
679
680
681
682
683
684
685
686
687
688
def maybe_prefix(prefix: str, name: str) -> str:
    """Join ``prefix`` and ``name`` with a dot, unless ``prefix`` is empty.

    Args:
        prefix: Leading component; when empty, ``name`` is returned as is.
        name: Trailing component.

    Returns:
        ``"prefix.name"`` for a non-empty prefix, otherwise ``name``.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name
689
690


XuruiYang's avatar
XuruiYang committed
691
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.

    The name is split on dots and every component that parses as an integer
    is collected. With ``num_attn_module == 1``, or when the name does not
    contain ``"attn"``, exactly one integer is expected and returned.
    Otherwise one or two integers are accepted, and two integers ``(a, b)``
    are combined as ``a * num_attn_module + b``.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> AssertionError if num_attn_module == 1

    Raises:
        AssertionError: If the name does not contain the expected number of
            integer components.
    """
    subnames = layer_name.split(".")
    int_vals: list[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                    " only contain one integer")

        return int_vals[0]
    else:
        # Require at least one integer so a name without any index fails
        # with the same assertion instead of an IndexError below.
        assert 1 <= len(int_vals) <= 2, (f"layer name {layer_name} should"
                                         " contain at most two integers")
        layer_index = int_vals[0] * num_attn_module + int_vals[1] if len(
            int_vals) == 2 else int_vals[0]
        return layer_index
718
719
720
721
722
723
724
725
726


def cast_overflow_tensors(
    tensors: torch.Tensor,
    offset: float = 1000,
) -> torch.Tensor:
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
727
    return tensors
728
729


730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
def fast_topk(values: torch.Tensor, topk: int,
              dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Optimized topk implementation that uses torch.max for k=1 case.

    This function provides better performance for the common case of k=1
    by using torch.max instead of the more general torch.topk.

    Args:
        values: Input tensor to find top-k values from
        topk: Number of top values to return (k). Must be > 0.
        dim: Dimension along which to compute topk

    Returns:
        Tuple of (values, indices) where values are the top-k values
        and indices are their corresponding indices in the input tensor
    """
    if topk == 1:
        # Use max along the specified dimension to get both value and index;
        # keepdim=True keeps a size-1 `dim`, matching the rank that
        # torch.topk produces.
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
753
754
755
756
757
758
759


def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
    """Return the hidden size for ``hf_config``.

    Uses the config's own ``hidden_size`` attribute when present; otherwise
    falls back to ``hidden_size`` of the config returned by
    ``hf_config.get_text_config()``.
    """
    if not hasattr(hf_config, "hidden_size"):
        return hf_config.get_text_config().hidden_size
    return hf_config.hidden_size
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802


# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Return the result of the ``torch.ops.vllm.sequence_parallel_chunk_impl``
    custom op applied to ``x``."""
    return torch.ops.vllm.sequence_parallel_chunk_impl(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """Return the slice of ``x`` along dimension 0 that belongs to the
    current tensor-parallel rank.

    The first dimension is padded up to a multiple of the tensor-parallel
    world size when needed, then split into equal chunks; the chunk at the
    current rank's position is returned as a narrowed view.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    leftover = x.size(0) % world_size
    if leftover == 0:
        padded = x
    else:
        # The pad spec covers the last two dimensions
        # (assumes x is 2-D, i.e. (num_tokens, hidden) - TODO confirm).
        padded = nn.functional.pad(x, (0, 0, 0, world_size - leftover))

    rows_per_rank = padded.shape[0] // world_size
    return torch.narrow(padded, 0, rank * rows_per_rank, rows_per_rank)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Shape-only counterpart of `sequence_parallel_chunk_impl`.

    Returns an uninitialized tensor with ``x``'s dtype and device whose
    first dimension is ``cdiv(x.size(0), tp_size)`` and whose remaining
    dimensions match ``x``.
    """
    world_size = get_tensor_model_parallel_world_size()
    out_shape = [cdiv(x.size(0), world_size), *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)


# Register `sequence_parallel_chunk_impl` under the op name that
# `sequence_parallel_chunk` calls via `torch.ops.vllm`, together with its
# fake implementation, which only allocates an output of the expected shape.
direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)