cpu_model_runner.py 5.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from contextlib import contextmanager
4
from typing import TYPE_CHECKING, Any
5
6

import torch
7
import torch.nn as nn
8
9
10
11

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
12
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
13
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
14
from vllm.v1.worker.utils import CpuGpuBuffer
15

16
17
18
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

19
20
21
22
23
24
# Module-level logger obtained from vLLM's ``init_logger`` using this
# module's ``__name__``.
logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Model runner for the CPU backend, built on top of ``GPUModelRunner``.

    The parent constructor is executed with ``torch.cuda.Event`` temporarily
    replaced by a no-op placeholder. Afterwards CUDA graphs and cascade
    attention are switched off and every device-side tensor attribute is
    re-pointed at its CPU counterpart.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Initialize the runner.

        Args:
            vllm_config: The vLLM configuration forwarded to the parent.
            device: Must compare equal to ``torch.device("cpu")``.
        """
        # While the parent constructor runs, ``torch.cuda.Event`` is a
        # placeholder whose ``record``/``synchronize`` do nothing.
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.

        Raises:
            ValueError: If more than one KV cache group is configured.
        """
        # Attention free models have zero kv_cache_groups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
        if num_kv_cache_groups == 0:
            return

        if num_kv_cache_groups > 1:
            # Adjacent string literals are concatenated verbatim, so the
            # separating space must be part of one of the literals.
            raise ValueError("Multiple KVCacheGroups is not "
                             "currently supported with CPU model runner.")

        metadata_builder = self.attn_groups[0][0].metadata_builder
        # Exact type match (not isinstance) is kept from the original check.
        assert type(metadata_builder) is TorchSDPAMetadataBuilderV1

        metadata_builder.reorder_batch(self.input_batch, scheduler_output)

    def _postprocess_tensors(self) -> None:
        """Re-point device tensor attributes at their CPU counterparts.

        Three groups of attributes are handled:

        * every ``CpuGpuBuffer`` attribute of the runner: ``gpu`` is set to
          the same object as ``cpu``;
        * every ``<name>_cpu_tensor`` tensor attribute of ``input_batch``:
          attribute ``<name>`` is replaced by it;
        * every ``<name>_cpu`` tensor attribute of each block table:
          attribute ``<name>`` is replaced by it.
        """

        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            """Set ``obj.<device_attr_name>`` to ``obj.<cpu_attr_name>``.

            Nothing happens unless both attributes exist and are not None.
            """
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        # Only the values are needed here; the attribute names are unused.
        for value in vars(self).values():
            if isinstance(value, CpuGpuBuffer):
                value.gpu = value.cpu

        # ``replace_tensor`` calls ``setattr`` on the object whose attribute
        # dict is being walked, so iterate over a snapshot of the items.
        cpu_tensor_suffix = "_cpu_tensor"
        for name, value in list(vars(self.input_batch).items()):
            if name.endswith(cpu_tensor_suffix) and isinstance(
                    value, torch.Tensor):
                replace_tensor(self.input_batch, name,
                               name.removesuffix(cpu_tensor_suffix))

        cpu_suffix = "_cpu"
        for block_table in self.input_batch.block_table.block_tables:
            for name, value in list(vars(block_table).items()):
                if name.endswith(cpu_suffix) and isinstance(
                        value, torch.Tensor):
                    replace_tensor(block_table, name,
                                   name.removesuffix(cpu_suffix))

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model and, if LoRA is configured, wrap it for LoRA.

        Args:
            eep_scale_up: Accepted for signature compatibility; not used by
                this implementation.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def get_model(self) -> nn.Module:
        """Return the model stored by ``load_model``."""
        return self.model

    def warming_up_model(self) -> None:
        """Perform one dummy run under the global compilation settings."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        """Do nothing; this runner does not initialize device properties."""
        pass

    def _sync_device(self) -> None:
        """Do nothing; this runner performs no device synchronization."""
        pass

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """Convert the sampled token id tensor to nested Python lists."""
        return sampled_token_ids.tolist()


@contextmanager
def _torch_cuda_wrapper():

    class _EventPlaceholder:

        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

    try:
        cuda_event = torch.cuda.Event
        torch.cuda.Event = _EventPlaceholder
        yield
    finally:
        torch.cuda.Event = cuda_event

133
134

@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    """Temporarily adjust global ``torch._inductor`` settings.

    If the inductor compile config in ``config.compilation_config`` has a
    truthy ``"max_autotune"`` entry, ``torch._inductor.config.freezing`` is
    set to ``True`` for the duration of the ``with`` body. The previous
    value of ``freezing`` is restored on exit in every case.

    Args:
        config: The vLLM configuration; only
            ``compilation_config.inductor_compile_config`` is read.
    """
    import torch._inductor.config

    inductor_config = config.compilation_config.inductor_compile_config
    # Read the current value before entering ``try`` so that ``finally``
    # never references an unbound name if this lookup fails.
    freezing_value = torch._inductor.config.freezing
    try:
        # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
        if inductor_config.get("max_autotune", False):
            torch._inductor.config.freezing = True
        yield
    finally:
        torch._inductor.config.freezing = freezing_value