cpu.py 3.36 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator

import torch

7
from vllm.attention.backends.abstract import AttentionBackend
8
from vllm.config import VllmConfig
9
10
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
11
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
12
13
14
15
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
16
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
17
18
19
20
21
22
23
24
25
from vllm.v1.kv_offload.worker.worker import OffloadingHandler


class CPUOffloadingSpec(OffloadingSpec):
    """Offloading spec that moves KV-cache blocks between GPU and CPU memory.

    Scheduler side: ``get_manager`` builds an offloading manager on top of a
    ``CPUBackend`` sized by ``num_cpu_blocks``, using either an LRU or an ARC
    eviction policy.

    Worker side: ``get_handlers`` builds ``CpuGpuOffloadingHandlers`` and
    exposes its GPU->CPU and CPU->GPU transfer handlers.

    Options are read from ``self.extra_config``, which is set up by the
    ``OffloadingSpec`` base class (per the error message below, it comes from
    ``kv_connector_extra_config``):

    * ``num_cpu_blocks`` (required): number of blocks in the CPU backend.
    * ``eviction_policy`` (optional, default ``"lru"``): ``"lru"`` or ``"arc"``.

    Raises:
        ValueError: if ``num_cpu_blocks`` is missing or falsy.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
        # A missing key, None, or 0 all count as "not configured": a CPU
        # backend with zero blocks could not hold anything.
        if not num_cpu_blocks:
            # ValueError subclasses Exception, so callers that previously
            # caught the generic Exception still catch this.
            raise ValueError(
                "num_cpu_blocks must be specified in kv_connector_extra_config"
            )

        self.num_cpu_blocks: int = num_cpu_blocks

        # scheduler-side state, built lazily by get_manager()
        self._manager: OffloadingManager | None = None

        # worker-side state, built lazily by get_handlers()
        self._handlers: CpuGpuOffloadingHandlers | None = None

        # Only read (and validated) by get_manager(); get_handlers() ignores it.
        self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")

    def get_manager(self) -> OffloadingManager:
        """Return the scheduler-side offloading manager, creating it on first use.

        The manager is cached, so repeated calls return the same instance.
        KV-cache events are enabled only when ``vllm_config.kv_events_config``
        is set and its ``enable_kv_cache_events`` flag is true.

        Returns:
            An ``LRUOffloadingManager`` or ``ARCOffloadingManager`` (depending
            on ``eviction_policy``) backed by a ``CPUBackend`` with
            ``num_cpu_blocks`` blocks of ``offloaded_block_size``.

        Raises:
            ValueError: if ``eviction_policy`` is neither ``"lru"`` nor ``"arc"``.
        """
        manager = self._manager
        # Explicit None check: relying on truthiness would rebuild the manager
        # (discarding its state) if the manager object ever evaluated falsy.
        if manager is None:
            # Validate the policy before building anything, so an invalid
            # config does not construct a backend only to throw it away.
            manager_cls: type[OffloadingManager]
            if self.eviction_policy == "lru":
                manager_cls = LRUOffloadingManager
            elif self.eviction_policy == "arc":
                manager_cls = ARCOffloadingManager
            else:
                raise ValueError(
                    f"Unknown eviction policy: {self.eviction_policy}. "
                    "Supported policies: lru, arc"
                )

            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (
                kv_events_config is not None and kv_events_config.enable_kv_cache_events
            )

            backend = CPUBackend(
                block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks
            )
            manager = manager_cls(backend=backend, enable_events=enable_events)
            self._manager = manager
        return manager

    def get_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """Yield the worker-side transfer handlers, creating them on first use.

        Yields two ``(source spec, destination spec, handler)`` tuples:
        GPU->CPU first, then CPU->GPU.

        The handlers are built only once. On later calls, ``kv_caches`` and
        ``attn_backends`` are ignored and the cached handlers are reused.
        Because this is a generator, the platform check and construction run
        when iteration starts, not when the method is called.

        Args:
            kv_caches: GPU KV-cache tensors, passed to
                ``CpuGpuOffloadingHandlers`` as ``gpu_caches`` (presumably
                keyed by layer name; confirm against the caller).
            attn_backends: attention backend class for each key.

        Raises:
            RuntimeError: if the current platform is not CUDA-alike.
        """
        handlers = self._handlers
        if handlers is None:
            if not current_platform.is_cuda_alike():
                # RuntimeError subclasses Exception, so existing generic
                # handlers still catch it.
                raise RuntimeError(
                    "CPU Offloading is currently only supported on CUDA-alike GPUs"
                )

            handlers = CpuGpuOffloadingHandlers(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
                cpu_block_size=self.offloaded_block_size,
                num_cpu_blocks=self.num_cpu_blocks,
                gpu_caches=kv_caches,
            )
            self._handlers = handlers

        yield GPULoadStoreSpec, CPULoadStoreSpec, handlers.gpu_to_cpu_handler
        yield CPULoadStoreSpec, GPULoadStoreSpec, handlers.cpu_to_gpu_handler