Unverified Commit e1af3af6 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove static mode (#4235)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent d9b674b8
...@@ -13,7 +13,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -13,7 +13,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0) TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@dynamo_worker(static=False) @dynamo_worker()
async def test_register(runtime: DistributedRuntime): async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor") component = runtime.namespace("test").component("tensor")
await component.create_service() await component.create_service()
......
...@@ -44,21 +44,16 @@ pub enum EngineConfig { ...@@ -44,21 +44,16 @@ pub enum EngineConfig {
/// Remote networked engines that we discover via etcd /// Remote networked engines that we discover via etcd
Dynamic(Box<LocalModel>), Dynamic(Box<LocalModel>),
/// Remote networked engines that we know about at startup
StaticRemote(Box<LocalModel>),
/// A Full service engine does it's own tokenization and prompt formatting. /// A Full service engine does it's own tokenization and prompt formatting.
StaticFull { StaticFull {
engine: Arc<dyn StreamingEngine>, engine: Arc<dyn StreamingEngine>,
model: Box<LocalModel>, model: Box<LocalModel>,
is_static: bool,
}, },
/// A core engine expects to be wrapped with pre/post processors that handle tokenization. /// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore { StaticCore {
engine: ExecutionContext, engine: ExecutionContext,
model: Box<LocalModel>, model: Box<LocalModel>,
is_static: bool,
is_prefill: bool, is_prefill: bool,
}, },
} }
...@@ -68,7 +63,6 @@ impl EngineConfig { ...@@ -68,7 +63,6 @@ impl EngineConfig {
use EngineConfig::*; use EngineConfig::*;
match self { match self {
Dynamic(lm) => lm, Dynamic(lm) => lm,
StaticRemote(lm) => lm,
StaticFull { model, .. } => model, StaticFull { model, .. } => model,
StaticCore { model, .. } => model, StaticCore { model, .. } => model,
} }
......
...@@ -7,7 +7,7 @@ use crate::{ ...@@ -7,7 +7,7 @@ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig}, entrypoint::EngineConfig,
kv_router::{KvPushRouter, KvRouter, PrefillRouter}, kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
...@@ -95,58 +95,6 @@ pub async fn prepare_engine( ...@@ -95,58 +95,6 @@ pub async fn prepare_engine(
request_template: local_model.request_template(), request_template: local_model.request_template(),
}) })
} }
EngineConfig::StaticRemote(local_model) => {
// The card should have been loaded at 'build' phase earlier
let card = local_model.card();
let router_mode = local_model.router_config().router_mode;
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
let model_manager = Arc::new(ModelManager::new());
Some(
model_manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let hf_tokenizer = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
hf_tokenizer,
None, // No prefill chooser in static mode
)
.await?;
let service_name = local_model.service_name().to_string();
tracing::info!("Static connecting to {service_name}");
Ok(PreparedEngine {
service_name,
engine: chat_engine,
inspect_template: false,
request_template: local_model.request_template(),
card: Some(local_model.into_card()),
})
}
EngineConfig::StaticFull { engine, model, .. } => { EngineConfig::StaticFull { engine, model, .. } => {
let service_name = model.service_name().to_string(); let service_name = model.service_name().to_string();
tracing::debug!("Model: {service_name} with engine pre-processing"); tracing::debug!("Model: {service_name} with engine pre-processing");
......
...@@ -43,22 +43,15 @@ pub async fn run( ...@@ -43,22 +43,15 @@ pub async fn run(
let endpoint = component.endpoint(&endpoint_id.name); let endpoint = component.endpoint(&endpoint_id.name);
let rt_fut: Pin<Box<dyn Future<Output = _> + Send + 'static>> = match engine_config { let rt_fut: Pin<Box<dyn Future<Output = _> + Send + 'static>> = match engine_config {
EngineConfig::StaticFull { EngineConfig::StaticFull { engine, mut model } => {
engine,
mut model,
is_static,
} => {
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
let ingress_chat = Ingress::< let ingress_chat = Ingress::<
Context<NvCreateChatCompletionRequest>, Context<NvCreateChatCompletionRequest>,
Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>, Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
>::for_engine(engine)?; >::for_engine(engine)?;
if !is_static {
model model
.attach(&endpoint, ModelType::Chat, ModelInput::Text) .attach(&endpoint, ModelType::Chat, ModelInput::Text)
.await?; .await?;
}
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start(); let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
Box::pin(fut_chat) Box::pin(fut_chat)
...@@ -66,7 +59,6 @@ pub async fn run( ...@@ -66,7 +59,6 @@ pub async fn run(
EngineConfig::StaticCore { EngineConfig::StaticCore {
engine: inner_engine, engine: inner_engine,
mut model, mut model,
is_static,
is_prefill, is_prefill,
} => { } => {
// Pre-processing is done ingress-side, so it should be already done. // Pre-processing is done ingress-side, so it should be already done.
...@@ -83,7 +75,6 @@ pub async fn run( ...@@ -83,7 +75,6 @@ pub async fn run(
.link(frontend)?; .link(frontend)?;
let ingress = Ingress::for_pipeline(pipeline)?; let ingress = Ingress::for_pipeline(pipeline)?;
if !is_static {
let model_type = if is_prefill { let model_type = if is_prefill {
ModelType::Prefill ModelType::Prefill
} else { } else {
...@@ -92,15 +83,10 @@ pub async fn run( ...@@ -92,15 +83,10 @@ pub async fn run(
model model
.attach(&endpoint, model_type, ModelInput::Tokens) .attach(&endpoint, model_type, ModelInput::Tokens)
.await?; .await?;
}
let fut = endpoint.endpoint_builder().handler(ingress).start(); let fut = endpoint.endpoint_builder().handler(ingress).start();
Box::pin(fut) Box::pin(fut)
} }
EngineConfig::StaticRemote(_) => {
panic!("StaticRemote definitions are only for the frontend end node.");
}
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
unreachable!("An endpoint input will never have a Dynamic engine"); unreachable!("An endpoint input will never have a Dynamic engine");
} }
...@@ -155,7 +141,6 @@ mod integration_tests { ...@@ -155,7 +141,6 @@ mod integration_tests {
.await .await
.map_err(|e| anyhow::anyhow!("Failed to build LocalModel: {}", e))?, .map_err(|e| anyhow::anyhow!("Failed to build LocalModel: {}", e))?,
), ),
is_static: false,
}; };
Ok((distributed_runtime, engine_config)) Ok((distributed_runtime, engine_config))
......
...@@ -6,7 +6,7 @@ use std::sync::Arc; ...@@ -6,7 +6,7 @@ use std::sync::Arc;
use crate::{ use crate::{
discovery::{ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{EngineConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
namespace::is_global_namespace, namespace::is_global_namespace,
...@@ -49,75 +49,6 @@ pub async fn run( ...@@ -49,75 +49,6 @@ pub async fn run(
.await?; .await?;
grpc_service grpc_service
} }
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
Some(
manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service
}
EngineConfig::StaticFull { engine, model, .. } => { EngineConfig::StaticFull { engine, model, .. } => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
......
...@@ -7,7 +7,7 @@ use crate::{ ...@@ -7,7 +7,7 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher}, discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType, endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{EngineConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
namespace::is_global_namespace, namespace::is_global_namespace,
...@@ -91,78 +91,6 @@ pub async fn run( ...@@ -91,78 +91,6 @@ pub async fn run(
.await?; .await?;
http_service http_service
} }
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
Some(
manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in http static mode
)
.await?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in http static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
EngineConfig::StaticFull { engine, model, .. } => { EngineConfig::StaticFull { engine, model, .. } => {
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
......
...@@ -213,11 +213,7 @@ async fn poll_backend_once( ...@@ -213,11 +213,7 @@ async fn poll_backend_once(
) -> anyhow::Result<usize> { ) -> anyhow::Result<usize> {
use dynamo_runtime::pipeline::Context; use dynamo_runtime::pipeline::Context;
// Send request to backend (try static mode first, fall back to dynamic mode) let response_stream = router.random(Context::new("".to_string())).await?;
let response_stream = match router.r#static(Context::new("".to_string())).await {
Ok(stream) => stream,
Err(_) => router.random(Context::new("".to_string())).await?,
};
// Collect responses from the stream // Collect responses from the stream
let mut responses = Vec::new(); let mut responses = Vec::new();
......
...@@ -8,7 +8,7 @@ use std::time::Duration; ...@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, InstanceSource}, component::Component,
discovery::{DiscoveryQuery, watch_and_extract_field}, discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{ pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
...@@ -226,12 +226,7 @@ impl KvRouter { ...@@ -226,12 +226,7 @@ impl KvRouter {
let generate_endpoint = component.endpoint("generate"); let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?; let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() { let instances_rx = client.instance_source.as_ref().clone();
InstanceSource::Dynamic(rx) => rx.clone(),
InstanceSource::Static => {
panic!("Expected dynamic instance source for KV routing");
}
};
// Watch for runtime config updates via discovery interface // Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery(); let discovery = component.drt().discovery();
...@@ -508,18 +503,13 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -508,18 +503,13 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> { ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
// Extract context ID for request tracking // Extract context ID for request tracking
let context_id = request.context().id().to_string(); let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first // Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id"); let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
request.backend_instance_id
{
// If instance_id is set, use it and compute actual overlap // If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0); let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id { if query_instance_id {
...@@ -564,12 +554,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -564,12 +554,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = request.context().clone(); let stream_context = request.context().clone();
if query_instance_id { if query_instance_id {
let instance_id_str = instance_id.to_string(); let instance_id_str = instance_id.to_string();
let response = let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format // Return the tokens in nvext.token_data format
let response_tokens = let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
Annotated::from_annotation("token_data", &request.token_ids)?;
tracing::trace!( tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}", "Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens response_tokens
...@@ -621,8 +609,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -621,8 +609,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}); });
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
}
}
} }
impl Drop for KvRouter { impl Drop for KvRouter {
......
...@@ -295,12 +295,7 @@ pub async fn start_kv_router_background( ...@@ -295,12 +295,7 @@ pub async fn start_kv_router_background(
// Get instances_rx for tracking current workers // Get instances_rx for tracking current workers
let client = generate_endpoint.client().await?; let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() { let instances_rx = client.instance_source.as_ref().clone();
dynamo_runtime::component::InstanceSource::Dynamic(rx) => rx.clone(),
dynamo_runtime::component::InstanceSource::Static => {
anyhow::bail!("Expected dynamic instance source for KV routing");
}
};
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided // Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = ( let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
......
...@@ -51,7 +51,7 @@ mod tests { ...@@ -51,7 +51,7 @@ mod tests {
/// Helper to create a test DistributedRuntime with NATS /// Helper to create a test DistributedRuntime with NATS
async fn create_test_drt() -> dynamo_runtime::DistributedRuntime { async fn create_test_drt() -> dynamo_runtime::DistributedRuntime {
let rt = Runtime::from_current().unwrap(); let rt = Runtime::from_current().unwrap();
let config = dynamo_runtime::distributed::DistributedConfig::from_settings(false); let config = dynamo_runtime::distributed::DistributedConfig::from_settings();
dynamo_runtime::DistributedRuntime::new(rt, config) dynamo_runtime::DistributedRuntime::new(rt, config)
.await .await
.expect("Failed to create DistributedRuntime") .expect("Failed to create DistributedRuntime")
......
...@@ -325,7 +325,6 @@ mod integration_tests { ...@@ -325,7 +325,6 @@ mod integration_tests {
let engine_config = EngineConfig::StaticFull { let engine_config = EngineConfig::StaticFull {
engine: make_echo_engine(), engine: make_echo_engine(),
model: Box::new(local_model.clone()), model: Box::new(local_model.clone()),
is_static: false, // This enables MDC registration!
}; };
let service = HttpService::builder() let service = HttpService::builder()
......
...@@ -69,7 +69,7 @@ mod namespace; ...@@ -69,7 +69,7 @@ mod namespace;
mod registry; mod registry;
pub mod service; pub mod service;
pub use client::{Client, InstanceSource}; pub use client::Client;
/// The root key-value path where each instance registers itself in. /// The root key-value path where each instance registers itself in.
/// An instance is namespace+component+endpoint+lease_id and must be unique. /// An instance is namespace+component+endpoint+lease_id and must be unique.
...@@ -166,10 +166,6 @@ pub struct Component { ...@@ -166,10 +166,6 @@ pub struct Component {
#[builder(setter(into))] #[builder(setter(into))]
namespace: Namespace, namespace: Namespace,
// A static component's endpoints cannot be discovered via etcd, they are
// fixed at startup time.
is_static: bool,
/// This hierarchy's own metrics registry /// This hierarchy's own metrics registry
#[builder(default = "crate::MetricsRegistry::new()")] #[builder(default = "crate::MetricsRegistry::new()")]
metrics_registry: crate::MetricsRegistry, metrics_registry: crate::MetricsRegistry,
...@@ -179,15 +175,12 @@ impl Hash for Component { ...@@ -179,15 +175,12 @@ impl Hash for Component {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.namespace.name().hash(state); self.namespace.name().hash(state);
self.name.hash(state); self.name.hash(state);
self.is_static.hash(state);
} }
} }
impl PartialEq for Component { impl PartialEq for Component {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.namespace.name() == other.namespace.name() self.namespace.name() == other.namespace.name() && self.name == other.name
&& self.name == other.name
&& self.is_static == other.is_static
} }
} }
...@@ -271,7 +264,6 @@ impl Component { ...@@ -271,7 +264,6 @@ impl Component {
Endpoint { Endpoint {
component: self.clone(), component: self.clone(),
name: endpoint.into(), name: endpoint.into(),
is_static: self.is_static,
labels: Vec::new(), labels: Vec::new(),
metrics_registry: crate::MetricsRegistry::new(), metrics_registry: crate::MetricsRegistry::new(),
} }
...@@ -436,8 +428,6 @@ pub struct Endpoint { ...@@ -436,8 +428,6 @@ pub struct Endpoint {
/// Endpoint name /// Endpoint name
name: String, name: String,
is_static: bool,
/// Additional labels for metrics /// Additional labels for metrics
labels: Vec<(String, String)>, labels: Vec<(String, String)>,
...@@ -449,15 +439,12 @@ impl Hash for Endpoint { ...@@ -449,15 +439,12 @@ impl Hash for Endpoint {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.component.hash(state); self.component.hash(state);
self.name.hash(state); self.name.hash(state);
self.is_static.hash(state);
} }
} }
impl PartialEq for Endpoint { impl PartialEq for Endpoint {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.component == other.component self.component == other.component && self.name == other.name
&& self.name == other.name
&& self.is_static == other.is_static
} }
} }
...@@ -554,27 +541,19 @@ impl Endpoint { ...@@ -554,27 +541,19 @@ impl Endpoint {
format!("{ns}/{cp}/{ep}/{lease_id:x}") format!("{ns}/{cp}/{ep}/{lease_id:x}")
} }
/// The endpoint as an EtcdPath object with lease ID /// The endpoint as an EtcdPath object with instance ID
pub fn etcd_path_object_with_lease_id(&self, lease_id: i64) -> EtcdPath { pub fn etcd_path_object_with_lease_id(&self, instance_id: i64) -> EtcdPath {
if self.is_static {
self.etcd_path()
} else {
EtcdPath::new_endpoint_with_lease( EtcdPath::new_endpoint_with_lease(
&self.component.namespace().name(), &self.component.namespace().name(),
self.component.name(), self.component.name(),
&self.name, &self.name,
lease_id, instance_id,
) )
.expect("Endpoint name and component name should be valid") .expect("Endpoint name and component name should be valid")
} }
}
pub fn name_with_id(&self, lease_id: u64) -> String { pub fn name_with_id(&self, instance_id: u64) -> String {
if self.is_static { format!("{}-{:x}", self.name, instance_id)
self.name.clone()
} else {
format!("{}-{:x}", self.name, lease_id)
}
} }
pub fn subject(&self) -> String { pub fn subject(&self) -> String {
...@@ -591,11 +570,7 @@ impl Endpoint { ...@@ -591,11 +570,7 @@ impl Endpoint {
} }
pub async fn client(&self) -> anyhow::Result<client::Client> { pub async fn client(&self) -> anyhow::Result<client::Client> {
if self.is_static { client::Client::new(self.clone()).await
client::Client::new_static(self.clone()).await
} else {
client::Client::new_dynamic(self.clone()).await
}
} }
pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder { pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder {
...@@ -612,8 +587,6 @@ pub struct Namespace { ...@@ -612,8 +587,6 @@ pub struct Namespace {
#[validate(custom(function = "validate_allowed_chars"))] #[validate(custom(function = "validate_allowed_chars"))]
name: String, name: String,
is_static: bool,
#[builder(default = "None")] #[builder(default = "None")]
parent: Option<Arc<Namespace>>, parent: Option<Arc<Namespace>>,
...@@ -636,8 +609,8 @@ impl std::fmt::Debug for Namespace { ...@@ -636,8 +609,8 @@ impl std::fmt::Debug for Namespace {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!( write!(
f, f,
"Namespace {{ name: {}; is_static: {}; parent: {:?} }}", "Namespace {{ name: {}; parent: {:?} }}",
self.name, self.is_static, self.parent self.name, self.parent
) )
} }
} }
...@@ -655,15 +628,10 @@ impl std::fmt::Display for Namespace { ...@@ -655,15 +628,10 @@ impl std::fmt::Display for Namespace {
} }
impl Namespace { impl Namespace {
pub(crate) fn new( pub(crate) fn new(runtime: DistributedRuntime, name: String) -> anyhow::Result<Self> {
runtime: DistributedRuntime,
name: String,
is_static: bool,
) -> anyhow::Result<Self> {
Ok(NamespaceBuilder::default() Ok(NamespaceBuilder::default()
.runtime(Arc::new(runtime)) .runtime(Arc::new(runtime))
.name(name) .name(name)
.is_static(is_static)
.build()?) .build()?)
} }
...@@ -672,7 +640,6 @@ impl Namespace { ...@@ -672,7 +640,6 @@ impl Namespace {
Ok(ComponentBuilder::from_runtime(self.runtime.clone()) Ok(ComponentBuilder::from_runtime(self.runtime.clone())
.name(name) .name(name)
.namespace(self.clone()) .namespace(self.clone())
.is_static(self.is_static)
.build()?) .build()?)
} }
...@@ -681,7 +648,6 @@ impl Namespace { ...@@ -681,7 +648,6 @@ impl Namespace {
Ok(NamespaceBuilder::default() Ok(NamespaceBuilder::default()
.runtime(self.runtime.clone()) .runtime(self.runtime.clone())
.name(name.into()) .name(name.into())
.is_static(self.is_static)
.parent(Some(Arc::new(self.clone()))) .parent(Some(Arc::new(self.clone())))
.build()?) .build()?)
} }
......
...@@ -21,62 +21,25 @@ use crate::{ ...@@ -21,62 +21,25 @@ use crate::{
transports::etcd::Client as EtcdClient, transports::etcd::Client as EtcdClient,
}; };
/// Each state will be have a nonce associated with it
/// The state will be emitted in a watch channel, so we can observe the
/// critical state transitions.
enum MapState {
/// The map is empty; value = nonce
Empty(u64),
/// The map is not-empty; values are (nonce, count)
NonEmpty(u64, u64),
/// The watcher has finished, no more events will be emitted
Finished,
}
enum EndpointEvent {
Put(String, u64),
Delete(String),
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Client { pub struct Client {
// This is me // This is me
pub endpoint: Endpoint, pub endpoint: Endpoint,
// These are the remotes I know about from watching etcd // These are the remotes I know about from watching key-value store
pub instance_source: Arc<InstanceSource>, pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
// These are the instance source ids less those reported as down from sending rpc // These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<u64>>>, instance_avail: Arc<ArcSwap<Vec<u64>>>,
// These are the instance source ids less those reported as busy (above threshold) // These are the instance source ids less those reported as busy (above threshold)
instance_free: Arc<ArcSwap<Vec<u64>>>, instance_free: Arc<ArcSwap<Vec<u64>>>,
} }
#[derive(Clone, Debug)]
pub enum InstanceSource {
Static,
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
}
impl Client { impl Client {
// Client will only talk to a single static endpoint // Client with auto-discover instances using key-value store
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
instance_source: Arc::new(InstanceSource::Static),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
})
}
// Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
tracing::debug!( tracing::debug!(
"Client::new_dynamic: Creating dynamic client for endpoint: {}", "Client::new_dynamic: Creating dynamic client for endpoint: {}",
endpoint.path() endpoint.path()
); );
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?; let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
tracing::debug!( tracing::debug!(
"Client::new_dynamic: Got instance source for endpoint: {}", "Client::new_dynamic: Got instance source for endpoint: {}",
...@@ -110,12 +73,9 @@ impl Client { ...@@ -110,12 +73,9 @@ impl Client {
self.endpoint.etcd_root() self.endpoint.etcd_root()
} }
/// Instances available from watching etcd /// Instances available from watching key-value store
pub fn instances(&self) -> Vec<Instance> { pub fn instances(&self) -> Vec<Instance> {
match self.instance_source.as_ref() { self.instance_source.borrow().clone()
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
} }
pub fn instance_ids(&self) -> Vec<u64> { pub fn instance_ids(&self) -> Vec<u64> {
...@@ -136,10 +96,10 @@ impl Client { ...@@ -136,10 +96,10 @@ impl Client {
"wait_for_instances: Starting wait for endpoint: {}", "wait_for_instances: Starting wait for endpoint: {}",
self.endpoint.path() self.endpoint.path()
); );
let mut instances: Vec<Instance> = vec![]; let mut rx = self.instance_source.as_ref().clone();
if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() {
// wait for there to be 1 or more endpoints // wait for there to be 1 or more endpoints
let mut iteration = 0; let mut iteration = 0;
let mut instances: Vec<Instance>;
loop { loop {
instances = rx.borrow_and_update().to_vec(); instances = rx.borrow_and_update().to_vec();
tracing::debug!( tracing::debug!(
...@@ -168,20 +128,9 @@ impl Client { ...@@ -168,20 +128,9 @@ impl Client {
} }
iteration += 1; iteration += 1;
} }
} else {
tracing::debug!(
"wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}",
self.endpoint.path()
);
}
Ok(instances) Ok(instances)
} }
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
}
/// Mark an instance as down/unavailable /// Mark an instance as down/unavailable
pub fn report_instance_down(&self, instance_id: u64) { pub fn report_instance_down(&self, instance_id: u64) {
let filtered = self let filtered = self
...@@ -204,7 +153,7 @@ impl Client { ...@@ -204,7 +153,7 @@ impl Client {
self.instance_free.store(Arc::new(free_ids)); self.instance_free.store(Arc::new(free_ids));
} }
/// Monitor the ETCD instance source and update instance_avail. /// Monitor the key-value instance source and update instance_avail.
fn monitor_instance_source(&self) { fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token(); let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone(); let client = self.clone();
...@@ -214,15 +163,7 @@ impl Client { ...@@ -214,15 +163,7 @@ impl Client {
endpoint_path endpoint_path
); );
tokio::task::spawn(async move { tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() { let mut rx = client.instance_source.as_ref().clone();
InstanceSource::Static => {
tracing::error!(
"monitor_instance_source: Static instance source is not watchable"
);
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
let mut iteration = 0; let mut iteration = 0;
while !cancel_token.is_cancelled() { while !cancel_token.is_cancelled() {
let instance_ids: Vec<u64> = rx let instance_ids: Vec<u64> = rx
...@@ -267,7 +208,7 @@ impl Client { ...@@ -267,7 +208,7 @@ impl Client {
async fn get_or_create_dynamic_instance_source( async fn get_or_create_dynamic_instance_source(
endpoint: &Endpoint, endpoint: &Endpoint,
) -> Result<Arc<InstanceSource>> { ) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
let drt = endpoint.drt(); let drt = endpoint.drt();
let instance_sources = drt.instance_sources(); let instance_sources = drt.instance_sources();
let mut instance_sources = instance_sources.lock().await; let mut instance_sources = instance_sources.lock().await;
...@@ -401,7 +342,7 @@ impl Client { ...@@ -401,7 +342,7 @@ impl Client {
let _ = watch_tx.send(vec![]); let _ = watch_tx.send(vec![]);
}); });
let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx)); let instance_source = Arc::new(watch_rx);
instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source)); instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
tracing::debug!( tracing::debug!(
"get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}", "get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}",
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::component::Component; use crate::component::{Component, Instance};
use crate::pipeline::PipelineError; use crate::pipeline::PipelineError;
use crate::storage::key_value_store::{ use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect, EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
...@@ -9,7 +9,7 @@ use crate::storage::key_value_store::{ ...@@ -9,7 +9,7 @@ use crate::storage::key_value_store::{
}; };
use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{ use crate::{
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace}, component::{self, ComponentBuilder, Endpoint, Namespace},
discovery::Discovery, discovery::Discovery,
metrics::PrometheusUpdateCallback, metrics::PrometheusUpdateCallback,
metrics::{MetricsHierarchy, MetricsRegistry}, metrics::{MetricsHierarchy, MetricsRegistry},
...@@ -24,6 +24,7 @@ use crate::runtime::Runtime; ...@@ -24,6 +24,7 @@ use crate::runtime::Runtime;
use async_once_cell::OnceCell; use async_once_cell::OnceCell;
use std::sync::{Arc, OnceLock, Weak}; use std::sync::{Arc, OnceLock, Weak};
use tokio::sync::watch::Receiver;
use anyhow::Result; use anyhow::Result;
use derive_getters::Dissolve; use derive_getters::Dissolve;
...@@ -32,6 +33,8 @@ use std::collections::HashMap; ...@@ -32,6 +33,8 @@ use std::collections::HashMap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
type InstanceMap = HashMap<Endpoint, Weak<Receiver<Vec<Instance>>>>;
/// Distributed [Runtime] which provides access to shared resources across the cluster, this includes /// Distributed [Runtime] which provides access to shared resources across the cluster, this includes
/// communication protocols and transports. /// communication protocols and transports.
#[derive(Clone)] #[derive(Clone)]
...@@ -60,11 +63,7 @@ pub struct DistributedRuntime { ...@@ -60,11 +63,7 @@ pub struct DistributedRuntime {
// paths in etcd to a minimum. // paths in etcd to a minimum.
component_registry: component::Registry, component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd, they must be know at instance_sources: Arc<tokio::sync::Mutex<InstanceMap>>,
// startup. Will not start etcd.
is_static: bool,
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
// Health Status // Health Status
system_health: Arc<parking_lot::Mutex<SystemHealth>>, system_health: Arc<parking_lot::Mutex<SystemHealth>>,
...@@ -95,12 +94,12 @@ impl std::fmt::Debug for DistributedRuntime { ...@@ -95,12 +94,12 @@ impl std::fmt::Debug for DistributedRuntime {
impl DistributedRuntime { impl DistributedRuntime {
pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> { pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
let (selected_kv_store, nats_config, is_static) = config.dissolve(); let (selected_kv_store, nats_config) = config.dissolve();
let runtime_clone = runtime.clone(); let runtime_clone = runtime.clone();
let (etcd_client, store) = match (is_static, selected_kv_store) { let (etcd_client, store) = match selected_kv_store {
(false, KeyValueStoreSelect::Etcd(etcd_config)) => { KeyValueStoreSelect::Etcd(etcd_config) => {
let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err| let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err|
// The returned error doesn't show because of a dropped runtime error, so // The returned error doesn't show because of a dropped runtime error, so
// log it first. // log it first.
...@@ -108,10 +107,8 @@ impl DistributedRuntime { ...@@ -108,10 +107,8 @@ impl DistributedRuntime {
let store = KeyValueStoreManager::etcd(etcd_client.clone()); let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store) (Some(etcd_client), store)
} }
(false, KeyValueStoreSelect::File(root)) => (None, KeyValueStoreManager::file(root)), KeyValueStoreSelect::File(root) => (None, KeyValueStoreManager::file(root)),
(true, _) | (false, KeyValueStoreSelect::Memory) => { KeyValueStoreSelect::Memory => (None, KeyValueStoreManager::memory()),
(None, KeyValueStoreManager::memory())
}
}; };
let nats_client = Some(nats_config.clone().connect().await?); let nats_client = Some(nats_config.clone().connect().await?);
...@@ -181,7 +178,6 @@ impl DistributedRuntime { ...@@ -181,7 +178,6 @@ impl DistributedRuntime {
discovery_client, discovery_client,
discovery_metadata, discovery_metadata,
component_registry: component::Registry::new(), component_registry: component::Registry::new(),
is_static,
instance_sources: Arc::new(Mutex::new(HashMap::new())), instance_sources: Arc::new(Mutex::new(HashMap::new())),
metrics_registry: crate::MetricsRegistry::new(), metrics_registry: crate::MetricsRegistry::new(),
system_health, system_health,
...@@ -283,13 +279,7 @@ impl DistributedRuntime { ...@@ -283,13 +279,7 @@ impl DistributedRuntime {
} }
pub async fn from_settings(runtime: Runtime) -> Result<Self> { pub async fn from_settings(runtime: Runtime) -> Result<Self> {
let config = DistributedConfig::from_settings(false); let config = DistributedConfig::from_settings();
Self::new(runtime, config).await
}
// Call this if you are using static workers that do not need etcd-based discovery.
pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
let config = DistributedConfig::from_settings(true);
Self::new(runtime, config).await Self::new(runtime, config).await
} }
...@@ -323,7 +313,7 @@ impl DistributedRuntime { ...@@ -323,7 +313,7 @@ impl DistributedRuntime {
/// Create a [`Namespace`] /// Create a [`Namespace`]
pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> { pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
Namespace::new(self.clone(), name.into(), self.is_static) Namespace::new(self.clone(), name.into())
} }
/// Returns the discovery interface for service registration and discovery /// Returns the discovery interface for service registration and discovery
...@@ -380,7 +370,7 @@ impl DistributedRuntime { ...@@ -380,7 +370,7 @@ impl DistributedRuntime {
self.runtime.graceful_shutdown_tracker() self.runtime.graceful_shutdown_tracker()
} }
pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> { pub fn instance_sources(&self) -> Arc<Mutex<InstanceMap>> {
self.instance_sources.clone() self.instance_sources.clone()
} }
} }
...@@ -389,15 +379,13 @@ impl DistributedRuntime { ...@@ -389,15 +379,13 @@ impl DistributedRuntime {
pub struct DistributedConfig { pub struct DistributedConfig {
pub store_backend: KeyValueStoreSelect, pub store_backend: KeyValueStoreSelect,
pub nats_config: nats::ClientOptions, pub nats_config: nats::ClientOptions,
pub is_static: bool,
} }
impl DistributedConfig { impl DistributedConfig {
pub fn from_settings(is_static: bool) -> DistributedConfig { pub fn from_settings() -> DistributedConfig {
DistributedConfig { DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::default()), store_backend: KeyValueStoreSelect::Etcd(Box::default()),
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
is_static,
} }
} }
...@@ -409,24 +397,26 @@ impl DistributedConfig { ...@@ -409,24 +397,26 @@ impl DistributedConfig {
DistributedConfig { DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)), store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)),
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
is_static: false,
} }
} }
} }
pub mod distributed_test_utils { pub mod distributed_test_utils {
//! Common test helper functions for DistributedRuntime tests //! Common test helper functions for DistributedRuntime tests
// TODO: Use in-memory DistributedRuntime for tests instead of full runtime when available.
/// Helper function to create a DRT instance for integration-only tests. /// Helper function to create a DRT instance for integration-only tests.
/// Uses from_current to leverage existing tokio runtime /// Uses from_current to leverage existing tokio runtime
/// Note: Settings are read from environment variables inside DistributedRuntime::from_settings_without_discovery /// Note: Settings are read from environment variables inside DistributedRuntime::from_settings
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
pub async fn create_test_drt_async() -> crate::DistributedRuntime { pub async fn create_test_drt_async() -> super::DistributedRuntime {
use crate::{storage::key_value_store::KeyValueStoreSelect, transports::nats};
let rt = crate::Runtime::from_current().unwrap(); let rt = crate::Runtime::from_current().unwrap();
crate::DistributedRuntime::from_settings_without_discovery(rt) let config = super::DistributedConfig {
.await store_backend: KeyValueStoreSelect::Memory,
.unwrap() nats_config: nats::ClientOptions::default(),
};
super::DistributedRuntime::new(rt, config).await.unwrap()
} }
} }
......
...@@ -80,7 +80,7 @@ impl HealthCheckManager { ...@@ -80,7 +80,7 @@ impl HealthCheckManager {
} }
// Create a client that discovers instances dynamically for this endpoint // Create a client that discovers instances dynamically for this endpoint
let client = Client::new_dynamic(endpoint).await?; let client = Client::new(endpoint).await?;
// Create PushRouter - it will use direct routing when we call direct() // Create PushRouter - it will use direct routing when we call direct()
let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new( let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG}; use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::{ use crate::{
component::{Client, Endpoint, InstanceSource}, component::{Client, Endpoint},
engine::{AsyncEngine, Data}, engine::{AsyncEngine, Data},
pipeline::{ pipeline::{
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
...@@ -118,9 +118,7 @@ where ...@@ -118,9 +118,7 @@ where
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
// Start worker monitor if provided and in dynamic mode // Start worker monitor if provided and in dynamic mode
if let Some(monitor) = worker_monitor.as_ref() if let Some(monitor) = worker_monitor.as_ref() {
&& matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
{
monitor.start_monitoring().await?; monitor.start_monitoring().await?;
} }
...@@ -196,6 +194,7 @@ where ...@@ -196,6 +194,7 @@ where
.await .await
} }
/*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject(); let subject = self.client.endpoint.subject();
tracing::debug!("static got subject: {subject}"); tracing::debug!("static got subject: {subject}");
...@@ -203,6 +202,7 @@ where ...@@ -203,6 +202,7 @@ where
tracing::debug!("router generate"); tracing::debug!("router generate");
self.addressed.generate(request).await self.addressed.generate(request).await
} }
*/
async fn generate_with_fault_detection( async fn generate_with_fault_detection(
&self, &self,
...@@ -268,16 +268,14 @@ where ...@@ -268,16 +268,14 @@ where
U: Data + for<'de> Deserialize<'de> + MaybeError, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.client.instance_source.as_ref() { //InstanceSource::Static => self.r#static(request).await,
InstanceSource::Static => self.r#static(request).await, match self.router_mode {
InstanceSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await, RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await, RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(instance_id) => self.direct(request, instance_id).await, RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => { RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter"); anyhow::bail!("KV routing should not call generate on PushRouter");
} }
},
} }
} }
} }
...@@ -324,6 +324,8 @@ impl Runtime { ...@@ -324,6 +324,8 @@ impl Runtime {
"Phase 3: All endpoints ended gracefully. Connections to NATS/ETCD will now be disconnected" "Phase 3: All endpoints ended gracefully. Connections to NATS/ETCD will now be disconnected"
); );
main_token.cancel(); main_token.cancel();
// TODO: We should likely call shutdown_background on tokio rt to stop it cleanly.
}); });
} }
} }
......
...@@ -264,7 +264,7 @@ mod concurrent_create_tests { ...@@ -264,7 +264,7 @@ mod concurrent_create_tests {
fn test_concurrent_etcd_create_race_condition() { fn test_concurrent_etcd_create_race_condition() {
let rt = Runtime::from_settings().unwrap(); let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone(); let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false); let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move { rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap(); let drt = DistributedRuntime::new(rt, config).await.unwrap();
......
...@@ -751,7 +751,7 @@ mod tests { ...@@ -751,7 +751,7 @@ mod tests {
fn test_ectd_client() { fn test_ectd_client() {
let rt = Runtime::from_settings().unwrap(); let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone(); let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false); let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move { rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap(); let drt = DistributedRuntime::new(rt, config).await.unwrap();
...@@ -794,7 +794,7 @@ mod tests { ...@@ -794,7 +794,7 @@ mod tests {
fn test_kv_cache() { fn test_kv_cache() {
let rt = Runtime::from_settings().unwrap(); let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone(); let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false); let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move { rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap(); let drt = DistributedRuntime::new(rt, config).await.unwrap();
......
...@@ -35,10 +35,17 @@ fn test_namespace_etcd_path_format() { ...@@ -35,10 +35,17 @@ fn test_namespace_etcd_path_format() {
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[tokio::test] #[tokio::test]
async fn test_recursive_namespace_implementation() { async fn test_recursive_namespace_implementation() {
use dynamo_runtime::{
distributed::DistributedConfig, storage::key_value_store::KeyValueStoreSelect,
transports::nats,
};
let runtime = Runtime::from_current().unwrap(); let runtime = Runtime::from_current().unwrap();
let distributed_runtime = DistributedRuntime::from_settings_without_discovery(runtime) let config = DistributedConfig {
.await store_backend: KeyValueStoreSelect::Memory,
.unwrap(); nats_config: nats::ClientOptions::default(),
};
let distributed_runtime = DistributedRuntime::new(runtime, config).await.unwrap();
// Test single namespace // Test single namespace
let ns1 = distributed_runtime.namespace("ns1").unwrap(); let ns1 = distributed_runtime.namespace("ns1").unwrap();
...@@ -74,10 +81,17 @@ async fn test_recursive_namespace_implementation() { ...@@ -74,10 +81,17 @@ async fn test_recursive_namespace_implementation() {
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[tokio::test] #[tokio::test]
async fn test_multiple_branches_recursive_namespaces() { async fn test_multiple_branches_recursive_namespaces() {
use dynamo_runtime::{
distributed::DistributedConfig, storage::key_value_store::KeyValueStoreSelect,
transports::nats,
};
let runtime = Runtime::from_current().unwrap(); let runtime = Runtime::from_current().unwrap();
let distributed_runtime = DistributedRuntime::from_settings_without_discovery(runtime) let config = DistributedConfig {
.await store_backend: KeyValueStoreSelect::Memory,
.unwrap(); nats_config: nats::ClientOptions::default(),
};
let distributed_runtime = DistributedRuntime::new(runtime, config).await.unwrap();
// Create root namespace // Create root namespace
let root = distributed_runtime.namespace("root").unwrap(); let root = distributed_runtime.namespace("root").unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment