Unverified Commit 30610e73 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: use KvPushRouter for prefill router (#3401)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent c48f49a4
...@@ -19,7 +19,7 @@ from typing import Optional ...@@ -19,7 +19,7 @@ from typing import Optional
import uvloop import uvloop
from dynamo.llm import KvRouter, KvRouterConfig from dynamo.llm import KvPushRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -41,7 +41,7 @@ class StandaloneRouterHandler: ...@@ -41,7 +41,7 @@ class StandaloneRouterHandler:
self.worker_endpoint_path = worker_endpoint_path self.worker_endpoint_path = worker_endpoint_path
self.block_size = block_size self.block_size = block_size
self.kv_router_config = kv_router_config self.kv_router_config = kv_router_config
self.kv_router: Optional[KvRouter] = None self.kv_push_router: Optional[KvPushRouter] = None
self.worker_client: Optional[Client] = None self.worker_client: Optional[Client] = None
async def initialize(self): async def initialize(self):
...@@ -65,121 +65,76 @@ class StandaloneRouterHandler: ...@@ -65,121 +65,76 @@ class StandaloneRouterHandler:
self.worker_client = await worker_endpoint.client() self.worker_client = await worker_endpoint.client()
# Create KvRouter with specified configuration # Create KvPushRouter with specified configuration
self.kv_router = KvRouter( self.kv_push_router = KvPushRouter(
endpoint=worker_endpoint, endpoint=worker_endpoint,
block_size=self.block_size, block_size=self.block_size,
kv_router_config=self.kv_router_config, kv_router_config=self.kv_router_config,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize KvRouter: {e}") logger.error(f"Failed to initialize KvPushRouter: {e}")
raise raise
async def find_best_worker(self, request): async def generate(self, request):
""" """
Find the best worker based on KV cache state. Generate tokens using the KV-aware router.
This endpoint is called by clients to determine which worker This endpoint routes the request to the best worker and streams back results.
should handle a request. Wraps the request into PreprocessedRequest format and wraps worker responses
into LLMEngineOutput format.
""" """
if self.kv_router is None: if self.kv_push_router is None:
# Fallback to round-robin if router not initialized logger.error("KvPushRouter not initialized - cannot process request")
logger.warning("KvRouter not initialized, falling back to round-robin") raise RuntimeError("Router not initialized")
yield {
"status": "fallback", # Wrap incoming request into PreprocessedRequest format for KvPushRouter
"message": "Router not initialized", # The request should already have most fields, but we ensure it has the structure
preprocessed_request = {
"model": request.get("model", "unknown"),
"token_ids": request["token_ids"],
"stop_conditions": request.get("stop_conditions", {}),
"sampling_options": request.get("sampling_options", {}),
"output_options": request.get("output_options", {}),
"eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []),
"extra_args": request.get("extra_args", {}),
}
# Route and process through KvPushRouter
async for worker_output in await self.kv_push_router.generate_from_request(
preprocessed_request
):
# Wrap worker output into LLMEngineOutput format
# Worker should return dict with at minimum kv_transfer_params in extra_args
llm_engine_output = {
"token_ids": worker_output.get("token_ids", []),
"tokens": worker_output.get("tokens"),
"text": worker_output.get("text"),
"cum_log_probs": worker_output.get("cum_log_probs"),
"log_probs": worker_output.get("log_probs"),
"top_logprobs": worker_output.get("top_logprobs"),
"finish_reason": worker_output.get("finish_reason"),
"index": worker_output.get("index"),
"extra_args": worker_output.get("extra_args"),
} }
return yield llm_engine_output
try: async def best_worker_id(self, token_ids, router_config_override=None):
# Get current workers
if self.worker_client is None:
yield {
"status": "error",
"message": "Worker client not initialized",
}
return
instance_ids = self.worker_client.instance_ids()
if not instance_ids:
yield {
"status": "error",
"message": "No workers available",
}
return
logger.debug(f"Routing request with {len(instance_ids)} available workers")
# Validate required fields
if "token_ids" not in request:
raise ValueError("Missing required field 'token_ids' in request")
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
token_ids = request["token_ids"]
request_id = request["request_id"]
# Use KvRouter to find the best worker with state updates
best_worker_id, overlap_blocks = await self.kv_router.find_best_match(
request_id=request_id,
tokens=token_ids,
update_states=True, # Always update states for routing
)
logger.debug(
f"Selected worker {best_worker_id} with {overlap_blocks} overlap blocks for request {request_id}"
)
yield {
"worker_id": best_worker_id,
"overlap_blocks": overlap_blocks,
}
except Exception as e:
logger.error(f"Error finding best worker: {e}")
yield {
"status": "error",
"message": str(e),
}
async def free(self, request):
""" """
Free resources associated with a request. Get the best worker ID for a given set of tokens without actually routing.
This endpoint is called when a request is completed to clean up This method returns the worker ID that would be selected based on KV cache
router state. overlap, but does NOT actually route the request or update router states.
It's useful for debugging, monitoring, or implementing custom routing logic.
""" """
if self.kv_router is None: if self.kv_push_router is None:
logger.warning("KvRouter not initialized") logger.error("KvPushRouter not initialized - cannot get best worker")
yield { raise RuntimeError("Router not initialized")
"status": "error",
"message": "Router not initialized",
}
return
try: return await self.kv_push_router.best_worker_id(
if "request_id" not in request: token_ids, router_config_override
raise ValueError("Missing required field 'request_id' in request") )
request_id = request["request_id"]
# Free the request from the router
await self.kv_router.free(request_id=request_id)
logger.debug(f"Freed resources for request {request_id}")
yield {
"status": "success",
"message": f"Request {request_id} freed successfully",
}
except Exception as e:
logger.error(f"Error freeing request: {e}")
yield {
"status": "error",
"message": str(e),
}
def parse_args(): def parse_args():
...@@ -308,20 +263,21 @@ async def worker(runtime: DistributedRuntime): ...@@ -308,20 +263,21 @@ async def worker(runtime: DistributedRuntime):
await handler.initialize() await handler.initialize()
# Expose endpoints # Expose endpoints
find_best_worker_endpoint = component.endpoint("find_best_worker") generate_endpoint = component.endpoint("generate")
free_endpoint = component.endpoint("free") best_worker_endpoint = component.endpoint("best_worker_id")
logger.debug("Starting to serve find_best_worker and free endpoints...") logger.debug("Starting to serve endpoints...")
# Serve both endpoints concurrently
try: try:
await asyncio.gather( await asyncio.gather(
find_best_worker_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.find_best_worker, handler.generate,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("service", "router")], metrics_labels=[("service", "router")],
), ),
free_endpoint.serve_endpoint( best_worker_endpoint.serve_endpoint(
handler.free, handler.best_worker_id,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("service", "router")], metrics_labels=[("service", "router")],
), ),
......
...@@ -8,7 +8,7 @@ import uuid ...@@ -8,7 +8,7 @@ import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from copy import deepcopy from copy import deepcopy
from typing import AsyncGenerator from typing import Any, AsyncGenerator, Dict
import msgspec import msgspec
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
...@@ -18,7 +18,6 @@ from vllm.v1.engine.exceptions import EngineDeadError ...@@ -18,7 +18,6 @@ from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
from .protocol import MyRequestOutput
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -126,27 +125,34 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -126,27 +125,34 @@ class DecodeWorkerHandler(BaseWorkerHandler):
default_sampling_params, default_sampling_params,
prefill_worker_client=None, prefill_worker_client=None,
prefill_router_client=None, prefill_router_client=None,
prefill_router_free_client=None,
): ):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client self.prefill_router_client = prefill_router_client
self.prefill_router_free_client = prefill_router_free_client
self.can_prefill = 0 self.can_prefill = 0
self._prefill_check_task = None self._prefill_check_task = None
if self.prefill_worker_client is not None: if self.prefill_worker_client or self.prefill_router_client:
self._prefill_check_task = asyncio.create_task(self._prefill_check_loop()) self._prefill_check_task = asyncio.create_task(self._prefill_check_loop())
async def _prefill_check_loop(self): async def _prefill_check_loop(self):
"""Background task that checks prefill worker availability every 5 seconds.""" """Background task that checks prefill router/worker availability every 5 seconds."""
while True: while True:
try: try:
if self.prefill_worker_client is not None: router_count = (
self.can_prefill = len(self.prefill_worker_client.instance_ids()) len(self.prefill_router_client.instance_ids())
logger.debug(f"Current Prefill Workers: {self.can_prefill}") if self.prefill_router_client is not None
else: else 0
self.can_prefill = 0 )
worker_count = (
len(self.prefill_worker_client.instance_ids())
if self.prefill_worker_client is not None
else 0
)
self.can_prefill = max(router_count, worker_count)
logger.debug(
f"Prefill availability - Routers: {router_count}, Workers: {worker_count}"
)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("Prefill check loop cancelled.") logger.warning("Prefill check loop cancelled.")
raise raise
...@@ -178,15 +184,10 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -178,15 +184,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if value is not None and hasattr(sampling_params, key): if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value) setattr(sampling_params, key, value)
# TODO: Change to prefill queue # Use prefill router or worker if available
# TODO: (PeaBrane) eventually, do not use a router_client and a free_client directly.
# This is least intrusive for now, but quite error prone. Should consider (major) refactoring
# TODO: (PeaBrane) longer term, decode workers should not handle prefill routing at all.
# Prefill routing logic should be integrated directly into the frontend service potentially.
if self.can_prefill: if self.can_prefill:
# Create a copy for prefill with specific modifications # Create prefill sampling params with modifications
prefill_sampling_params = deepcopy(sampling_params) prefill_sampling_params = deepcopy(sampling_params)
if prefill_sampling_params.extra_args is None: if prefill_sampling_params.extra_args is None:
prefill_sampling_params.extra_args = {} prefill_sampling_params.extra_args = {}
prefill_sampling_params.extra_args["kv_transfer_params"] = { prefill_sampling_params.extra_args["kv_transfer_params"] = {
...@@ -195,68 +196,55 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -195,68 +196,55 @@ class DecodeWorkerHandler(BaseWorkerHandler):
prefill_sampling_params.max_tokens = 1 prefill_sampling_params.max_tokens = 1
prefill_sampling_params.min_tokens = 1 prefill_sampling_params.min_tokens = 1
prefill_request = {
"token_ids": request["token_ids"],
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
used_prefill_router = False
try: try:
prefill_worker_id = None # Send request with sampling_params and request_id in extra_args
prefill_request = request.copy()
# TODO (PeaBrane): this smells a bit bad as not we have two nestings
# of extra_args (an inner one again in sampling_params)
prefill_request["extra_args"] = {
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
# Try router first if available, fallback to worker
if ( if (
self.prefill_router_client is not None self.prefill_router_client is not None
and self.prefill_router_client.instance_ids() and self.prefill_router_client.instance_ids()
): ):
used_prefill_router = True # Call router's generate endpoint which returns LLMEngineOutput
best_worker_response = await anext(
await self.prefill_router_client.generate(
{
"token_ids": request["token_ids"],
"request_id": request_id,
}
)
)
prefill_worker_id = best_worker_response.data().get("worker_id")
if prefill_worker_id is not None:
prefill_response = await anext( prefill_response = await anext(
await self.prefill_worker_client.direct( await self.prefill_router_client.generate(
prefill_request, prefill_worker_id, context=context prefill_request, context=context
) )
) )
else: elif self.prefill_worker_client is not None:
# Fallback to direct worker with same format
prefill_response = await anext( prefill_response = await anext(
await self.prefill_worker_client.round_robin( await self.prefill_worker_client.round_robin(
prefill_request, context=context prefill_request, context=context
) )
) )
else:
raise ValueError("No prefill router or worker available")
prefill_output = prefill_response.data()
# Extract kv_transfer_params from response
kv_transfer_params = prefill_output.get("extra_args", {}).get(
"kv_transfer_params"
)
if kv_transfer_params:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = kv_transfer_params
except Exception as e: except Exception as e:
if context.is_stopped() or context.is_killed(): if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}") logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
return return
raise e logger.warning(f"Prefill error: {e}, falling back to local prefill")
finally:
if used_prefill_router:
await anext(
await self.prefill_router_free_client.generate(
{"request_id": request_id}
)
)
logger.debug(f"Freed router state for request {request_id}")
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
# Modify original sampling_params for decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = prefill_response.kv_transfer_params
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
...@@ -276,11 +264,17 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -276,11 +264,17 @@ class PrefillWorkerHandler(BaseWorkerHandler):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context): async def generate(self, request, context):
request_id = request["request_id"] # Extract from PreprocessedRequest format - request_id and sampling_params from extra_args
extra_args = request.get("extra_args", {})
request_id = extra_args.get("request_id", str(uuid.uuid4().hex))
logger.debug(f"New Prefill Request ID: {request_id}") logger.debug(f"New Prefill Request ID: {request_id}")
prompt = TokensPrompt(prompt_token_ids=request["token_ids"]) token_ids = request["token_ids"]
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams) prompt = TokensPrompt(prompt_token_ids=token_ids)
# Get sampling_params from extra_args
sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
async with self._abort_monitor(context, request_id, is_prefill=True): async with self._abort_monitor(context, request_id, is_prefill=True):
try: try:
...@@ -291,20 +285,22 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -291,20 +285,22 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown() self.runtime.shutdown()
os._exit(1) os._exit(1)
# Generate only 1 token in prefill
try: try:
async for res in gen: async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}") logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id, token_ids = res.outputs[0].token_ids if res.outputs else []
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids, output: Dict[str, Any] = {
prompt_logprobs=res.prompt_logprobs, "token_ids": list(token_ids),
outputs=res.outputs, "extra_args": (
finished=res.finished, {"kv_transfer_params": res.kv_transfer_params}
metrics=res.metrics, if res.kv_transfer_params
kv_transfer_params=res.kv_transfer_params, else {}
).model_dump_json() ),
}
yield output
except asyncio.CancelledError: except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests # raise the error because we cannot migrate prefill requests
raise GeneratorExit( raise GeneratorExit(
......
...@@ -227,14 +227,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -227,14 +227,7 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client = ( prefill_router_client = (
await runtime.namespace(config.namespace) await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers .component("router") # Standalone router for prefill workers
.endpoint("find_best_worker") .endpoint("generate")
.client()
)
prefill_router_free_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("free")
.client() .client()
) )
...@@ -268,7 +261,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -268,7 +261,6 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params, default_sampling_params,
prefill_worker_client, prefill_worker_client,
prefill_router_client, prefill_router_client,
prefill_router_free_client,
) )
# Set up KV event publisher for prefix caching if enabled # Set up KV event publisher for prefix caching if enabled
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict
from vllm.outputs import CompletionOutput
from vllm.sequence import PromptLogprobs, RequestMetrics
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
kv_transfer_params: Optional[dict[str, Any]] = None
...@@ -174,7 +174,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -174,7 +174,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::WorkerStats>()?; m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?; m.add_class::<llm::kv::KvStats>()?;
m.add_class::<llm::kv::SpecDecodeStats>()?; m.add_class::<llm::kv::SpecDecodeStats>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvPushRouter>()?; m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?; m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
......
...@@ -866,118 +866,6 @@ async fn create_kv_router_from_endpoint( ...@@ -866,118 +866,6 @@ async fn create_kv_router_from_endpoint(
Ok(kv_router) Ok(kv_router)
} }
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config=None))]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<&super::entrypoint::KvRouterConfig>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
let kv_router = create_kv_router_from_endpoint(
endpoint,
block_size,
kv_router_config.map(|c| c.inner()),
)
.await?;
Ok(Self { inner: kv_router })
})
}
#[pyo3(signature = (request_id, tokens, update_states=false, router_config_override=None))]
fn find_best_match<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
update_states: bool,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(
Some(&request_id),
&tokens,
router_config_override.as_ref(),
update_states,
)
.await
.map_err(to_pyerr)?;
Ok((worker_id, overlap_blocks))
})
}
fn add_request<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
overlap_blocks: u32,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.add_request(request_id, &tokens, overlap_blocks, worker_id)
.await;
Ok(())
})
}
fn mark_prefill_completed<'p>(
&self,
py: Python<'p>,
request_id: String,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.mark_prefill_completed(&request_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.free(&request_id).await.map_err(to_pyerr)?;
Ok(())
})
}
#[getter]
fn block_size(&self) -> PyResult<u32> {
Ok(self.inner.block_size())
}
}
#[pyclass] #[pyclass]
pub(crate) struct KvPushRouter { pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<llm_rs::kv_router::KvPushRouter>,
...@@ -1072,7 +960,7 @@ impl KvPushRouter { ...@@ -1072,7 +960,7 @@ impl KvPushRouter {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None))] #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))]
fn generate<'p>( fn generate<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -1083,9 +971,10 @@ impl KvPushRouter { ...@@ -1083,9 +971,10 @@ impl KvPushRouter {
output_options: Option<PyObject>, output_options: Option<PyObject>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
worker_id: Option<i64>, worker_id: Option<i64>,
extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults // Depythonize the options with defaults
let (stop_conditions, sampling_options, output_options, router_config_override) = let (stop_conditions, sampling_options, output_options, router_config_override, extra_args) =
Python::with_gil(|py| { Python::with_gil(|py| {
let stop_conditions: StopConditions = if let Some(obj) = stop_conditions { let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
depythonize(obj.bind(py)).map_err(to_pyerr)? depythonize(obj.bind(py)).map_err(to_pyerr)?
...@@ -1112,11 +1001,18 @@ impl KvPushRouter { ...@@ -1112,11 +1001,18 @@ impl KvPushRouter {
None None
}; };
let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
Ok::<_, PyErr>(( Ok::<_, PyErr>((
stop_conditions, stop_conditions,
sampling_options, sampling_options,
output_options, output_options,
router_config_override, router_config_override,
extra_args,
)) ))
})?; })?;
...@@ -1129,7 +1025,8 @@ impl KvPushRouter { ...@@ -1129,7 +1025,8 @@ impl KvPushRouter {
.stop_conditions(stop_conditions) .stop_conditions(stop_conditions)
.sampling_options(sampling_options) .sampling_options(sampling_options)
.output_options(output_options) .output_options(output_options)
.router_config_override(router_config_override); .router_config_override(router_config_override)
.extra_args(extra_args);
// Set backend_instance_id if worker_id is provided // Set backend_instance_id if worker_id is provided
if let Some(worker_id) = worker_id { if let Some(worker_id) = worker_id {
......
...@@ -1129,103 +1129,6 @@ class ZmqKvEventListener: ...@@ -1129,103 +1129,6 @@ class ZmqKvEventListener:
""" """
... ...
class KvRouter:
"""
A KV Router that decides which worker to use based on KV cache overlap.
This router tracks request states and manages KV cache distribution across workers.
"""
def __init__(
self,
endpoint: Endpoint,
block_size: int,
kv_router_config: Optional[KvRouterConfig] = None,
consumer_uuid: Optional[str] = None,
) -> None:
"""
Create a new KvRouter instance.
Args:
endpoint: The endpoint to associate with this router
block_size: The KV cache block size
kv_router_config: Optional configuration for the KV router
consumer_uuid: Optional unique identifier for this router instance.
If not provided, a UUID will be generated.
"""
...
async def find_best_match(
self,
request_id: str,
tokens: List[int],
*,
update_states: bool = False,
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens.
Args:
request_id: Unique identifier for the request used for tracking
tokens: List of token IDs to find matches for
update_states: Whether to update router states for this request (default: False)
router_config_override: Optional router configuration override with fields:
- overlap_score_weight: Optional weight for overlap score
- router_temperature: Optional temperature for worker selection
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def add_request(
self,
request_id: str,
tokens: List[int],
overlap_blocks: int,
worker_id: int,
) -> None:
"""
Add a request to the router's tracking system.
Args:
request_id: Unique identifier for the request
tokens: List of token IDs for the request
overlap_blocks: Number of overlapping blocks found
worker_id: ID of the worker handling this request
"""
...
async def mark_prefill_completed(self, request_id: str) -> None:
"""
Mark that prefill has been completed for a request.
Args:
request_id: The request ID to mark as prefill completed
"""
...
async def free(self, request_id: str) -> None:
"""
Free resources associated with a request.
Args:
request_id: The request ID to free
"""
...
@property
def block_size(self) -> int:
"""
Get the KV cache block size.
Returns:
The block size in tokens
"""
...
class KvPushRouter: class KvPushRouter:
""" """
A KV-aware push router that performs intelligent routing based on KV cache overlap. A KV-aware push router that performs intelligent routing based on KV cache overlap.
......
...@@ -26,7 +26,6 @@ from dynamo._core import KvIndexer as KvIndexer ...@@ -26,7 +26,6 @@ from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvPushRouter as KvPushRouter from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
......
...@@ -271,6 +271,7 @@ fn run_request( ...@@ -271,6 +271,7 @@ fn run_request(
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
extra_args: None,
}; };
work_request work_request
.response_channel .response_channel
......
...@@ -210,6 +210,7 @@ mod tests { ...@@ -210,6 +210,7 @@ mod tests {
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
extra_args: None,
}) })
} }
......
...@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
extra_args: None,
}; };
if signal.completed && token_count < max_tokens { if signal.completed && token_count < max_tokens {
......
...@@ -84,6 +84,10 @@ pub struct LLMEngineOutput { ...@@ -84,6 +84,10 @@ pub struct LLMEngineOutput {
// Index field for batch requests to match OpenAI format // Index field for batch requests to match OpenAI format
pub index: Option<u32>, pub index: Option<u32>,
/// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
} }
impl LLMEngineOutput { impl LLMEngineOutput {
...@@ -97,6 +101,7 @@ impl LLMEngineOutput { ...@@ -97,6 +101,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled), finish_reason: Some(FinishReason::Cancelled),
index: None, index: None,
extra_args: None,
} }
} }
...@@ -110,6 +115,7 @@ impl LLMEngineOutput { ...@@ -110,6 +115,7 @@ impl LLMEngineOutput {
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
top_logprobs: None, top_logprobs: None,
index: None, index: None,
extra_args: None,
} }
} }
...@@ -123,6 +129,7 @@ impl LLMEngineOutput { ...@@ -123,6 +129,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Length), finish_reason: Some(FinishReason::Length),
index: None, index: None,
extra_args: None,
} }
} }
...@@ -136,6 +143,7 @@ impl LLMEngineOutput { ...@@ -136,6 +143,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)), finish_reason: Some(FinishReason::Error(err_msg)),
index: None, index: None,
extra_args: None,
} }
} }
} }
......
...@@ -59,6 +59,11 @@ pub struct PreprocessedRequest { ...@@ -59,6 +59,11 @@ pub struct PreprocessedRequest {
/// Router configuration overrides for this specific request /// Router configuration overrides for this specific request
#[builder(default)] #[builder(default)]
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
} }
impl PreprocessedRequest { impl PreprocessedRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment