cpu_model_runner.py 4.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from contextlib import contextmanager
4
from typing import Any
5
6

import torch
7
import torch.nn as nn
8

9
import vllm.utils.cpu_triton_utils as cpu_tl
10
11
12
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
13
from vllm.tracing import instrument
14
from vllm.v1.utils import CpuGpuBuffer
15
16
17
18
19
20
21
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Model runner that executes on CPU by adapting ``GPUModelRunner``.

    After the parent constructor has run, every device-side tensor/buffer
    handle found on the runner, its input batch and its block tables is
    re-pointed at the matching CPU object, and the slot-mapping kernel used by
    ``vllm.v1.worker.block_table`` is swapped for the CPU implementation.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Build the runner and rewire it for CPU-only execution.

        Args:
            vllm_config: Configuration forwarded unchanged to the parent
                constructor.
            device: Target device; must compare equal to
                ``torch.device("cpu")``.

        Raises:
            AssertionError: If ``device`` is not the CPU device, or if a
                speculative decoding config is present.
        """
        # The parent constructor runs with torch.Event / torch.cuda.Stream
        # temporarily replaced by inert placeholders (see _torch_cuda_wrapper).
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()
        self._postprocess_triton()

    def _postprocess_tensors(self) -> None:
        """Alias all device-side tensors and buffers to their CPU twins.

        Three groups of objects are patched in place:

        * every ``CpuGpuBuffer`` attribute of the runner itself,
        * every ``<name>_cpu_tensor`` / ``<name>`` tensor pair on
          ``self.input_batch``,
        * every ``CpuGpuBuffer`` attribute of each block table in
          ``self.input_batch.block_table.block_tables``.
        """
        # Naming convention relied on below: the CPU copy of attribute
        # ``foo`` is stored as ``foo_cpu_tensor``.
        cpu_suffix = "_cpu_tensor"

        # Note: replace device tensors with cpu tensors
        def replace_tensor(
            obj: Any, cpu_attr_name: str, device_attr_name: str
        ) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            # Only rebind when both sides really are tensors, so unrelated
            # attributes that merely share the naming pattern stay untouched.
            if isinstance(cpu_tensor, torch.Tensor) and isinstance(
                device_tensor, torch.Tensor
            ):
                setattr(obj, device_attr_name, cpu_tensor)

        def alias_buffers(obj: Any) -> None:
            # Make the ``gpu`` side of each CpuGpuBuffer the very same object
            # as its ``cpu`` side.
            for value in vars(obj).values():
                if isinstance(value, CpuGpuBuffer):
                    value.gpu = value.cpu

        alias_buffers(self)

        # setattr() below only rebinds attributes that already exist, so the
        # dict being iterated never changes size.
        for name, value in vars(self.input_batch).items():
            if name.endswith(cpu_suffix) and isinstance(value, torch.Tensor):
                replace_tensor(
                    self.input_batch, name, name.removesuffix(cpu_suffix)
                )

        for block_table in self.input_batch.block_table.block_tables:
            alias_buffers(block_table)

    def _postprocess_triton(self) -> None:
        """Swap the block-table slot-mapping kernel for the CPU version.

        Rebinds the module-level ``_compute_slot_mapping_kernel`` in
        ``vllm.v1.worker.block_table`` to
        ``cpu_tl.compute_slot_mapping_kernel``. This is a process-wide patch
        of that module attribute.
        """
        import vllm.v1.worker.block_table

        vllm.v1.worker.block_table._compute_slot_mapping_kernel = (
            cpu_tl.compute_slot_mapping_kernel
        )

    @instrument(span_name="Loading (CPU)")
    def load_model(self, load_dummy_weights: bool = False) -> None:
        """Load the model described by ``self.vllm_config`` into ``self.model``.

        Args:
            load_dummy_weights: Must be false; dummy-weight loading is
                rejected by this runner.

        Raises:
            ValueError: If ``load_dummy_weights`` is true.
        """
        if load_dummy_weights:
            raise ValueError(
                "Loading dummy weights (needed for elastic EP scale-up) "
                "Is not supported by the CPU Model Runner."
            )
        logger.info("Starting to load model %s...", self.model_config.model)
        # ``get_model`` here is the module-level loader imported at the top of
        # the file; the method of the same name is only reachable via ``self``.
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config, self.device)

    def get_model(self) -> nn.Module:
        """Return the module stored by :meth:`load_model`."""
        return self.model

    @instrument(span_name="Warmup (CPU)")
    def warming_up_model(self) -> None:
        """Run one dummy pass so compilation happens before serving.

        The dummy run is sized ``min(max(16, max_num_reqs),
        max_num_batched_tokens)`` and executes inside
        ``_set_global_compilation_settings``.
        """
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(
                min(
                    max(16, self.max_num_reqs),
                    self.scheduler_config.max_num_batched_tokens,
                )
            )

        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        """Intentionally do nothing on CPU."""
        pass

    def _sync_device(self) -> None:
        """Intentionally do nothing on CPU."""
        pass

    def _zero_block_ids(self, block_ids: list[int]) -> None:
        """Intentionally do nothing; ``block_ids`` is ignored."""
        # CPU attention assigns -INF to logits at invalid positions,
        # so stale KV cache data never affects computation.
        pass

105
106
107
108
109
110
111
112

@contextmanager
def _torch_cuda_wrapper():
    class _EventPlaceholder:
        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

113
114
115
116
    class _StreamPlaceholder:
        def __init__(self, *args, **kwargs) -> None:
            pass

117
    cuda_event = torch.Event
118
    cuda_stream = torch.cuda.Stream
119
    try:
120
        torch.Event = _EventPlaceholder
121
        torch.cuda.Stream = _StreamPlaceholder
122
123
        yield
    finally:
124
        torch.Event = cuda_event
125
        torch.cuda.Stream = cuda_stream
126

127
128

@contextmanager
129
def _set_global_compilation_settings(config: VllmConfig):
130
    import torch._inductor.config as torch_inductor_config
131

132
    inductor_config = config.compilation_config.inductor_compile_config
133
    # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
134
    freezing_value = torch_inductor_config.freezing
135
136
    try:
        if inductor_config.get("max_autotune", False):
137
            torch_inductor_config.freezing = True
138
139
        yield
    finally:
140
        torch_inductor_config.freezing = freezing_value