# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utils for model executor."""

5
import copy
6
from typing import Any
7
8
9
10
11

import torch


def set_random_seed(seed: int) -> None:
    """Delegate seeding to the active platform's ``seed_everything``.

    Args:
        seed: Seed value forwarded unchanged to
            ``current_platform.seed_everything``.
    """
    # Imported inside the function body rather than at module import time
    # (presumably to keep platform resolution lazy — not verified here).
    from vllm.platforms import current_platform as platform

    platform.seed_everything(seed)


def set_weight_attrs(
    weight: torch.Tensor,
19
    weight_attrs: dict[str, Any] | None,
20
21
22
23
24
25
26
27
28
29
30
31
32
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
33
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
34
35
36
37
38
39
40
41
42
43

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
44
        from vllm.platforms import current_platform
45

46
        if current_platform.use_sync_weight_loader() and key == "weight_loader":
47
            value = current_platform.make_synced_weight_loader(value)
48
        setattr(weight, key, value)
49
50


def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Return a deep-copied packed-modules mapping for ``model``.

    If ``model`` itself defines a non-empty ``packed_modules_mapping``, a deep
    copy of it is returned unchanged. Otherwise the ``packed_modules_mapping``
    attributes of the model's direct children (if any) are deep-copied and
    merged into a single dict.

    Args:
        model: The module whose packed-modules mapping is requested.

    Returns:
        A freshly copied dict. Mutating it does not affect ``model`` or its
        children.

    Raises:
        ValueError: If a child maps a key, already collected from an earlier
            child, to a different value.
    """
    declared = getattr(model, "packed_modules_mapping", None)
    merged = {} if declared is None else copy.deepcopy(declared)

    # An explicit mapping on the model takes precedence; nothing is inferred.
    if merged:
        return merged

    # Only direct children are consulted, not the whole submodule tree.
    for submodule in model.children():
        found = getattr(submodule, "packed_modules_mapping", None)
        contribution = {} if found is None else copy.deepcopy(found)

        for packed_name, members in contribution.items():
            if packed_name in merged and merged[packed_name] != members:
                raise ValueError(
                    f"Can't update {type(model).__name__}'s packed_modules_mapping "
                    f"safely because of conflicts from {type(submodule).__name__}."
                )
        merged.update(contribution)

    return merged


def get_moe_expert_mapping(
    model: torch.nn.Module,
) -> list[tuple[str, str, int, str]]:
    """Return the MoE expert mapping exposed by ``model`` or a direct child.

    Resolution order:
      1. ``model.get_expert_mapping()`` if the model has a truthy attribute
         of that name.
      2. Otherwise, ``get_expert_mapping()`` of the first direct child that
         has a non-``None`` attribute of that name.
      3. Otherwise, an empty list.

    Args:
        model: The module to query.

    Returns:
        Whatever the selected ``get_expert_mapping`` call returns, or ``[]``
        when neither the model nor any direct child provides one.
    """
    getter = getattr(model, "get_expert_mapping", None)
    if getter:
        return getter()

    # Only direct children are checked, not the whole submodule tree. The
    # generator is lazy, so children after the first match are never touched.
    child_getters = (
        getattr(submodule, "get_expert_mapping", None)
        for submodule in model.children()
    )
    for child_getter in child_getters:
        if child_getter is not None:
            return child_getter()
    return []