"docs/kubernetes/observability/metrics.md" did not exist on "0a2a820bcacda705d927c6fdfcf37ec076e4e3fd"
Unverified Commit d1911020 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[P/D] NIXL Integration (#17751)


Signed-off-by: default avatarApostaC <yihua98@uchicago.edu>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarApostaC <yihua98@uchicago.edu>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
parent 05a4324f
......@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
return response
......
......@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
model=model_name,
choices=choices,
usage=usage,
)
kv_transfer_params=final_res_batch[0].kv_transfer_params)
def _create_completion_logprobs(
self,
......
......@@ -112,6 +112,8 @@ if TYPE_CHECKING:
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
def get_default_cache_root():
......@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# insecure method and it is needed for some reason.
"VLLM_ALLOW_INSECURE_SERIALIZATION":
lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
# Host (IP address or hostname) used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_HOST":
lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT":
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
}
# end-env-vars-definition
......
......@@ -11,10 +11,6 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger
if TYPE_CHECKING:
......@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer = (attn_metadata is not None
and has_kv_transfer_group()
and is_v1_kv_transfer_group())
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
try:
yield
finally:
......@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.wait_for_save()
_forward_context = prev_context
......@@ -4,7 +4,7 @@ import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Generic, Optional, Union
from typing import Any, Generic, Optional, Union
import torch
from typing_extensions import TypeVar, deprecated
......@@ -103,6 +103,7 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
"""
def __init__(
......@@ -120,6 +121,7 @@ class RequestOutput:
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
......@@ -133,11 +135,13 @@ class RequestOutput:
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params
for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs):
......
......@@ -36,6 +36,12 @@ class KVCacheBlocks:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
def get_unhashed_block_ids(self) -> list[int]:
    """Return the IDs of the blocks that have no block hash.

    Walks `self.blocks` in order and keeps the `block_id` of every block
    whose `block_hash` is None.
    """
    unhashed = (blk for blk in self.blocks if blk.block_hash is None)
    return [blk.block_id for blk in unhashed]
class KVCacheManager:
......@@ -116,6 +122,12 @@ class KVCacheManager:
- The number of computed tokens.
"""
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
......@@ -173,6 +185,7 @@ class KVCacheManager:
num_new_tokens: int,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
......@@ -186,6 +199,9 @@ class KVCacheManager:
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
Blocks layout:
```
......@@ -255,7 +271,9 @@ class KVCacheManager:
new_blocks = self.single_type_manager.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
if not self.enable_caching:
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks)
# Speculated tokens might be rejected in the future, so we do
......@@ -350,3 +368,16 @@ class KVCacheManager:
A list of KV cache events.
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
    """Return the block IDs currently held for `request_id`.

    The request must already have an entry in
    `single_type_manager.req_to_blocks`; an unknown request ID trips the
    assertion.
    """
    blocks_by_request = self.single_type_manager.req_to_blocks
    assert request_id in blocks_by_request
    return [blk.block_id for blk in blocks_by_request[request_id]]
def get_num_blocks(self, request_id: str) -> int:
    """Return how many blocks are currently held for `request_id`.

    The request must already have an entry in
    `single_type_manager.req_to_blocks`; an unknown request ID trips the
    assertion.
    """
    blocks_by_request = self.single_type_manager.req_to_blocks
    assert request_id in blocks_by_request
    return len(blocks_by_request[request_id])
......@@ -4,6 +4,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
......@@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
    """Return this scheduler's KV connector.

    The default implementation has no connector and always returns None.
    """
    return None
......@@ -5,13 +5,15 @@ from __future__ import annotations
import time
from collections import defaultdict, deque
from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole,
KVTransferParams)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
......@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# P/D: requests that have finished recving KV transfers
self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> deque of CachedRequestData
......@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0]
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
......@@ -330,49 +345,55 @@ class Scheduler(SchedulerInterface):
continue
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
num_external_tokens = (
0 if self.connector is None else
num_external_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens += num_external_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
num_new_tokens = 0
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_tokens,
computed_blocks,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
if new_blocks is None:
# The request cannot be scheduled.
......@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_tokens,
)
self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KVS state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
......@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
......@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
......@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
......@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
......@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
......@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)
# P/D: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
......@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> None:
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
if not delay_free_blocks:
self._free_blocks(request)
return kv_xfer_params
def _free_blocks(self, request: Request) -> None:
    """Release a finished request's KV cache blocks and forget the request.

    Expects the request to be finished and its entry in
    `_cached_reqs_data` to be gone already (both are asserted). Frees the
    request's blocks and block hashes through the KV cache manager, then
    removes the request from `self.requests`.
    """
    assert request.is_finished()
    assert request.request_id not in self._cached_reqs_data
    kv_cache_manager = self.kv_cache_manager
    kv_cache_manager.free(request)
    kv_cache_manager.free_block_hashes(request)
    del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:
    """Return the combined length of the waiting and running queues."""
    return sum(len(queue) for queue in (self.waiting, self.running))
......@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None:
    """Shut down the KV event publisher, if one is set."""
    publisher = self.kv_event_publisher
    if not publisher:
        return
    publisher.shutdown()
########################################################################
# P/D Related Methods
########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
    """Return `self.connector`, which is None when no connector is in use."""
    return self.connector
def _connector_finished(
        self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
    """Notify the KV connector, if there is one, that `request` finished.

    Returns:
        The connector's `request_finished()` result for the request and
        its current block IDs, or `(False, None)` when no connector is
        configured. `_free_request()` unpacks the pair as
        `(delay_free_blocks, kv_xfer_params)`.
    """
    connector = self.connector
    if connector is not None:
        req_blocks = self.kv_cache_manager.get_block_ids(request.request_id)
        return connector.request_finished(request, req_blocks)
    return False, None
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
    """P/D: check whether a request's remote KV transfer has completed.

    The `finished_recving_kv_req_ids` set is filled in by
    `_update_from_kv_xfer_finished()` from the model runner output. If
    the request is not in that set yet, nothing changes and False is
    returned. Otherwise the request's already-allocated blocks are
    cached, `request.num_computed_tokens` is updated, the request ID is
    removed from the set and True is returned; the caller is then
    expected to move the request from WAITING_FOR_REMOTE_KVS back to
    WAITING.
    """
    req_id = request.request_id
    if req_id not in self.finished_recving_kv_req_ids:
        return False

    # The transfer is done, so the blocks can finally be cached.
    kv_cache_manager = self.kv_cache_manager
    num_blocks = len(kv_cache_manager.get_block_ids(req_id))
    num_computed_tokens = num_blocks * self.block_size
    # Hold the last token back when the blocks cover the whole request;
    # presumably so that at least one token is left to schedule.
    if num_computed_tokens == request.num_tokens:
        num_computed_tokens -= 1
    kv_cache_manager.single_type_manager.cache_blocks(
        request,
        kv_cache_manager.req_to_block_hashes[req_id],
        num_computed_tokens,
    )

    # Scheduling resumes from this many already-computed tokens.
    request.num_computed_tokens = num_computed_tokens

    # Ready: drop the request from the finished-receiving set.
    self.finished_recving_kv_req_ids.remove(req_id)
    return True
def _update_from_kv_xfer_finished(self,
                                  model_runner_output: ModelRunnerOutput):
    """P/D: fold finished KV transfers from the output into scheduler state.

    The model runner output carries `finished_recving` and
    `finished_sending` request IDs (either may be None or empty):
        * finished_recving: remembered in `finished_recving_kv_req_ids`
          so the request can be scheduled during a later step.
        * finished_sending: the request's blocks are freed.
    """
    # Receives that completed since the last step.
    recv_done = model_runner_output.finished_recving or ()
    for req_id in recv_done:
        logger.debug("Finished recving KV transfer for request %s", req_id)
        self.finished_recving_kv_req_ids.add(req_id)

    # Sends that completed since the last step: blocks can be released.
    send_done = model_runner_output.finished_sending or ()
    for req_id in send_done:
        logger.debug("Finished sending KV transfer for request %s", req_id)
        self._free_blocks(self.requests[req_id])
......@@ -105,6 +105,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
@property
def finished(self) -> bool:
......
......@@ -182,6 +182,15 @@ class EngineCore:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]):
......
......@@ -3,7 +3,7 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
......@@ -146,6 +146,7 @@ class RequestState:
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Optional[RequestOutput]:
finished = finish_reason is not None
......@@ -167,13 +168,15 @@ class RequestState:
if not outputs:
return None
return self._new_request_output(request_id, outputs, finished)
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA:
......@@ -189,6 +192,7 @@ class RequestState:
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
)
def _new_completion_output(
......@@ -335,6 +339,7 @@ class OutputProcessor:
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.is_prefilling = False
......@@ -350,7 +355,8 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
......
......@@ -100,12 +100,16 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=None,
finished_recving=None)
# SPDX-License-Identifier: Apache-2.0
import enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
......@@ -61,6 +62,15 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
......@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered
......
# SPDX-License-Identifier: Apache-2.0
import copy
import gc
import time
import weakref
......@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
......@@ -1065,15 +1067,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
......@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = output
hidden_states, aux_hidden_states = model_output
else:
hidden_states = output
hidden_states = model_output
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
......@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
def kv_connector_no_forward(
        self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
    """Service KV send/recv for a step that has no forward pass to run.

    Sets up the KV connector inside a forward context created with no
    attention metadata, then collects the finished transfers.

    Returns:
        `EMPTY_MODEL_RUNNER_OUTPUT` itself when nothing finished sending
        or receiving; otherwise a shallow copy of it with
        `finished_sending` and `finished_recving` filled in.
    """
    # KV transfers still need servicing even with nothing to compute.
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        sending_done, recving_done = (
            self.get_finished_kv_transfers(scheduler_output))

    if sending_done or recving_done:
        # Copy first so the shared empty output is never mutated.
        result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        result.finished_sending = sending_done
        result.finished_recving = recving_done
        return result
    return EMPTY_MODEL_RUNNER_OUTPUT
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    """Bind this step's connector metadata and start loading KV.

    Does nothing when no KV transfer group exists. Otherwise the
    scheduler output must carry connector metadata (asserted), and the
    call must happen inside a forward context because
    `get_forward_context()` is passed to the connector.
    """
    if not has_kv_transfer_group():
        return

    # Hand the connector the metadata for this step before any KV load.
    connector = get_kv_transfer_group()
    assert isinstance(connector, KVConnectorBase_V1)
    connector_metadata = scheduler_output.kv_connector_metadata
    assert connector_metadata is not None
    connector.bind_connector_metadata(connector_metadata)

    # Background KV cache transfers happen here.
    # These transfers are designed to be async and the requests
    # involved may be disjoint from the running requests.
    # Do this here to save a collective_rpc.
    connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Call the connector's `wait_for_save()` if a KV transfer group exists.

    Does nothing when there is no KV transfer group.
    """
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Ask the KV connector which transfers have finished.

    Passes `scheduler_output.finished_req_ids` to the connector's
    `get_finished()` and returns its result, which callers in this file
    unpack as `(finished_sending, finished_recving)`. Returns
    `(None, None)` when there is no KV transfer group.
    """
    if not has_kv_transfer_group():
        return None, None
    return get_kv_transfer_group().get_finished(
        scheduler_output.finished_req_ids)
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
......@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment