Unverified Commit 889ab67e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Fix default RouterMode value (#1092)

The Python bindings use the default value for RouterMode. Previously that was Random (good), but now it became None (bad).

Remove the `Option` wrapper and clean up the duplicate RouterMode. I was trying to avoid putting the `KV` variant in dynamo-runtime's enum. Turns out adding those two characters gives us a healthy simplification, and restores the old default router value.

Also clean up two noisy log messages emitted while waiting for KV routing metrics to start in the worker.
parent 2a5eb7e7
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::Parser; use clap::Parser;
use std::sync::Arc; use std::sync::Arc;
use dynamo_llm::{ use dynamo_llm::{
http::service::{ http::service::{discovery::ModelWatcher, service_v2::HttpService},
discovery::{LLMRouterMode, ModelWatcher},
service_v2::HttpService,
},
model_type::ModelType, model_type::ModelType,
}; };
use dynamo_runtime::{ use dynamo_runtime::{
logging, transports::etcd::PrefixWatcher, DistributedRuntime, Result, Runtime, Worker, logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, Result,
Runtime, Worker,
}; };
#[derive(Parser)] #[derive(Parser)]
...@@ -89,7 +75,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -89,7 +75,7 @@ async fn app(runtime: Runtime) -> Result<()> {
component.clone(), component.clone(),
manager.clone(), manager.clone(),
&etcd_path, &etcd_path,
LLMRouterMode::Random, RouterMode::Random,
) )
.await?, .await?,
); );
......
...@@ -17,7 +17,6 @@ use std::collections::HashMap; ...@@ -17,7 +17,6 @@ use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
use clap::ValueEnum; use clap::ValueEnum;
use dynamo_llm::http::service::discovery::LLMRouterMode;
use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode; use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices /// Required options depend on the in and out choices
...@@ -195,25 +194,12 @@ pub enum RouterMode { ...@@ -195,25 +194,12 @@ pub enum RouterMode {
KV, KV,
} }
impl RouterMode { impl From<RouterMode> for RuntimeRouterMode {
pub fn is_kv_routing(&self) -> bool { fn from(r: RouterMode) -> RuntimeRouterMode {
*self == RouterMode::KV match r {
} RouterMode::RoundRobin => RuntimeRouterMode::RoundRobin,
RouterMode::Random => RuntimeRouterMode::Random,
pub fn as_runtime(&self) -> Option<RuntimeRouterMode> { RouterMode::KV => RuntimeRouterMode::KV,
match self {
RouterMode::RoundRobin => Some(RuntimeRouterMode::RoundRobin),
RouterMode::Random => Some(RuntimeRouterMode::Random),
// Runtime router does not have KV, it's a dynamo-llm thing, not dynamo-runtime
RouterMode::KV => None,
}
}
pub fn as_llm(&self) -> LLMRouterMode {
match self {
RouterMode::RoundRobin => LLMRouterMode::RoundRobin,
RouterMode::Random => LLMRouterMode::Random,
RouterMode::KV => LLMRouterMode::KV,
} }
} }
} }
...@@ -102,7 +102,7 @@ pub async fn prepare_engine( ...@@ -102,7 +102,7 @@ pub async fn prepare_engine(
let router = let router =
PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client( PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client, client,
flags.router_mode.as_runtime(), flags.router_mode.into(),
) )
.await?; .await?;
let service_backend = match &flags.router_mode { let service_backend = match &flags.router_mode {
...@@ -133,7 +133,7 @@ pub async fn prepare_engine( ...@@ -133,7 +133,7 @@ pub async fn prepare_engine(
PushRouter::< PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>, Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, flags.router_mode.as_runtime()) >::from_client(client, flags.router_mode.into())
.await?, .await?,
), ),
ModelType::Completion => { ModelType::Completion => {
......
...@@ -17,7 +17,6 @@ use std::sync::Arc; ...@@ -17,7 +17,6 @@ use std::sync::Arc;
use crate::input::common; use crate::input::common;
use crate::{EngineConfig, Flags}; use crate::{EngineConfig, Flags};
use dynamo_llm::http::service::discovery::LLMRouterMode;
use dynamo_llm::http::service::ModelManager; use dynamo_llm::http::service::ModelManager;
use dynamo_llm::{ use dynamo_llm::{
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
...@@ -31,6 +30,7 @@ use dynamo_llm::{ ...@@ -31,6 +30,7 @@ use dynamo_llm::{
}, },
}; };
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
...@@ -65,7 +65,7 @@ pub async fn run( ...@@ -65,7 +65,7 @@ pub async fn run(
http_service.model_manager().clone(), http_service.model_manager().clone(),
etcd_client.clone(), etcd_client.clone(),
&network_prefix, &network_prefix,
flags.router_mode.as_llm(), flags.router_mode.into(),
) )
.await?; .await?;
} }
...@@ -121,7 +121,7 @@ async fn run_watcher( ...@@ -121,7 +121,7 @@ async fn run_watcher(
model_manager: ModelManager, model_manager: ModelManager,
etcd_client: etcd::Client, etcd_client: etcd::Client,
network_prefix: &str, network_prefix: &str,
router_mode: LLMRouterMode, router_mode: RouterMode,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = Arc::new( let watch_obj = Arc::new(
discovery::ModelWatcher::new(component, model_manager, network_prefix, router_mode).await?, discovery::ModelWatcher::new(component, model_manager, network_prefix, router_mode).await?,
......
...@@ -76,7 +76,10 @@ pub async fn run( ...@@ -76,7 +76,10 @@ pub async fn run(
} }
None => { None => {
// echo_full engine doesn't need a path // echo_full engine doesn't need a path
Default::default() match &flags.model_name {
Some(name) => LocalModel::with_name_only(name),
None => Default::default(),
}
} }
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc; use std::sync::Arc;
...@@ -22,8 +10,8 @@ use tokio::sync::mpsc::Receiver; ...@@ -22,8 +10,8 @@ use tokio::sync::mpsc::Receiver;
use dynamo_runtime::{ use dynamo_runtime::{
component::{self, Component, ComponentEndpointInfo}, component::{self, Component, ComponentEndpointInfo},
pipeline::{ pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
RouterMode as RuntimeRouterMode, SegmentSource, ServiceBackend, SingleIn, Source, ServiceBackend, SingleIn, Source,
}, },
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
slug::Slug, slug::Slug,
...@@ -165,33 +153,11 @@ impl std::fmt::Display for ModelNetworkName { ...@@ -165,33 +153,11 @@ impl std::fmt::Display for ModelNetworkName {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LLMRouterMode {
Random,
RoundRobin,
KV,
}
impl LLMRouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == LLMRouterMode::KV
}
pub fn as_runtime(&self) -> Option<RuntimeRouterMode> {
match self {
LLMRouterMode::RoundRobin => Some(RuntimeRouterMode::RoundRobin),
LLMRouterMode::Random => Some(RuntimeRouterMode::Random),
// Runtime router does not have KV, it's a dynamo-llm thing, not dynamo-runtime
LLMRouterMode::KV => None,
}
}
}
pub struct ModelWatcher { pub struct ModelWatcher {
prefix: String, prefix: String,
manager: ModelManager, manager: ModelManager,
drt: DistributedRuntime, drt: DistributedRuntime,
router_mode: LLMRouterMode, router_mode: RouterMode,
kv_chooser: Option<Arc<KvRouter>>, kv_chooser: Option<Arc<KvRouter>>,
} }
...@@ -200,7 +166,7 @@ impl ModelWatcher { ...@@ -200,7 +166,7 @@ impl ModelWatcher {
component: Component, component: Component,
model_manager: ModelManager, model_manager: ModelManager,
network_prefix: &str, network_prefix: &str,
router_mode: LLMRouterMode, router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> { ) -> anyhow::Result<ModelWatcher> {
let kv_chooser = if router_mode.is_kv_routing() { let kv_chooser = if router_mode.is_kv_routing() {
let selector = Box::new(DefaultWorkerSelector {}); let selector = Box::new(DefaultWorkerSelector {});
...@@ -329,14 +295,14 @@ impl ModelWatcher { ...@@ -329,14 +295,14 @@ impl ModelWatcher {
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client( let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client.clone(), client.clone(),
self.router_mode.as_runtime(), self.router_mode,
) )
.await?; .await?;
let service_backend = match self.router_mode { let service_backend = match self.router_mode {
LLMRouterMode::Random | LLMRouterMode::RoundRobin => { RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
LLMRouterMode::KV => { RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else { let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable"); anyhow::bail!("KV routing mode with no chooser, should be unreachable");
}; };
...@@ -363,14 +329,14 @@ impl ModelWatcher { ...@@ -363,14 +329,14 @@ impl ModelWatcher {
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client( let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client, client,
self.router_mode.as_runtime(), self.router_mode,
) )
.await?; .await?;
let service_backend = match self.router_mode { let service_backend = match self.router_mode {
LLMRouterMode::Random | LLMRouterMode::RoundRobin => { RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
LLMRouterMode::KV => { RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else { let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable"); anyhow::bail!("KV routing mode with no chooser, should be unreachable");
}; };
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Once;
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
...@@ -23,6 +25,9 @@ use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; ...@@ -23,6 +25,9 @@ use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use tokio::sync::watch; use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new();
pub struct KvMetricsAggregator { pub struct KvMetricsAggregator {
pub service_name: String, pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>, pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
...@@ -71,7 +76,14 @@ pub async fn collect_endpoints( ...@@ -71,7 +76,14 @@ pub async fn collect_endpoints(
.filter(|e| e.subject.starts_with(subject)) .filter(|e| e.subject.starts_with(subject))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if endpoints.is_empty() { if endpoints.is_empty() {
tracing::debug!("Metrics endpoint not visible yet"); // Only print it once, we poll while the worker starts
METRICS_WAITING_MESSAGE.call_once(|| {
tracing::debug!("Waiting for metrics endpoint..");
});
} else {
METRICS_FOUND_MESSAGE.call_once(|| {
tracing::debug!("Found metrics endpoint");
});
} }
Ok(endpoints) Ok(endpoints)
......
...@@ -36,6 +36,13 @@ impl Default for LocalModel { ...@@ -36,6 +36,13 @@ impl Default for LocalModel {
} }
impl LocalModel { impl LocalModel {
pub fn with_name_only(name: &str) -> Self {
LocalModel {
card: ModelDeploymentCard::with_name_only(name),
..Default::default()
}
}
pub fn card(&self) -> &ModelDeploymentCard { pub fn card(&self) -> &ModelDeploymentCard {
&self.card &self.card
} }
......
...@@ -43,11 +43,11 @@ where ...@@ -43,11 +43,11 @@ where
/// How we choose which endpoint to send traffic to. /// How we choose which endpoint to send traffic to.
/// ///
/// Setting this to None means we never intend to call `generate` on this PushRouter. We are /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine. /// not using it as an AsyncEngine.
/// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly. /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
/// dynamo-llm's KV Routing does this. /// dynamo-llm's KV Routing does this.
router_mode: Option<RouterMode>, router_mode: RouterMode,
/// Number of round robin requests handled. Used to decide which server is next. /// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>, round_robin_counter: Arc<AtomicU64>,
...@@ -62,14 +62,20 @@ where ...@@ -62,14 +62,20 @@ where
_phantom: PhantomData<(T, U)>, _phantom: PhantomData<(T, U)>,
} }
// Note there's no KV router in here because we are in dynamo-runtime. The KvRouter is in #[derive(Default, Debug, Clone, Copy, PartialEq)]
// dynamo-llm.
#[derive(Default, Debug, Clone)]
pub enum RouterMode { pub enum RouterMode {
#[default] #[default]
Random, Random,
RoundRobin, RoundRobin,
Direct(i64), Direct(i64),
// Marker value, KV routing itself is in dynamo-llm
KV,
}
impl RouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV
}
} }
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> { async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
...@@ -84,10 +90,7 @@ where ...@@ -84,10 +90,7 @@ where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
pub async fn from_client( pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
client: Client,
router_mode: Option<RouterMode>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter { Ok(PushRouter {
client, client,
...@@ -189,10 +192,10 @@ where ...@@ -189,10 +192,10 @@ where
match &self.client.endpoints { match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await, EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode { EndpointSource::Dynamic(_) => match self.router_mode {
Some(RouterMode::Random) => self.random(request).await, RouterMode::Random => self.random(request).await,
Some(RouterMode::RoundRobin) => self.round_robin(request).await, RouterMode::RoundRobin => self.round_robin(request).await,
Some(RouterMode::Direct(endpoint_id)) => self.direct(request, endpoint_id).await, RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
None => { RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter"); anyhow::bail!("KV routing should not call generate on PushRouter");
} }
}, },
......
...@@ -126,19 +126,23 @@ impl ServiceClient { ...@@ -126,19 +126,23 @@ impl ServiceClient {
} }
let deadline = tokio::time::Instant::now() + timeout; let deadline = tokio::time::Instant::now() + timeout;
let services: Vec<ServiceInfo> = stream::until_deadline(sub, deadline) let mut services = vec![];
.map(|message| serde_json::from_slice::<ServiceInfo>(&message.payload)) let mut s = stream::until_deadline(sub, deadline);
.filter_map(|info| async move { while let Some(message) = s.next().await {
match info { if message.payload.is_empty() {
Ok(info) => Some(info), // Expected while we wait for KV metrics in worker to start
Err(e) => { tracing::trace!(service_name, "collect_services: empty payload from nats");
log::debug!("error decoding service info: {:?}", e); continue;
None }
} let info = serde_json::from_slice::<ServiceInfo>(&message.payload);
match info {
Ok(info) => services.push(info),
Err(err) => {
let payload = String::from_utf8_lossy(&message.payload);
tracing::debug!(%err, service_name, %payload, "error decoding service info");
} }
}) }
.collect() }
.await;
Ok(ServiceSet { services }) Ok(ServiceSet { services })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment