cpu_worker.py 18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A CPU worker class."""
4
import os
5
from importlib import util
6
from typing import List, Optional, Set, Tuple, Type
7
8
9
10

import torch
import torch.distributed

11
import vllm.envs as envs
12
from vllm.attention import get_attn_backend
13
14
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
15
from vllm.distributed import (ensure_model_parallel_initialized,
16
                              init_distributed_environment)
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
from vllm.model_executor import set_random_seed
20
from vllm.sequence import ExecuteModelRequest
21
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
22
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
23
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
24
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
25
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
26
                                     WorkerInput)

# Module-level logger for this file.
logger = init_logger(__name__)


class CPUCacheEngine:
    """Manages the KV cache for the CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations,
    such as copying blocks. Swapping is not supported on this backend.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        """Resolve the cache geometry and dtype, then allocate the cache.

        Raises:
            NotImplementedError: if ``cache_config.cache_dtype`` is not one
                of "auto", "fp8" or "fp8_e5m2".
        """
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks actually holds the number of
        # CPU blocks for the CPU backend, because we want to reuse the KV
        # cache management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        if cache_config.cache_dtype == "auto":
            # Store the cache in the model's own dtype.
            self.dtype = model_config.dtype
        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
            self.dtype = torch.float8_e5m2
        else:
            raise NotImplementedError(f"Unsupported KV cache type "
                                      f"{cache_config.cache_dtype}.")

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocate one KV cache tensor per layer on CPU.

        The per-layer shape comes from the attention backend's
        ``get_kv_cache_shape``. Contents are uninitialized (``torch.empty``).
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        # A distinct tensor for every layer.
        return [
            torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Unsupported on CPU; always raises NotImplementedError."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Unsupported on CPU; always raises NotImplementedError."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy cache blocks in place via the attention backend.

        ``src_to_dsts`` is forwarded unchanged to
        ``attn_backend.copy_blocks``. The worker builds it as an int64
        ``(N, 2)`` tensor of (src, dst) block pairs.
        """
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block summed over layers.

        When ``model_config.use_mla`` is set, no separate value cache is
        counted; only the key-sized cache contributes.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block if not model_config.use_mla else 0
        total = num_layers * (key_cache_block + value_cache_block)
        if cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total


class CPUWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
    ) -> None:
        """Set up thread affinity, the model runner and (optional) profiler.

        The KV cache is not allocated here; see ``initialize_cache``.
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        vllm_config.parallel_config.rank = rank

        self.distributed_init_method = distributed_init_method

        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        # "all" means no explicit binding (init_device skips it). The NUMA
        # helper falls back to this value, so it must be set before the call.
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = (
                self.get_cpus_id_binding_based_on_numa_nodes())
        elif omp_cpuids != "all":
            # Manual binding: one "|"-separated CPU list per rank. ("all" is
            # skipped here: splitting it would raise IndexError for rank > 0.)
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

        # Return hidden states from target model if the draft model is an
        # mlp_speculator (or medusa / eagle) distinct from the target model.
        speculative_config = self.speculative_config
        model_config = self.model_config
        needs_hidden_states = (
            speculative_config is not None
            and (speculative_config.draft_model_config.model
                 != model_config.model)
            and (speculative_config.draft_model_config.hf_config.model_type
                 in ["medusa", "mlp_speculator", "eagle"]))
        speculative_args = ({
            "return_hidden_states": True
        } if needs_hidden_states else {})

        ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
        if self.model_config.runner_type == "pooling":
            ModelRunnerClass = CPUPoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = CPUEncoderDecoderModelRunner
        self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            # The custom class wraps the default runner built above.
            self.model_runner = model_runner_cls(self.model_runner)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # Initialize cpu_cache as pooling models don't initialize kv_caches
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler; raises RuntimeError if not enabled."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler; raises RuntimeError if not enabled."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        """Bind OpenMP threads, init torch.distributed and seed the RNG."""
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        self.device = torch.device("cpu")
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.

        Raises:
            ValueError: if the block count is non-positive or cannot hold
                ``max_model_len`` tokens.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner."""
        return self.model_runner.list_loras()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.
        """
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create one cache engine per pipeline stage and bind its caches."""
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [engine.cpu_cache for engine in self.cache_engine]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.cpu_cache)
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(cache is not None for cache in self.cpu_cache)

        # Populate the cache to warmup the memory
        for cache in self.cpu_cache:
            for layer_cache in cache:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        # None until initialize_cache has run.
        return self.cpu_cache

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        """Apply the requested block copies on this virtual engine's cache."""
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build a WorkerInput; block copies become an int64 (N, 2) tensor."""
        assert execute_model_req is not None
        virtual_engine: int = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        # Swapping is not supported on CPU.
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(
            self.cache_config.block_size, self.cache_config.cache_dtype,
            self.model_config, self.parallel_config)

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return a comma-separated CPU id list for this rank's OMP threads.

        Requires the optional ``numa`` and ``psutil`` packages. NUMA nodes
        with no CPU in this process's affinity mask are dropped. Rank
        ``self.rank`` then takes the lowest-numbered
        ``cores_per_node - reserved`` allowed CPUs of the ``self.rank``-th
        remaining node, where ``reserved`` is
        ``VLLM_CPU_NUM_OF_RESERVED_CPU`` capped at half a node's cores.
        Returns ``self.local_omp_cpuid`` unchanged (no binding) when the
        packages are missing or there are fewer usable nodes than ranks.
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            # NOTE(review): physical core count, while node_to_cpus likely
            # returns logical CPU ids — confirm on SMT systems.
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # Allowed CPUs of each NUMA node. Sorted so the prefix taken
            # below is deterministic: set iteration order is not guaranteed.
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = sorted(
                    set(info.node_to_cpus(i)).intersection(cpus_allow_list))
                if node_intersect:
                    node_to_cpus.append(node_intersect)

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed due to "
                    "world size: %d is larger than "
                    "allowed NUMA nodes number: %d. "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of package numa and psutil, "
                "fallback to no thread-binding. To get better performance, "
                "please try to manually bind threads.")
        return rank_to_cpus