"vscode:/vscode.git/clone" did not exist on "2c2096ed838154f0ca8f77ed5600aeb302c49f6a"
Unverified Commit 30610e73 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: use KvPushRouter for prefill router (#3401)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent c48f49a4
......@@ -19,7 +19,7 @@ from typing import Optional
import uvloop
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.llm import KvPushRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -41,7 +41,7 @@ class StandaloneRouterHandler:
self.worker_endpoint_path = worker_endpoint_path
self.block_size = block_size
self.kv_router_config = kv_router_config
self.kv_router: Optional[KvRouter] = None
self.kv_push_router: Optional[KvPushRouter] = None
self.worker_client: Optional[Client] = None
async def initialize(self):
......@@ -65,121 +65,76 @@ class StandaloneRouterHandler:
self.worker_client = await worker_endpoint.client()
# Create KvRouter with specified configuration
self.kv_router = KvRouter(
# Create KvPushRouter with specified configuration
self.kv_push_router = KvPushRouter(
endpoint=worker_endpoint,
block_size=self.block_size,
kv_router_config=self.kv_router_config,
)
except Exception as e:
logger.error(f"Failed to initialize KvRouter: {e}")
logger.error(f"Failed to initialize KvPushRouter: {e}")
raise
async def find_best_worker(self, request):
async def generate(self, request):
"""
Find the best worker based on KV cache state.
Generate tokens using the KV-aware router.
This endpoint is called by clients to determine which worker
should handle a request.
This endpoint routes the request to the best worker and streams back results.
Wraps the request into PreprocessedRequest format and wraps worker responses
into LLMEngineOutput format.
"""
if self.kv_router is None:
# Fallback to round-robin if router not initialized
logger.warning("KvRouter not initialized, falling back to round-robin")
yield {
"status": "fallback",
"message": "Router not initialized",
if self.kv_push_router is None:
logger.error("KvPushRouter not initialized - cannot process request")
raise RuntimeError("Router not initialized")
# Wrap incoming request into PreprocessedRequest format for KvPushRouter
# The request should already have most fields, but we ensure it has the structure
preprocessed_request = {
"model": request.get("model", "unknown"),
"token_ids": request["token_ids"],
"stop_conditions": request.get("stop_conditions", {}),
"sampling_options": request.get("sampling_options", {}),
"output_options": request.get("output_options", {}),
"eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []),
"extra_args": request.get("extra_args", {}),
}
# Route and process through KvPushRouter
async for worker_output in await self.kv_push_router.generate_from_request(
preprocessed_request
):
# Wrap worker output into LLMEngineOutput format
# Worker should return dict with at minimum kv_transfer_params in extra_args
llm_engine_output = {
"token_ids": worker_output.get("token_ids", []),
"tokens": worker_output.get("tokens"),
"text": worker_output.get("text"),
"cum_log_probs": worker_output.get("cum_log_probs"),
"log_probs": worker_output.get("log_probs"),
"top_logprobs": worker_output.get("top_logprobs"),
"finish_reason": worker_output.get("finish_reason"),
"index": worker_output.get("index"),
"extra_args": worker_output.get("extra_args"),
}
return
yield llm_engine_output
try:
# Get current workers
if self.worker_client is None:
yield {
"status": "error",
"message": "Worker client not initialized",
}
return
instance_ids = self.worker_client.instance_ids()
if not instance_ids:
yield {
"status": "error",
"message": "No workers available",
}
return
logger.debug(f"Routing request with {len(instance_ids)} available workers")
# Validate required fields
if "token_ids" not in request:
raise ValueError("Missing required field 'token_ids' in request")
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
token_ids = request["token_ids"]
request_id = request["request_id"]
# Use KvRouter to find the best worker with state updates
best_worker_id, overlap_blocks = await self.kv_router.find_best_match(
request_id=request_id,
tokens=token_ids,
update_states=True, # Always update states for routing
)
logger.debug(
f"Selected worker {best_worker_id} with {overlap_blocks} overlap blocks for request {request_id}"
)
yield {
"worker_id": best_worker_id,
"overlap_blocks": overlap_blocks,
}
except Exception as e:
logger.error(f"Error finding best worker: {e}")
yield {
"status": "error",
"message": str(e),
}
async def free(self, request):
async def best_worker_id(self, token_ids, router_config_override=None):
"""
Free resources associated with a request.
Get the best worker ID for a given set of tokens without actually routing.
This endpoint is called when a request is completed to clean up
router state.
This method returns the worker ID that would be selected based on KV cache
overlap, but does NOT actually route the request or update router states.
It's useful for debugging, monitoring, or implementing custom routing logic.
"""
if self.kv_router is None:
logger.warning("KvRouter not initialized")
yield {
"status": "error",
"message": "Router not initialized",
}
return
if self.kv_push_router is None:
logger.error("KvPushRouter not initialized - cannot get best worker")
raise RuntimeError("Router not initialized")
try:
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
request_id = request["request_id"]
# Free the request from the router
await self.kv_router.free(request_id=request_id)
logger.debug(f"Freed resources for request {request_id}")
yield {
"status": "success",
"message": f"Request {request_id} freed successfully",
}
except Exception as e:
logger.error(f"Error freeing request: {e}")
yield {
"status": "error",
"message": str(e),
}
return await self.kv_push_router.best_worker_id(
token_ids, router_config_override
)
def parse_args():
......@@ -308,20 +263,21 @@ async def worker(runtime: DistributedRuntime):
await handler.initialize()
# Expose endpoints
find_best_worker_endpoint = component.endpoint("find_best_worker")
free_endpoint = component.endpoint("free")
generate_endpoint = component.endpoint("generate")
best_worker_endpoint = component.endpoint("best_worker_id")
logger.debug("Starting to serve find_best_worker and free endpoints...")
logger.debug("Starting to serve endpoints...")
# Serve both endpoints concurrently
try:
await asyncio.gather(
find_best_worker_endpoint.serve_endpoint(
handler.find_best_worker,
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("service", "router")],
),
free_endpoint.serve_endpoint(
handler.free,
best_worker_endpoint.serve_endpoint(
handler.best_worker_id,
graceful_shutdown=True,
metrics_labels=[("service", "router")],
),
......
......@@ -8,7 +8,7 @@ import uuid
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import AsyncGenerator
from typing import Any, AsyncGenerator, Dict
import msgspec
from vllm.inputs import TokensPrompt
......@@ -18,7 +18,6 @@ from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
from .protocol import MyRequestOutput
configure_dynamo_logging()
logger = logging.getLogger(__name__)
......@@ -126,27 +125,34 @@ class DecodeWorkerHandler(BaseWorkerHandler):
default_sampling_params,
prefill_worker_client=None,
prefill_router_client=None,
prefill_router_free_client=None,
):
super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client
self.prefill_router_free_client = prefill_router_free_client
self.can_prefill = 0
self._prefill_check_task = None
if self.prefill_worker_client is not None:
if self.prefill_worker_client or self.prefill_router_client:
self._prefill_check_task = asyncio.create_task(self._prefill_check_loop())
async def _prefill_check_loop(self):
"""Background task that checks prefill worker availability every 5 seconds."""
"""Background task that checks prefill router/worker availability every 5 seconds."""
while True:
try:
if self.prefill_worker_client is not None:
self.can_prefill = len(self.prefill_worker_client.instance_ids())
logger.debug(f"Current Prefill Workers: {self.can_prefill}")
else:
self.can_prefill = 0
router_count = (
len(self.prefill_router_client.instance_ids())
if self.prefill_router_client is not None
else 0
)
worker_count = (
len(self.prefill_worker_client.instance_ids())
if self.prefill_worker_client is not None
else 0
)
self.can_prefill = max(router_count, worker_count)
logger.debug(
f"Prefill availability - Routers: {router_count}, Workers: {worker_count}"
)
except asyncio.CancelledError:
logger.warning("Prefill check loop cancelled.")
raise
......@@ -178,15 +184,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# TODO: Change to prefill queue
# TODO: (PeaBrane) eventually, do not use a router_client and a free_client directly.
# This is least intrusive for now, but quite error prone. Should consider (major) refactoring
# TODO: (PeaBrane) longer term, decode workers should not handle prefill routing at all.
# Prefill routing logic should be integrated directly into the frontend service potentially.
# Use prefill router or worker if available
if self.can_prefill:
# Create a copy for prefill with specific modifications
# Create prefill sampling params with modifications
prefill_sampling_params = deepcopy(sampling_params)
if prefill_sampling_params.extra_args is None:
prefill_sampling_params.extra_args = {}
prefill_sampling_params.extra_args["kv_transfer_params"] = {
......@@ -195,68 +196,55 @@ class DecodeWorkerHandler(BaseWorkerHandler):
prefill_sampling_params.max_tokens = 1
prefill_sampling_params.min_tokens = 1
prefill_request = {
"token_ids": request["token_ids"],
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
used_prefill_router = False
try:
prefill_worker_id = None
# Send request with sampling_params and request_id in extra_args
prefill_request = request.copy()
# TODO (PeaBrane): this smells a bit bad as not we have two nestings
# of extra_args (an inner one again in sampling_params)
prefill_request["extra_args"] = {
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
# Try router first if available, fallback to worker
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
used_prefill_router = True
best_worker_response = await anext(
await self.prefill_router_client.generate(
{
"token_ids": request["token_ids"],
"request_id": request_id,
}
)
)
prefill_worker_id = best_worker_response.data().get("worker_id")
if prefill_worker_id is not None:
# Call router's generate endpoint which returns LLMEngineOutput
prefill_response = await anext(
await self.prefill_worker_client.direct(
prefill_request, prefill_worker_id, context=context
await self.prefill_router_client.generate(
prefill_request, context=context
)
)
else:
elif self.prefill_worker_client is not None:
# Fallback to direct worker with same format
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
else:
raise ValueError("No prefill router or worker available")
prefill_output = prefill_response.data()
# Extract kv_transfer_params from response
kv_transfer_params = prefill_output.get("extra_args", {}).get(
"kv_transfer_params"
)
if kv_transfer_params:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = kv_transfer_params
except Exception as e:
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
return
raise e
finally:
if used_prefill_router:
await anext(
await self.prefill_router_free_client.generate(
{"request_id": request_id}
)
)
logger.debug(f"Freed router state for request {request_id}")
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
# Modify original sampling_params for decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = prefill_response.kv_transfer_params
logger.warning(f"Prefill error: {e}, falling back to local prefill")
async with self._abort_monitor(context, request_id):
try:
......@@ -276,11 +264,17 @@ class PrefillWorkerHandler(BaseWorkerHandler):
super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context):
request_id = request["request_id"]
# Extract from PreprocessedRequest format - request_id and sampling_params from extra_args
extra_args = request.get("extra_args", {})
request_id = extra_args.get("request_id", str(uuid.uuid4().hex))
logger.debug(f"New Prefill Request ID: {request_id}")
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
token_ids = request["token_ids"]
prompt = TokensPrompt(prompt_token_ids=token_ids)
# Get sampling_params from extra_args
sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
......@@ -291,20 +285,22 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
prompt_logprobs=res.prompt_logprobs,
outputs=res.outputs,
finished=res.finished,
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
token_ids = res.outputs[0].token_ids if res.outputs else []
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"extra_args": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else {}
),
}
yield output
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
......
......@@ -227,14 +227,7 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("find_best_worker")
.client()
)
prefill_router_free_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("free")
.endpoint("generate")
.client()
)
......@@ -268,7 +261,6 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params,
prefill_worker_client,
prefill_router_client,
prefill_router_free_client,
)
# Set up KV event publisher for prefix caching if enabled
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict
from vllm.outputs import CompletionOutput
from vllm.sequence import PromptLogprobs, RequestMetrics
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
kv_transfer_params: Optional[dict[str, Any]] = None
......@@ -174,7 +174,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?;
m.add_class::<llm::kv::SpecDecodeStats>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?;
......
......@@ -866,118 +866,6 @@ async fn create_kv_router_from_endpoint(
Ok(kv_router)
}
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config=None))]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<&super::entrypoint::KvRouterConfig>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
let kv_router = create_kv_router_from_endpoint(
endpoint,
block_size,
kv_router_config.map(|c| c.inner()),
)
.await?;
Ok(Self { inner: kv_router })
})
}
#[pyo3(signature = (request_id, tokens, update_states=false, router_config_override=None))]
fn find_best_match<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
update_states: bool,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(
Some(&request_id),
&tokens,
router_config_override.as_ref(),
update_states,
)
.await
.map_err(to_pyerr)?;
Ok((worker_id, overlap_blocks))
})
}
fn add_request<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
overlap_blocks: u32,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.add_request(request_id, &tokens, overlap_blocks, worker_id)
.await;
Ok(())
})
}
fn mark_prefill_completed<'p>(
&self,
py: Python<'p>,
request_id: String,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.mark_prefill_completed(&request_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.free(&request_id).await.map_err(to_pyerr)?;
Ok(())
})
}
#[getter]
fn block_size(&self) -> PyResult<u32> {
Ok(self.inner.block_size())
}
}
#[pyclass]
pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>,
......@@ -1072,7 +960,7 @@ impl KvPushRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None))]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))]
fn generate<'p>(
&self,
py: Python<'p>,
......@@ -1083,9 +971,10 @@ impl KvPushRouter {
output_options: Option<PyObject>,
router_config_override: Option<PyObject>,
worker_id: Option<i64>,
extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
let (stop_conditions, sampling_options, output_options, router_config_override) =
let (stop_conditions, sampling_options, output_options, router_config_override, extra_args) =
Python::with_gil(|py| {
let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
depythonize(obj.bind(py)).map_err(to_pyerr)?
......@@ -1112,11 +1001,18 @@ impl KvPushRouter {
None
};
let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
Ok::<_, PyErr>((
stop_conditions,
sampling_options,
output_options,
router_config_override,
extra_args,
))
})?;
......@@ -1129,7 +1025,8 @@ impl KvPushRouter {
.stop_conditions(stop_conditions)
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override);
.router_config_override(router_config_override)
.extra_args(extra_args);
// Set backend_instance_id if worker_id is provided
if let Some(worker_id) = worker_id {
......
......@@ -1129,103 +1129,6 @@ class ZmqKvEventListener:
"""
...
class KvRouter:
"""
A KV Router that decides which worker to use based on KV cache overlap.
This router tracks request states and manages KV cache distribution across workers.
"""
def __init__(
self,
endpoint: Endpoint,
block_size: int,
kv_router_config: Optional[KvRouterConfig] = None,
consumer_uuid: Optional[str] = None,
) -> None:
"""
Create a new KvRouter instance.
Args:
endpoint: The endpoint to associate with this router
block_size: The KV cache block size
kv_router_config: Optional configuration for the KV router
consumer_uuid: Optional unique identifier for this router instance.
If not provided, a UUID will be generated.
"""
...
async def find_best_match(
self,
request_id: str,
tokens: List[int],
*,
update_states: bool = False,
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens.
Args:
request_id: Unique identifier for the request used for tracking
tokens: List of token IDs to find matches for
update_states: Whether to update router states for this request (default: False)
router_config_override: Optional router configuration override with fields:
- overlap_score_weight: Optional weight for overlap score
- router_temperature: Optional temperature for worker selection
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def add_request(
self,
request_id: str,
tokens: List[int],
overlap_blocks: int,
worker_id: int,
) -> None:
"""
Add a request to the router's tracking system.
Args:
request_id: Unique identifier for the request
tokens: List of token IDs for the request
overlap_blocks: Number of overlapping blocks found
worker_id: ID of the worker handling this request
"""
...
async def mark_prefill_completed(self, request_id: str) -> None:
"""
Mark that prefill has been completed for a request.
Args:
request_id: The request ID to mark as prefill completed
"""
...
async def free(self, request_id: str) -> None:
"""
Free resources associated with a request.
Args:
request_id: The request ID to free
"""
...
@property
def block_size(self) -> int:
"""
Get the KV cache block size.
Returns:
The block size in tokens
"""
...
class KvPushRouter:
"""
A KV-aware push router that performs intelligent routing based on KV cache overlap.
......
......@@ -26,7 +26,6 @@ from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput
......
......@@ -271,6 +271,7 @@ fn run_request(
top_logprobs: None,
finish_reason: None,
index: None,
extra_args: None,
};
work_request
.response_channel
......
......@@ -210,6 +210,7 @@ mod tests {
top_logprobs: None,
finish_reason: None,
index: None,
extra_args: None,
})
}
......
......@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None,
finish_reason: None,
index: None,
extra_args: None,
};
if signal.completed && token_count < max_tokens {
......
......@@ -84,6 +84,10 @@ pub struct LLMEngineOutput {
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
/// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
}
impl LLMEngineOutput {
......@@ -97,6 +101,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled),
index: None,
extra_args: None,
}
}
......@@ -110,6 +115,7 @@ impl LLMEngineOutput {
finish_reason: Some(FinishReason::Stop),
top_logprobs: None,
index: None,
extra_args: None,
}
}
......@@ -123,6 +129,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Length),
index: None,
extra_args: None,
}
}
......@@ -136,6 +143,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)),
index: None,
extra_args: None,
}
}
}
......
......@@ -59,6 +59,11 @@ pub struct PreprocessedRequest {
/// Router configuration overrides for this specific request
#[builder(default)]
pub router_config_override: Option<RouterConfigOverride>,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
}
impl PreprocessedRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment