# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator

import torch

7
from vllm.config import VllmConfig
8
from vllm.platforms import current_platform
9
from vllm.v1.attention.backend import AttentionBackend
10
from vllm.v1.kv_cache_interface import KVCacheConfig
11
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
12
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
13
14
15
16
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
17
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
18
19
20
21
from vllm.v1.kv_offload.worker.worker import OffloadingHandler


class CPUOffloadingSpec(OffloadingSpec):
    """Offloading spec that keeps offloaded KV-cache blocks in CPU memory.

    On the scheduler side, ``get_manager`` lazily builds an
    ``OffloadingManager`` (LRU or ARC eviction) on top of a ``CPUBackend``.
    On the worker side, ``get_handlers`` lazily builds the GPU->CPU and
    CPU->GPU transfer handlers.

    Keys read from ``self.extra_config``:

    * ``cpu_bytes_to_use`` (required): CPU byte budget for offloaded blocks;
      converted with ``int()``.
    * ``eviction_policy`` (optional, default ``"lru"``): ``"lru"`` or ``"arc"``.
    """

    def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
        """Validate the CPU byte budget and size the CPU block pool.

        Raises:
            ValueError: if ``cpu_bytes_to_use`` is missing or falsy, or if the
                KV-cache groups do not share exactly one page size.
        """
        super().__init__(vllm_config, kv_cache_config)

        cpu_bytes_to_use = self.extra_config.get("cpu_bytes_to_use")
        if not cpu_bytes_to_use:
            # ValueError subclasses Exception, so `except Exception` callers
            # keep working.
            raise ValueError(
                "cpu_bytes_to_use must be specified in kv_connector_extra_config"
            )

        # calculate kv_bytes_per_offloaded_block
        assert kv_cache_config is not None
        page_sizes = {
            kv_cache_group.kv_cache_spec.page_size_bytes
            for kv_cache_group in kv_cache_config.kv_cache_groups
        }
        # Explicit check rather than `assert`: under `python -O` a stripped
        # assert would let set.pop() pick an arbitrary page size (or raise a
        # bare KeyError when there are no groups).
        if len(page_sizes) != 1:
            raise ValueError(
                "CPU offloading requires all KV cache groups to share a single "
                f"page size, got {sorted(page_sizes)}"
            )
        page_size_bytes = page_sizes.pop()

        # Bytes of one GPU block across every KV-cache tensor, times world_size.
        # NOTE(review): the world_size factor presumably charges every rank's
        # shard against the single cpu_bytes_to_use budget -- confirm.
        kv_bytes_per_block = (
            page_size_bytes
            * len(kv_cache_config.kv_cache_tensors)
            * vllm_config.parallel_config.world_size
        )
        # One offloaded block covers (offloaded_block_size // gpu_block_size)
        # GPU blocks.
        kv_bytes_per_offloaded_block = kv_bytes_per_block * (
            self.offloaded_block_size // self.gpu_block_size
        )

        # Guard against division by zero (e.g. no KV-cache tensors).
        self.num_blocks = (
            int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block
            if kv_bytes_per_offloaded_block > 0
            else 0
        )

        # scheduler-side state, created lazily by get_manager()
        self._manager: OffloadingManager | None = None

        # worker-side state, created lazily by get_handlers()
        self._handlers: CpuGpuOffloadingHandlers | None = None

        # Validated lazily in get_manager().
        self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")

    def get_manager(self) -> OffloadingManager:
        """Return the scheduler-side offloading manager, creating it on first use.

        KV-cache events are enabled on the manager when
        ``vllm_config.kv_events_config`` is set and has
        ``enable_kv_cache_events``.

        Raises:
            ValueError: if ``eviction_policy`` is neither ``"lru"`` nor ``"arc"``.
        """
        if self._manager is None:
            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (
                kv_events_config is not None and kv_events_config.enable_kv_cache_events
            )

            backend = CPUBackend(
                block_size=self.offloaded_block_size, num_blocks=self.num_blocks
            )

            if self.eviction_policy == "lru":
                self._manager = LRUOffloadingManager(
                    backend=backend, enable_events=enable_events
                )
            elif self.eviction_policy == "arc":
                self._manager = ARCOffloadingManager(
                    backend=backend, enable_events=enable_events
                )
            else:
                raise ValueError(
                    f"Unknown eviction policy: {self.eviction_policy}. "
                    "Supported policies: lru, arc"
                )

        return self._manager

    def get_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """Yield ``(source spec, destination spec, handler)`` transfer triples.

        This is a generator. The platform check and handler construction run
        only when iteration starts. The handlers are cached on the instance
        and reused by later calls. It yields GPU->CPU first, then CPU->GPU.

        Args:
            kv_caches: GPU KV-cache tensors keyed by name, passed through as
                ``gpu_caches``.
            attn_backends: attention backend class per name, passed through to
                ``CpuGpuOffloadingHandlers``.

        Raises:
            RuntimeError: if the current platform is not CUDA-alike.
        """
        if self._handlers is None:
            if not current_platform.is_cuda_alike():
                # RuntimeError subclasses Exception, so existing callers
                # still catch it.
                raise RuntimeError(
                    "CPU Offloading is currently only supported on CUDA-alike GPUs"
                )

            self._handlers = CpuGpuOffloadingHandlers(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
                cpu_block_size=self.offloaded_block_size,
                num_cpu_blocks=self.num_blocks,
                gpu_caches=kv_caches,
            )

        assert self._handlers is not None
        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handlers.gpu_to_cpu_handler
        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handlers.cpu_to_gpu_handler