Unverified Commit 3e0d4a34 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Move `KVTransferConfig` from `config/__init__.py` to `config/kv_transfer.py` (#24434)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 562663a0
...@@ -9,7 +9,6 @@ import hashlib ...@@ -9,7 +9,6 @@ import hashlib
import inspect import inspect
import json import json
import textwrap import textwrap
import uuid
import warnings import warnings
from collections.abc import Mapping from collections.abc import Mapping
from contextlib import contextmanager from contextlib import contextmanager
...@@ -34,6 +33,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, ...@@ -34,6 +33,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
from vllm.config.compilation import (CompilationConfig, CompilationLevel, from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig) CUDAGraphMode, PassConfig)
from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig) ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
...@@ -3210,107 +3210,6 @@ class ObservabilityConfig: ...@@ -3210,107 +3210,6 @@ class ObservabilityConfig:
self.collect_detailed_traces[0].split(",")) self.collect_detailed_traces[0].split(","))
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]
@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
kv_connector: Optional[str] = None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: Optional[str] = None
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
kv_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
kv_parallel_size: int = 1
"""The number of parallel instances for KV cache transfer. For
P2pNcclConnector, this should be 2."""
kv_ip: str = "127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
kv_connector_module_path: Optional[str] = None
"""The Python module path to dynamically load the KV connector from.
Only supported in V1."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
f"is set, supported roles are {get_args(KVRole)}")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVRole)
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVProducer)
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
@config @config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig: class VllmConfig:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, Optional, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]
@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
kv_connector: Optional[str] = None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: Optional[str] = None
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
kv_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
kv_parallel_size: int = 1
"""The number of parallel instances for KV cache transfer. For
P2pNcclConnector, this should be 2."""
kv_ip: str = "127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
kv_connector_module_path: Optional[str] = None
"""The Python module path to dynamically load the KV connector from.
Only supported in V1."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
f"is set, supported roles are {get_args(KVRole)}")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVRole)
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVProducer)
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in get_args(KVConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
...@@ -14,7 +14,8 @@ from vllm.logger import init_logger ...@@ -14,7 +14,8 @@ from vllm.logger import init_logger
# yapf: enable # yapf: enable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
......
...@@ -15,7 +15,7 @@ import msgpack ...@@ -15,7 +15,7 @@ import msgpack
import torch import torch
import zmq import zmq
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
......
...@@ -13,7 +13,7 @@ import zmq ...@@ -13,7 +13,7 @@ import zmq
from safetensors.torch import load as safetensors_load from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port from vllm.utils import join_host_port, make_zmq_path, split_host_port
......
...@@ -20,7 +20,7 @@ from typing import Callable, Optional ...@@ -20,7 +20,7 @@ from typing import Callable, Optional
import torch import torch
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
......
...@@ -204,7 +204,7 @@ class LLM: ...@@ -204,7 +204,7 @@ class LLM:
if "kv_transfer_config" in kwargs and isinstance( if "kv_transfer_config" in kwargs and isinstance(
kwargs["kv_transfer_config"], dict): kwargs["kv_transfer_config"], dict):
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
raw_config_dict = kwargs["kv_transfer_config"] raw_config_dict = kwargs["kv_transfer_config"]
try: try:
kwargs["kv_transfer_config"] = KVTransferConfig( kwargs["kv_transfer_config"] = KVTransferConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment