Commit ec2e7307 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): Basic routing choice (#524)

As a first step towards KV routing:
- introduce a `--router-mode` in dynamo-run that only does random and round-robin right now. Not that interesting yet.
- Make the vllm engine publish the KV events received from our patched vllm.

Now we "just" need to connect the two. Easy right?
parent f91d5488
...@@ -17,6 +17,9 @@ use std::collections::HashMap; ...@@ -17,6 +17,9 @@ use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use clap::ValueEnum;
use dynamo_runtime::component::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices /// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
...@@ -92,6 +95,13 @@ pub struct Flags { ...@@ -92,6 +95,13 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub leader_addr: Option<String>, pub leader_addr: Option<String>,
/// If using `out=dyn://..` with multiple backends, this says how to route the requests.
///
/// Mostly interesting for KV-aware routing.
/// Defaults to RouterMode::Random
#[arg(long, default_value = "random")]
pub router_mode: RouterMode,
/// Internal use only. /// Internal use only.
// Start the python vllm engine sub-process. // Start the python vllm engine sub-process.
#[arg(long, hide = true, default_value = "false")] #[arg(long, hide = true, default_value = "false")]
...@@ -198,3 +208,29 @@ fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> { ...@@ -198,3 +208,29 @@ fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
gpu_id: nums[2], gpu_id: nums[2],
}) })
} }
/// How requests are distributed across multiple backend workers when
/// using `out=dyn://..` (selected via the `--router-mode` CLI flag).
#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug)]
pub enum RouterMode {
    /// Pick a backend at random for each request. This is the default.
    #[default]
    Random,
    /// Cycle through the available backends in order.
    #[value(name = "round-robin")]
    RoundRobin,
    /// KV-cache-aware routing. Not implemented yet (currently a `todo!`
    /// in the `From<RouterMode>` conversion).
    #[value(name = "kv")]
    KV,
}
impl RouterMode {
    /// Returns `true` when the KV-cache-aware routing strategy is selected.
    pub fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }
}
impl From<RouterMode> for RuntimeRouterMode {
    /// Maps the CLI-level router mode onto the runtime's router mode.
    fn from(r: RouterMode) -> RuntimeRouterMode {
        // Exhaustive match instead of a `_` catch-all: if a new RouterMode
        // variant is added, this must be a compile error rather than the
        // variant silently falling back to Random routing.
        match r {
            RouterMode::Random => RuntimeRouterMode::Random,
            RouterMode::RoundRobin => RuntimeRouterMode::RoundRobin,
            RouterMode::KV => todo!("KV not implemented yet"),
        }
    }
}
...@@ -31,7 +31,7 @@ use std::time::{Duration, Instant}; ...@@ -31,7 +31,7 @@ use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use crate::input::common; use crate::input::common;
use crate::EngineConfig; use crate::{EngineConfig, Flags};
/// Max tokens in each response. /// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model /// TODO: For batch mode this should be the full context size of the model
...@@ -64,11 +64,12 @@ struct Entry { ...@@ -64,11 +64,12 @@ struct Entry {
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
cancel_token: CancellationToken, flags: Flags,
maybe_card: Option<ModelDeploymentCard>, maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf, input_jsonl: PathBuf,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory // Check if the path exists and is a directory
if !input_jsonl.exists() || !input_jsonl.is_file() { if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!( anyhow::bail!(
...@@ -78,7 +79,7 @@ pub async fn run( ...@@ -78,7 +79,7 @@ pub async fn run(
} }
let (service_name, engine, _inspect_template) = let (service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?; common::prepare_engine(runtime, flags, engine_config).await?;
let service_name_ref = Arc::new(service_name); let service_name_ref = Arc::new(service_name);
let pre_processor = if let Some(card) = maybe_card { let pre_processor = if let Some(card) = maybe_card {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::EngineConfig; use crate::{flags::RouterMode, EngineConfig, Flags};
use dynamo_llm::{ use dynamo_llm::{
backend::Backend, backend::Backend,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
...@@ -34,6 +34,7 @@ use std::sync::Arc; ...@@ -34,6 +34,7 @@ use std::sync::Arc;
/// Turns an EngineConfig into an OpenAIChatCompletionsStreamingEngine. /// Turns an EngineConfig into an OpenAIChatCompletionsStreamingEngine.
pub async fn prepare_engine( pub async fn prepare_engine(
runtime: Runtime, runtime: Runtime,
flags: Flags,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<(String, OpenAIChatCompletionsStreamingEngine, bool)> { ) -> anyhow::Result<(String, OpenAIChatCompletionsStreamingEngine, bool)> {
match engine_config { match engine_config {
...@@ -41,14 +42,21 @@ pub async fn prepare_engine( ...@@ -41,14 +42,21 @@ pub async fn prepare_engine(
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)? .namespace(endpoint_id.namespace.clone())?
.component(endpoint_id.component)? .component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name); .endpoint(endpoint_id.name.clone());
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?; let mut client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
client.set_router_mode(flags.router_mode.into());
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?; client.wait_for_endpoints().await?;
tracing::info!("Model discovered"); tracing::info!("Model discovered");
}
RouterMode::KV => todo!(),
}
// The service_name isn't used for text chat outside of logs, // The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration. // so use the path. That avoids having to listen on etcd for model registration.
......
...@@ -28,21 +28,22 @@ use dynamo_llm::{ ...@@ -28,21 +28,22 @@ use dynamo_llm::{
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
network::Ingress, ManyOut, Operator, SegmentSource, ServiceBackend, SingleIn, Source, network::Ingress, ManyOut, Operator, SegmentSource, ServiceBackend, SingleIn, Source,
}; };
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime, Runtime}; use dynamo_runtime::{protocols::Endpoint, DistributedRuntime};
use crate::EngineConfig; use crate::EngineConfig;
pub async fn run( pub async fn run(
runtime: Runtime, distributed_runtime: DistributedRuntime,
path: String, path: String,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// This will attempt to connect to NATS and etcd // This will attempt to connect to NATS and etcd
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let cancel_token = runtime.primary_token().clone(); let cancel_token = distributed_runtime.primary_token().clone();
let endpoint_id: Endpoint = path.parse()?; let endpoint_id: Endpoint = path.parse()?;
let etcd_client = distributed_runtime.etcd_client();
let (ingress, service_name) = match engine_config { let (ingress, service_name) = match engine_config {
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
...@@ -85,7 +86,7 @@ pub async fn run( ...@@ -85,7 +86,7 @@ pub async fn run(
model_type: ModelType::Chat, model_type: ModelType::Chat,
}; };
let component = distributed let component = distributed_runtime
.namespace(endpoint_id.namespace)? .namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?; .component(endpoint_id.component)?;
let endpoint = component let endpoint = component
...@@ -94,8 +95,8 @@ pub async fn run( ...@@ -94,8 +95,8 @@ pub async fn run(
.await? .await?
.endpoint(endpoint_id.name); .endpoint(endpoint_id.name);
if let Some(etcd_client) = distributed.etcd_client() { if let Some(etcd_client) = etcd_client {
let network_name = endpoint.subject(); let network_name = endpoint.subject_to(etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}"); tracing::debug!("Registering with etcd as {network_name}");
etcd_client etcd_client
.kv_create( .kv_create(
......
...@@ -32,16 +32,16 @@ use dynamo_runtime::{ ...@@ -32,16 +32,16 @@ use dynamo_runtime::{
DistributedRuntime, Runtime, DistributedRuntime, Runtime,
}; };
use crate::EngineConfig; use crate::{EngineConfig, Flags};
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
http_port: u16, flags: Flags,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder() let http_service = service_v2::HttpService::builder()
.port(http_port) .port(flags.http_port)
.enable_chat_endpoints(true) .enable_chat_endpoints(true)
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.build()?; .build()?;
......
...@@ -21,7 +21,7 @@ use futures::StreamExt; ...@@ -21,7 +21,7 @@ use futures::StreamExt;
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
use crate::input::common; use crate::input::common;
use crate::EngineConfig; use crate::{EngineConfig, Flags};
/// Max response tokens for each single query. Must be less than model context size. /// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this /// TODO: Cmd line flag to overwrite this
...@@ -29,15 +29,16 @@ const MAX_TOKENS: u32 = 8192; ...@@ -29,15 +29,16 @@ const MAX_TOKENS: u32 = 8192;
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
cancel_token: CancellationToken, flags: Flags,
single_prompt: Option<String>, single_prompt: Option<String>,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let (service_name, engine, inspect_template): ( let (service_name, engine, inspect_template): (
String, String,
OpenAIChatCompletionsStreamingEngine, OpenAIChatCompletionsStreamingEngine,
bool, bool,
) = common::prepare_engine(runtime.clone(), engine_config).await?; ) = common::prepare_engine(runtime, flags, engine_config).await?;
main_loop( main_loop(
cancel_token, cancel_token,
&service_name, &service_name,
......
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::io::Read;
#[cfg(any(feature = "vllm", feature = "sglang"))] #[cfg(any(feature = "vllm", feature = "sglang"))]
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc};
use dynamo_llm::{ use dynamo_llm::{
backend::ExecutionContext, model_card::model::ModelDeploymentCard, backend::ExecutionContext, kv_router::publisher::KvMetricsPublisher,
model_card::model::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine, types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
}; };
use dynamo_runtime::protocols::Endpoint; use dynamo_runtime::{protocols::Endpoint, DistributedRuntime};
mod flags; mod flags;
pub use flags::Flags; pub use flags::Flags;
...@@ -41,6 +42,9 @@ const ENDPOINT_SCHEME: &str = "dyn://"; ...@@ -41,6 +42,9 @@ const ENDPOINT_SCHEME: &str = "dyn://";
/// the command line. Hence it's optional, and defaults to this. /// the command line. Hence it's optional, and defaults to this.
const INVISIBLE_MODEL_NAME: &str = "dynamo-run"; const INVISIBLE_MODEL_NAME: &str = "dynamo-run";
/// The component name for the KV publisher, if used
const KV_PUBLISHER_COMPONENT: &str = "kvpublisher";
/// How we identify a python string endpoint /// How we identify a python string endpoint
#[cfg(feature = "python")] #[cfg(feature = "python")]
const PYTHON_STR_SCHEME: &str = "pystr:"; const PYTHON_STR_SCHEME: &str = "pystr:";
...@@ -70,6 +74,12 @@ pub enum EngineConfig { ...@@ -70,6 +74,12 @@ pub enum EngineConfig {
None, None,
} }
/// Distributed system values.
///
/// Bundles the parsed `dyn://` endpoint id with the distributed-runtime
/// connection. Built once upfront (when the input is `Input::Endpoint`)
/// so both the endpoint handler and the KV metrics publisher can use the
/// same runtime instead of connecting to NATS/etcd twice.
struct DynInput {
    // Parsed from the `in=dyn://<namespace>.<component>.<name>` path.
    endpoint_id: Endpoint,
    // Connection to NATS/etcd (DistributedRuntime::from_settings).
    distributed_runtime: DistributedRuntime,
}
#[allow(unused_mut)] #[allow(unused_mut)]
pub async fn run( pub async fn run(
runtime: dynamo_runtime::Runtime, runtime: dynamo_runtime::Runtime,
...@@ -171,6 +181,19 @@ pub async fn run( ...@@ -171,6 +181,19 @@ pub async fn run(
} }
}; };
// If we are in a distributed system, we need to know our component upfront
let dyn_input = match &in_opt {
Input::Endpoint(endpoint_path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint_id: Endpoint = endpoint_path.parse()?;
Some(DynInput {
endpoint_id,
distributed_runtime,
})
}
_ => None,
};
#[cfg(any(feature = "vllm", feature = "sglang"))] #[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
...@@ -219,7 +242,6 @@ pub async fn run( ...@@ -219,7 +242,6 @@ pub async fn run(
} }
#[cfg(feature = "sglang")] #[cfg(feature = "sglang")]
Output::SgLang => { Output::SgLang => {
use dynamo_engine_sglang;
let Some(model_path) = model_path else { let Some(model_path) = model_path else {
anyhow::bail!("out=sglang requires flag --model-path=<full-path-to-model-dir>"); anyhow::bail!("out=sglang requires flag --model-path=<full-path-to-model-dir>");
}; };
...@@ -234,7 +256,7 @@ pub async fn run( ...@@ -234,7 +256,7 @@ pub async fn run(
let node_conf = dynamo_llm::engines::MultiNodeConfig { let node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes, num_nodes: flags.num_nodes,
node_rank: flags.node_rank, node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(), leader_addr: flags.leader_addr.clone().unwrap_or_default(),
}; };
if node_conf.num_nodes > 1 { if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await { if let Ok(Some(if_name)) = net::get_primary_interface().await {
...@@ -256,7 +278,7 @@ pub async fn run( ...@@ -256,7 +278,7 @@ pub async fn run(
node_conf, node_conf,
flags.tensor_parallel_size, flags.tensor_parallel_size,
flags.base_gpu_id, flags.base_gpu_id,
flags.extra_engine_args, flags.extra_engine_args.clone(),
) )
.await?; .await?;
extra = Some(Box::pin(async move { extra = Some(Box::pin(async move {
...@@ -289,7 +311,7 @@ pub async fn run( ...@@ -289,7 +311,7 @@ pub async fn run(
let node_conf = dynamo_llm::engines::MultiNodeConfig { let node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes, num_nodes: flags.num_nodes,
node_rank: flags.node_rank, node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(), leader_addr: flags.leader_addr.clone().unwrap_or_default(),
}; };
if node_conf.num_nodes > 1 { if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await { if let Ok(Some(if_name)) = net::get_primary_interface().await {
...@@ -302,6 +324,19 @@ pub async fn run( ...@@ -302,6 +324,19 @@ pub async fn run(
} }
} }
if node_conf.node_rank == 0 { if node_conf.node_rank == 0 {
let kv_metrics_publisher = if let Some(dyn_input) = &dyn_input {
let kvp_component = dyn_input
.distributed_runtime
.namespace(dyn_input.endpoint_id.namespace.clone())?
.component(KV_PUBLISHER_COMPONENT)?;
let kvp = Arc::new(KvMetricsPublisher::new()?);
let kvp_inner = kvp.clone();
tokio::spawn(async move { kvp_inner.create_endpoint(kvp_component).await });
Some(kvp)
} else {
None
};
// vllm multi-node only the leader runs vllm // vllm multi-node only the leader runs vllm
let (engine, vllm_future) = dynamo_engine_vllm::make_leader_engine( let (engine, vllm_future) = dynamo_engine_vllm::make_leader_engine(
cancel_token.clone(), cancel_token.clone(),
...@@ -309,7 +344,8 @@ pub async fn run( ...@@ -309,7 +344,8 @@ pub async fn run(
&sock_prefix, &sock_prefix,
node_conf, node_conf,
flags.tensor_parallel_size, flags.tensor_parallel_size,
flags.extra_engine_args, flags.extra_engine_args.clone(),
kv_metrics_publisher,
) )
.await?; .await?;
extra = Some(Box::pin(async move { extra = Some(Box::pin(async move {
...@@ -330,7 +366,6 @@ pub async fn run( ...@@ -330,7 +366,6 @@ pub async fn run(
} }
#[cfg(feature = "llamacpp")] #[cfg(feature = "llamacpp")]
Output::LlamaCpp => { Output::LlamaCpp => {
use dynamo_engine_llamacpp;
let Some(model_path) = model_path else { let Some(model_path) = model_path else {
anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>"); anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>");
}; };
...@@ -408,35 +443,25 @@ pub async fn run( ...@@ -408,35 +443,25 @@ pub async fn run(
match in_opt { match in_opt {
Input::Http => { Input::Http => {
crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?; crate::input::http::run(runtime.clone(), flags, engine_config).await?;
} }
Input::Text => { Input::Text => {
crate::input::text::run(runtime.clone(), cancel_token.clone(), None, engine_config) crate::input::text::run(runtime.clone(), flags, None, engine_config).await?;
.await?;
} }
Input::Stdin => { Input::Stdin => {
let mut prompt = String::new(); let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap(); std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run( crate::input::text::run(runtime.clone(), flags, Some(prompt), engine_config).await?;
runtime.clone(),
cancel_token.clone(),
Some(prompt),
engine_config,
)
.await?;
} }
Input::Batch(path) => { Input::Batch(path) => {
crate::input::batch::run( crate::input::batch::run(runtime.clone(), flags, maybe_card, path, engine_config)
runtime.clone(),
cancel_token.clone(),
maybe_card,
path,
engine_config,
)
.await?; .await?;
} }
Input::Endpoint(path) => { Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?; let Some(dyn_input) = dyn_input else {
unreachable!("We set dyn_input earlier");
};
crate::input::endpoint::run(dyn_input.distributed_runtime, path, engine_config).await?;
} }
Input::None => { Input::None => {
// Multi-node setup. The engine sub-process has been started and is talking // Multi-node setup. The engine sub-process has been started and is talking
......
...@@ -32,7 +32,7 @@ Example: ...@@ -32,7 +32,7 @@ Example:
const ZMQ_SOCKET_PREFIX: &str = "dyn"; const ZMQ_SOCKET_PREFIX: &str = "dyn";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>|none] out=[See available engines below] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json]"; const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>|none] out=[See available engines below] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
logging::init(); logging::init();
...@@ -94,6 +94,7 @@ fn main() -> anyhow::Result<()> { ...@@ -94,6 +94,7 @@ fn main() -> anyhow::Result<()> {
node_config, node_config,
flags.tensor_parallel_size, flags.tensor_parallel_size,
flags.extra_engine_args, flags.extra_engine_args,
flags.router_mode.is_kv_routing(),
); );
} }
} else { } else {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod common;
pub mod echo_core;
pub mod echo_full;
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
// limitations under the License. // limitations under the License.
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_stream::stream; use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_llm::engines::MultiNodeConfig; use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput}; use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream}; use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn}; use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
...@@ -40,6 +42,7 @@ impl VllmEngine { ...@@ -40,6 +42,7 @@ impl VllmEngine {
node_conf: MultiNodeConfig, node_conf: MultiNodeConfig,
tensor_parallel_size: u32, tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
let w = worker::start( let w = worker::start(
cancel_token.clone(), cancel_token.clone(),
...@@ -48,6 +51,7 @@ impl VllmEngine { ...@@ -48,6 +51,7 @@ impl VllmEngine {
node_conf, node_conf,
tensor_parallel_size, tensor_parallel_size,
extra_engine_args, extra_engine_args,
kv_metrics_publisher,
) )
.await?; .await?;
let engine = VllmEngine { let engine = VllmEngine {
......
...@@ -26,6 +26,7 @@ use dynamo_runtime::CancellationToken; ...@@ -26,6 +26,7 @@ use dynamo_runtime::CancellationToken;
use dynamo_llm::backend::ExecutionContext; use dynamo_llm::backend::ExecutionContext;
use dynamo_llm::engines::MultiNodeConfig; use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
mod engine; mod engine;
use engine::VllmEngine; use engine::VllmEngine;
...@@ -50,6 +51,8 @@ pub async fn make_leader_engine( ...@@ -50,6 +51,8 @@ pub async fn make_leader_engine(
tensor_parallel_size: u32, tensor_parallel_size: u32,
// Path to extra engine args file // Path to extra engine args file
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router // When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> pipeline_error::Result<(ExecutionContext, impl Future<Output = ()>)> { ) -> pipeline_error::Result<(ExecutionContext, impl Future<Output = ()>)> {
let ray_obj = if node_conf.num_nodes > 1 { let ray_obj = if node_conf.num_nodes > 1 {
let r = ray::start_leader(node_conf.leader_addr.parse()?)?; let r = ray::start_leader(node_conf.leader_addr.parse()?)?;
...@@ -69,6 +72,7 @@ pub async fn make_leader_engine( ...@@ -69,6 +72,7 @@ pub async fn make_leader_engine(
node_conf, node_conf,
tensor_parallel_size, tensor_parallel_size,
extra_engine_args, extra_engine_args,
kv_metrics_publisher,
) )
.await?; .await?;
let vllm_process = engine.take_vllm_worker_handle(); let vllm_process = engine.take_vllm_worker_handle();
......
...@@ -31,7 +31,11 @@ pub fn run_subprocess( ...@@ -31,7 +31,11 @@ pub fn run_subprocess(
node_config: MultiNodeConfig, node_config: MultiNodeConfig,
tp_size: u32, tp_size: u32,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if with_kv_routing {
set_kv_routing_vars()?;
}
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize" pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") { if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py)); let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
...@@ -47,6 +51,7 @@ pub fn run_subprocess( ...@@ -47,6 +51,7 @@ pub fn run_subprocess(
("tp_size_str", &tp_size.to_string()), ("tp_size_str", &tp_size.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()), ("nnodes_str", &node_config.num_nodes.to_string()),
("extra_engine_args", extra_engine_args_str), ("extra_engine_args", extra_engine_args_str),
("enable_prefix_caching", &with_kv_routing.to_string()),
] ]
.into_py_dict(py) .into_py_dict(py)
.unwrap(); .unwrap();
...@@ -57,3 +62,28 @@ pub fn run_subprocess( ...@@ -57,3 +62,28 @@ pub fn run_subprocess(
Ok(()) Ok(())
}) })
} }
// These environment variables trigger our vllm patch to emit KV routing events
fn set_kv_routing_vars() -> anyhow::Result<()> {
    let exe = env::current_exe()?;
    let exe_dir = exe
        .parent()
        .ok_or_else(|| anyhow::anyhow!("Current binary has no directory"))?;
    // The C API library sits next to the current binary (cargo places the
    // executable and the cdylib in the same target directory).
    //
    // Fix: the previous `PathBuf::set_file_name` call replaced the *final
    // component of the directory path itself* (e.g. turning `target/debug`
    // into `target/libdynamo_llm_capi.so`, one level too high). `join`
    // appends the library name inside the binary's directory.
    let lib = exe_dir.join("libdynamo_llm_capi.so");
    let vars = [
        // Path to the C API Library
        ("VLLM_KV_CAPI_PATH", lib.display().to_string()),
        // Identifiers to publish KV related information
        ("VLLM_KV_NAMESPACE", "dynamo".to_string()),
        ("VLLM_KV_COMPONENT", "vllm".to_string()),
        // Worker ID used for identifying workers in distributed settings
        ("VLLM_WORKER_ID", "0".to_string()),
    ];
    // Only provide defaults; anything the operator already exported wins.
    for (kvar, default_v) in vars {
        if env::var(kvar).is_err() {
            env::set_var(kvar, default_v);
        }
    }
    Ok(())
}
...@@ -36,6 +36,7 @@ arg_map = { ...@@ -36,6 +36,7 @@ arg_map = {
"max_seq_len_to_capture": 8192, "max_seq_len_to_capture": 8192,
"tensor_parallel_size": int(tp_size_str), "tensor_parallel_size": int(tp_size_str),
"pipeline_parallel_size": int(nnodes_str), "pipeline_parallel_size": int(nnodes_str),
"enable_prefix_caching": enable_prefix_caching.lower() == "true",
} }
json_map = {} json_map = {}
if extra_engine_args != "": if extra_engine_args != "":
......
...@@ -33,11 +33,11 @@ use tokio::io::AsyncBufReadExt; ...@@ -33,11 +33,11 @@ use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc::{error::SendError, Sender}; use tokio::sync::mpsc::{error::SendError, Sender};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics; use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput; use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest; use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest;
use dynamo_llm::protocols::common::FinishReason; use dynamo_llm::protocols::common::FinishReason;
use dynamo_llm::{engines::MultiNodeConfig, kv_router::publisher::KvMetricsPublisher};
/// Wait this long for the vllm sub-process to stop after we send it a KILL /// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500); const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
...@@ -164,6 +164,8 @@ pub async fn start( ...@@ -164,6 +164,8 @@ pub async fn start(
_node_conf: MultiNodeConfig, _node_conf: MultiNodeConfig,
tensor_parallel_size: u32, tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router // When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<VllmWorker> { ) -> anyhow::Result<VllmWorker> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize" pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") { if let Ok(venv) = env::var("VIRTUAL_ENV") {
...@@ -186,12 +188,17 @@ pub async fn start( ...@@ -186,12 +188,17 @@ pub async fn start(
data, data,
tensor_parallel_size, tensor_parallel_size,
extra_engine_args, extra_engine_args,
kv_metrics_publisher.is_some(),
) )
.await?; .await?;
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process); let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat)); tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
tokio::spawn(metrics_loop(cancel_token.clone(), metrics)); tokio::spawn(metrics_loop(
cancel_token.clone(),
metrics,
kv_metrics_publisher.clone(),
));
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new())); let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8); let (tx, rx) = tokio::sync::mpsc::channel(8);
...@@ -313,6 +320,7 @@ async fn start_vllm( ...@@ -313,6 +320,7 @@ async fn start_vllm(
mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>, mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
tensor_parallel_size: u32, tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<tokio::process::Child> { ) -> anyhow::Result<tokio::process::Child> {
let mut vllm_args = vec![ let mut vllm_args = vec![
"--internal-vllm-process".to_string(), "--internal-vllm-process".to_string(),
...@@ -322,6 +330,9 @@ async fn start_vllm( ...@@ -322,6 +330,9 @@ async fn start_vllm(
if let Some(args_path) = extra_engine_args { if let Some(args_path) = extra_engine_args {
vllm_args.push(format!("--extra-engine-args={}", args_path.display())); vllm_args.push(format!("--extra-engine-args={}", args_path.display()));
} }
if with_kv_routing {
vllm_args.push("--router-mode=kv".to_string());
}
let self_path = std::env::current_exe()?; let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path) let mut proc = tokio::process::Command::new(self_path)
...@@ -475,7 +486,11 @@ async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq:: ...@@ -475,7 +486,11 @@ async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::
} }
// NOTE: Custom to our patch of vllm. // NOTE: Custom to our patch of vllm.
async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) { async fn metrics_loop(
cancel_token: CancellationToken,
mut socket: async_zmq::Pull,
publisher: Option<Arc<KvMetricsPublisher>>,
) {
loop { loop {
let maybe_metrics = tokio::select! { let maybe_metrics = tokio::select! {
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
...@@ -551,15 +566,14 @@ async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pu ...@@ -551,15 +566,14 @@ async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pu
match metrics_result { match metrics_result {
Ok(metrics) => { Ok(metrics) => {
// TODO: These metrics could be attached to StatsHandler or Events if let Some(metrics_publisher) = publisher.as_ref() {
// for aggregation and visualization. if let Err(err) = metrics_publisher.publish(metrics.into()) {
tracing::debug!("Received vllm metrics: {:?}", metrics); tracing::error!(%err, "Failed publishing KV metrics");
}
}
} }
Err(err) => { Err(err) => {
tracing::error!( tracing::error!("Error deserializing vllm metrics with Python pickle: {err}");
"Error deserializing vllm metrics with Python pickle: {}",
err
);
} }
} }
} }
......
...@@ -22,7 +22,7 @@ use dynamo_runtime::{ ...@@ -22,7 +22,7 @@ use dynamo_runtime::{
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
raise, raise,
transports::etcd::{KeyValue, WatchEvent}, transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result, DistributedRuntime,
}; };
use super::ModelManager; use super::ModelManager;
...@@ -60,14 +60,41 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver< ...@@ -60,14 +60,41 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
match event { match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await { WatchEvent::Put(kv) => {
let key = match kv.key_str() {
Ok(key) => key,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF8 in model key");
continue;
}
};
tracing::debug!(key, "adding model");
// model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
continue;
}
};
if state.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
continue;
}
match handle_put(model_entry, state.clone()).await {
Ok((model_name, model_type)) => { Ok((model_name, model_type)) => {
tracing::info!("added {} model: {}", model_type, model_name); tracing::info!("added {} model: {}", model_type, model_name);
} }
Err(e) => { Err(e) => {
tracing::error!("error adding model: {}", e); tracing::error!("error adding model: {}", e);
} }
}, }
}
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await { WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok((model_name, model_type)) => { Ok((model_name, model_type)) => {
tracing::info!("removed {} model: {}", model_type, model_name); tracing::info!("removed {} model: {}", model_type, model_name);
...@@ -80,7 +107,10 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver< ...@@ -80,7 +107,10 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<
} }
} }
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> { async fn handle_delete(
kv: &KeyValue,
state: Arc<ModelWatchState>,
) -> anyhow::Result<(&str, ModelType)> {
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!(key, "removing model"); tracing::debug!(key, "removing model");
...@@ -98,14 +128,10 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s ...@@ -98,14 +128,10 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s
// models. // models.
// //
// If this method errors, for the near term, we will delete the offending key. // If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(String, ModelType)> { async fn handle_put(
let key = kv.key_str()?; model_entry: ModelEntry,
tracing::debug!(key, "adding model"); state: Arc<ModelWatchState>,
) -> anyhow::Result<(String, ModelType)> {
// model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
let service_name = model_entry.name.clone();
if model_entry.model_type != state.model_type { if model_entry.model_type != state.model_type {
raise!( raise!(
"model type mismatch: {} != {}", "model type mismatch: {} != {}",
...@@ -125,7 +151,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin ...@@ -125,7 +151,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin
.await?; .await?;
state state
.manager .manager
.add_chat_completions_model(&service_name, Arc::new(client))?; .add_chat_completions_model(&model_entry.name, Arc::new(client))?;
} }
ModelType::Completion => { ModelType::Completion => {
let client = state let client = state
...@@ -137,9 +163,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin ...@@ -137,9 +163,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin
.await?; .await?;
state state
.manager .manager
.add_completions_model(&service_name, Arc::new(client))?; .add_completions_model(&model_entry.name, Arc::new(client))?;
} }
} }
Ok((service_name, state.model_type)) Ok((model_entry.name, state.model_type))
} }
...@@ -68,7 +68,7 @@ mod namespace; ...@@ -68,7 +68,7 @@ mod namespace;
mod registry; mod registry;
pub mod service; pub mod service;
pub use client::Client; pub use client::{Client, RouterMode};
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
......
...@@ -48,12 +48,25 @@ enum EndpointEvent { ...@@ -48,12 +48,25 @@ enum EndpointEvent {
Delete(String), Delete(String),
} }
/// How a dynamic client selects the backend endpoint for each request.
///
/// Consulted when endpoints are discovered dynamically; a static client
/// always targets its single fixed endpoint and ignores this mode.
#[derive(Default, Debug, Clone, Copy)]
pub enum RouterMode {
    /// Pick an endpoint at random on every request (the default).
    #[default]
    Random,
    /// Cycle through the live endpoints using a shared request counter.
    RoundRobin,
    // KV-aware routing: planned but not implemented yet.
    //KV,
    //
    /// Always and only go to the given endpoint ID.
    // TODO: Is this useful?
    Direct(i64),
}
#[derive(Clone)] #[derive(Clone)]
pub struct Client<T: Data, U: Data> { pub struct Client<T: Data, U: Data> {
endpoint: Endpoint, endpoint: Endpoint,
router: PushRouter<T, U>, router: PushRouter<T, U>,
counter: Arc<AtomicU64>, counter: Arc<AtomicU64>,
endpoints: EndpointSource, endpoints: EndpointSource,
router_mode: RouterMode,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
...@@ -74,6 +87,7 @@ where ...@@ -74,6 +87,7 @@ where
endpoint, endpoint,
counter: Arc::new(AtomicU64::new(0)), counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Static, endpoints: EndpointSource::Static,
router_mode: Default::default(),
}) })
} }
...@@ -157,6 +171,7 @@ where ...@@ -157,6 +171,7 @@ where
endpoint, endpoint,
counter: Arc::new(AtomicU64::new(0)), counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Dynamic(watch_rx), endpoints: EndpointSource::Dynamic(watch_rx),
router_mode: Default::default(),
}) })
} }
...@@ -177,6 +192,10 @@ where ...@@ -177,6 +192,10 @@ where
} }
} }
/// Choose how this client distributes requests across backend endpoints.
///
/// Takes effect for subsequent requests; only dynamic endpoint sources
/// consult the mode (static clients always use their fixed endpoint).
pub fn set_router_mode(&mut self, mode: RouterMode) {
    self.router_mode = mode;
}
/// Wait for at least one [`Endpoint`] to be available /// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<()> { pub async fn wait_for_endpoints(&self) -> Result<()> {
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() { if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
...@@ -213,6 +232,7 @@ where ...@@ -213,6 +232,7 @@ where
let offset = counter % count as u64; let offset = counter % count as u64;
endpoints[offset as usize] endpoints[offset as usize]
}; };
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id); let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
...@@ -235,6 +255,7 @@ where ...@@ -235,6 +255,7 @@ where
let offset = counter % count as u64; let offset = counter % count as u64;
endpoints[offset as usize] endpoints[offset as usize]
}; };
tracing::trace!("random router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id); let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
...@@ -286,10 +307,13 @@ where ...@@ -286,10 +307,13 @@ where
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
tracing::debug!("Client::generate: {:?}", self.endpoints);
match &self.endpoints { match &self.endpoints {
EndpointSource::Static => self.r#static(request).await, EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => self.random(request).await, EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
} }
} }
} }
...@@ -88,6 +88,10 @@ impl DistributedRuntime { ...@@ -88,6 +88,10 @@ impl DistributedRuntime {
&self.runtime &self.runtime
} }
/// The primary [`CancellationToken`] for this distributed runtime.
///
/// Convenience pass-through to the inner runtime's `primary_token`;
/// returns a clone-by-value token handle.
pub fn primary_token(&self) -> CancellationToken {
    self.runtime.primary_token()
}
/// The etcd lease all our components will be attached to. /// The etcd lease all our components will be attached to.
/// Not available for static workers. /// Not available for static workers.
pub fn primary_lease(&self) -> Option<etcd::Lease> { pub fn primary_lease(&self) -> Option<etcd::Lease> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment