Commit ec2e7307 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): Basic routing choice (#524)

As a first step towards KV routing:
- introduce a `--router-mode` in dynamo-run that only does random and round-robin right now. Not that interesting yet.
- Make the vllm engine publish the KV events received from our patched vllm.

Now we "just" need to connect the two. Easy right?
parent f91d5488
......@@ -17,6 +17,9 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;
use clap::ValueEnum;
use dynamo_runtime::component::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
......@@ -92,6 +95,13 @@ pub struct Flags {
#[arg(long)]
pub leader_addr: Option<String>,
/// If using `out=dyn://..` with multiple backends, this says how to route the requests.
///
/// Mostly interesting for KV-aware routing.
/// Defaults to RouterMode::Random
#[arg(long, default_value = "random")]
pub router_mode: RouterMode,
/// Internal use only.
// Start the python vllm engine sub-process.
#[arg(long, hide = true, default_value = "false")]
......@@ -198,3 +208,29 @@ fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
gpu_id: nums[2],
})
}
/// How `dynamo-run` distributes requests across multiple `dyn://` backends.
///
/// Parsed from the `--router-mode` CLI flag via clap's `ValueEnum`.
#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug)]
pub enum RouterMode {
/// Pick a backend at random. This is the default.
#[default]
Random,
/// Cycle through backends in order.
#[value(name = "round-robin")]
RoundRobin,
/// KV-cache-aware routing. Accepted on the command line but not wired up yet.
#[value(name = "kv")]
KV,
}
impl RouterMode {
    /// Whether KV-cache-aware routing was requested on the command line.
    pub fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }
}
impl From<RouterMode> for RuntimeRouterMode {
    /// Convert the CLI-level router mode into the runtime's router mode.
    fn from(r: RouterMode) -> RuntimeRouterMode {
        // Match exhaustively: a catch-all `_` arm would silently map any
        // future CLI variant to Random instead of forcing an update here.
        match r {
            RouterMode::Random => RuntimeRouterMode::Random,
            RouterMode::RoundRobin => RuntimeRouterMode::RoundRobin,
            RouterMode::KV => todo!("KV not implemented yet"),
        }
    }
}
......@@ -31,7 +31,7 @@ use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use crate::input::common;
use crate::EngineConfig;
use crate::{EngineConfig, Flags};
/// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model
......@@ -64,11 +64,12 @@ struct Entry {
pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
flags: Flags,
maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory
if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!(
......@@ -78,7 +79,7 @@ pub async fn run(
}
let (service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?;
common::prepare_engine(runtime, flags, engine_config).await?;
let service_name_ref = Arc::new(service_name);
let pre_processor = if let Some(card) = maybe_card {
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::EngineConfig;
use crate::{flags::RouterMode, EngineConfig, Flags};
use dynamo_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
......@@ -34,6 +34,7 @@ use std::sync::Arc;
/// Turns an EngineConfig into an OpenAIChatCompletionsStreamingEngine.
pub async fn prepare_engine(
runtime: Runtime,
flags: Flags,
engine_config: EngineConfig,
) -> anyhow::Result<(String, OpenAIChatCompletionsStreamingEngine, bool)> {
match engine_config {
......@@ -41,14 +42,21 @@ pub async fn prepare_engine(
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
.namespace(endpoint_id.namespace.clone())?
.component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name.clone());
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
let mut client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
client.set_router_mode(flags.router_mode.into());
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
}
RouterMode::KV => todo!(),
}
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
......
......@@ -28,21 +28,22 @@ use dynamo_llm::{
use dynamo_runtime::pipeline::{
network::Ingress, ManyOut, Operator, SegmentSource, ServiceBackend, SingleIn, Source,
};
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime, Runtime};
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime};
use crate::EngineConfig;
pub async fn run(
runtime: Runtime,
distributed_runtime: DistributedRuntime,
path: String,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
// This will attempt to connect to NATS and etcd
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let cancel_token = runtime.primary_token().clone();
let cancel_token = distributed_runtime.primary_token().clone();
let endpoint_id: Endpoint = path.parse()?;
let etcd_client = distributed_runtime.etcd_client();
let (ingress, service_name) = match engine_config {
EngineConfig::StaticFull {
service_name,
......@@ -85,7 +86,7 @@ pub async fn run(
model_type: ModelType::Chat,
};
let component = distributed
let component = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?;
let endpoint = component
......@@ -94,8 +95,8 @@ pub async fn run(
.await?
.endpoint(endpoint_id.name);
if let Some(etcd_client) = distributed.etcd_client() {
let network_name = endpoint.subject();
if let Some(etcd_client) = etcd_client {
let network_name = endpoint.subject_to(etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}");
etcd_client
.kv_create(
......
......@@ -32,16 +32,16 @@ use dynamo_runtime::{
DistributedRuntime, Runtime,
};
use crate::EngineConfig;
use crate::{EngineConfig, Flags};
/// Build and run an HTTP service
pub async fn run(
runtime: Runtime,
http_port: u16,
flags: Flags,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder()
.port(http_port)
.port(flags.http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()?;
......
......@@ -21,7 +21,7 @@ use futures::StreamExt;
use std::io::{ErrorKind, Write};
use crate::input::common;
use crate::EngineConfig;
use crate::{EngineConfig, Flags};
/// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this
......@@ -29,15 +29,16 @@ const MAX_TOKENS: u32 = 8192;
pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
flags: Flags,
single_prompt: Option<String>,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let (service_name, engine, inspect_template): (
String,
OpenAIChatCompletionsStreamingEngine,
bool,
) = common::prepare_engine(runtime.clone(), engine_config).await?;
) = common::prepare_engine(runtime, flags, engine_config).await?;
main_loop(
cancel_token,
&service_name,
......
......@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Read;
#[cfg(any(feature = "vllm", feature = "sglang"))]
use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc};
use dynamo_llm::{
backend::ExecutionContext, model_card::model::ModelDeploymentCard,
backend::ExecutionContext, kv_router::publisher::KvMetricsPublisher,
model_card::model::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
};
use dynamo_runtime::protocols::Endpoint;
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime};
mod flags;
pub use flags::Flags;
......@@ -41,6 +42,9 @@ const ENDPOINT_SCHEME: &str = "dyn://";
/// the command line. Hence it's optional, and defaults to this.
const INVISIBLE_MODEL_NAME: &str = "dynamo-run";
/// The component name for the KV publisher, if used
const KV_PUBLISHER_COMPONENT: &str = "kvpublisher";
/// How we identify a python string endpoint
#[cfg(feature = "python")]
const PYTHON_STR_SCHEME: &str = "pystr:";
......@@ -70,6 +74,12 @@ pub enum EngineConfig {
None,
}
/// Distributed system values.
///
/// Resolved up-front when the input is a `dyn://` endpoint so that later
/// stages (engine setup, KV publisher, endpoint serving) can reuse the same
/// connected runtime and parsed endpoint id.
struct DynInput {
// Parsed from the `dyn://` path given on the command line.
endpoint_id: Endpoint,
// Connected to NATS/etcd via `DistributedRuntime::from_settings`.
distributed_runtime: DistributedRuntime,
}
#[allow(unused_mut)]
pub async fn run(
runtime: dynamo_runtime::Runtime,
......@@ -171,6 +181,19 @@ pub async fn run(
}
};
// If we are in a distributed system, we need to know our component upfront
let dyn_input = match &in_opt {
Input::Endpoint(endpoint_path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint_id: Endpoint = endpoint_path.parse()?;
Some(DynInput {
endpoint_id,
distributed_runtime,
})
}
_ => None,
};
#[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
......@@ -219,7 +242,6 @@ pub async fn run(
}
#[cfg(feature = "sglang")]
Output::SgLang => {
use dynamo_engine_sglang;
let Some(model_path) = model_path else {
anyhow::bail!("out=sglang requires flag --model-path=<full-path-to-model-dir>");
};
......@@ -234,7 +256,7 @@ pub async fn run(
let node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(),
leader_addr: flags.leader_addr.clone().unwrap_or_default(),
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
......@@ -256,7 +278,7 @@ pub async fn run(
node_conf,
flags.tensor_parallel_size,
flags.base_gpu_id,
flags.extra_engine_args,
flags.extra_engine_args.clone(),
)
.await?;
extra = Some(Box::pin(async move {
......@@ -289,7 +311,7 @@ pub async fn run(
let node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(),
leader_addr: flags.leader_addr.clone().unwrap_or_default(),
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
......@@ -302,6 +324,19 @@ pub async fn run(
}
}
if node_conf.node_rank == 0 {
let kv_metrics_publisher = if let Some(dyn_input) = &dyn_input {
let kvp_component = dyn_input
.distributed_runtime
.namespace(dyn_input.endpoint_id.namespace.clone())?
.component(KV_PUBLISHER_COMPONENT)?;
let kvp = Arc::new(KvMetricsPublisher::new()?);
let kvp_inner = kvp.clone();
tokio::spawn(async move { kvp_inner.create_endpoint(kvp_component).await });
Some(kvp)
} else {
None
};
// vllm multi-node only the leader runs vllm
let (engine, vllm_future) = dynamo_engine_vllm::make_leader_engine(
cancel_token.clone(),
......@@ -309,7 +344,8 @@ pub async fn run(
&sock_prefix,
node_conf,
flags.tensor_parallel_size,
flags.extra_engine_args,
flags.extra_engine_args.clone(),
kv_metrics_publisher,
)
.await?;
extra = Some(Box::pin(async move {
......@@ -330,7 +366,6 @@ pub async fn run(
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
use dynamo_engine_llamacpp;
let Some(model_path) = model_path else {
anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>");
};
......@@ -408,35 +443,25 @@ pub async fn run(
match in_opt {
Input::Http => {
crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?;
crate::input::http::run(runtime.clone(), flags, engine_config).await?;
}
Input::Text => {
crate::input::text::run(runtime.clone(), cancel_token.clone(), None, engine_config)
.await?;
crate::input::text::run(runtime.clone(), flags, None, engine_config).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(
runtime.clone(),
cancel_token.clone(),
Some(prompt),
engine_config,
)
.await?;
crate::input::text::run(runtime.clone(), flags, Some(prompt), engine_config).await?;
}
Input::Batch(path) => {
crate::input::batch::run(
runtime.clone(),
cancel_token.clone(),
maybe_card,
path,
engine_config,
)
.await?;
crate::input::batch::run(runtime.clone(), flags, maybe_card, path, engine_config)
.await?;
}
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
let Some(dyn_input) = dyn_input else {
unreachable!("We set dyn_input earlier");
};
crate::input::endpoint::run(dyn_input.distributed_runtime, path, engine_config).await?;
}
Input::None => {
// Multi-node setup. The engine sub-process has been started and is talking
......
......@@ -32,7 +32,7 @@ Example:
const ZMQ_SOCKET_PREFIX: &str = "dyn";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>|none] out=[See available engines below] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json]";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>|none] out=[See available engines below] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]";
fn main() -> anyhow::Result<()> {
logging::init();
......@@ -94,6 +94,7 @@ fn main() -> anyhow::Result<()> {
node_config,
flags.tensor_parallel_size,
flags.extra_engine_args,
flags.router_mode.is_kv_routing(),
);
}
} else {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod common;
pub mod echo_core;
pub mod echo_full;
......@@ -14,11 +14,13 @@
// limitations under the License.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
......@@ -40,6 +42,7 @@ impl VllmEngine {
node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<Self> {
let w = worker::start(
cancel_token.clone(),
......@@ -48,6 +51,7 @@ impl VllmEngine {
node_conf,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher,
)
.await?;
let engine = VllmEngine {
......
......@@ -26,6 +26,7 @@ use dynamo_runtime::CancellationToken;
use dynamo_llm::backend::ExecutionContext;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
mod engine;
use engine::VllmEngine;
......@@ -50,6 +51,8 @@ pub async fn make_leader_engine(
tensor_parallel_size: u32,
// Path to extra engine args file
extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> pipeline_error::Result<(ExecutionContext, impl Future<Output = ()>)> {
let ray_obj = if node_conf.num_nodes > 1 {
let r = ray::start_leader(node_conf.leader_addr.parse()?)?;
......@@ -69,6 +72,7 @@ pub async fn make_leader_engine(
node_conf,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher,
)
.await?;
let vllm_process = engine.take_vllm_worker_handle();
......
......@@ -31,7 +31,11 @@ pub fn run_subprocess(
node_config: MultiNodeConfig,
tp_size: u32,
extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<()> {
if with_kv_routing {
set_kv_routing_vars()?;
}
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
......@@ -47,6 +51,7 @@ pub fn run_subprocess(
("tp_size_str", &tp_size.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
("extra_engine_args", extra_engine_args_str),
("enable_prefix_caching", &with_kv_routing.to_string()),
]
.into_py_dict(py)
.unwrap();
......@@ -57,3 +62,28 @@ pub fn run_subprocess(
Ok(())
})
}
// These environment variables trigger our vllm patch to emit KV routing events
fn set_kv_routing_vars() -> anyhow::Result<()> {
let exe = env::current_exe()?;
let exe_dir = exe
.parent()
.ok_or(anyhow::anyhow!("Current binary has no directory"))?;
let mut lib = PathBuf::from(exe_dir);
lib.set_file_name("libdynamo_llm_capi.so");
let vars = [
// Path to the C API Library
("VLLM_KV_CAPI_PATH", lib.display().to_string()),
// Identifiers to publish KV related information
("VLLM_KV_NAMESPACE", "dynamo".to_string()),
("VLLM_KV_COMPONENT", "vllm".to_string()),
// Worker ID used for identifying workers in distributed settings
("VLLM_WORKER_ID", "0".to_string()),
];
for (kvar, default_v) in vars {
if env::var(kvar).is_err() {
env::set_var(kvar, default_v);
}
}
Ok(())
}
......@@ -36,6 +36,7 @@ arg_map = {
"max_seq_len_to_capture": 8192,
"tensor_parallel_size": int(tp_size_str),
"pipeline_parallel_size": int(nnodes_str),
"enable_prefix_caching": enable_prefix_caching.lower() == "true",
}
json_map = {}
if extra_engine_args != "":
......
......@@ -33,11 +33,11 @@ use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc::{error::SendError, Sender};
use tokio::task::JoinHandle;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest;
use dynamo_llm::protocols::common::FinishReason;
use dynamo_llm::{engines::MultiNodeConfig, kv_router::publisher::KvMetricsPublisher};
/// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
......@@ -164,6 +164,8 @@ pub async fn start(
_node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<VllmWorker> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") {
......@@ -186,12 +188,17 @@ pub async fn start(
data,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher.is_some(),
)
.await?;
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
tokio::spawn(metrics_loop(cancel_token.clone(), metrics));
tokio::spawn(metrics_loop(
cancel_token.clone(),
metrics,
kv_metrics_publisher.clone(),
));
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8);
......@@ -313,6 +320,7 @@ async fn start_vllm(
mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<tokio::process::Child> {
let mut vllm_args = vec![
"--internal-vllm-process".to_string(),
......@@ -322,6 +330,9 @@ async fn start_vllm(
if let Some(args_path) = extra_engine_args {
vllm_args.push(format!("--extra-engine-args={}", args_path.display()));
}
if with_kv_routing {
vllm_args.push("--router-mode=kv".to_string());
}
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
......@@ -475,7 +486,11 @@ async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::
}
// NOTE: Custom to our patch of vllm.
async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) {
async fn metrics_loop(
cancel_token: CancellationToken,
mut socket: async_zmq::Pull,
publisher: Option<Arc<KvMetricsPublisher>>,
) {
loop {
let maybe_metrics = tokio::select! {
_ = cancel_token.cancelled() => {
......@@ -551,15 +566,14 @@ async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pu
match metrics_result {
Ok(metrics) => {
// TODO: These metrics could be attached to StatsHandler or Events
// for aggregation and visualization.
tracing::debug!("Received vllm metrics: {:?}", metrics);
if let Some(metrics_publisher) = publisher.as_ref() {
if let Err(err) = metrics_publisher.publish(metrics.into()) {
tracing::error!(%err, "Failed publishing KV metrics");
}
}
}
Err(err) => {
tracing::error!(
"Error deserializing vllm metrics with Python pickle: {}",
err
);
tracing::error!("Error deserializing vllm metrics with Python pickle: {err}");
}
}
}
......
......@@ -22,7 +22,7 @@ use dynamo_runtime::{
protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result,
DistributedRuntime,
};
use super::ModelManager;
......@@ -60,14 +60,41 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
Ok((model_name, model_type)) => {
tracing::info!("added {} model: {}", model_type, model_name);
WatchEvent::Put(kv) => {
let key = match kv.key_str() {
Ok(key) => key,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF8 in model key");
continue;
}
};
tracing::debug!(key, "adding model");
// model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
continue;
}
};
if state.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
continue;
}
Err(e) => {
tracing::error!("error adding model: {}", e);
match handle_put(model_entry, state.clone()).await {
Ok((model_name, model_type)) => {
tracing::info!("added {} model: {}", model_type, model_name);
}
Err(e) => {
tracing::error!("error adding model: {}", e);
}
}
},
}
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok((model_name, model_type)) => {
tracing::info!("removed {} model: {}", model_type, model_name);
......@@ -80,7 +107,10 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<
}
}
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
async fn handle_delete(
kv: &KeyValue,
state: Arc<ModelWatchState>,
) -> anyhow::Result<(&str, ModelType)> {
let key = kv.key_str()?;
tracing::debug!(key, "removing model");
......@@ -98,14 +128,10 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(String, ModelType)> {
let key = kv.key_str()?;
tracing::debug!(key, "adding model");
// model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
let service_name = model_entry.name.clone();
async fn handle_put(
model_entry: ModelEntry,
state: Arc<ModelWatchState>,
) -> anyhow::Result<(String, ModelType)> {
if model_entry.model_type != state.model_type {
raise!(
"model type mismatch: {} != {}",
......@@ -125,7 +151,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin
.await?;
state
.manager
.add_chat_completions_model(&service_name, Arc::new(client))?;
.add_chat_completions_model(&model_entry.name, Arc::new(client))?;
}
ModelType::Completion => {
let client = state
......@@ -137,9 +163,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(Strin
.await?;
state
.manager
.add_completions_model(&service_name, Arc::new(client))?;
.add_completions_model(&model_entry.name, Arc::new(client))?;
}
}
Ok((service_name, state.model_type))
Ok((model_entry.name, state.model_type))
}
......@@ -68,7 +68,7 @@ mod namespace;
mod registry;
pub mod service;
pub use client::Client;
pub use client::{Client, RouterMode};
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
......
......@@ -48,12 +48,25 @@ enum EndpointEvent {
Delete(String),
}
/// Strategy the client uses to pick one endpoint instance per request
/// when endpoints are discovered dynamically.
#[derive(Default, Debug, Clone, Copy)]
pub enum RouterMode {
/// Pick an endpoint at random. This is the default.
#[default]
Random,
/// Cycle through the available endpoints in order.
RoundRobin,
// Planned: KV-cache-aware routing.
//KV,
//
// Always and only go to the given endpoint ID.
// TODO: Is this useful?
Direct(i64),
}
// Typed client for issuing requests against a component endpoint.
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
// The endpoint this client talks to.
endpoint: Endpoint,
router: PushRouter<T, U>,
// Monotonic request counter; drives round-robin selection.
counter: Arc<AtomicU64>,
// Where endpoint instances come from: fixed (static) or etcd-discovered (dynamic).
endpoints: EndpointSource,
// How requests are spread across dynamic endpoints; see `set_router_mode`.
router_mode: RouterMode,
}
#[derive(Clone, Debug)]
......@@ -74,6 +87,7 @@ where
endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Static,
router_mode: Default::default(),
})
}
......@@ -157,6 +171,7 @@ where
endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Dynamic(watch_rx),
router_mode: Default::default(),
})
}
......@@ -177,6 +192,10 @@ where
}
}
/// Choose how subsequent requests are distributed across dynamic endpoints.
pub fn set_router_mode(&mut self, mode: RouterMode) {
    self.router_mode = mode;
}
/// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<()> {
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
......@@ -213,6 +232,7 @@ where
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
......@@ -235,6 +255,7 @@ where
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("random router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
......@@ -286,10 +307,13 @@ where
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
tracing::debug!("Client::generate: {:?}", self.endpoints);
match &self.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => self.random(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
}
}
}
......@@ -88,6 +88,10 @@ impl DistributedRuntime {
&self.runtime
}
/// Convenience accessor for the underlying [`Runtime`]'s primary
/// cancellation token.
pub fn primary_token(&self) -> CancellationToken {
self.runtime.primary_token()
}
/// The etcd lease all our components will be attached to.
/// Not available for static workers.
pub fn primary_lease(&self) -> Option<etcd::Lease> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment