Unverified Commit aeb79e62 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support multiple models on single ingress node (#1127)

We can now do this:

- Node 1:

```
dynamo-run in=http out=dyn
```

- Node 2 and 3, two instances of the 'backend' component in the nemotron_ultra pipeline:

```
dynamo-run in=dyn://nemotron_ultra.backend.generate out=vllm /data/models/NemotronUltra
```

- Node 4 and 5, two instances of the 'backend' component in the nemotron_super pipeline:

```
dynamo-run in=dyn://nemotron_super.backend.generate out=vllm /data/models/NemotronSuper
```

The ingress node will discover all four instances and route correctly. We have been planning for this for a long time now.

As part of this, auto-discovery is now always `out=dyn`, with no extra URL parts. Previously, the ingress node could only route to a single pipeline.

Also:
- Refactor endpoint / instance naming now that I understand them
- Fix removing models when their instance stops.
parent 74221fd7
...@@ -53,13 +53,13 @@ pub struct Client { ...@@ -53,13 +53,13 @@ pub struct Client {
// This is me // This is me
pub endpoint: Endpoint, pub endpoint: Endpoint,
// These are the remotes I know about // These are the remotes I know about
pub endpoints: EndpointSource, pub instances: InstanceSource,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum EndpointSource { pub enum InstanceSource {
Static, Static,
Dynamic(tokio::sync::watch::Receiver<Vec<ComponentEndpointInfo>>), Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
} }
impl Client { impl Client {
...@@ -67,18 +67,18 @@ impl Client { ...@@ -67,18 +67,18 @@ impl Client {
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client { Ok(Client {
endpoint, endpoint,
endpoints: EndpointSource::Static, instances: InstanceSource::Static,
}) })
} }
// Client with auto-discover endpoints using etcd // Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
// create live endpoint watcher // create live endpoint watcher
let Some(etcd_client) = &endpoint.component.drt.etcd_client else { let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
anyhow::bail!("Attempt to create a dynamic client on a static endpoint"); anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
}; };
let prefix_watcher = etcd_client let prefix_watcher = etcd_client
.kv_get_and_watch_prefix(endpoint.etcd_path()) .kv_get_and_watch_prefix(endpoint.etcd_root())
.await?; .await?;
let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve(); let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
...@@ -114,7 +114,7 @@ impl Client { ...@@ -114,7 +114,7 @@ impl Client {
match kv_event { match kv_event {
WatchEvent::Put(kv) => { WatchEvent::Put(kv) => {
let key = String::from_utf8(kv.key().to_vec()); let key = String::from_utf8(kv.key().to_vec());
let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value()); let val = serde_json::from_slice::<Instance>(kv.value());
if let (Ok(key), Ok(val)) = (key, val) { if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val); map.insert(key.clone(), val);
} else { } else {
...@@ -133,65 +133,64 @@ impl Client { ...@@ -133,65 +133,64 @@ impl Client {
} }
} }
let endpoints: Vec<ComponentEndpointInfo> = map.values().cloned().collect(); let instances: Vec<Instance> = map.values().cloned().collect();
if watch_tx.send(endpoints).is_err() { if watch_tx.send(instances).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix); tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
break; break;
} }
} }
tracing::debug!("Completed endpoint watcher for prefix: {}", prefix); tracing::debug!("Completed endpoint watcher for prefix: {prefix}");
let _ = watch_tx.send(vec![]); let _ = watch_tx.send(vec![]);
}); });
Ok(Client { Ok(Client {
endpoint, endpoint,
endpoints: EndpointSource::Dynamic(watch_rx), instances: InstanceSource::Dynamic(watch_rx),
}) })
} }
/// String identifying `<namespace>/<component>/<endpoint>`
pub fn path(&self) -> String { pub fn path(&self) -> String {
self.endpoint.path() self.endpoint.path()
} }
/// String identifying `<namespace>/component/<component>/<endpoint>` /// The root etcd path we watch in etcd to discover new instances to route to.
pub fn etcd_path(&self) -> String { pub fn etcd_root(&self) -> String {
self.endpoint.etcd_path() self.endpoint.etcd_root()
} }
pub fn endpoints(&self) -> Vec<ComponentEndpointInfo> { pub fn instances(&self) -> Vec<Instance> {
match &self.endpoints { match &self.instances {
EndpointSource::Static => vec![], InstanceSource::Static => vec![],
EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(), InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
} }
} }
pub fn endpoint_ids(&self) -> Vec<i64> { pub fn instance_ids(&self) -> Vec<i64> {
self.endpoints().into_iter().map(|ep| ep.id()).collect() self.instances().into_iter().map(|ep| ep.id()).collect()
} }
/// Wait for at least one [`Endpoint`] to be available /// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_endpoints(&self) -> Result<Vec<ComponentEndpointInfo>> { pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut endpoints: Vec<ComponentEndpointInfo> = vec![]; let mut instances: Vec<Instance> = vec![];
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() { if let InstanceSource::Dynamic(mut rx) = self.instances.clone() {
// wait for there to be 1 or more endpoints // wait for there to be 1 or more endpoints
loop { loop {
endpoints = rx.borrow_and_update().to_vec(); instances = rx.borrow_and_update().to_vec();
if endpoints.is_empty() { if instances.is_empty() {
rx.changed().await?; rx.changed().await?;
} else { } else {
break; break;
} }
} }
} }
Ok(endpoints) Ok(instances)
} }
/// Is this component know at startup and not discovered via etcd? /// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool { pub fn is_static(&self) -> bool {
matches!(self.endpoints, EndpointSource::Static) matches!(self.instances, InstanceSource::Static)
} }
} }
...@@ -59,10 +59,7 @@ impl EndpointConfigBuilder { ...@@ -59,10 +59,7 @@ impl EndpointConfigBuilder {
let lease = lease.or(endpoint.drt().primary_lease()); let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0); let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
tracing::debug!( tracing::debug!("Starting endpoint: {}", endpoint.etcd_path(lease_id));
"Starting endpoint: {}",
endpoint.etcd_path_with_id(lease_id)
);
let service_name = endpoint.component.service_name(); let service_name = endpoint.component.service_name();
...@@ -115,11 +112,11 @@ impl EndpointConfigBuilder { ...@@ -115,11 +112,11 @@ impl EndpointConfigBuilder {
// make the components service endpoint discovery in etcd // make the components service endpoint discovery in etcd
// client.register_service() // client.register_service()
let info = ComponentEndpointInfo { let info = Instance {
component: endpoint.component.name.clone(), component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(), endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.name.clone(), namespace: endpoint.component.namespace.name.clone(),
lease_id, instance_id: lease_id,
transport: TransportType::NatsTcp(endpoint.subject_to(lease_id)), transport: TransportType::NatsTcp(endpoint.subject_to(lease_id)),
}; };
...@@ -127,7 +124,7 @@ impl EndpointConfigBuilder { ...@@ -127,7 +124,7 @@ impl EndpointConfigBuilder {
if let Some(etcd_client) = &endpoint.component.drt.etcd_client { if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
if let Err(e) = etcd_client if let Err(e) = etcd_client
.kv_create(endpoint.etcd_path_with_id(lease_id), info, Some(lease_id)) .kv_create(endpoint.etcd_path(lease_id), info, Some(lease_id))
.await .await
{ {
tracing::error!("Failed to register discoverable service: {:?}", e); tracing::error!("Failed to register discoverable service: {:?}", e);
......
...@@ -85,8 +85,6 @@ impl ServiceConfigBuilder { ...@@ -85,8 +85,6 @@ impl ServiceConfigBuilder {
.await .await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
tracing::debug!("Service started TEMP");
// new copy of service_name as the previous one is moved into the task above // new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name(); let service_name = component.service_name();
...@@ -101,7 +99,6 @@ impl ServiceConfigBuilder { ...@@ -101,7 +99,6 @@ impl ServiceConfigBuilder {
// drop the guard to unlock the mutex // drop the guard to unlock the mutex
drop(guard); drop(guard);
tracing::debug!("create done");
Ok(component) Ok(component)
} }
} }
......
...@@ -25,7 +25,7 @@ use std::{ ...@@ -25,7 +25,7 @@ use std::{
}; };
use crate::{ use crate::{
component::{Client, Endpoint, EndpointSource}, component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data}, engine::{AsyncEngine, Data},
pipeline::{AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn}, pipeline::{AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn},
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
...@@ -41,7 +41,7 @@ where ...@@ -41,7 +41,7 @@ where
/// The Client is how we gather remote endpoint information from etcd. /// The Client is how we gather remote endpoint information from etcd.
pub client: Client, pub client: Client,
/// How we choose which endpoint to send traffic to. /// How we choose which instance to send traffic to.
/// ///
/// Setting this to KV means we never intend to call `generate` on this PushRouter. We are /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine. /// not using it as an AsyncEngine.
...@@ -52,7 +52,7 @@ where ...@@ -52,7 +52,7 @@ where
/// Number of round robin requests handled. Used to decide which server is next. /// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>, round_robin_counter: Arc<AtomicU64>,
/// The next step in the chain. PushRouter (this object) picks an endpoint, /// The next step in the chain. PushRouter (this object) picks an instances,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic. /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>, addressed: Arc<AddressedPushRouter>,
...@@ -101,25 +101,25 @@ where ...@@ -101,25 +101,25 @@ where
}) })
} }
/// Issue a request to the next available endpoint in a round-robin fashion /// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed); let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = { let instance_id = {
let endpoints = self.client.endpoints(); let instances = self.client.instances();
let count = endpoints.len(); let count = instances.len();
if count == 0 { if count == 0 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}", "no instances found for endpoint {:?}",
self.client.endpoint.etcd_path() self.client.endpoint.etcd_root()
)); ));
} }
let offset = counter % count as u64; let offset = counter % count as u64;
endpoints[offset as usize].id() instances[offset as usize].id()
}; };
tracing::trace!("round robin router selected {endpoint_id}"); tracing::trace!("round robin router selected {instance_id}");
let subject = self.client.endpoint.subject_to(endpoint_id); let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await self.addressed.generate(request).await
...@@ -127,22 +127,22 @@ where ...@@ -127,22 +127,22 @@ where
/// Issue a request to a random endpoint /// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let endpoint_id = { let instance_id = {
let endpoints = self.client.endpoints(); let instances = self.client.instances();
let count = endpoints.len(); let count = instances.len();
if count == 0 { if count == 0 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}", "no instances found for endpoint {:?}",
self.client.endpoint.etcd_path() self.client.endpoint.etcd_root()
)); ));
} }
let counter = rand::rng().random::<u64>(); let counter = rand::rng().random::<u64>();
let offset = counter % count as u64; let offset = counter % count as u64;
endpoints[offset as usize].id() instances[offset as usize].id()
}; };
tracing::trace!("random router selected {endpoint_id}"); tracing::trace!("random router selected {instance_id}");
let subject = self.client.endpoint.subject_to(endpoint_id); let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await self.addressed.generate(request).await
...@@ -152,22 +152,21 @@ where ...@@ -152,22 +152,21 @@ where
pub async fn direct( pub async fn direct(
&self, &self,
request: SingleIn<T>, request: SingleIn<T>,
endpoint_id: i64, instance_id: i64,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
let found = { let found = {
let endpoints = self.client.endpoints(); let instances = self.client.instances();
endpoints.iter().any(|ep| ep.id() == endpoint_id) instances.iter().any(|ep| ep.id() == instance_id)
}; };
if !found { if !found {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"endpoint_id={} not found for endpoint {:?}", "instance_id={instance_id} not found for endpoint {:?}",
endpoint_id, self.client.endpoint.etcd_root()
self.client.endpoint.etcd_path()
)); ));
} }
let subject = self.client.endpoint.subject_to(endpoint_id); let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await self.addressed.generate(request).await
...@@ -189,12 +188,12 @@ where ...@@ -189,12 +188,12 @@ where
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.client.endpoints { match &self.client.instances {
EndpointSource::Static => self.r#static(request).await, InstanceSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode { InstanceSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await, RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await, RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await, RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => { RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter"); anyhow::bail!("KV routing should not call generate on PushRouter");
} }
......
...@@ -114,7 +114,7 @@ mod integration { ...@@ -114,7 +114,7 @@ mod integration {
.client::<String, Annotated<String>>() .client::<String, Annotated<String>>()
.await?; .await?;
client.wait_for_endpoints().await?; client.wait_for_instances().await?;
let client = Arc::new(client); let client = Arc::new(client);
let start = Instant::now(); let start = Instant::now();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment