Unverified Commit e1af3af6 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove static mode (#4235)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent d9b674b8
......@@ -13,7 +13,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@dynamo_worker(static=False)
@dynamo_worker()
async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor")
await component.create_service()
......
......@@ -44,21 +44,16 @@ pub enum EngineConfig {
/// Remote networked engines that we discover via etcd
Dynamic(Box<LocalModel>),
/// Remote networked engines that we know about at startup
StaticRemote(Box<LocalModel>),
/// A Full service engine does it's own tokenization and prompt formatting.
StaticFull {
engine: Arc<dyn StreamingEngine>,
model: Box<LocalModel>,
is_static: bool,
},
/// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore {
engine: ExecutionContext,
model: Box<LocalModel>,
is_static: bool,
is_prefill: bool,
},
}
......@@ -68,7 +63,6 @@ impl EngineConfig {
use EngineConfig::*;
match self {
Dynamic(lm) => lm,
StaticRemote(lm) => lm,
StaticFull { model, .. } => model,
StaticCore { model, .. } => model,
}
......
......@@ -7,7 +7,7 @@ use crate::{
backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig},
entrypoint::EngineConfig,
kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::ModelDeploymentCard,
......@@ -95,58 +95,6 @@ pub async fn prepare_engine(
request_template: local_model.request_template(),
})
}
EngineConfig::StaticRemote(local_model) => {
// The card should have been loaded at 'build' phase earlier
let card = local_model.card();
let router_mode = local_model.router_config().router_mode;
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
let model_manager = Arc::new(ModelManager::new());
Some(
model_manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let hf_tokenizer = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
hf_tokenizer,
None, // No prefill chooser in static mode
)
.await?;
let service_name = local_model.service_name().to_string();
tracing::info!("Static connecting to {service_name}");
Ok(PreparedEngine {
service_name,
engine: chat_engine,
inspect_template: false,
request_template: local_model.request_template(),
card: Some(local_model.into_card()),
})
}
EngineConfig::StaticFull { engine, model, .. } => {
let service_name = model.service_name().to_string();
tracing::debug!("Model: {service_name} with engine pre-processing");
......
......@@ -43,22 +43,15 @@ pub async fn run(
let endpoint = component.endpoint(&endpoint_id.name);
let rt_fut: Pin<Box<dyn Future<Output = _> + Send + 'static>> = match engine_config {
EngineConfig::StaticFull {
engine,
mut model,
is_static,
} => {
EngineConfig::StaticFull { engine, mut model } => {
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let ingress_chat = Ingress::<
Context<NvCreateChatCompletionRequest>,
Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
>::for_engine(engine)?;
if !is_static {
model
.attach(&endpoint, ModelType::Chat, ModelInput::Text)
.await?;
}
model
.attach(&endpoint, ModelType::Chat, ModelInput::Text)
.await?;
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
Box::pin(fut_chat)
......@@ -66,7 +59,6 @@ pub async fn run(
EngineConfig::StaticCore {
engine: inner_engine,
mut model,
is_static,
is_prefill,
} => {
// Pre-processing is done ingress-side, so it should be already done.
......@@ -83,24 +75,18 @@ pub async fn run(
.link(frontend)?;
let ingress = Ingress::for_pipeline(pipeline)?;
if !is_static {
let model_type = if is_prefill {
ModelType::Prefill
} else {
ModelType::Chat | ModelType::Completions
};
model
.attach(&endpoint, model_type, ModelInput::Tokens)
.await?;
}
let model_type = if is_prefill {
ModelType::Prefill
} else {
ModelType::Chat | ModelType::Completions
};
model
.attach(&endpoint, model_type, ModelInput::Tokens)
.await?;
let fut = endpoint.endpoint_builder().handler(ingress).start();
Box::pin(fut)
}
EngineConfig::StaticRemote(_) => {
panic!("StaticRemote definitions are only for the frontend end node.");
}
EngineConfig::Dynamic(_) => {
unreachable!("An endpoint input will never have a Dynamic engine");
}
......@@ -155,7 +141,6 @@ mod integration_tests {
.await
.map_err(|e| anyhow::anyhow!("Failed to build LocalModel: {}", e))?,
),
is_static: false,
};
Ok((distributed_runtime, engine_config))
......
......@@ -6,7 +6,7 @@ use std::sync::Arc;
use crate::{
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
entrypoint::{EngineConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
namespace::is_global_namespace,
......@@ -49,75 +49,6 @@ pub async fn run(
.await?;
grpc_service
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
Some(
manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service
}
EngineConfig::StaticFull { engine, model, .. } => {
let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
......
......@@ -7,7 +7,7 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
entrypoint::{EngineConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
namespace::is_global_namespace,
......@@ -91,78 +91,6 @@ pub async fn run(
.await?;
http_service
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let kv_chooser = if router_mode == RouterMode::KV {
Some(
manager
.kv_chooser_for(
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
)
.await?,
)
} else {
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in http static mode
)
.await?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in http static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
EngineConfig::StaticFull { engine, model, .. } => {
let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
......
......@@ -213,11 +213,7 @@ async fn poll_backend_once(
) -> anyhow::Result<usize> {
use dynamo_runtime::pipeline::Context;
// Send request to backend (try static mode first, fall back to dynamic mode)
let response_stream = match router.r#static(Context::new("".to_string())).await {
Ok(stream) => stream,
Err(_) => router.random(Context::new("".to_string())).await?,
};
let response_stream = router.random(Context::new("".to_string())).await?;
// Collect responses from the stream
let mut responses = Vec::new();
......
......@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Component, InstanceSource},
component::Component,
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
......@@ -226,12 +226,7 @@ impl KvRouter {
let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
InstanceSource::Dynamic(rx) => rx.clone(),
InstanceSource::Static => {
panic!("Expected dynamic instance source for KV routing");
}
};
let instances_rx = client.instance_source.as_ref().clone();
// Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery();
......@@ -508,120 +503,111 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
request.backend_instance_id
{
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
}
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
)
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
// then the request will not be routed to the worker,
// and instead the worker_instance_id will be returned.
let stream_context = request.context().clone();
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format
let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens
);
let stream = stream::iter(vec![response, response_tokens]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let context_for_monitoring = stream_context.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
loop {
tokio::select! {
biased;
_ = context_for_monitoring.stopped() => {
tracing::debug!("Request {context_id} cancelled, ending stream");
break;
}
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
)
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
// then the request will not be routed to the worker,
// and instead the worker_instance_id will be returned.
let stream_context = request.context().clone();
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response =
Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format
let response_tokens =
Annotated::from_annotation("token_data", &request.token_ids)?;
tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens
);
let stream = stream::iter(vec![response, response_tokens]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let context_for_monitoring = stream_context.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
loop {
tokio::select! {
biased;
_ = context_for_monitoring.stopped() => {
tracing::debug!("Request {context_id} cancelled, ending stream");
break;
}
item = response_stream.next() => {
let Some(item) = item else {
break;
};
item = response_stream.next() => {
let Some(item) = item else {
break;
};
if !prefill_marked {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
prefill_marked = true;
}
yield item;
if !prefill_marked {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
prefill_marked = true;
}
yield item;
}
}
}
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
}
......
......@@ -295,12 +295,7 @@ pub async fn start_kv_router_background(
// Get instances_rx for tracking current workers
let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
dynamo_runtime::component::InstanceSource::Dynamic(rx) => rx.clone(),
dynamo_runtime::component::InstanceSource::Static => {
anyhow::bail!("Expected dynamic instance source for KV routing");
}
};
let instances_rx = client.instance_source.as_ref().clone();
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
......
......@@ -51,7 +51,7 @@ mod tests {
/// Helper to create a test DistributedRuntime with NATS
async fn create_test_drt() -> dynamo_runtime::DistributedRuntime {
let rt = Runtime::from_current().unwrap();
let config = dynamo_runtime::distributed::DistributedConfig::from_settings(false);
let config = dynamo_runtime::distributed::DistributedConfig::from_settings();
dynamo_runtime::DistributedRuntime::new(rt, config)
.await
.expect("Failed to create DistributedRuntime")
......
......@@ -325,7 +325,6 @@ mod integration_tests {
let engine_config = EngineConfig::StaticFull {
engine: make_echo_engine(),
model: Box::new(local_model.clone()),
is_static: false, // This enables MDC registration!
};
let service = HttpService::builder()
......
......@@ -69,7 +69,7 @@ mod namespace;
mod registry;
pub mod service;
pub use client::{Client, InstanceSource};
pub use client::Client;
/// The root key-value path where each instance registers itself in.
/// An instance is namespace+component+endpoint+lease_id and must be unique.
......@@ -166,10 +166,6 @@ pub struct Component {
#[builder(setter(into))]
namespace: Namespace,
// A static component's endpoints cannot be discovered via etcd, they are
// fixed at startup time.
is_static: bool,
/// This hierarchy's own metrics registry
#[builder(default = "crate::MetricsRegistry::new()")]
metrics_registry: crate::MetricsRegistry,
......@@ -179,15 +175,12 @@ impl Hash for Component {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.namespace.name().hash(state);
self.name.hash(state);
self.is_static.hash(state);
}
}
impl PartialEq for Component {
fn eq(&self, other: &Self) -> bool {
self.namespace.name() == other.namespace.name()
&& self.name == other.name
&& self.is_static == other.is_static
self.namespace.name() == other.namespace.name() && self.name == other.name
}
}
......@@ -271,7 +264,6 @@ impl Component {
Endpoint {
component: self.clone(),
name: endpoint.into(),
is_static: self.is_static,
labels: Vec::new(),
metrics_registry: crate::MetricsRegistry::new(),
}
......@@ -436,8 +428,6 @@ pub struct Endpoint {
/// Endpoint name
name: String,
is_static: bool,
/// Additional labels for metrics
labels: Vec<(String, String)>,
......@@ -449,15 +439,12 @@ impl Hash for Endpoint {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.component.hash(state);
self.name.hash(state);
self.is_static.hash(state);
}
}
impl PartialEq for Endpoint {
fn eq(&self, other: &Self) -> bool {
self.component == other.component
&& self.name == other.name
&& self.is_static == other.is_static
self.component == other.component && self.name == other.name
}
}
......@@ -554,27 +541,19 @@ impl Endpoint {
format!("{ns}/{cp}/{ep}/{lease_id:x}")
}
/// The endpoint as an EtcdPath object with lease ID
pub fn etcd_path_object_with_lease_id(&self, lease_id: i64) -> EtcdPath {
if self.is_static {
self.etcd_path()
} else {
EtcdPath::new_endpoint_with_lease(
&self.component.namespace().name(),
self.component.name(),
&self.name,
lease_id,
)
.expect("Endpoint name and component name should be valid")
}
/// The endpoint as an EtcdPath object with instance ID
pub fn etcd_path_object_with_lease_id(&self, instance_id: i64) -> EtcdPath {
EtcdPath::new_endpoint_with_lease(
&self.component.namespace().name(),
self.component.name(),
&self.name,
instance_id,
)
.expect("Endpoint name and component name should be valid")
}
pub fn name_with_id(&self, lease_id: u64) -> String {
if self.is_static {
self.name.clone()
} else {
format!("{}-{:x}", self.name, lease_id)
}
pub fn name_with_id(&self, instance_id: u64) -> String {
format!("{}-{:x}", self.name, instance_id)
}
pub fn subject(&self) -> String {
......@@ -591,11 +570,7 @@ impl Endpoint {
}
pub async fn client(&self) -> anyhow::Result<client::Client> {
if self.is_static {
client::Client::new_static(self.clone()).await
} else {
client::Client::new_dynamic(self.clone()).await
}
client::Client::new(self.clone()).await
}
pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder {
......@@ -612,8 +587,6 @@ pub struct Namespace {
#[validate(custom(function = "validate_allowed_chars"))]
name: String,
is_static: bool,
#[builder(default = "None")]
parent: Option<Arc<Namespace>>,
......@@ -636,8 +609,8 @@ impl std::fmt::Debug for Namespace {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Namespace {{ name: {}; is_static: {}; parent: {:?} }}",
self.name, self.is_static, self.parent
"Namespace {{ name: {}; parent: {:?} }}",
self.name, self.parent
)
}
}
......@@ -655,15 +628,10 @@ impl std::fmt::Display for Namespace {
}
impl Namespace {
pub(crate) fn new(
runtime: DistributedRuntime,
name: String,
is_static: bool,
) -> anyhow::Result<Self> {
pub(crate) fn new(runtime: DistributedRuntime, name: String) -> anyhow::Result<Self> {
Ok(NamespaceBuilder::default()
.runtime(Arc::new(runtime))
.name(name)
.is_static(is_static)
.build()?)
}
......@@ -672,7 +640,6 @@ impl Namespace {
Ok(ComponentBuilder::from_runtime(self.runtime.clone())
.name(name)
.namespace(self.clone())
.is_static(self.is_static)
.build()?)
}
......@@ -681,7 +648,6 @@ impl Namespace {
Ok(NamespaceBuilder::default()
.runtime(self.runtime.clone())
.name(name.into())
.is_static(self.is_static)
.parent(Some(Arc::new(self.clone())))
.build()?)
}
......
......@@ -21,62 +21,25 @@ use crate::{
transports::etcd::Client as EtcdClient,
};
/// Each state will be have a nonce associated with it
/// The state will be emitted in a watch channel, so we can observe the
/// critical state transitions.
enum MapState {
/// The map is empty; value = nonce
Empty(u64),
/// The map is not-empty; values are (nonce, count)
NonEmpty(u64, u64),
/// The watcher has finished, no more events will be emitted
Finished,
}
enum EndpointEvent {
Put(String, u64),
Delete(String),
}
#[derive(Clone, Debug)]
pub struct Client {
// This is me
pub endpoint: Endpoint,
// These are the remotes I know about from watching etcd
pub instance_source: Arc<InstanceSource>,
// These are the remotes I know about from watching key-value store
pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
// These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<u64>>>,
// These are the instance source ids less those reported as busy (above threshold)
instance_free: Arc<ArcSwap<Vec<u64>>>,
}
#[derive(Clone, Debug)]
pub enum InstanceSource {
Static,
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
}
impl Client {
// Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
instance_source: Arc::new(InstanceSource::Static),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
})
}
// Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
// Client with auto-discover instances using key-value store
pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
tracing::debug!(
"Client::new_dynamic: Creating dynamic client for endpoint: {}",
endpoint.path()
);
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
tracing::debug!(
"Client::new_dynamic: Got instance source for endpoint: {}",
......@@ -110,12 +73,9 @@ impl Client {
self.endpoint.etcd_root()
}
/// Instances available from watching etcd
/// Instances available from watching key-value store
pub fn instances(&self) -> Vec<Instance> {
match self.instance_source.as_ref() {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
self.instance_source.borrow().clone()
}
pub fn instance_ids(&self) -> Vec<u64> {
......@@ -136,52 +96,41 @@ impl Client {
"wait_for_instances: Starting wait for endpoint: {}",
self.endpoint.path()
);
let mut instances: Vec<Instance> = vec![];
if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() {
// wait for there to be 1 or more endpoints
let mut iteration = 0;
loop {
instances = rx.borrow_and_update().to_vec();
let mut rx = self.instance_source.as_ref().clone();
// wait for there to be 1 or more endpoints
let mut iteration = 0;
let mut instances: Vec<Instance>;
loop {
instances = rx.borrow_and_update().to_vec();
tracing::debug!(
"wait_for_instances: iteration={}, current_instance_count={}, endpoint={}",
iteration,
instances.len(),
self.endpoint.path()
);
if instances.is_empty() {
tracing::debug!(
"wait_for_instances: iteration={}, current_instance_count={}, endpoint={}",
iteration,
"wait_for_instances: No instances yet, waiting for change notification for endpoint: {}",
self.endpoint.path()
);
rx.changed().await?;
tracing::debug!(
"wait_for_instances: Change notification received for endpoint: {}",
self.endpoint.path()
);
} else {
tracing::info!(
"wait_for_instances: Found {} instance(s) for endpoint: {}",
instances.len(),
self.endpoint.path()
);
if instances.is_empty() {
tracing::debug!(
"wait_for_instances: No instances yet, waiting for change notification for endpoint: {}",
self.endpoint.path()
);
rx.changed().await?;
tracing::debug!(
"wait_for_instances: Change notification received for endpoint: {}",
self.endpoint.path()
);
} else {
tracing::info!(
"wait_for_instances: Found {} instance(s) for endpoint: {}",
instances.len(),
self.endpoint.path()
);
break;
}
iteration += 1;
break;
}
} else {
tracing::debug!(
"wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}",
self.endpoint.path()
);
iteration += 1;
}
Ok(instances)
}
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
}
/// Mark an instance as down/unavailable
pub fn report_instance_down(&self, instance_id: u64) {
let filtered = self
......@@ -204,7 +153,7 @@ impl Client {
self.instance_free.store(Arc::new(free_ids));
}
/// Monitor the ETCD instance source and update instance_avail.
/// Monitor the key-value instance source and update instance_avail.
fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
......@@ -214,15 +163,7 @@ impl Client {
endpoint_path
);
tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => {
tracing::error!(
"monitor_instance_source: Static instance source is not watchable"
);
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
let mut rx = client.instance_source.as_ref().clone();
let mut iteration = 0;
while !cancel_token.is_cancelled() {
let instance_ids: Vec<u64> = rx
......@@ -267,7 +208,7 @@ impl Client {
async fn get_or_create_dynamic_instance_source(
endpoint: &Endpoint,
) -> Result<Arc<InstanceSource>> {
) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
let drt = endpoint.drt();
let instance_sources = drt.instance_sources();
let mut instance_sources = instance_sources.lock().await;
......@@ -401,7 +342,7 @@ impl Client {
let _ = watch_tx.send(vec![]);
});
let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
let instance_source = Arc::new(watch_rx);
instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
tracing::debug!(
"get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}",
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::component::Component;
use crate::component::{Component, Instance};
use crate::pipeline::PipelineError;
use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
......@@ -9,7 +9,7 @@ use crate::storage::key_value_store::{
};
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
component::{self, ComponentBuilder, Endpoint, Namespace},
discovery::Discovery,
metrics::PrometheusUpdateCallback,
metrics::{MetricsHierarchy, MetricsRegistry},
......@@ -24,6 +24,7 @@ use crate::runtime::Runtime;
use async_once_cell::OnceCell;
use std::sync::{Arc, OnceLock, Weak};
use tokio::sync::watch::Receiver;
use anyhow::Result;
use derive_getters::Dissolve;
......@@ -32,6 +33,8 @@ use std::collections::HashMap;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
type InstanceMap = HashMap<Endpoint, Weak<Receiver<Vec<Instance>>>>;
/// Distributed [Runtime] which provides access to shared resources across the cluster, this includes
/// communication protocols and transports.
#[derive(Clone)]
......@@ -60,11 +63,7 @@ pub struct DistributedRuntime {
// paths in etcd to a minimum.
component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd, they must be know at
// startup. Will not start etcd.
is_static: bool,
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
instance_sources: Arc<tokio::sync::Mutex<InstanceMap>>,
// Health Status
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
......@@ -95,12 +94,12 @@ impl std::fmt::Debug for DistributedRuntime {
impl DistributedRuntime {
pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
let (selected_kv_store, nats_config, is_static) = config.dissolve();
let (selected_kv_store, nats_config) = config.dissolve();
let runtime_clone = runtime.clone();
let (etcd_client, store) = match (is_static, selected_kv_store) {
(false, KeyValueStoreSelect::Etcd(etcd_config)) => {
let (etcd_client, store) = match selected_kv_store {
KeyValueStoreSelect::Etcd(etcd_config) => {
let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err|
// The returned error doesn't show because of a dropped runtime error, so
// log it first.
......@@ -108,10 +107,8 @@ impl DistributedRuntime {
let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store)
}
(false, KeyValueStoreSelect::File(root)) => (None, KeyValueStoreManager::file(root)),
(true, _) | (false, KeyValueStoreSelect::Memory) => {
(None, KeyValueStoreManager::memory())
}
KeyValueStoreSelect::File(root) => (None, KeyValueStoreManager::file(root)),
KeyValueStoreSelect::Memory => (None, KeyValueStoreManager::memory()),
};
let nats_client = Some(nats_config.clone().connect().await?);
......@@ -181,7 +178,6 @@ impl DistributedRuntime {
discovery_client,
discovery_metadata,
component_registry: component::Registry::new(),
is_static,
instance_sources: Arc::new(Mutex::new(HashMap::new())),
metrics_registry: crate::MetricsRegistry::new(),
system_health,
......@@ -283,13 +279,7 @@ impl DistributedRuntime {
}
pub async fn from_settings(runtime: Runtime) -> Result<Self> {
let config = DistributedConfig::from_settings(false);
Self::new(runtime, config).await
}
// Call this if you are using static workers that do not need etcd-based discovery.
pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
let config = DistributedConfig::from_settings(true);
let config = DistributedConfig::from_settings();
Self::new(runtime, config).await
}
......@@ -323,7 +313,7 @@ impl DistributedRuntime {
/// Create a [`Namespace`]
pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
Namespace::new(self.clone(), name.into(), self.is_static)
Namespace::new(self.clone(), name.into())
}
/// Returns the discovery interface for service registration and discovery
......@@ -380,7 +370,7 @@ impl DistributedRuntime {
self.runtime.graceful_shutdown_tracker()
}
pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
pub fn instance_sources(&self) -> Arc<Mutex<InstanceMap>> {
self.instance_sources.clone()
}
}
......@@ -389,15 +379,13 @@ impl DistributedRuntime {
pub struct DistributedConfig {
pub store_backend: KeyValueStoreSelect,
pub nats_config: nats::ClientOptions,
pub is_static: bool,
}
impl DistributedConfig {
pub fn from_settings(is_static: bool) -> DistributedConfig {
pub fn from_settings() -> DistributedConfig {
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::default()),
nats_config: nats::ClientOptions::default(),
is_static,
}
}
......@@ -409,24 +397,26 @@ impl DistributedConfig {
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)),
nats_config: nats::ClientOptions::default(),
is_static: false,
}
}
}
pub mod distributed_test_utils {
//! Common test helper functions for DistributedRuntime tests
// TODO: Use in-memory DistributedRuntime for tests instead of full runtime when available.
/// Helper function to create a DRT instance for integration-only tests.
/// Uses from_current to leverage existing tokio runtime
/// Note: Settings are read from environment variables inside DistributedRuntime::from_settings_without_discovery
/// Note: Settings are read from environment variables inside DistributedRuntime::from_settings
#[cfg(feature = "integration")]
pub async fn create_test_drt_async() -> crate::DistributedRuntime {
pub async fn create_test_drt_async() -> super::DistributedRuntime {
use crate::{storage::key_value_store::KeyValueStoreSelect, transports::nats};
let rt = crate::Runtime::from_current().unwrap();
crate::DistributedRuntime::from_settings_without_discovery(rt)
.await
.unwrap()
let config = super::DistributedConfig {
store_backend: KeyValueStoreSelect::Memory,
nats_config: nats::ClientOptions::default(),
};
super::DistributedRuntime::new(rt, config).await.unwrap()
}
}
......
......@@ -80,7 +80,7 @@ impl HealthCheckManager {
}
// Create a client that discovers instances dynamically for this endpoint
let client = Client::new_dynamic(endpoint).await?;
let client = Client::new(endpoint).await?;
// Create PushRouter - it will use direct routing when we call direct()
let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
......
......@@ -3,7 +3,7 @@
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::{
component::{Client, Endpoint, InstanceSource},
component::{Client, Endpoint},
engine::{AsyncEngine, Data},
pipeline::{
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
......@@ -118,9 +118,7 @@ where
let addressed = addressed_router(&client.endpoint).await?;
// Start worker monitor if provided and in dynamic mode
if let Some(monitor) = worker_monitor.as_ref()
&& matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
{
if let Some(monitor) = worker_monitor.as_ref() {
monitor.start_monitoring().await?;
}
......@@ -196,6 +194,7 @@ where
.await
}
/*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
tracing::debug!("static got subject: {subject}");
......@@ -203,6 +202,7 @@ where
tracing::debug!("router generate");
self.addressed.generate(request).await
}
*/
async fn generate_with_fault_detection(
&self,
......@@ -268,16 +268,14 @@ where
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.client.instance_source.as_ref() {
InstanceSource::Static => self.r#static(request).await,
InstanceSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
},
//InstanceSource::Static => self.r#static(request).await,
match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
}
}
}
......@@ -324,6 +324,8 @@ impl Runtime {
"Phase 3: All endpoints ended gracefully. Connections to NATS/ETCD will now be disconnected"
);
main_token.cancel();
// TODO: We should likely call shutdown_background on tokio rt to stop it cleanly.
});
}
}
......
......@@ -264,7 +264,7 @@ mod concurrent_create_tests {
fn test_concurrent_etcd_create_race_condition() {
let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false);
let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap();
......
......@@ -751,7 +751,7 @@ mod tests {
fn test_ectd_client() {
let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false);
let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap();
......@@ -794,7 +794,7 @@ mod tests {
fn test_kv_cache() {
let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false);
let config = DistributedConfig::from_settings();
rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap();
......
......@@ -35,10 +35,17 @@ fn test_namespace_etcd_path_format() {
#[cfg(feature = "integration")]
#[tokio::test]
async fn test_recursive_namespace_implementation() {
use dynamo_runtime::{
distributed::DistributedConfig, storage::key_value_store::KeyValueStoreSelect,
transports::nats,
};
let runtime = Runtime::from_current().unwrap();
let distributed_runtime = DistributedRuntime::from_settings_without_discovery(runtime)
.await
.unwrap();
let config = DistributedConfig {
store_backend: KeyValueStoreSelect::Memory,
nats_config: nats::ClientOptions::default(),
};
let distributed_runtime = DistributedRuntime::new(runtime, config).await.unwrap();
// Test single namespace
let ns1 = distributed_runtime.namespace("ns1").unwrap();
......@@ -74,10 +81,17 @@ async fn test_recursive_namespace_implementation() {
#[cfg(feature = "integration")]
#[tokio::test]
async fn test_multiple_branches_recursive_namespaces() {
use dynamo_runtime::{
distributed::DistributedConfig, storage::key_value_store::KeyValueStoreSelect,
transports::nats,
};
let runtime = Runtime::from_current().unwrap();
let distributed_runtime = DistributedRuntime::from_settings_without_discovery(runtime)
.await
.unwrap();
let config = DistributedConfig {
store_backend: KeyValueStoreSelect::Memory,
nats_config: nats::ClientOptions::default(),
};
let distributed_runtime = DistributedRuntime::new(runtime, config).await.unwrap();
// Create root namespace
let root = distributed_runtime.namespace("root").unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment