# cpu_executor.py
1
import os
2
from typing import Dict, List, Set, Tuple
3
4
5

import torch

6
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
7
8
9
10
11
12
13
14
15
16
17
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_distributed_init_method, get_ip, get_open_port

logger = init_logger(__name__)


class CPUExecutor(ExecutorBase):
    """Executor that runs the model on the local CPU.

    All work is delegated to a single in-process ``CPUWorker`` (the driver
    worker, rank 0). ``_init_worker`` asserts a world size of 1, so there is
    no multi-worker support.
    """

    def _init_executor(self) -> None:
        """Validate/adjust the configs for the CPU backend, then start the worker."""
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"
        # These helpers rewrite settings the CPU backend cannot honour
        # (float16, CUDA graphs, prefix caching, chunked prefill) and set the
        # CPU KV-cache budget; each logs a warning when it changes something.
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)

        # Instantiate the worker and load the model to CPU.
        self._init_worker()

    def _init_worker(self):
        """Create the single rank-0 ``CPUWorker``, init its device and load the model."""
        # Imported lazily, presumably to avoid loading worker code until an
        # executor is actually constructed.
        from vllm.worker.cpu_worker import CPUWorker

        assert self.parallel_config.world_size == 1, (
            "CPUExecutor only supports single CPU socket currently.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = CPUWorker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.

        Returns:
            The ``(num_gpu_blocks, num_cpu_blocks)``-style pair reported by
            the driver worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Args:
            num_gpu_blocks: Number of KV blocks to allocate. On this backend
                these blocks live in CPU memory (see the note below).
            num_cpu_blocks: Forwarded unchanged to the worker.
        """
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        # NOTE: A `cpu block` on the CPU backend lives in CPU memory but is
        # referred to as a `gpu block`, because we want to reuse the existing
        # block management procedure.
        logger.info("# CPU blocks: %d", num_gpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int) -> List[SamplerOutput]:
        """Run one model step on the driver worker and return its outputs.

        ``num_lookahead_slots`` is accepted (presumably for compatibility
        with the common executor signature) but is NOT forwarded to the CPU
        worker.
        """
        return self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

    # The LoRA methods below forward straight to the driver worker. Note that
    # `_init_executor` asserts `lora_config is None`, so LoRA is not enabled
    # on this backend.
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``add_lora`` to the driver worker."""
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``remove_lora`` to the driver worker."""
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Forward ``list_loras`` to the driver worker."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        """No-op health check.

        CPUExecutor will always be healthy as long as it's running.
        """
        return


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust ``config`` in place for CPU execution and return the same object.

    - A ``torch.float16`` dtype is replaced by ``torch.bfloat16``.
    - ``enforce_eager`` is forced to ``True`` (CUDA graphs are unavailable).

    Each adjustment logs a warning only when it actually changes the config.
    """
    wants_fp16 = config.dtype == torch.float16
    if wants_fp16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16

    if config.enforce_eager:
        # Already eager: nothing more to change.
        return config

    logger.warning(
        "CUDA graph is not supported on CPU, fallback to the eager "
        "mode.")
    config.enforce_eager = True
    return config


115
116
117
118
119
120
121
122
123
def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    """Switch off chunked prefill (unsupported on CPU) and return ``config``.

    The config is modified in place. A warning is logged only when the flag
    was previously enabled.
    """
    if not config.chunked_prefill_enabled:
        return config

    logger.warning("Chunked prefill is not supported on CPU, disable it.")
    config.chunked_prefill_enabled = False
    return config


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    """Adjust ``config`` in place for the CPU backend and return it.

    Prefix caching is disabled with a warning. The KV-cache budget is read
    from the ``VLLM_CPU_KVCACHE_SPACE`` environment variable as a whole
    number of GiB (``1 << 30`` bytes) and stored in
    ``config.cpu_kvcache_space_bytes``. When the variable is unset, empty or
    ``0``, a default of 4 GiB is used and a warning is logged.

    Raises:
        RuntimeError: if ``VLLM_CPU_KVCACHE_SPACE`` is not an integer or is
            negative.
    """
    _GB = 1 << 30
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disable it.")
        config.enable_prefix_caching = False

    # An empty/whitespace-only value (e.g. `VLLM_CPU_KVCACHE_SPACE=`) is
    # treated like an unset variable instead of crashing inside int().
    raw_space = os.getenv("VLLM_CPU_KVCACHE_SPACE", "").strip()
    kv_cache_space_str = raw_space or "0"
    try:
        kv_cache_space = int(kv_cache_space_str)
    except ValueError as err:
        # Name the offending variable instead of surfacing int()'s generic
        # "invalid literal" message.
        raise RuntimeError(
            "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
            f" {kv_cache_space_str!r}, expect a positive integer value."
        ) from err

    if kv_cache_space < 0:
        raise RuntimeError(
            "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a positive integer value.")

    if kv_cache_space == 0:
        config.cpu_kvcache_space_bytes = 4 * _GB  # type: ignore
        logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
                       "for CPU backend is not set, using 4 by default.")
    else:
        config.cpu_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore

    return config