"vllm/vscode:/vscode.git/clone" did not exist on "0c637391359e53c7b41e99ebe7188bd6dd097b8f"
cpu_model_runner.py 4.73 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from contextlib import contextmanager
4
from typing import TYPE_CHECKING, Any, Optional
5
6

import torch
7
import torch.nn as nn
8
9
10
11

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
12
from vllm.v1.utils import CpuGpuBuffer
13
14
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

15
16
17
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

18
19
20
21
22
# Module-level logger for this file, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Model runner that reuses ``GPUModelRunner`` logic on a CPU device.

    The parent constructor runs with ``torch.cuda.Event``/``torch.cuda.Stream``
    temporarily stubbed out (via ``_torch_cuda_wrapper``). Afterwards, every
    device-side buffer the parent allocated is re-pointed at its CPU
    counterpart (``_postprocess_tensors``), so only one copy is kept.
    Speculative decoding, CUDA graphs and cascade attention are disabled.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # The parent constructor is wrapped so any torch.cuda.Event/Stream it
        # creates becomes a no-op placeholder instead of touching CUDA.
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    # Note: Remove the override after new attention backend finished
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """Reject multiple KV-cache groups, then defer to the parent.

        Raises:
            ValueError: if more than one KV-cache group is configured.
        """
        if len(self.kv_cache_config.kv_cache_groups) > 1:
            raise ValueError(
                "Multiple KVCacheGroups is not "
                "currently supported with CPU model runner."
            )
        super()._may_reorder_batch(scheduler_output)

    def _postprocess_tensors(self) -> None:
        """Alias device-side tensors to their already-allocated CPU tensors.

        Because this runner's device is the CPU (asserted in ``__init__``), a
        separate "device" copy is redundant. Each such attribute is rebound to
        the CPU tensor:

        * ``CpuGpuBuffer`` attributes on the runner and on every block table
          get ``.gpu = .cpu``.
        * For each ``<name>_cpu_tensor`` tensor attribute on ``input_batch``,
          the matching ``<name>`` attribute is replaced when both exist and
          are tensors.
        """
        cpu_suffix = "_cpu_tensor"

        def replace_tensor(
            obj: Any, cpu_attr_name: str, device_attr_name: str
        ) -> None:
            # Only swap when both sides exist and are tensors.
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for v in vars(self).values():
            if isinstance(v, CpuGpuBuffer):
                v.gpu = v.cpu

        # Snapshot the items: replace_tensor() calls setattr() on the same
        # object, which would otherwise mutate the __dict__ being iterated
        # (RuntimeError if a new key gets added).
        for k, v in list(vars(self.input_batch).items()):
            if k.endswith(cpu_suffix) and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k.removesuffix(cpu_suffix))

        for block_table in self.input_batch.block_table.block_tables:
            for v in vars(block_table).values():
                if isinstance(v, CpuGpuBuffer):
                    v.gpu = v.cpu

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model and, if LoRA is configured, wrap it for LoRA.

        ``eep_scale_up`` is accepted for signature compatibility but is not
        used by this override.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config, self.device)

    def get_model(self) -> nn.Module:
        """Return the loaded model."""
        return self.model

    def warming_up_model(self) -> None:
        """Run one dummy forward pass so compilation happens up front."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        with _set_global_compilation_settings(self.vllm_config):
            # Token count: at least 16 (or max_num_reqs if larger), capped by
            # the scheduler's per-step token budget.
            self._dummy_run(
                min(
                    max(16, self.max_num_reqs),
                    self.scheduler_config.max_num_batched_tokens,
                )
            )

        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        # Overridden as a no-op for the CPU backend.
        pass

    def _sync_device(self) -> None:
        # Overridden as a no-op for the CPU backend.
        pass

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """Convert sampled token ids to nested Python lists via ``tolist()``.

        The return annotation implies a 2-D tensor (TODO confirm at callers).
        """
        return sampled_token_ids.tolist()

    def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Return ``(0, None)``: no data-parallel padding on CPU."""
        # Note: For CPU backend, dp padding is not required for now.
        return 0, None

102
103
104
105
106
107
108
109

@contextmanager
def _torch_cuda_wrapper():
    """Temporarily replace ``torch.cuda.Event``/``torch.cuda.Stream`` with no-ops.

    Inside the ``with`` block, constructing either class yields a placeholder
    that accepts any arguments and does nothing. The CPU runner wraps
    ``GPUModelRunner.__init__`` in this. The original classes are restored on
    exit, even if the body raises.
    """

    class _EventPlaceholder:
        """Stand-in for ``torch.cuda.Event``; all methods are no-ops."""

        def __init__(self, *args, **kwargs) -> None:
            pass

        # Accept and ignore arguments (e.g. ``record(stream)``) so calls that
        # are valid on torch.cuda.Event do not raise TypeError here.
        def record(self, *args, **kwargs) -> None:
            return None

        def synchronize(self, *args, **kwargs) -> None:
            return None

    class _StreamPlaceholder:
        """Stand-in for ``torch.cuda.Stream``; construction does nothing."""

        def __init__(self, *args, **kwargs) -> None:
            pass

    cuda_event = torch.cuda.Event
    cuda_stream = torch.cuda.Stream
    try:
        torch.cuda.Event = _EventPlaceholder
        torch.cuda.Stream = _StreamPlaceholder
        yield
    finally:
        # Always restore the real classes, even if the body raised.
        torch.cuda.Event = cuda_event
        torch.cuda.Stream = cuda_stream
123

124
125

@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    """Force Inductor parameter freezing on while max-autotune is requested.

    When ``config.compilation_config.inductor_compile_config`` has a truthy
    ``"max_autotune"`` entry, ``torch._inductor.config.freezing`` is set to
    True for the duration of the ``with`` block. Whatever value ``freezing``
    had on entry is written back on exit, even if the body raises.
    """
    import torch._inductor.config as inductor_cfg

    compile_options = config.compilation_config.inductor_compile_config
    # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
    saved_freezing = inductor_cfg.freezing
    try:
        wants_autotune = compile_options.get("max_autotune", False)
        if wants_autotune:
            inductor_cfg.freezing = True
        yield
    finally:
        inductor_cfg.freezing = saved_freezing