cpu_model_runner.py 4.23 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from contextlib import contextmanager
4
from typing import Any
5
6

import torch
7
import torch.nn as nn
8
9
10
11

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
12
from vllm.v1.utils import CpuGpuBuffer
13
14
15
16
17
18
19
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# Module-level logger named after this module; used for the model-load and
# warm-up progress messages below.
logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):
    """Run the V1 ``GPUModelRunner`` logic on a CPU device.

    The base runner is reused as-is, with a few CPU-specific adjustments:
    ``torch.cuda.Event``/``torch.cuda.Stream`` are stubbed out while the base
    constructor runs, "device" buffers are rebound to their CPU copies, CUDA
    graphs and cascade attention are disabled, and the ``_sync_device`` /
    ``_init_device_properties`` hooks are no-ops.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # Build the base runner with inert CUDA event/stream placeholders so
        # any such objects it constructs do not require a CUDA device.
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        # CUDA graph capture and cascade attention are disabled on CPU.
        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        """Rebind the base runner's "device" tensors to their CPU counterparts.

        Since this runner targets the CPU, the device-side handle of each
        host/device pair is pointed at the existing CPU tensor instead of
        keeping a separate copy.
        """
        # InputBatch names its host-side tensors "<name>_cpu_tensor"; the
        # matching device-side attribute is "<name>".
        cpu_tensor_suffix = "_cpu_tensor"

        def replace_tensor(
            obj: Any, cpu_attr_name: str, device_attr_name: str
        ) -> None:
            # Rebind only when both attributes are present (non-None).
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        def alias_cpu_gpu_buffers(obj: Any) -> None:
            # Make the .gpu side of every CpuGpuBuffer attribute share .cpu.
            for value in vars(obj).values():
                if isinstance(value, CpuGpuBuffer):
                    value.gpu = value.cpu

        alias_cpu_gpu_buffers(self)

        # Iterate over a snapshot: replace_tensor() calls setattr() on the
        # same object, which would insert a new __dict__ key (and break live
        # iteration) if the device attribute were resolved from the class.
        for name, value in list(vars(self.input_batch).items()):
            if name.endswith(cpu_tensor_suffix) and isinstance(value, torch.Tensor):
                replace_tensor(
                    self.input_batch, name, name.removesuffix(cpu_tensor_suffix)
                )

        for block_table in self.input_batch.block_table.block_tables:
            alias_cpu_gpu_buffers(block_table)

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model via ``get_model`` and apply LoRA when configured.

        ``eep_scale_up`` is unused here (presumably kept so the signature
        matches the base runner's).
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config, self.device)

    def get_model(self) -> nn.Module:
        """Return the loaded model."""
        return self.model

    def warming_up_model(self) -> None:
        """Run one dummy pass so compilation happens before serving.

        Only a single generic shape is used: ``max(16, max_num_reqs)`` tokens,
        capped at the scheduler's ``max_num_batched_tokens``.
        """
        logger.info("Warming up model for the compilation...")
        with _set_global_compilation_settings(self.vllm_config):
            # Only generate graph for the generic shape.
            num_tokens = min(
                max(16, self.max_num_reqs),
                self.scheduler_config.max_num_batched_tokens,
            )
            self._dummy_run(num_tokens)
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        """No-op on CPU: there are no device properties to query."""

    def _sync_device(self) -> None:
        """No-op on CPU: there is no device to synchronize with."""

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """Convert sampled token ids to nested Python lists via ``tolist()``."""
        return sampled_token_ids.tolist()

    def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
        """Return ``(0, None)``: no data-parallel padding on the CPU backend."""
        # Note: For CPU backend, dp padding is not required for now.
        return 0, None


@contextmanager
def _torch_cuda_wrapper():
    """Temporarily replace ``torch.cuda.Event`` and ``torch.cuda.Stream``.

    Inside the ``with`` body, constructing either class returns an inert
    placeholder that accepts any constructor arguments. Event placeholders
    expose ``record()`` and ``synchronize()`` that do nothing. The original
    classes are restored on exit, including when the body raises; objects
    created inside the block stay placeholders afterwards.
    """

    class _EventPlaceholder:
        """Stand-in for ``torch.cuda.Event`` whose methods do nothing."""

        def __init__(self, *args, **kwargs) -> None:
            pass

        # Accept arbitrary arguments: torch.cuda.Event.record takes an
        # optional stream, so callers may legitimately pass one.
        def record(self, *args, **kwargs) -> None:
            return None

        def synchronize(self, *args, **kwargs) -> None:
            return None

    class _StreamPlaceholder:
        """Stand-in for ``torch.cuda.Stream``; accepts any constructor args."""

        def __init__(self, *args, **kwargs) -> None:
            pass

    cuda_event = torch.cuda.Event
    cuda_stream = torch.cuda.Stream
    try:
        torch.cuda.Event = _EventPlaceholder
        torch.cuda.Stream = _StreamPlaceholder
        yield
    finally:
        torch.cuda.Event = cuda_event
        torch.cuda.Stream = cuda_stream


@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    """Adjust global Inductor settings for the duration of a ``with`` block.

    If ``config.compilation_config.inductor_compile_config`` enables
    ``max_autotune``, Inductor's ``freezing`` flag is switched on, because the
    MKLDNN and CPPGEMM backends require frozen parameters. Whatever
    ``freezing`` value was in effect beforehand is restored on exit, even if
    the body raises.
    """
    import torch._inductor.config as inductor_globals

    compile_options = config.compilation_config.inductor_compile_config
    previous_freezing = inductor_globals.freezing
    try:
        autotune_requested = compile_options.get("max_autotune", False)
        if autotune_requested:
            inductor_globals.freezing = True
        yield
    finally:
        inductor_globals.freezing = previous_freezing