cpu_model_runner.py 6.43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from contextlib import contextmanager
4
from typing import TYPE_CHECKING, Any, Optional
5
6

import torch
7
import torch.nn as nn
8
9
10
11

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
12
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
13
from vllm.v1.utils import CpuGpuBuffer
14
15
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

16
17
18
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

19
20
21
22
23
24
logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Model runner for the CPU backend.

    Reuses ``GPUModelRunner`` and adapts it to CPU: CUDA event/stream
    classes are stubbed out while the base constructor runs, CUDA graphs
    and cascade attention are switched off, and device-side buffers are
    aliased to their CPU counterparts.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Build the runner on top of the base-class initialisation.

        Args:
            vllm_config: The vLLM configuration forwarded to the base class.
            device: Target device; must be ``torch.device("cpu")``.
        """
        # The base constructor instantiates torch.cuda.Event / Stream;
        # swap them for no-op placeholders while it runs.
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.

        Raises:
            ValueError: If more than one KV cache group is configured.
        """
        # Attention free models have zero kv_cache_groups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
        if num_kv_cache_groups == 0:
            return

        if num_kv_cache_groups > 1:
            raise ValueError("Multiple KVCacheGroups is not "
                             "currently supported with CPU model runner.")

        # Guard against encoder-only / pooling models where `attn_groups`
        # may be empty or lack the expected metadata_builder.
        # Without this check, accessing `attn_groups[0][0]` would trigger
        # an error on the CPU backend.
        attn_groups = getattr(self, "attn_groups", None)
        if not attn_groups or not attn_groups[0]:
            return

        builder = getattr(attn_groups[0][0], "metadata_builders", None)
        if isinstance(builder, list):
            # An empty builder list means there is nothing to reorder with.
            if not builder:
                return
            builder = builder[0]

        if not isinstance(builder, TorchSDPAMetadataBuilderV1):
            # Encoder-only / rerank models do not benefit from reordering,
            # so we safely skip here.
            return

        # Safe path for decoder/attention-heavy models
        builder.reorder_batch(self.input_batch, scheduler_output)

    def _postprocess_tensors(self) -> None:
        """Make every device-side tensor/buffer refer to its CPU twin."""

        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for v in vars(self).values():
            if isinstance(v, CpuGpuBuffer):
                v.gpu = v.cpu

        # `foo_cpu_tensor` is paired with a device attribute named `foo`.
        # Iterate over a snapshot because replace_tensor() assigns
        # attributes on the very object whose __dict__ is being walked.
        cpu_suffix = "_cpu_tensor"
        for k, v in list(vars(self.input_batch).items()):
            if k.endswith(cpu_suffix) and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-len(cpu_suffix)])

        for block_table in self.input_batch.block_table.block_tables:
            for v in vars(block_table).values():
                if isinstance(v, CpuGpuBuffer):
                    v.gpu = v.cpu

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model and, when LoRA is configured, wrap it with LoRA.

        Args:
            eep_scale_up: Not used by this implementation; presumably kept
                for signature compatibility with the base class.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def get_model(self) -> nn.Module:
        """Return the model stored by ``load_model``."""
        return self.model

    def warming_up_model(self) -> None:
        """Run one dummy pass under the global compilation settings."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        """No-op on the CPU backend."""
        pass

    def _sync_device(self) -> None:
        """No-op on the CPU backend."""
        pass

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """Convert the sampled token id tensor to nested Python lists."""
        return sampled_token_ids.tolist()

    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Return ``(0, None)``: no data-parallel padding is applied."""
        # Note: For CPU backend, dp padding is not required for now.
        return 0, None

138
139
140
141
142
143
144
145
146
147

@contextmanager
def _torch_cuda_wrapper():
    """Temporarily replace ``torch.cuda.Event`` / ``torch.cuda.Stream``.

    Inside the ``with`` body both names refer to inert placeholder classes
    whose constructors accept and ignore any arguments. The original
    classes are restored on exit, including when the body raises.
    """

    class _EventPlaceholder:
        """Stand-in event whose ``record``/``synchronize`` do nothing."""

        def __init__(self, *args, **kwargs) -> None:
            pass

        def record(self) -> None:
            return None

        def synchronize(self) -> None:
            return None

    class _StreamPlaceholder:
        """Stand-in stream with no behaviour."""

        def __init__(self, *args, **kwargs) -> None:
            pass

    # Remember the real classes so they can be put back afterwards.
    cuda_event = torch.cuda.Event
    cuda_stream = torch.cuda.Stream
    try:
        torch.cuda.Event = _EventPlaceholder
        torch.cuda.Stream = _StreamPlaceholder
        yield
    finally:
        torch.cuda.Event = cuda_event
        torch.cuda.Stream = cuda_stream
162

163
164

@contextmanager
165
def _set_global_compilation_settings(config: VllmConfig):
166
    import torch._inductor.config as torch_inductor_config
167

168
    inductor_config = config.compilation_config.inductor_compile_config
169
    # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
170
    freezing_value = torch_inductor_config.freezing
171
172
    try:
        if inductor_config.get("max_autotune", False):
173
            torch_inductor_config.freezing = True
174
175
        yield
    finally:
176
        torch_inductor_config.freezing = freezing_value