# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
4
5
6

from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
7
8
9
10
11
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorBase_V1,
    KVConnectorRole,
)
12
13
14

if TYPE_CHECKING:
    # Needed only for the string annotations used in this module; these
    # imports are skipped at runtime because TYPE_CHECKING is False then.
    from vllm.config import VllmConfig
    from vllm.v1.kv_cache_interface import KVCacheConfig
16

17
# Module-level KV connector instance. It stays None until
# ensure_kv_transfer_initialized() creates one, and
# ensure_kv_transfer_shutdown() resets it to None.
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
18
19
20
21


def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the module-level KV connector.

    Returns:
        The connector currently stored in ``_KV_CONNECTOR_AGENT``.

    Raises:
        AssertionError: If no connector has been initialized yet.
    """
    # Raise explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type and message are unchanged.
    if _KV_CONNECTOR_AGENT is None:
        raise AssertionError(
            "disaggregated KV cache transfer parallel group is not initialized"
        )
    return _KV_CONNECTOR_AGENT


def has_kv_transfer_group() -> bool:
    """Report whether the module-level KV connector is currently set."""
    connector_is_set = _KV_CONNECTOR_AGENT is not None
    return connector_is_set


31
def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> bool:
    """Check whether a KV connector is a v1 connector.

    Args:
        connector: The KV connector to check. If None, the global KV
            connector is checked instead.

    Returns:
        True if the selected connector is an instance of
        ``KVConnectorBase_V1``; False otherwise, including when no
        connector is available.

    Note:
        This function will no longer be needed after the v1 KV connector
        becomes the default.
    """
    if connector is None:
        connector = _KV_CONNECTOR_AGENT

    # isinstance(None, ...) is False, so an unset connector yields False
    # without a separate None check.
    return isinstance(connector, KVConnectorBase_V1)


52
53
54
def ensure_kv_transfer_initialized(
    vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
) -> None:
    """Create the module-level KV connector if it is needed and not yet set.

    Does nothing when ``vllm_config.kv_transfer_config`` is None, when that
    config is not a KV transfer instance, or when a connector already exists.

    Args:
        vllm_config: Configuration whose ``kv_transfer_config`` decides
            whether a connector is created; it is also passed to the factory.
        kv_cache_config: Optional KV cache configuration forwarded to the
            connector factory.

    Raises:
        ValueError: If a connector is required but ``envs.VLLM_USE_V1`` is
            falsy.
    """
    global _KV_CONNECTOR_AGENT

    transfer_config = vllm_config.kv_transfer_config
    if transfer_config is None:
        return

    if not transfer_config.is_kv_transfer_instance:
        return

    if _KV_CONNECTOR_AGENT is not None:
        # Already initialized; keep the existing connector.
        return

    if not envs.VLLM_USE_V1:
        raise ValueError("V0 is no longer supported")

    _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
        config=vllm_config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=kv_cache_config,
    )
76
77
78
79
80
81
82


def ensure_kv_transfer_shutdown() -> None:
    """Shut down the module-level KV connector, if any, and clear it."""
    global _KV_CONNECTOR_AGENT

    active_connector = _KV_CONNECTOR_AGENT
    if active_connector is None:
        return

    # Clear the global only after shutdown() returns, so a failing shutdown
    # leaves the connector in place.
    active_connector.shutdown()
    _KV_CONNECTOR_AGENT = None