# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any

import torch
import torch.nn as nn

import vllm.utils.cpu_triton_utils as cpu_tl
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.tracing import instrument
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# Module-level logger, named after this module.
logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Model runner that reuses ``GPUModelRunner`` to execute on the CPU.

    The CUDA-specific parts of the parent are neutralised:

    * ``torch.Event`` and ``torch.cuda.Stream`` are stubbed while the parent
      constructor runs.
    * Device-side buffers are aliased to their CPU counterparts.
    * Triton kernels are swapped for the fallbacks in
      ``vllm.utils.cpu_triton_utils``.
    * The stream/event-based copy helpers used by speculative decoding are
      overridden with synchronous versions.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # Stub out torch.Event / torch.cuda.Stream while the parent constructor
        # runs (presumably it creates CUDA events/streams, which need CUDA).
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        # Note: speculative decoding is now supported on CPU with C++ native impls

        # Disable CUDA graphs and cascade attention on the CPU backend.
        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()
        self._postprocess_triton()

    def _postprocess_tensors(self) -> None:
        """Point device-side tensors at their CPU counterparts.

        The runner's device is the CPU (asserted in ``__init__``), so each
        "device" tensor is replaced by the matching CPU tensor. Both names
        then refer to the same tensor. This covers:

        * ``CpuGpuBuffer`` attributes of the runner and of every block table;
        * ``<name>_cpu_tensor`` / ``<name>`` tensor pairs on the input batch.
        """
        cpu_suffix = "_cpu_tensor"

        def replace_tensor(
            obj: Any, cpu_attr_name: str, device_attr_name: str
        ) -> None:
            # Alias only when both attributes exist and are tensors.
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if isinstance(cpu_tensor, torch.Tensor) and isinstance(
                device_tensor, torch.Tensor
            ):
                setattr(obj, device_attr_name, cpu_tensor)

        for v in vars(self).values():
            if isinstance(v, CpuGpuBuffer):
                v.gpu = v.cpu

        # Iterate over a snapshot: replace_tensor() assigns attributes on the
        # same object whose __dict__ is being walked.
        for k, v in list(vars(self.input_batch).items()):
            if k.endswith(cpu_suffix) and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k.removesuffix(cpu_suffix))

        for block_table in self.input_batch.block_table.block_tables:
            for v in vars(block_table).values():
                if isinstance(v, CpuGpuBuffer):
                    v.gpu = v.cpu

    def _postprocess_triton(self) -> None:
        """Rebind Triton kernels in shared vLLM modules to CPU fallbacks.

        The replacements come from ``vllm.utils.cpu_triton_utils``. They are
        installed by rebinding module attributes, so they affect every user of
        those modules in this process. They are only picked up by callers that
        look the kernel up through the module at call time.
        """
        import vllm.v1.worker.block_table

        vllm.v1.worker.block_table._compute_slot_mapping_kernel = (
            cpu_tl.compute_slot_mapping_kernel
        )

        # Speculative decoding fallbacks
        import vllm.v1.sample.rejection_sampler
        import vllm.v1.spec_decode.eagle
        import vllm.v1.spec_decode.utils

        vllm.v1.spec_decode.eagle.eagle_prepare_inputs_padded_kernel = (
            cpu_tl.eagle_prepare_inputs_padded_kernel
        )
        vllm.v1.spec_decode.eagle.eagle_prepare_next_token_padded_kernel = (
            cpu_tl.eagle_prepare_next_token_padded_kernel
        )
        vllm.v1.spec_decode.eagle.copy_and_expand_eagle_inputs_kernel = (
            cpu_tl.copy_and_expand_eagle_inputs_kernel
        )
        vllm.v1.spec_decode.utils.eagle_step_slot_mapping_metadata_kernel = (
            cpu_tl.eagle_step_slot_mapping_metadata_kernel
        )
        vllm.v1.sample.rejection_sampler.rejection_greedy_sample_kernel = (
            cpu_tl.rejection_greedy_sample_kernel
        )
        vllm.v1.sample.rejection_sampler.rejection_random_sample_kernel = (
            cpu_tl.rejection_random_sample_kernel
        )
        vllm.v1.sample.rejection_sampler.expand_kernel = cpu_tl.expand_kernel
        vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = (
            cpu_tl.sample_recovered_tokens_kernel
        )

    @instrument(span_name="Loading (CPU)")
    def load_model(self, load_dummy_weights: bool = False) -> None:
        """Load the model, wrap it for LoRA if configured, then load the drafter.

        Args:
            load_dummy_weights: must be False. Dummy-weight loading (needed for
                elastic EP scale-up) is not supported by this runner.

        Raises:
            ValueError: if ``load_dummy_weights`` is True.
        """
        if load_dummy_weights:
            raise ValueError(
                "Loading dummy weights (needed for elastic EP scale-up) "
                "is not supported by the CPU Model Runner."
            )

        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config, self.device)

        # `drafter` is presumably only set when speculative decoding is enabled.
        if hasattr(self, "drafter"):
            logger.info_once("Loading drafter model...")
            self.drafter.load_model(self.model)

    def get_model(self) -> nn.Module:
        """Return the model loaded by ``load_model``."""
        return self.model

    @instrument(span_name="Warmup (CPU)")
    def warming_up_model(self) -> None:
        """Run dummy passes so compilation happens before real requests.

        The target model is run once. When a separate draft model is
        configured, its proposer is also run once.
        """
        logger.info("Warming up model for the compilation...")
        warmup_size = max(16, self.max_num_reqs)
        # Only generate graph for the generic shape
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(
                min(warmup_size, self.scheduler_config.max_num_batched_tokens)
            )

        # Warm up drafter for speculative decoding
        if self.speculative_config and self.speculative_config.uses_draft_model():
            from vllm.v1.spec_decode.draft_model import DraftModelProposer

            if isinstance(self.drafter, DraftModelProposer):
                logger.info("Warming up drafter model...")
                self.drafter.dummy_run(warmup_size)

        logger.info("Warming up done.")

    def initialize_kv_cache(
        self,
        kv_cache_config: KVCacheConfig,
        is_profiling: bool = False,
    ) -> None:
        """Delegate to the parent implementation, then log the drafter kind."""
        super().initialize_kv_cache(kv_cache_config, is_profiling)

        if self.speculative_config:
            if self.speculative_config.use_eagle():
                logger.info("EAGLE drafter KV cache initialized for CPU backend")
            elif self.speculative_config.uses_draft_model():
                logger.info("Draft model KV cache initialized for CPU backend")

    def _init_device_properties(self) -> None:
        """Intentionally a no-op on the CPU backend."""
        pass

    def _sync_device(self) -> None:
        """Intentionally a no-op on the CPU backend."""
        pass

    def _zero_block_ids(self, block_ids: list[int]) -> None:
        # CPU attention assigns -INF to logits at invalid positions,
        # so stale KV cache data never affects computation.
        pass

    # =========================================================================
    # CPU-safe overrides for speculative decoding methods
    # These methods override GPU-specific implementations that use CUDA streams
    # =========================================================================

    def _copy_draft_token_ids_to_cpu(
        self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
    ) -> None:
        """CPU-safe version: no async copy needed, tensors already on CPU."""
        # With async scheduling, skip the copy unless structured-output
        # requests are present or the sampling metadata has output_token_ids.
        if self.use_async_scheduling and not (
            scheduler_output.has_structured_output_requests
            or self.input_batch.sampling_metadata.output_token_ids
        ):
            return
        # Record which requests the (possibly copied) draft tokens belong to.
        self._draft_token_req_ids = self.input_batch.req_ids.copy()

        draft_token_ids: torch.Tensor = self._draft_token_ids
        if not torch.is_tensor(draft_token_ids):
            return

        num_reqs = draft_token_ids.shape[0]
        if self.draft_token_ids_cpu is not None:
            if not zeros_only:
                self.draft_token_ids_cpu[:num_reqs].copy_(draft_token_ids)
            else:
                self.draft_token_ids_cpu[:num_reqs] = 0

    def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
        """CPU-safe version: no event synchronization needed."""
        if isinstance(self._draft_token_ids, list):
            return self._draft_token_ids, self.input_batch.req_ids
        req_ids = self._draft_token_req_ids
        if req_ids is None:
            return [], []
        if self.draft_token_ids_cpu is not None:
            return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids
        return [], []

    def _copy_valid_sampled_token_count(
        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
    ) -> None:
        """CPU-safe version: direct copy without CUDA streams."""
        if self.valid_sampled_token_count_cpu is None:
            return

        counts = valid_sampled_tokens_count
        counts_cpu = self.valid_sampled_token_count_cpu
        counts_cpu[: counts.shape[0]].copy_(counts)
        self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

    def _get_valid_sampled_token_count(self) -> list[int]:
        """CPU-safe version: no event synchronization needed."""
        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
        if prev_sampled_token_ids is None:
            return []

        counts_cpu = self.valid_sampled_token_count_cpu
        if counts_cpu is None:
            return []
        return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """CPU-safe version: direct tolist() without CUDA events."""
        return sampled_token_ids.tolist()

@contextmanager
def _torch_cuda_wrapper():
    class _EventPlaceholder:
        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

233
234
235
236
    class _StreamPlaceholder:
        def __init__(self, *args, **kwargs) -> None:
            pass

237
    cuda_event = torch.Event
238
    cuda_stream = torch.cuda.Stream
239
    try:
240
        torch.Event = _EventPlaceholder
241
        torch.cuda.Stream = _StreamPlaceholder
242
243
        yield
    finally:
244
        torch.Event = cuda_event
245
        torch.cuda.Stream = cuda_stream


@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    """Temporarily adjust global Inductor settings for CPU compilation.

    When ``max_autotune`` is enabled in the config's Inductor compile config,
    Inductor's global ``freezing`` flag is set to True for the duration of the
    ``with`` block. The previous value is always restored on exit, even if the
    body raises.

    Args:
        config: vLLM config. Only
            ``config.compilation_config.inductor_compile_config`` is read.
    """
    import torch._inductor.config as torch_inductor_config

    inductor_config = config.compilation_config.inductor_compile_config
    # Note: The MKLDNN and CPPGEMM backends require frozen parameters.
    saved_freezing = torch_inductor_config.freezing
    try:
        if inductor_config.get("max_autotune", False):
            torch_inductor_config.freezing = True
        yield
    finally:
        torch_inductor_config.freezing = saved_freezing