spec.py 2.15 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING

import torch

9
from vllm.attention.backends.abstract import AttentionBackend
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class OffloadingSpec(ABC):
    """Spec for an offloading connector.

    Concrete specs subclass this and implement ``get_manager`` and
    ``get_handlers``. The constructor reads the settings shared by all
    specs out of the vLLM config.

    Attributes:
        vllm_config: The config object passed to the constructor.
        extra_config: ``kv_transfer_config.kv_connector_extra_config`` taken
            from ``vllm_config``.
        gpu_block_size: ``vllm_config.cache_config.block_size``.
        offloaded_block_size: ``extra_config["block_size"]`` converted with
            ``int()``; falls back to ``gpu_block_size`` when the key is
            absent. Must be a positive multiple of ``gpu_block_size``.
    """

    def __init__(self, vllm_config: "VllmConfig"):
        """
        Read the common offloading settings from ``vllm_config``.

        Args:
            vllm_config: The vLLM config; its ``kv_transfer_config`` must
                not be None.

        Raises:
            AssertionError: If ``kv_transfer_config`` is None, or if the
                offloaded block size is not a positive multiple of the
                GPU block size.
        """
        logger.warning(
            "Initializing OffloadingSpec. This API is experimental and "
            "subject to change in the future as we iterate the design."
        )
        self.vllm_config = vllm_config

        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None, (
            "OffloadingSpec requires vllm_config.kv_transfer_config to be set"
        )
        self.extra_config = kv_transfer_config.kv_connector_extra_config

        self.gpu_block_size = vllm_config.cache_config.block_size
        self.offloaded_block_size = int(
            self.extra_config.get("block_size", self.gpu_block_size)
        )

        # The modulo check alone lets 0 and negative multiples through
        # (0 % n == 0 and -n % n == 0), so positivity is checked explicitly.
        assert self.offloaded_block_size > 0, (
            f"offloaded block size must be positive, "
            f"got {self.offloaded_block_size}"
        )
        assert self.offloaded_block_size % self.gpu_block_size == 0, (
            f"offloaded block size ({self.offloaded_block_size}) must be a "
            f"multiple of the GPU block size ({self.gpu_block_size})"
        )

    @abstractmethod
    def get_manager(self) -> OffloadingManager:
        """
        Get an OffloadingManager that will be used
        by the scheduler-side offloading connector to track
        offloaded blocks and manage evictions.
        """

    @abstractmethod
    def get_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """
        Get offloading handlers along with their respective src and dst types.

        Args:
            kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
            attn_backends: A dictionary of layer_name -> AttentionBackend.

        Yields:
            Tuples of (src_type, dst_type, offloading_handler).
        """