Unverified Commit d1911020 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[P/D] NIXL Integration (#17751)


Signed-off-by: default avatarApostaC <yihua98@uchicago.edu>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarApostaC <yihua98@uchicago.edu>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
parent 05a4324f
...@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
choices=choices, choices=choices,
usage=usage, usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
) )
return response return response
......
...@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
model=model_name, model=model_name,
choices=choices, choices=choices,
usage=usage, usage=usage,
) kv_transfer_params=final_res_batch[0].kv_transfer_params)
def _create_completion_logprobs( def _create_completion_logprobs(
self, self,
......
...@@ -112,6 +112,8 @@ if TYPE_CHECKING: ...@@ -112,6 +112,8 @@ if TYPE_CHECKING:
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
def get_default_cache_root(): def get_default_cache_root():
...@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# insecure method and it is needed for some reason. # insecure method and it is needed for some reason.
"VLLM_ALLOW_INSECURE_SERIALIZATION": "VLLM_ALLOW_INSECURE_SERIALIZATION":
lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
# IP address used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_HOST":
lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT":
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -11,10 +11,6 @@ import torch.distributed as dist ...@@ -11,10 +11,6 @@ import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any, ...@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
dp_metadata=dp_metadata) dp_metadata=dp_metadata)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer = (attn_metadata is not None
and has_kv_transfer_group()
and is_v1_kv_transfer_group())
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
try: try:
yield yield
finally: finally:
...@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any, ...@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
"(batchsize, count, median_time(ms)): %s"), "(batchsize, count, median_time(ms)): %s"),
forward_stats) forward_stats)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.wait_for_save()
_forward_context = prev_context _forward_context = prev_context
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
from collections.abc import MutableSequence from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Union from typing import Any, Generic, Optional, Union
import torch import torch
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
...@@ -103,6 +103,7 @@ class RequestOutput: ...@@ -103,6 +103,7 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt. encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only. None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit. num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
""" """
def __init__( def __init__(
...@@ -120,6 +121,7 @@ class RequestOutput: ...@@ -120,6 +121,7 @@ class RequestOutput:
num_cached_tokens: Optional[int] = None, num_cached_tokens: Optional[int] = None,
*, *,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
...@@ -133,11 +135,13 @@ class RequestOutput: ...@@ -133,11 +135,13 @@ class RequestOutput:
self.encoder_prompt = encoder_prompt self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
def add(self, next_output: "RequestOutput", aggregate: bool) -> None: def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one""" """Merge subsequent RequestOutput into this one"""
self.finished |= next_output.finished self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params
for next_completion in next_output.outputs: for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs): for i, completion in enumerate(self.outputs):
......
...@@ -36,6 +36,12 @@ class KVCacheBlocks: ...@@ -36,6 +36,12 @@ class KVCacheBlocks:
"""Converts the KVCacheBlocks instance to a list of block IDs.""" """Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks] return [block.block_id for block in self.blocks]
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
return [
block.block_id for block in self.blocks if block.block_hash is None
]
class KVCacheManager: class KVCacheManager:
...@@ -116,6 +122,12 @@ class KVCacheManager: ...@@ -116,6 +122,12 @@ class KVCacheManager:
- The number of computed tokens. - The number of computed tokens.
""" """
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or # Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching if (not self.enable_caching
...@@ -173,6 +185,7 @@ class KVCacheManager: ...@@ -173,6 +185,7 @@ class KVCacheManager:
num_new_tokens: int, num_new_tokens: int,
new_computed_blocks: Optional[KVCacheBlocks] = None, new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]: ) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append. """Add slots for a request with new tokens to append.
...@@ -186,6 +199,9 @@ class KVCacheManager: ...@@ -186,6 +199,9 @@ class KVCacheManager:
num_lookahead_tokens: The number of speculative tokens to allocate. num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such This is used by spec decode proposers with kv-cache such
as eagle. as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
Blocks layout: Blocks layout:
``` ```
...@@ -255,7 +271,9 @@ class KVCacheManager: ...@@ -255,7 +271,9 @@ class KVCacheManager:
new_blocks = self.single_type_manager.allocate_new_blocks( new_blocks = self.single_type_manager.allocate_new_blocks(
request.request_id, num_tokens_need_slot) request.request_id, num_tokens_need_slot)
if not self.enable_caching: # P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks) return KVCacheBlocks(new_blocks)
# Speculated tokens might be rejected in the future, so we does # Speculated tokens might be rejected in the future, so we does
...@@ -350,3 +368,16 @@ class KVCacheManager: ...@@ -350,3 +368,16 @@ class KVCacheManager:
A list of KV cache events. A list of KV cache events.
""" """
return self.block_pool.take_events() return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return [
block.block_id
for block in self.single_type_manager.req_to_blocks[request_id]
]
def get_num_blocks(self, request_id: str):
"""Get the number of blocks."""
assert request_id in self.single_type_manager.req_to_blocks
return len(self.single_type_manager.req_to_blocks[request_id])
...@@ -4,6 +4,7 @@ from collections.abc import Iterable ...@@ -4,6 +4,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -137,3 +138,6 @@ class SchedulerInterface(ABC): ...@@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the scheduler.""" """Shutdown the scheduler."""
raise NotImplementedError raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
return None
...@@ -5,13 +5,15 @@ from __future__ import annotations ...@@ -5,13 +5,15 @@ from __future__ import annotations
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Any, Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole,
KVTransferParams)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
...@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface): ...@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
# This is flushed at the end of each scheduling step. # This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set() self.finished_req_ids: set[str] = set()
# P/D: requests in process of recving KV transfers
self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step. # them at each scheduling step.
# Request id -> deque of CachedRequestData # Request id -> deque of CachedRequestData
...@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface): ...@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0] request = self.waiting[0]
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Skip request if the structured output request is still waiting # Skip request if the structured output request is still waiting
# for FSM compilation. # for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM: if request.status == RequestStatus.WAITING_FOR_FSM:
...@@ -330,26 +345,33 @@ class Scheduler(SchedulerInterface): ...@@ -330,26 +345,33 @@ class Scheduler(SchedulerInterface):
continue continue
# Get already-cached tokens. # Get already-cached tokens.
computed_blocks, num_computed_tokens = \ new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks( self.kv_cache_manager.get_computed_blocks(
request) request)
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
num_external_tokens = ( num_external_tokens, load_kv_async = (
0 if self.connector is None else (0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_computed_tokens)) request, num_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens += num_external_tokens num_computed_tokens += num_external_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
num_new_tokens = 0
# Number of tokens to be scheduled. # Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of # We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests, # `request.num_prompt_tokens` to consider the resumed
# which have output tokens. # requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold < if (0 < self.scheduler_config.long_prefill_token_threshold
num_new_tokens): < num_new_tokens):
num_new_tokens = ( num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold) self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
...@@ -358,21 +380,20 @@ class Scheduler(SchedulerInterface): ...@@ -358,21 +380,20 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs. # Schedule encoder inputs.
if request.has_encoder_inputs: if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens, (encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs( new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens, request, num_computed_tokens, num_new_tokens,
encoder_budget) encoder_budget)
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled. # The request cannot be scheduled.
break break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_tokens, num_new_tokens + num_external_tokens,
computed_blocks, new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens, num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
) )
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
...@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface): ...@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
if self.connector is not None: if self.connector is not None:
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
new_computed_blocks + new_blocks,
num_external_tokens, num_external_tokens,
) )
self.waiting.popleft() self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output: if request.use_structured_output:
structured_output_request_ids[ structured_output_request_ids[
request.request_id] = req_index request.request_id] = req_index
...@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface): ...@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = ( req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids() self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
...@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface): ...@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
stopped = False stopped = False
new_logprobs = None new_logprobs = None
new_token_ids = generated_token_ids new_token_ids = generated_token_ids
kv_transfer_params = None
# Append generated tokens and check for stop. Note that if # Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner # a request is still being prefilled, we expect the model runner
...@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface): ...@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len) stopped = check_stop(request, self.max_model_len)
if stopped: if stopped:
self._free_request(request) kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
break break
...@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface): ...@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids: if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request. # Add EngineCoreOutput for this Request.
outputs.append( outputs.append(
EngineCoreOutput( EngineCoreOutput(
...@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface): ...@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
new_logprobs=new_logprobs, new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors, new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason, stop_reason=request.stop_reason,
events=request.take_events())) events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))
else: else:
# Invariant: EngineCore returns no partial prefill outputs. # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors assert not prompt_logprobs_tensors
...@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface): ...@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
# P/D: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused. # Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs: for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs # NOTE(rob): since we free stopped reqs above, adding stopped reqs
...@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface): ...@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
request.status = finished_status request.status = finished_status
self._free_request(request) self._free_request(request)
def _free_request(self, request: Request) -> None: def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished() assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request) delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None) self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id) self.finished_req_ids.add(request.request_id)
if not delay_free_blocks:
self._free_blocks(request)
return kv_xfer_params
def _free_blocks(self, request: Request):
assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running) return len(self.waiting) + len(self.running)
...@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface): ...@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None: def shutdown(self) -> None:
if self.kv_event_publisher: if self.kv_event_publisher:
self.kv_event_publisher.shutdown() self.kv_event_publisher.shutdown()
########################################################################
# P/D Related Methods
########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
return self.connector
def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
"""Invoke the KV connector request_finished() method if applicable."""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
P/D: check if the request_id is finished_recving.
The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based
on the worker side connector.
When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False
# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.single_type_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
"""
P/D: update the scheduler state based on the output.
The Worker side connectors add finished_recving and
finished_sending reqs to the output.
* if finished_sending: free the blocks
# if finished_recving: add to state so we can
scheduler the request during the next step.
"""
# P/D: update recv and send status from last step.
for req_id in (model_runner_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])
...@@ -105,6 +105,7 @@ class EngineCoreOutput( ...@@ -105,6 +105,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
@property @property
def finished(self) -> bool: def finished(self) -> bool:
......
...@@ -182,6 +182,15 @@ class EngineCore: ...@@ -182,6 +182,15 @@ class EngineCore:
# Start grammar compilation asynchronously # Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req) self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]): def abort_requests(self, request_ids: list[str]):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
...@@ -146,6 +146,7 @@ class RequestState: ...@@ -146,6 +146,7 @@ class RequestState:
new_token_ids: list[int], new_token_ids: list[int],
finish_reason: Optional[FinishReason], finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None], stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Optional[RequestOutput]: ) -> Optional[RequestOutput]:
finished = finish_reason is not None finished = finish_reason is not None
...@@ -167,13 +168,15 @@ class RequestState: ...@@ -167,13 +168,15 @@ class RequestState:
if not outputs: if not outputs:
return None return None
return self._new_request_output(request_id, outputs, finished) return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
def _new_request_output( def _new_request_output(
self, self,
request_id: str, request_id: str,
outputs: list[CompletionOutput], outputs: list[CompletionOutput],
finished: bool, finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> RequestOutput: ) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA: if self.output_kind == RequestOutputKind.DELTA:
...@@ -189,6 +192,7 @@ class RequestState: ...@@ -189,6 +192,7 @@ class RequestState:
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
outputs=outputs, outputs=outputs,
finished=finished, finished=finished,
kv_transfer_params=kv_transfer_params,
) )
def _new_completion_output( def _new_completion_output(
...@@ -335,6 +339,7 @@ class OutputProcessor: ...@@ -335,6 +339,7 @@ class OutputProcessor:
new_token_ids = engine_core_output.new_token_ids new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.is_prefilling = False req_state.is_prefilling = False
...@@ -350,7 +355,8 @@ class OutputProcessor: ...@@ -350,7 +355,8 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects. # 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output( if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason): new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None: if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate(). # AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output) req_state.queue.put(request_output)
......
...@@ -100,12 +100,16 @@ class ModelRunnerOutput: ...@@ -100,12 +100,16 @@ class ModelRunnerOutput:
# [prompt_len] # [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[], EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={}, req_id_to_index={},
sampled_token_ids=[], sampled_token_ids=[],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
) finished_sending=None,
finished_recving=None)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum import enum
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -61,6 +62,15 @@ class Request: ...@@ -61,6 +62,15 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes: if self.mm_hashes:
...@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum): ...@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request.""" """Status of a request."""
WAITING = enum.auto() WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto() WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto() RUNNING = enum.auto()
PREEMPTED = enum.auto() PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered # Note: anything after PREEMPTED will be considered
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
import gc import gc
import time import time
import weakref import weakref
...@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig, ...@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -1065,16 +1067,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1065,16 +1067,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do. # Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs. # Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = ( attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output)) self._prepare_inputs(scheduler_output))
...@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
output = self.model( self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = output hidden_states = model_output
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states. # For mid-pipeline stages, return the hidden states.
...@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids, spec_token_ids=spec_token_ids,
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
) )
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def generate_draft_token_ids( def generate_draft_token_ids(
self, self,
sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
...@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context, self.vllm_config.compilation_config.static_forward_context,
self.kv_caches) self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()( self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self), weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec, kv_cache_config.kv_cache_groups[0].kv_cache_spec,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment