cpu_worker.py 12.9 KB
Newer Older
1
"""A CPU worker class."""
2
from typing import Any, Dict, List, Optional, Tuple
3
4
5
6
7

import torch
import torch.distributed

from vllm.attention import get_attn_backend
8
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
9
10
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
11
12
13
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
                              init_distributed_environment)
14
15
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
16
from vllm.sequence import ExecuteModelRequest, SamplerOutput
17
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
18
from vllm.worker.cpu_model_runner import CPUModelRunner
19
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

# Module-level logger, created through vLLM's logging helper and keyed by
# this module's name.
logger = init_logger(__name__)


class CPUCacheEngine:
    """Manages the KV cache for CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying. Swapping is not supported and raises NotImplementedError.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        """Read the cache geometry from the configs and allocate the cache.

        Args:
            cache_config: Supplies block_size, cache_dtype and the block
                count (read from num_gpu_blocks, see note below).
            model_config: Supplies head size, layer/head counts and dtype.
            parallel_config: Passed to the model_config getters that depend
                on the parallel layout.
            device_config: Must describe a "cpu" device.
        """
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually the number of CPU
        # blocks for the CPU backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        self.dtype = self._resolve_cache_dtype(cache_config.cache_dtype,
                                               model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    @staticmethod
    def _resolve_cache_dtype(cache_dtype: str,
                             model_config: ModelConfig) -> torch.dtype:
        """Map a cache dtype string to a torch dtype.

        "auto" means "use the model's dtype"; any other value is looked up
        in STR_DTYPE_TO_TORCH_DTYPE (a KeyError propagates for unknown names).
        """
        if cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU.

        Returns one uninitialized tensor per layer; the shape is whatever the
        attention backend reports for the given block count and geometry.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        return [
            torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Unsupported on CPU; always raises NotImplementedError."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Unsupported on CPU; always raises NotImplementedError."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy cache blocks via the attention backend.

        Args:
            src_to_dsts: Block copy description, forwarded unchanged to the
                backend's copy_blocks. CPUWorker.cache_copy passes a
                torch.Tensor here (presumably rows of (src, dst) block
                indices -- confirm against the backend).
        """
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block across all layers.

        The size is: num_layers * 2 (key + value) * block_size * num_kv_heads
        * head_size * element size of the resolved cache dtype.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)

        dtype = CPUCacheEngine._resolve_cache_dtype(cache_dtype, model_config)
        # An empty tensor is the cheapest way to query the per-element size.
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total


113
class CPUWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configs and build the CPU model runner.

        The KV cache is not created here; it is set up later by
        initialize_cache. A driver worker must have rank 0.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.model_runner = CPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CPUCacheEngine
        self.cpu_cache: List[torch.Tensor]

    def init_device(self) -> None:
        """Set up the distributed environment, then seed the RNGs."""
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Delegate model loading to the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.

        Returns:
            A (num_gpu_blocks, num_cpu_blocks) tuple whose second element is
            always 0.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.

        Raises:
            ValueError: If the block count is not positive or cannot hold a
                sequence of the model's maximum length.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.

        Raises:
            ValueError: If there are no blocks at all, or if
                block_size * num_cpu_blocks is smaller than max_model_len.
        """
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create the cache engine and publish its cache and block size."""
        self.cache_engine = CPUCacheEngine(self.cache_config,
                                           self.model_config,
                                           self.parallel_config,
                                           self.device_config)
        self.cpu_cache = self.cache_engine.cpu_cache
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.cpu_cache is not None

        # Populate the cache to warmup the memory
        for layer_cache in self.cpu_cache:
            layer_cache.fill_(0)

    def cache_copy(
        self,
        blocks_to_copy: torch.Tensor,
    ) -> None:
        """Ask the cache engine to copy blocks; a no-op for an empty tensor."""
        if blocks_to_copy.numel() > 0:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Run one model step, after applying any pending block copies.

        The driver worker derives the step metadata from execute_model_req
        and broadcasts it from rank 0; every other worker receives it from
        that broadcast.

        Args:
            execute_model_req: The request for this step. Must be provided on
                the driver worker.

        Returns:
            An empty list when there are no sequence groups to run, otherwise
            a single-element list holding the model runner's output.
        """
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            # Broadcast the tensor built above, not the raw request field:
            # cache_copy calls .numel() on whatever the receiving workers
            # get, so they need the same tensor form the driver uses.
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.cpu_cache)

        # CPU worker only supports single-step execution.
        return [output]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(
            self.cache_config.block_size, self.cache_config.cache_dtype,
            self.model_config, self.parallel_config)