# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import TYPE_CHECKING, Optional

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole

if TYPE_CHECKING:
    from vllm.config import VllmConfig

# Process-wide KV connector. Stays None until
# ensure_kv_transfer_initialized() creates it.
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None

# Process-wide LMCache connector. Stays None until
# ensure_kv_transfer_initialized() creates it, which only happens when
# envs.VLLM_LMCACHE_ENABLE is set.
_KV_LMCACHE_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None


def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the module-level KV connector.

    Raises:
        AssertionError: If the connector has not been created yet.
    """
    connector = _KV_CONNECTOR_AGENT
    assert connector is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return connector

def get_lmcache_connector() -> Optional[KVConnectorBaseType]:
    """Return the module-level LMCache connector, or None if it is unset.

    Unlike get_kv_transfer_group(), this accessor does not assert that the
    connector exists, so callers must handle a None result.
    """
    return _KV_LMCACHE_CONNECTOR_AGENT


def has_kv_transfer_group() -> bool:
    """Report whether the module-level KV connector has been created."""
    connector = _KV_CONNECTOR_AGENT
    return connector is not None


def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Tell whether a KV connector is a v1 connector.

    Args:
        connector: The connector to inspect. When None, the module-level
            KV connector is inspected instead.

    Returns:
        True if the inspected connector is an instance of
        KVConnectorBase_V1; False otherwise, including when no connector
        is available at all.

    Note:
        This function will no longer be needed after the v1 KV connector
        becomes the default.
    """
    # Fall back to the global connector only when the caller gave none.
    candidate = _KV_CONNECTOR_AGENT if connector is None else connector
    return candidate is not None and isinstance(candidate,
                                                KVConnectorBase_V1)


# Fixed engine id assigned to the standalone LMCache connector's config.
_LMCACHE_ENGINE_ID = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"


def _create_lmcache_connector(
        vllm_config: "VllmConfig") -> LMCacheConnectorV1:
    """Build a worker-role LMCache connector from a copy of ``vllm_config``."""
    # Local import, kept where the original code had it.
    from vllm.config import KVTransferConfig

    # Work on a deep copy so the caller's config keeps its own
    # kv_transfer_config untouched.
    lmcache_config = copy.deepcopy(vllm_config)
    lmcache_config.kv_transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnectorV1", kv_role="kv_both")
    lmcache_config.kv_transfer_config.engine_id = _LMCACHE_ENGINE_ID
    return LMCacheConnectorV1(lmcache_config, role=KVConnectorRole.WORKER)


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """Create the module-level KV connectors if they do not exist yet.

    Two independent singletons are managed here:

    * The LMCache connector is created when ``envs.VLLM_LMCACHE_ENABLE`` is
      truthy, regardless of whether ``vllm_config.kv_transfer_config`` is
      set.
    * The KV transfer connector is created only when
      ``vllm_config.kv_transfer_config`` is set and its
      ``is_kv_transfer_instance`` is truthy. The v1 or v0 factory is chosen
      according to ``envs.VLLM_USE_V1``.

    A connector that already exists is left untouched, so repeated calls
    are safe.

    Args:
        vllm_config: The vLLM configuration used to build the connectors.
    """
    global _KV_CONNECTOR_AGENT
    global _KV_LMCACHE_CONNECTOR_AGENT

    # This runs before the kv_transfer_config check below on purpose: the
    # LMCache connector does not depend on that config being present.
    if envs.VLLM_LMCACHE_ENABLE and _KV_LMCACHE_CONNECTOR_AGENT is None:
        _KV_LMCACHE_CONNECTOR_AGENT = _create_lmcache_connector(vllm_config)

    if vllm_config.kv_transfer_config is None:
        return

    if (vllm_config.kv_transfer_config.is_kv_transfer_instance
            and _KV_CONNECTOR_AGENT is None):
        if envs.VLLM_USE_V1:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
                config=vllm_config, role=KVConnectorRole.WORKER)
        else:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
                rank=get_world_group().rank,
                local_rank=get_world_group().local_rank,
                config=vllm_config,
            )