kv_transfer_state.py 3.32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import TYPE_CHECKING, Optional
xuxz's avatar
xuxz committed
4
import copy
5
from vllm import envs
yangshj1's avatar
yangshj1 committed
6
from vllm.config import get_current_vllm_config
7
8
9
10
11
12
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
yangshj1's avatar
yangshj1 committed
13
14
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
15
16
17
18
19

if TYPE_CHECKING:
    from vllm.config import VllmConfig

# Connector built from ``vllm_config.kv_transfer_config``; set by
# ensure_kv_transfer_initialized() and read via get_kv_transfer_group().
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
# Auxiliary LMCacheConnectorV1 created on the first call to
# ensure_kv_transfer_initialized(), independently of the configured
# connector; read via get_lmcache_connector().
_KV_LMCACHE_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None


def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the connector built from the user's KV transfer config.

    Raises:
        AssertionError: (when assertions are enabled) if
            ``ensure_kv_transfer_initialized`` has not yet stored a
            connector in ``_KV_CONNECTOR_AGENT``.
    """
    agent = _KV_CONNECTOR_AGENT
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent

def get_lmcache_connector() -> KVConnectorBaseType:
    """Return the auxiliary LMCache connector for this process.

    The connector is stored in ``_KV_LMCACHE_CONNECTOR_AGENT`` by
    ``ensure_kv_transfer_initialized``. Calling this function before that
    has happened trips the assertion below (when assertions are enabled).

    Returns:
        The connector held in ``_KV_LMCACHE_CONNECTOR_AGENT``.
    """
    assert _KV_LMCACHE_CONNECTOR_AGENT is not None, (
        "LM cache transfer parallel group is not initialized")
    return _KV_LMCACHE_CONNECTOR_AGENT


def has_kv_transfer_group() -> bool:
    """Report whether the configured KV connector agent has been created."""
    agent_missing = _KV_CONNECTOR_AGENT is None
    return not agent_missing


def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Tell whether a KV connector is a v1 (``KVConnectorBase_V1``) connector.

    Args:
        connector: Connector to inspect. When omitted or None, the
            module-level ``_KV_CONNECTOR_AGENT`` is inspected instead.

    Returns:
        False when there is nothing to inspect (neither the argument nor the
        global is set); otherwise whether the connector is an instance of
        ``KVConnectorBase_V1``.

    Note:
        This helper becomes unnecessary once the v1 KV connector is the
        default.
    """
    candidate = connector if connector is not None else _KV_CONNECTOR_AGENT
    return candidate is not None and isinstance(candidate, KVConnectorBase_V1)


# Engine id assigned to the auxiliary LMCache connector's KVTransferConfig.
# NOTE(review): this is a fixed literal rather than a per-instance value, so
# every process that calls ensure_kv_transfer_initialized() uses the same id;
# confirm that this sharing is intended.
_LMCACHE_ENGINE_ID = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"


def _ensure_lmcache_connector_initialized(vllm_config: "VllmConfig") -> None:
    """Create the auxiliary LMCache connector if it does not exist yet.

    The connector is built from a deep copy of ``vllm_config`` whose
    ``kv_transfer_config`` is replaced by an ``LMCacheConnectorV1`` /
    ``kv_both`` config, so the caller's config object is not modified.
    Once ``_KV_LMCACHE_CONNECTOR_AGENT`` is set, this is a no-op.
    """
    global _KV_LMCACHE_CONNECTOR_AGENT

    if _KV_LMCACHE_CONNECTOR_AGENT is not None:
        return

    lmcache_config = copy.deepcopy(vllm_config)
    from vllm.config import KVTransferConfig
    lmcache_config.kv_transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnectorV1", kv_role="kv_both")
    lmcache_config.kv_transfer_config.engine_id = _LMCACHE_ENGINE_ID

    # The global is assigned only after construction succeeds, so a failed
    # attempt leaves it unset and a later call will try again.
    _KV_LMCACHE_CONNECTOR_AGENT = LMCacheConnectorV1(
        lmcache_config, role=KVConnectorRole.WORKER)


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize KV cache transfer parallel group.

    Two process-wide connectors are managed here:

    * ``_KV_LMCACHE_CONNECTOR_AGENT``: an auxiliary ``LMCacheConnectorV1``,
      created on the first call regardless of ``vllm_config``'s KV transfer
      settings.
    * ``_KV_CONNECTOR_AGENT``: the connector described by
      ``vllm_config.kv_transfer_config``. It is created only when that
      config is present and marks this process as a KV transfer instance.
      The v1 factory is used when ``envs.VLLM_USE_V1`` is set; otherwise
      the v0 factory is used.

    An already-created connector is never replaced by a later call.

    Args:
        vllm_config: The engine configuration.
    """
    global _KV_CONNECTOR_AGENT

    # Runs before the early return below, so the LMCache connector exists
    # even when no KV transfer config is supplied.
    _ensure_lmcache_connector_initialized(vllm_config)

    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config is None:
        return

    if (kv_transfer_config.is_kv_transfer_instance
            and _KV_CONNECTOR_AGENT is None):
        if envs.VLLM_USE_V1:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
                config=vllm_config, role=KVConnectorRole.WORKER)
        else:
            world_group = get_world_group()
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
                rank=world_group.rank,
                local_rank=world_group.local_rank,
                config=vllm_config,
            )