Commit 29813508 authored by Graham King, committed by GitHub

feat(dynamo-run): KV-aware routing (#1064)

Router:
```
dynamo-run in=http out=dyn://dynamo.endpoint.generate --router-mode kv
```

Worker (* N):
```
dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
```

You need patched vllm and the C bindings `.so`. Full docs in the updated guide: `docs/guides/dynamo_run.md`.

This gives us a pure-Rust ingress node: OpenAI-compliant HTTP server + pre-processor + KV-aware router.
parent b82e7327
......@@ -18,7 +18,7 @@ use std::sync::Arc;
use dynamo_llm::{
http::service::{
discovery::{model_watcher, ModelWatchState},
discovery::{LLMRouterMode, ModelWatcher},
service_v2::HttpService,
},
model_type::ModelType,
......@@ -83,18 +83,22 @@ async fn app(runtime: Runtime) -> Result<()> {
for model_type in [ModelType::Chat, ModelType::Completion] {
let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str());
let state = Arc::new(ModelWatchState {
prefix: etcd_path.clone(),
manager: manager.clone(),
drt: distributed.clone(),
});
let watch_obj = Arc::new(
ModelWatcher::new(
component.clone(),
manager.clone(),
&etcd_path,
LLMRouterMode::Random,
)
.await?,
);
if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher =
etcd_client.kv_get_and_watch_prefix(etcd_path).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
tokio::spawn(model_watcher(state, receiver));
tokio::spawn(watch_obj.watch(receiver));
}
}
......
......@@ -20,6 +20,8 @@
// 2. Update the backend component to produce a config in a standard location.
// 3. Update the KvRouter to read the config from the backend component.
use std::sync::Arc;
use clap::Parser;
use dynamo_llm::kv_router::{
......@@ -65,7 +67,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?;
let router = Ingress::for_engine(router)?;
let router = Ingress::for_engine(Arc::new(router))?;
component
.service_builder()
......
......@@ -4,6 +4,7 @@
* [Automatically download a model from Hugging Face](#use-model-from-hugging-face)
* [Run a model from local file](#run-a-model-from-local-file)
* [Distributed system](#distributed-system)
* [KV-aware routing](#kv-aware-routing)
* [Full usage details](#full-usage-details)
* [Setup](#setup)
* [mistral.rs](#mistralrs)
......@@ -23,7 +24,7 @@ It supports the following engines: mistralrs, llamacpp, sglang, vllm and tensorr
Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv]
```
Example: `dynamo run Qwen/Qwen3-0.6B`
......@@ -111,6 +112,43 @@ The `llama3B_pool` name is purely symbolic, pick anything as long as it matches
Run `dynamo-run --help` for more options.
### KV-aware routing
**Setup**
Only patched vllm currently supports KV-aware routing. Key setup steps:
1. `etcd` and `nats` (see earlier) must be running and accessible from all nodes.
1. Create a virtualenv: `uv venv kvtest`, then source its `activate` script.
1. EITHER install Dynamo's vllm branch: `uv pip install ai-dynamo-vllm`,
1. OR install upstream vllm 0.8.4 (`uv pip install vllm==0.8.4`) and patch it: `cd kvtest/lib/python3.12/site-packages`, `patch -p1 < $REPO_ROOT/container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch`.
1. Build the C bindings. `cd $REPO_ROOT/lib/bindings/c`. `cargo build`.
1. Put the library you just built on the library path: `export LD_LIBRARY_PATH=$REPO_ROOT/target/debug/`.
If you patched locally (instead of installing `ai-dynamo-vllm`) you will need to edit vllm's `platforms/__init__.py` to undo a patch change:
```
#vllm_version = version("ai_dynamo_vllm")
vllm_version = version("vllm")
```
**Start the workers**
The workers are started normally.
```
dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
```
**Start the ingress node**
```
dynamo-run in=http out=dyn://dynamo.endpoint.generate --router-mode kv
```
The only difference from the distributed system above is `--router-mode kv`. The patched vllm announces when a KV block is created or removed. The Dynamo router finds the worker with the best match for the KV blocks a request needs and directs the traffic to that node.
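The routing decision described above can be illustrated with a minimal sketch. It is not the `dynamo-llm` implementation: `WorkerId`, `block_hashes`, `matched_prefix` and `pick_worker` are hypothetical names, the hashing scheme is an assumption, and worker load is ignored entirely. It only shows the idea of splitting a tokenized prompt into fixed-size KV blocks and preferring the worker that already holds the longest run of leading blocks.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

/// Hypothetical worker identifier; not a Dynamo type.
type WorkerId = u64;

/// Hash each full block of `block_size` tokens, chaining in the previous
/// block's hash so that a block hash identifies the whole prefix up to it.
/// A trailing partial block is ignored. `block_size` must be non-zero.
fn block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
    let mut prev: u64 = 0;
    tokens
        .chunks_exact(block_size)
        .map(|block| {
            let mut h = DefaultHasher::new();
            prev.hash(&mut h);
            block.hash(&mut h);
            prev = h.finish();
            prev
        })
        .collect()
}

/// Count how many leading blocks of the request a worker already caches.
fn matched_prefix(request: &[u64], cached: &HashSet<u64>) -> usize {
    request.iter().take_while(|h| cached.contains(*h)).count()
}

/// Pick the worker with the longest cached prefix.
/// Ties go to the lowest worker id so the result is deterministic.
fn pick_worker(
    request: &[u64],
    workers: &HashMap<WorkerId, HashSet<u64>>,
) -> Option<WorkerId> {
    workers
        .iter()
        .map(|(id, cached)| (matched_prefix(request, cached), *id))
        .max_by(|a, b| a.0.cmp(&b.0).then(b.1.cmp(&a.1)))
        .map(|(_, id)| id)
}
```

In this sketch each worker's `HashSet` stands in for the block created/removed announcements: a "created" event would insert a hash and a "removed" event would delete it. A real router would most likely also weigh how busy each worker is, which is why a comparison against `random` and `round-robin` is still worthwhile.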
For performance testing, run a typical workload with `--router-mode kv` and compare it against `--router-mode random|round-robin` to see whether it benefits from KV-aware routing.
## Full usage details
`dynamo-run` is what `dynamo run` executes. It is also an example of what you can build in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide demonstrates how you can build from source with all the features.
......@@ -427,7 +465,7 @@ async def worker(runtime: DistributedRuntime):
#
component = runtime.namespace("namespace").component("component")
await component.create_service()
model_path = "Qwen/Qwen2.5-0.5B-Instruct" # or "/data/models/Qwen2.5-0.5B-Instruct"
model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B"
model_type = ModelType.Backend
endpoint = component.endpoint("endpoint")
# Optional last param to register_llm is model_name. If not present derives it from model_path
......@@ -457,7 +495,7 @@ if __name__ == "__main__":
The `model_path` can be:
- A HuggingFace repo ID. It will be downloaded and cached locally.
- A HuggingFace repo ID, optionally prefixed with `hf://`. It will be downloaded and cached locally.
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that.
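The three accepted forms of `model_path` can be told apart with a few checks. The sketch below is an illustration only: `ModelSource` and `classify` are hypothetical names, and the order of the checks is an assumption rather than what `dynamo-run` actually does.

```rust
use std::path::Path;

/// Hypothetical classification of a `model_path` value.
#[derive(Debug, PartialEq)]
enum ModelSource {
    /// A HuggingFace repo ID, with any `hf://` prefix removed.
    HfRepo(String),
    /// A local folder, assumed to be a checkout of a HuggingFace repo.
    LocalDir(String),
    /// A local file with a `.gguf` extension.
    GgufFile(String),
}

/// Classify `model_path` (assumed order: explicit prefix, then local
/// GGUF file, then local folder, otherwise treat it as a repo ID).
fn classify(model_path: &str) -> ModelSource {
    // An explicit `hf://` prefix always means a repo ID.
    if let Some(repo) = model_path.strip_prefix("hf://") {
        return ModelSource::HfRepo(repo.to_string());
    }
    let path = Path::new(model_path);
    let is_gguf = path
        .extension()
        .and_then(|e| e.to_str())
        .map_or(false, |e| e.eq_ignore_ascii_case("gguf"));
    if path.is_file() && is_gguf {
        ModelSource::GgufFile(model_path.to_string())
    } else if path.is_dir() {
        ModelSource::LocalDir(model_path.to_string())
    } else {
        // Nothing on disk matches, so assume it names a repo to download.
        ModelSource::HfRepo(model_path.to_string())
    }
}
```

Under these assumptions, `classify("hf://Qwen/Qwen3-0.6B")` and `classify("Qwen/Qwen3-0.6B")` give the same repo ID as long as no local folder with that name exists. The sketch does not verify that a folder actually contains `config.json` or the tokenizer files.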
......
......@@ -17,6 +17,7 @@ use std::collections::HashMap;
use std::path::PathBuf;
use clap::ValueEnum;
use dynamo_llm::http::service::discovery::LLMRouterMode;
use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices
......@@ -184,7 +185,7 @@ impl Flags {
}
}
#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug)]
#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug, Copy)]
pub enum RouterMode {
#[default]
Random,
......@@ -198,14 +199,21 @@ impl RouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV
}
}
impl From<RouterMode> for RuntimeRouterMode {
fn from(r: RouterMode) -> RuntimeRouterMode {
match r {
RouterMode::RoundRobin => RuntimeRouterMode::RoundRobin,
RouterMode::KV => todo!("KV not implemented yet"),
_ => RuntimeRouterMode::Random,
pub fn as_runtime(&self) -> Option<RuntimeRouterMode> {
match self {
RouterMode::RoundRobin => Some(RuntimeRouterMode::RoundRobin),
RouterMode::Random => Some(RuntimeRouterMode::Random),
// Runtime router does not have KV, it's a dynamo-llm thing, not dynamo-runtime
RouterMode::KV => None,
}
}
pub fn as_llm(&self) -> LLMRouterMode {
match self {
RouterMode::RoundRobin => LLMRouterMode::RoundRobin,
RouterMode::Random => LLMRouterMode::Random,
RouterMode::KV => LLMRouterMode::KV,
}
}
}
......@@ -19,6 +19,7 @@ use dynamo_llm::{
backend::{Backend, ExecutionContext},
engines::StreamingEngineAdapter,
http::service::discovery::ModelNetworkName,
kv_router::{scheduler::DefaultWorkerSelector, KvPushRouter, KvRouter},
model_card::ModelDeploymentCard,
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
......@@ -67,82 +68,91 @@ pub async fn prepare_engine(
let client = endpoint.client().await?;
let mut cache_dir = None;
let engine: OpenAIChatCompletionsStreamingEngine = match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
tracing::info!("Waiting for remote model..");
let remote_endpoints = client.wait_for_endpoints().await?;
debug_assert!(!remote_endpoints.is_empty());
tracing::info!(count = remote_endpoints.len(), "Model(s) discovered");
tracing::info!("Waiting for remote model..");
let network_name: ModelNetworkName = (&remote_endpoints[0]).into();
let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot run distributed components without etcd");
};
let network_entry = network_name.load_entry(etcd_client.clone()).await?;
let mut card = network_entry.load_mdc(endpoint_id, etcd_client).await?;
let remote_endpoints = client.wait_for_endpoints().await?;
debug_assert!(!remote_endpoints.is_empty());
tracing::info!(count = remote_endpoints.len(), "Model(s) discovered");
match network_entry.model_type {
ModelType::Backend => {
// Download tokenizer.json etc to local disk
cache_dir = Some(
card.move_from_nats(distributed_runtime.nats_client())
.await?,
);
let network_name: ModelNetworkName = (&remote_endpoints[0]).into();
let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot run distributed components without etcd");
};
let network_entry = network_name.load_entry(etcd_client.clone()).await?;
let mut card = network_entry.load_mdc(endpoint_id, etcd_client).await?;
// The backend doesn't mind what we expose to the user (chat or
// completions), and this function is only used by text and batch input so
// the user doesn't see the HTTP request. So use Chat.
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor =
OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
flags.router_mode.into(),
)
.await?;
let engine: OpenAIChatCompletionsStreamingEngine = match network_entry.model_type {
ModelType::Backend => {
// Download tokenizer.json etc to local disk
cache_dir = Some(
card.move_from_nats(distributed_runtime.nats_client())
.await?,
);
frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(ServiceBackend::from_engine(Arc::new(router)))?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?
// The backend doesn't mind what we expose to the user (chat or
// completions), and this function is only used by text and batch input so
// the user doesn't see the HTTP request. So use Chat.
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
flags.router_mode.as_runtime(),
)
.await?;
let service_backend = match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
ServiceBackend::from_engine(Arc::new(router))
}
ModelType::Chat => Arc::new(
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(
client, flags.router_mode.into()
)
.await?,
),
ModelType::Completion => {
anyhow::bail!("text and batch input only accept remote Chat models, not Completion");
/*
Arc::new(
PushRouter::<
CompletionRequest,
Annotated<CompletionResponse>,
>::from_client(
client, flags.router_mode.into()
)
.await?,
RouterMode::KV => {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
endpoint.component().clone(),
dynamo_llm::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
*/
.await?;
let kv_push_router = KvPushRouter::new(router, Arc::new(chooser));
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
}
};
frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?
}
ModelType::Chat => Arc::new(
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, flags.router_mode.as_runtime())
.await?,
),
ModelType::Completion => {
anyhow::bail!(
"text and batch input only accept remote Chat models, not Completion"
);
/*
Arc::new(
PushRouter::<
CompletionRequest,
Annotated<CompletionResponse>,
>::from_client(
client, flags.router_mode.into()
)
.await?,
)
*/
}
RouterMode::KV => todo!(),
};
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = endpoint.subject();
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{pin::Pin, sync::Arc};
use std::{future::Future, pin::Pin, sync::Arc};
use dynamo_llm::{
backend::Backend,
......@@ -52,7 +52,8 @@ pub async fn run(
.await?
.endpoint(&endpoint_id.name);
let (rt_fut, mut card) = match engine_config {
let (rt_fut, card): (Pin<Box<dyn Future<Output = _> + Send + 'static>>, _) = match engine_config
{
EngineConfig::StaticFull { engine, mut model } => {
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let ingress_chat = Ingress::<
......@@ -63,7 +64,7 @@ pub async fn run(
model.attach(&endpoint, ModelType::Chat).await?;
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
(fut_chat, model.card().clone())
(Box::pin(fut_chat), Some(model.card().clone()))
}
EngineConfig::StaticCore {
engine: inner_engine,
......@@ -86,10 +87,13 @@ pub async fn run(
model.attach(&endpoint, ModelType::Backend).await?;
let fut = endpoint.endpoint_builder().handler(ingress).start();
(fut, model.card().clone())
(Box::pin(fut), Some(model.card().clone()))
}
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
// We can only get here for `in=dyn out=vllm|sglang`, because vllm and sglang are a
// subprocess that we talk to like a remote endpoint.
// That means the vllm/sglang subprocess is doing all the work, we are idle.
(never_ready(), None)
}
};
......@@ -102,12 +106,18 @@ pub async fn run(
}
// Cleanup on shutdown
if let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown");
if let Some(mut card) = card {
if let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown");
}
}
Ok(())
}
fn never_ready() -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
Box::pin(std::future::pending::<anyhow::Result<()>>())
}
......@@ -17,6 +17,7 @@ use std::sync::Arc;
use crate::input::common;
use crate::{EngineConfig, Flags};
use dynamo_llm::http::service::discovery::LLMRouterMode;
use dynamo_llm::http::service::ModelManager;
use dynamo_llm::{
engines::StreamingEngineAdapter,
......@@ -29,6 +30,7 @@ use dynamo_llm::{
openai::completions::{CompletionRequest, CompletionResponse},
},
};
use dynamo_runtime::component::Component;
use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime};
......@@ -59,10 +61,11 @@ pub async fn run(
// Listen for models registering themselves in etcd, add them to HTTP service
run_watcher(
distributed_runtime.clone(),
component.clone(),
http_service.model_manager().clone(),
etcd_client.clone(),
&network_prefix,
flags.router_mode.as_llm(),
)
.await?;
}
......@@ -114,19 +117,18 @@ pub async fn run(
/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher(
distributed_runtime: DistributedRuntime,
component: Component,
model_manager: ModelManager,
etcd_client: etcd::Client,
network_prefix: &str,
router_mode: LLMRouterMode,
) -> anyhow::Result<()> {
let state = Arc::new(discovery::ModelWatchState {
prefix: network_prefix.to_string(),
manager: model_manager,
drt: distributed_runtime.clone(),
});
let watch_obj = Arc::new(
discovery::ModelWatcher::new(component, model_manager, network_prefix, router_mode).await?,
);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver));
let _watcher_task = tokio::spawn(watch_obj.watch(receiver));
Ok(())
}
......@@ -22,6 +22,9 @@ const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
#[cfg(feature = "python")]
const PYTHON_STR_SCHEME: &str = "pystr:";
/// Where we will attach the vllm/sglang subprocess. Invisible to users.
pub const INTERNAL_ENDPOINT: &str = "dyn://dynamo.internal.worker";
pub enum EngineConfig {
/// A remote networked engine we don't know about yet
Dynamic(Endpoint),
......@@ -45,6 +48,10 @@ pub async fn run(
out_opt: Output,
flags: Flags,
) -> anyhow::Result<()> {
if matches!(&in_opt, Input::Endpoint(_)) && matches!(&out_opt, Output::Endpoint(_)) {
anyhow::bail!("Cannot use endpoint for both in and out");
}
let cancel_token = runtime.primary_token();
let maybe_path = flags
.model_path_pos
......@@ -120,6 +127,14 @@ pub async fn run(
// TODO Does sglang support GGUF? Can we make it work?
anyhow::bail!("`--model-path` should point at a HuggingFace repo checkout");
}
// If `in=dyn` we want the sglang subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => INTERNAL_ENDPOINT.parse()?,
};
let multi_node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
......@@ -128,6 +143,7 @@ pub async fn run(
let (py_script, child) = match subprocess::start(
subprocess::sglang::PY,
&local_model,
&endpoint,
flags.tensor_parallel_size,
if flags.base_gpu_id == 0 {
None
......@@ -154,16 +170,24 @@ pub async fn run(
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
let endpoint: Endpoint = subprocess::ENDPOINT.parse()?;
EngineConfig::Dynamic(endpoint)
}
Output::Vllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
// If `in=dyn` we want the vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => INTERNAL_ENDPOINT.parse()?,
};
let (py_script, child) = match subprocess::start(
subprocess::vllm::PY,
&local_model,
&endpoint,
flags.tensor_parallel_size,
None, // base_gpu_id. vllm uses CUDA_VISIBLE_DEVICES instead
None, // multi-node config. vllm uses `ray`, see guide
......@@ -182,7 +206,6 @@ pub async fn run(
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
let endpoint: Endpoint = subprocess::ENDPOINT.parse()?;
EngineConfig::Dynamic(endpoint)
}
......
......@@ -30,7 +30,7 @@ Example:
- OR: ./dynamo-run /data/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#;
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv]";
fn main() -> anyhow::Result<()> {
// Set log level based on verbosity flag
......
......@@ -13,18 +13,18 @@ use tokio::io::AsyncBufReadExt;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang;
pub mod vllm;
/// Internal endpoint to connect the subprocess over etcd/nats
pub const ENDPOINT: &str = "dyn://dynamo.internal.worker";
pub async fn start(
// The Python code to run
py_script: &'static str,
// Model info
local_model: &LocalModel,
// Endpoint to connect the subprocess over etcd/nats
endpoint: &EndpointId,
// How many GPUs to use
tensor_parallel_size: u32,
// sglang which GPU to start from, on a multi-GPU system
......@@ -43,13 +43,15 @@ pub async fn start(
let mut args = vec![
script_path.to_string_lossy().to_string(),
"--endpoint".to_string(),
ENDPOINT.to_string(),
endpoint.as_url(),
"--model-path".to_string(),
local_model.path().to_string_lossy().to_string(),
"--model-name".to_string(),
local_model.display_name().to_string(),
"--tensor-parallel-size".to_string(),
tensor_parallel_size.to_string(),
"--kv-block-size".to_string(),
dynamo_llm::DEFAULT_KV_BLOCK_SIZE.to_string(),
];
// sglang only
if let Some(base_gpu_id) = base_gpu_id {
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# A very basic example of sglang worker handling pre-processed requests.
#
# Dynamo does the HTTP handling, prompt templating and tokenization, then forwards the
# request via NATS to this python script, which runs sglang.
#
# Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed
# in lib/bindings/python `maturin develop` and `pip install -e .` should do it
# Start nats and etcd:
# - nats-server -js
#
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
import argparse
import asyncio
......@@ -28,8 +17,9 @@ from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
logging.basicConfig(level=logging.DEBUG)
......@@ -44,6 +34,7 @@ class Config:
model_name: Optional[str]
base_gpu_id: int
tensor_parallel_size: int
kv_block_size: int
nnodes: int
node_rank: int
dist_init_addr: str
......@@ -106,6 +97,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
"page_size": config.kv_block_size,
}
if config.dist_init_addr != "":
arg_map["trust_remote_code"] = True
......@@ -168,6 +160,9 @@ def cmd_line_args():
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of machines SGLang will use"
)
......@@ -214,6 +209,7 @@ def cmd_line_args():
config.endpoint = parsed_endpoint_name
config.base_gpu_id = args.base_gpu_id
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# A very basic example of vllm worker handling pre-processed requests.
#
# Dynamo does the HTTP handling, prompt templating and tokenization, then forwards the
# request via NATS to this python script, which runs vllm.
#
# Setup a virtualenv with dynamo.llm, dynamo.runtime and vllm installed
# in lib/bindings/python `maturin develop` and `pip install -e .` should do it
# Start nats and etcd:
# - nats-server -js
#
# Window 1: `python vllm_inc.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
# - `libdynamo_llm_capi.so` is in the system lib path or its containing folder is in LD_LIBRARY_PATH
# It builds in target/debug/ by default.
import argparse
import asyncio
......@@ -30,11 +25,12 @@ from vllm.entrypoints.openai.api_server import (
)
from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, register_llm
from dynamo.llm import KvMetricsPublisher, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
logging.basicConfig(level=logging.DEBUG)
......@@ -48,6 +44,7 @@ class Config:
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
extra_engine_args: str
......@@ -56,9 +53,37 @@ class RequestHandler:
Request handler for the generate endpoint
"""
def __init__(self, engine, default_sampling_params):
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.metrics_publisher = KvMetricsPublisher()
def setup_kv_metrics(self):
if not hasattr(self.engine_client, "set_metrics_publisher"):
logging.debug("VLLM version does not support KV metrics")
return
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
# logging.debug(f"Received request: {request}")
......@@ -126,6 +151,8 @@ async def init(runtime: DistributedRuntime, config: Config):
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
"block_size": config.kv_block_size,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
......@@ -142,6 +169,14 @@ async def init(runtime: DistributedRuntime, config: Config):
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
os.environ["VLLM_WORKER_ID"] = str(endpoint.lease_id())
os.environ[
"VLLM_KV_CAPI_PATH"
] = "libdynamo_llm_capi.so" # Must be on LD_LIBRARY_PATH
os.environ["VLLM_KV_NAMESPACE"] = config.namespace
os.environ["VLLM_KV_COMPONENT"] = config.component
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
......@@ -151,11 +186,12 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shut down (i.e., let open TCP streams finish)
# after the lease is revoked
await endpoint.serve_endpoint(
RequestHandler(engine_client, default_sampling_params).generate
)
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
......@@ -183,6 +219,9 @@ def cmd_line_args():
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -213,6 +252,7 @@ def cmd_line_args():
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -31,7 +31,7 @@ from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
class Config:
......
......@@ -43,7 +43,7 @@ from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
class Config:
......
......@@ -51,7 +51,6 @@ module-name = "dynamo._core"
manifest-path = "Cargo.toml"
python-packages = ["dynamo"]
python-source = "src"
features = ["dynamo-llm/block-manager"]
[build-system]
requires = ["maturin>=1.0,<2.0", "patchelf"]
......
......@@ -37,7 +37,9 @@ impl KvRouter {
llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None)
.await
.map_err(to_pyerr)?;
Ok(Self { inner })
Ok(Self {
inner: Arc::new(inner),
})
})
}
......
......@@ -20,21 +20,19 @@ use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use dynamo_runtime::{
component::{self, ComponentEndpointInfo},
component::{self, Component, ComponentEndpointInfo},
pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter, ManyOut, Operator,
RouterMode as RuntimeRouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
},
protocols::{self, annotated::Annotated},
slug::Slug,
traits::DistributedRuntimeProvider as _,
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime,
};
use super::ModelManager;
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use crate::{
backend::Backend,
......@@ -46,6 +44,12 @@ use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
};
use crate::{
kv_router::{scheduler::DefaultWorkerSelector, KvPushRouter, KvRouter},
protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
},
};
use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
......@@ -161,180 +165,253 @@ impl std::fmt::Display for ModelNetworkName {
}
}
pub struct ModelWatchState {
pub prefix: String,
pub manager: ModelManager,
pub drt: DistributedRuntime,
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LLMRouterMode {
Random,
RoundRobin,
KV,
}
impl LLMRouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == LLMRouterMode::KV
}
pub fn as_runtime(&self) -> Option<RuntimeRouterMode> {
match self {
LLMRouterMode::RoundRobin => Some(RuntimeRouterMode::RoundRobin),
LLMRouterMode::Random => Some(RuntimeRouterMode::Random),
// Runtime router does not have KV, it's a dynamo-llm thing, not dynamo-runtime
LLMRouterMode::KV => None,
}
}
}
pub struct ModelWatcher {
prefix: String,
manager: ModelManager,
drt: DistributedRuntime,
router_mode: LLMRouterMode,
kv_chooser: Option<Arc<KvRouter>>,
}
pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
impl ModelWatcher {
pub async fn new(
component: Component,
model_manager: ModelManager,
network_prefix: &str,
router_mode: LLMRouterMode,
) -> anyhow::Result<ModelWatcher> {
let kv_chooser = if router_mode.is_kv_routing() {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
Some(Arc::new(chooser))
} else {
None
};
Ok(Self {
prefix: network_prefix.to_string(),
manager: model_manager,
drt: component.drt().clone(),
router_mode,
kv_chooser,
})
}
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
pub async fn watch(self: Arc<Self>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
continue;
}
};
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
continue;
}
};
if state.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
continue;
}
match handle_put(&model_entry, state.clone()).await {
Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model");
match self.clone().handle_put(&model_entry).await {
Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model");
}
Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name);
}
}
}
WatchEvent::Delete(kv) => match self.clone().handle_delete(&kv).await {
Ok(model_name) => {
tracing::info!("removed model {}", model_name);
}
Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name);
tracing::error!("error removing model: {}", e);
}
}
},
}
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => {
tracing::info!("removed model {}", model_name);
}
Err(e) => {
tracing::error!("error removing model: {}", e);
}
},
}
}
}
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> anyhow::Result<&str> {
let key = kv.key_str()?;
tracing::debug!(key, "removing model");
async fn handle_delete(self: Arc<Self>, kv: &KeyValue) -> anyhow::Result<&str> {
let key = kv.key_str()?;
tracing::debug!(key, "removing model");
let model_name = key.trim_start_matches(&state.prefix);
let model_name = key.trim_start_matches(&self.prefix);
// Ignore the errors because model could be either type
let _ = state.manager.remove_chat_completions_model(model_name);
let _ = state.manager.remove_completions_model(model_name);
// Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(model_name);
let _ = self.manager.remove_completions_model(model_name);
Ok(model_name)
}
Ok(model_name)
}
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(model_entry: &ModelEntry, state: Arc<ModelWatchState>) -> anyhow::Result<()> {
let endpoint_id = model_entry.endpoint.clone();
let client = state
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?
.endpoint(&endpoint_id.name)
.client()
.await?;
let Some(etcd_client) = state.drt.etcd_client() else {
// Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let card = match model_entry.load_mdc(endpoint_id, etcd_client).await {
Ok(card) => {
tracing::debug!(card.display_name, "adding model");
Some(card)
}
Err(err) => {
// `dynamo serve` isn't using MDC yet so can't be an error
tracing::info!(%err, "load_mdc did not complete");
None
}
};
match model_entry.model_type {
ModelType::Backend => {
// A Backend model expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle both.
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card");
};
// Download tokenizer.json etc to local disk
// This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
// OpenAIPreprocessor::new loads the files, so we can delete them after this
// function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(state.drt.nats_client()).await?);
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client.clone(),
RouterMode::Random, // TODO how do we configure this?
)
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(self: Arc<ModelWatcher>, model_entry: &ModelEntry) -> anyhow::Result<()> {
let endpoint_id = model_entry.endpoint.clone();
let client = self
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?
.endpoint(&endpoint_id.name)
.client()
.await?;
let chat_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(ServiceBackend::from_engine(Arc::new(router)))?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
state
.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
let frontend = SegmentSource::<
SingleIn<CompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
RouterMode::Random, // TODO how do we configure this?
)
.await?;
let Some(etcd_client) = self.drt.etcd_client() else {
// Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let card = match model_entry.load_mdc(endpoint_id, etcd_client).await {
Ok(card) => {
tracing::debug!(card.display_name, "adding model");
Some(card)
}
Err(err) => {
// `dynamo serve` isn't using MDC yet so can't be an error
tracing::info!(%err, "load_mdc did not complete");
None
}
};
match model_entry.model_type {
ModelType::Backend => {
// A Backend model expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle both.
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card");
};
// Download tokenizer.json etc to local disk
// This cache_dir is a tempfile::TempDir that will be deleted on drop. I _think_
// OpenAIPreprocessor::new loads the files, so we can delete them after this
// function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client.clone(),
self.router_mode.as_runtime(),
)
.await?;
let service_backend = match self.router_mode {
LLMRouterMode::Random | LLMRouterMode::RoundRobin => {
ServiceBackend::from_engine(Arc::new(router))
}
LLMRouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable");
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
};
let completions_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(ServiceBackend::from_engine(Arc::new(router)))?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
state
.manager
.add_completions_model(&model_entry.name, completions_engine)?;
}
ModelType::Chat => {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
.await?;
let engine = Arc::new(push_router);
state
.manager
.add_chat_completions_model(&model_entry.name, engine)?;
}
ModelType::Completion => {
let push_router =
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
let chat_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
let frontend = SegmentSource::<
SingleIn<CompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
Default::default(),
self.router_mode.as_runtime(),
)
.await?;
let engine = Arc::new(push_router);
state
.manager
.add_completions_model(&model_entry.name, engine)?;
let service_backend = match self.router_mode {
LLMRouterMode::Random | LLMRouterMode::RoundRobin => {
ServiceBackend::from_engine(Arc::new(router))
}
LLMRouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable");
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
};
let completions_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_completions_model(&model_entry.name, completions_engine)?;
}
ModelType::Chat => {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
.await?;
let engine = Arc::new(push_router);
self.manager
.add_chat_completions_model(&model_entry.name, engine)?;
}
ModelType::Completion => {
let push_router =
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
client,
Default::default(),
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_completions_model(&model_entry.name, engine)?;
}
}
}
Ok(())
Ok(())
}
}
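The routing-mode selection in `handle_put` above can be sketched in isolation. This is a minimal, self-contained illustration, not the code from this commit: `RuntimeRouterMode` is a simplified stand-in for the runtime enum, and `BackendKind` and `select_backend` are hypothetical names introduced only to show which engine would end up wrapped in the `ServiceBackend` for each mode.

```rust
/// Simplified stand-in for dynamo-runtime's `RouterMode` (assumption: only
/// the two variants referenced in the diff are modelled here).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RuntimeRouterMode {
    Random,
    RoundRobin,
}

/// Mirrors the `LLMRouterMode` enum added in the diff.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LLMRouterMode {
    Random,
    RoundRobin,
    KV,
}

impl LLMRouterMode {
    /// KV routing has no runtime-level equivalent, so it maps to `None`.
    pub fn as_runtime(&self) -> Option<RuntimeRouterMode> {
        match self {
            LLMRouterMode::Random => Some(RuntimeRouterMode::Random),
            LLMRouterMode::RoundRobin => Some(RuntimeRouterMode::RoundRobin),
            LLMRouterMode::KV => None,
        }
    }
}

/// Hypothetical label for the engine that would be placed behind the service backend.
#[derive(Debug, PartialEq)]
pub enum BackendKind {
    /// A plain push router using a runtime routing mode.
    PushRouter(RuntimeRouterMode),
    /// A push router wrapped with a KV-aware worker chooser.
    KvPushRouter,
}

/// Hypothetical helper: decide which backend to build for a routing mode.
/// `has_chooser` models whether `kv_chooser` is `Some`.
pub fn select_backend(mode: LLMRouterMode, has_chooser: bool) -> Result<BackendKind, String> {
    match mode.as_runtime() {
        Some(runtime_mode) => Ok(BackendKind::PushRouter(runtime_mode)),
        None if has_chooser => Ok(BackendKind::KvPushRouter),
        None => Err("KV routing mode with no chooser".to_string()),
    }
}
```

The error branch corresponds to the "should be unreachable" bail in the diff: `ModelWatcher::new` only builds a chooser when the mode is `KV`, so the two fields are expected to stay in sync.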
......@@ -13,18 +13,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::{
component::Component,
component::{Component, EndpointSource},
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
SingleIn,
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
ResponseStream, SingleIn,
},
prelude::*,
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use std::sync::Arc;
pub mod indexer;
pub mod metrics_aggregator;
......@@ -42,11 +43,16 @@ use crate::{
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
preprocessor::BackendInput,
protocols::common::llm_backend::LLMEngineOutput,
tokens::TokenBlockSequence,
};
use dynamo_runtime::traits::events::EventSubscriber;
// TODO: Allow user to change
pub const DEFAULT_KV_BLOCK_SIZE: usize = 16;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
pub const KV_EVENT_SUBJECT: &str = "kv_events";
......@@ -63,6 +69,8 @@ pub trait WorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
indexer: KvIndexer,
scheduler: KvScheduler,
......@@ -74,7 +82,7 @@ impl KvRouter {
component: Component,
block_size: usize,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Arc<Self>> {
) -> Result<Self> {
let cancellation_token = component
.drt()
.primary_lease()
......@@ -100,10 +108,7 @@ impl KvRouter {
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: RouterEvent = match serde_json::from_slice(&event.payload) {
Ok(event) => {
tracing::debug!("received kv event: {:?}", event);
event
}
Ok(event) => event,
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// Choosing warn and continue to process other events from other workers
......@@ -112,16 +117,16 @@ impl KvRouter {
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
}
}
});
Ok(Arc::new(Self {
Ok(Self {
scheduler,
indexer,
block_size,
}))
})
}
// [TODO] indexer needs to take 'lora_id' as parameter
......@@ -137,28 +142,33 @@ impl KvRouter {
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
}
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate(
&self,
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let isl_tokens = request.tokens.len();
/// Given these tokens, find the worker with the best match in its KV cache.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<i64> {
let isl_tokens = tokens.len();
let block_size = self.block_size;
let (complete_blocks, _partial_block) =
TokenBlockSequence::split_tokens(&request.tokens, block_size, 1337_u64);
TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
let local_block_hashes = complete_blocks
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
}
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate(
&self,
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let worker_id = self.find_best_match(&request.tokens).await?;
let response = RouterResponse { worker_id };
let response = Annotated::from_data(response);
......@@ -166,3 +176,35 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
pub struct KvPushRouter {
inner: PushRouter<BackendInput, Annotated<LLMEngineOutput>>,
chooser: Arc<KvRouter>,
}
impl KvPushRouter {
pub fn new(
inner: PushRouter<BackendInput, Annotated<LLMEngineOutput>>,
chooser: Arc<KvRouter>,
) -> Self {
KvPushRouter { inner, chooser }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for KvPushRouter
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match &self.inner.client.endpoints {
EndpointSource::Static => self.inner.r#static(request).await,
EndpointSource::Dynamic(_) => {
let worker_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, worker_id).await
}
}
}
}
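The idea behind `find_best_match` (split the prompt into complete blocks, hash each block, then prefer the worker whose cache overlaps the longest prefix) can be sketched as follows. This is a simplified illustration under stated assumptions, not the implementation in this commit: it uses the standard library's `DefaultHasher` instead of xxh3, represents each worker's cache as a plain list of block hashes instead of a radix tree, and ignores worker load, which the real scheduler may take into account. The function names and the tie-break rule are hypothetical.

```rust
use std::cmp::Reverse;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

/// Hash every complete block of `block_size` tokens; a trailing partial
/// block is dropped. Assumes `block_size > 0` (`chunks_exact` panics on 0).
pub fn block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
    tokens
        .chunks_exact(block_size)
        .map(|block| {
            // Assumption: DefaultHasher stands in for the xxh3 hash used in the diff.
            let mut hasher = DefaultHasher::new();
            block.hash(&mut hasher);
            hasher.finish()
        })
        .collect()
}

/// Count how many leading blocks of the request are already cached.
pub fn overlap(request: &[u64], cached: &[u64]) -> usize {
    request
        .iter()
        .zip(cached)
        .take_while(|(a, b)| a == b)
        .count()
}

/// Pick the worker with the longest cached prefix. Ties go to the lowest
/// worker id (an arbitrary choice made here for determinism).
/// Returns `None` when there are no workers.
pub fn find_best_match(
    tokens: &[u32],
    block_size: usize,
    workers: &HashMap<i64, Vec<u64>>,
) -> Option<i64> {
    let request = block_hashes(tokens, block_size);
    workers
        .iter()
        .map(|(id, cached)| (overlap(&request, cached), *id))
        .max_by_key(|&(score, id)| (score, Reverse(id)))
        .map(|(_, id)| id)
}
```

In the diff, the selected worker id is then passed to `PushRouter::direct`, so choosing a worker and sending the request remain separate steps.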
......@@ -58,7 +58,6 @@ use std::{
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing as log;
use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
......@@ -160,6 +159,7 @@ impl RouterEvent {
}
/// A block in the Radix Tree.
#[derive(Debug)]
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
......@@ -245,7 +245,6 @@ impl RadixTree {
let current_borrow = current.borrow();
current_borrow.children.get(&block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(&block.borrow().workers);
......@@ -284,7 +283,7 @@ impl RadixTree {
pub fn apply_event(&mut self, event: RouterEvent) {
let (worker_id, event) = (event.worker_id, event.event);
let (id, op) = (event.event_id, event.data);
log::debug!(id, "Store operation: {:?}", op);
tracing::trace!(id, "Store operation: {:?}", op);
let worker_lookup = self.lookup.entry(worker_id).or_default();
......@@ -301,7 +300,7 @@ impl RadixTree {
let mut current = match current {
Some(current) => current.clone(),
None => {
log::warn!(
tracing::warn!(
worker_id = worker_id.to_string(),
id,
parent_hash = ?op.parent_hash,
......@@ -344,7 +343,7 @@ impl RadixTree {
}
}
KvCacheEventData::Removed(remove) => {
// log::trace!(id, "KV Remove Operation: {:?}", op);
// tracing::trace!(id, "KV Remove Operation: {:?}", op);
// let mut worker_lookup = self.lookup.get(&worker_id).expect("Worker not found");
for block in remove.block_hashes {
......@@ -355,7 +354,7 @@ impl RadixTree {
let entry = match worker_lookup.get(&block) {
Some(entry) => entry.clone(),
None => {
log::warn!(
tracing::warn!(
worker_id = worker_id.to_string(),
id,
"Failed to find block to remove; skipping remove operation"
......@@ -562,7 +561,7 @@ impl KvIndexer {
}
_ = cancel.cancelled() => {
log::debug!("KvCacheIndexer progress loop shutting down");
tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
}
......@@ -576,7 +575,7 @@ impl KvIndexer {
.unwrap()
}));
log::debug!("KvCacheIndexer task completed");
tracing::debug!("KvCacheIndexer task completed");
});
let once = OnceLock::new();
......@@ -624,7 +623,7 @@ impl KvIndexerInterface for KvIndexer {
};
if let Err(e) = self.match_tx.send(req).await {
log::error!(
tracing::error!(
"Failed to send match request: {:?}; the indexer maybe offline",
e
);
......@@ -640,13 +639,13 @@ impl KvIndexerInterface for KvIndexer {
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
log::debug!(
tracing::debug!(
"Finding matches for request tokens: {:?} / len: {}",
tokens,
tokens.len()
);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
log::debug!("Computed sequence: {:?}", sequence);
tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await
}
......@@ -748,12 +747,12 @@ impl KvIndexerSharded {
Ok(req) = shard_broadcast_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
if let Err(e) = req.resp.send(matches).await {
log::trace!("Failed to send match response: {:?}", e);
tracing::trace!("Failed to send match response: {:?}", e);
}
}
_ = cancel.cancelled() => {
log::debug!("KvCacheIndexer progress loop shutting down");
tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
}
......@@ -767,7 +766,7 @@ impl KvIndexerSharded {
.unwrap()
}));
log::debug!("KvCacheIndexer task completed");
tracing::debug!("KvCacheIndexer task completed");
}));
}
......
......@@ -70,10 +70,8 @@ pub async fn collect_endpoints(
.into_endpoints()
.filter(|e| e.subject.starts_with(subject))
.collect::<Vec<_>>();
tracing::debug!("Endpoints: {endpoints:?}");
if endpoints.is_empty() {
tracing::warn!("No endpoints found matching subject {subject}");
tracing::debug!("Metrics endpoint not visible yet");
}
Ok(endpoints)
......@@ -92,7 +90,6 @@ pub async fn collect_endpoints_task(
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(backoff_delay) => {
......@@ -106,8 +103,6 @@ pub async fn collect_endpoints_task(
continue;
}
};
tracing::debug!("unfiltered endpoints: {:?}", unfiltered_endpoints);
let endpoints: Vec<Endpoint> = unfiltered_endpoints
.into_iter()
.filter(|s| s.data.is_some())
......@@ -125,13 +120,7 @@ pub async fn collect_endpoints_task(
}
)
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_subject
);
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints);
......