Unverified Commit 59dd311c authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[KVConnector] Keep KVTransferParams as a dict (#18033)

parent d066e520
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional from typing import Any, Optional
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig) ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
...@@ -124,20 +122,20 @@ def create_request( ...@@ -124,20 +122,20 @@ def create_request(
) -> Request: ) -> Request:
"""Make dummy request for testing.""" """Make dummy request for testing."""
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode: if do_remote_decode:
assert not do_remote_prefill assert not do_remote_prefill
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False, kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True) do_remote_decode=True)
elif do_remote_prefill: elif do_remote_prefill:
kv_transfer_params = NixlKVTransferParams( kv_transfer_params = dict(do_remote_prefill=True,
do_remote_prefill=True, do_remote_decode=False,
do_remote_decode=False, remote_engine_id="my-engine-id",
remote_engine_id="my-engine-id", remote_block_ids=list(
remote_block_ids=list(range(num_remote_blocks)), range(num_remote_blocks)),
remote_host="my-host", remote_host="my-host",
remote_port=1234) remote_port=1234)
else:
kv_transfer_params = None
max_tokens = 1 if do_remote_decode else max_tokens max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens) sampling_params = SamplingParams(max_tokens=max_tokens)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole, KVTransferParams) KVConnectorBase_V1, KVConnectorRole)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"] __all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
...@@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum): ...@@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVTransferParams:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["KVTransferParams"]:
return None
@dataclass @dataclass
class KVConnectorMetadata: class KVConnectorMetadata:
""" """
...@@ -75,7 +58,6 @@ class KVConnectorMetadata: ...@@ -75,7 +58,6 @@ class KVConnectorMetadata:
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning( logger.warning(
...@@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC): ...@@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod @abstractmethod
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, self,
......
...@@ -16,7 +16,7 @@ import zmq ...@@ -16,7 +16,7 @@ import zmq
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group) get_tp_group)
...@@ -44,56 +44,6 @@ except ImportError: ...@@ -44,56 +44,6 @@ except ImportError:
NixlWrapper = None NixlWrapper = None
@dataclass
class NixlKVTransferParams(KVTransferParams):
def __init__(
self,
do_remote_prefill: bool,
do_remote_decode: bool,
remote_block_ids: Optional[list[int]] = None,
remote_host: Optional[str] = None,
remote_port: Optional[int] = None,
remote_engine_id: Optional[str] = None,
):
self.do_remote_prefill = do_remote_prefill
self.do_remote_decode = do_remote_decode
self.remote_block_ids = remote_block_ids
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_engine_id = remote_engine_id
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["NixlKVTransferParams"]:
# If no raw transfer params passed, return None.
if raw_dict is None:
return None
# Validate the request is formatted properly.
if (("do_remote_prefill" not in raw_dict)
or ("do_remote_decode" not in raw_dict)
or ("remote_block_ids" not in raw_dict)
or ("remote_host" not in raw_dict)
or ("remote_port" not in raw_dict)
or ("remote_engine_id" not in raw_dict)):
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", raw_dict)
return None
return NixlKVTransferParams(
do_remote_prefill=raw_dict["do_remote_prefill"],
do_remote_decode=raw_dict["do_remote_decode"],
remote_block_ids=raw_dict["remote_block_ids"],
remote_host=raw_dict["remote_host"],
remote_port=raw_dict["remote_port"],
remote_engine_id=raw_dict["remote_engine_id"],
)
class NixlAgentMetadata( class NixlAgentMetadata(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
...@@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self, self,
request_id: str, request_id: str,
local_block_ids: list[int], local_block_ids: list[int],
kv_transfer_params: NixlKVTransferParams, kv_transfer_params: dict[str, Any],
): ):
assert request_id not in self.requests
assert kv_transfer_params.remote_block_ids is not None
assert kv_transfer_params.remote_engine_id is not None
assert kv_transfer_params.remote_host is not None
assert kv_transfer_params.remote_port is not None
self.requests[request_id] = ReqMeta( self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids, local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params.remote_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params.remote_engine_id, remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params.remote_host, remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params.remote_port, remote_port=kv_transfer_params["remote_port"],
) )
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
...@@ -253,52 +196,52 @@ class NixlConnectorScheduler: ...@@ -253,52 +196,52 @@ class NixlConnectorScheduler:
asynchronously (between scheduler steps). asynchronously (between scheduler steps).
""" """
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector get_num_new_matched_tokens: " "NIXLConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s", "num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens, request.kv_transfer_params) num_computed_tokens, params)
# No KVTransfer for this request.
if request.kv_transfer_params is None:
return 0, False
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
# Remote prefill: get all prompt blocks from remote. if params is not None and params.get("do_remote_prefill"):
if request.kv_transfer_params.do_remote_prefill: # Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0 assert num_computed_tokens % self.block_size == 0
rounded_num_prompt_tokens = round_down( rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size) len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0 return count, count > 0
# No remote prefill for this request.
return 0, False return 0, False
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks", blocks: "KVCacheBlocks",
num_external_tokens: int): num_external_tokens: int):
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector update_state_after_alloc: " "NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s", "num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, request.kv_transfer_params) num_external_tokens, params)
if request.kv_transfer_params is None: if params is not None and params.get("do_remote_prefill"):
return
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if request.kv_transfer_params.do_remote_prefill:
# NOTE(rob): if prompt < block_size, no remote blocks # NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so # since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens # skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks. # should be 0 if there are no remote blocks.
if request.kv_transfer_params.remote_block_ids: if params.get("remote_block_ids"):
# Get unhashed blocks to pull from remote. if all(p in params for p in ("remote_engine_id", "remote_host",
self._reqs_need_recv[request.request_id] = ( "remote_port")):
request, blocks.get_unhashed_block_ids()) # Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", params)
else: else:
assert num_external_tokens == 0 assert num_external_tokens == 0
# Only trigger 1 KV transfer per request. # Only trigger 1 KV transfer per request.
request.kv_transfer_params.do_remote_prefill = False params["do_remote_prefill"] = False
def build_connector_meta( def build_connector_meta(
self, self,
...@@ -308,7 +251,7 @@ class NixlConnectorScheduler: ...@@ -308,7 +251,7 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta. # Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items(): for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert isinstance(req.kv_transfer_params, NixlKVTransferParams) assert req.kv_transfer_params is not None
meta.add_new_req( meta.add_new_req(
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=block_ids,
...@@ -330,34 +273,30 @@ class NixlConnectorScheduler: ...@@ -330,34 +273,30 @@ class NixlConnectorScheduler:
should be freed now or will be sent asynchronously and freed later. should be freed now or will be sent asynchronously and freed later.
""" """
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector request_finished, " "NIXLConnector request_finished, request_status=%s, "
"request_status=%s, kv_transfer_params=%s", request.status, "kv_transfer_params=%s", request.status, params)
request.kv_transfer_params)
if request.kv_transfer_params is None:
return False, None
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if ((not request.kv_transfer_params.do_remote_decode) if (params is None or not params.get("do_remote_decode")
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)): or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None return False, None
# Get computed blocks. # Get computed blocks.
all_full = request.num_computed_tokens % self.block_size == 0 all_full = request.num_computed_tokens % self.block_size == 0
computed_block_ids = (block_ids if all_full else block_ids[:-1]) computed_block_ids = block_ids if all_full else block_ids[:-1]
# If prompt < block_size, no xfer so free blocks immediately. # If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0 delay_free_blocks = len(computed_block_ids) > 0
return delay_free_blocks, NixlKVTransferParams( return delay_free_blocks, dict(
do_remote_prefill=True, do_remote_prefill=True,
do_remote_decode=False, do_remote_decode=False,
remote_block_ids=computed_block_ids, remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id, remote_engine_id=self.engine_id,
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
).__dict__ )
class NixlConnectorWorker: class NixlConnectorWorker:
......
...@@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch ...@@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole)
KVTransferParams)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
...@@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface): ...@@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
return self.connector return self.connector
def _connector_finished( def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]: self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
"""Invoke the KV connector request_finished() method if applicable.""" """
Invoke the KV connector request_finished() method if applicable.
Returns optional kv transfer parameters to be included with the
request outputs.
"""
if self.connector is None: if self.connector is None:
return False, None return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
......
...@@ -182,14 +182,10 @@ class EngineCore: ...@@ -182,14 +182,10 @@ class EngineCore:
# Start grammar compilation asynchronously # Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None: if req.kv_transfer_params is not None and (
if (kv_connector := self.scheduler.get_kv_connector()): not self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector. logger.warning("Got kv_transfer_params, but no KVConnector found. "
kv_connector.set_kv_transfer_params(req) "Disabling KVTransfer for this request.")
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req) self.scheduler.add_request(req)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import enum import enum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -62,14 +61,10 @@ class Request: ...@@ -62,14 +61,10 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed). # P/D: Connector-specific KV transfer parameters.
raw_params = (None if sampling_params.extra_args is None kv_params = (None if sampling_params.extra_args is None else
else sampling_params.extra_args.get( sampling_params.extra_args.get("kv_transfer_params"))
"kv_transfer_params", None)) self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment