Unverified Commit aeb79e62 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support multiple models on single ingress node (#1127)

We can now do this:

- Node 1:

```
dynamo-run in=http out=dyn
```

- Node 2 and 3, two instances of the 'backend' component in the nemotron_ultra pipeline:

```
dynamo-run in=dyn://nemotron_ultra.backend.generate out=vllm /data/models/NemotronUltra
```

- Node 4 and 5, two instances of the 'backend' component in the nemotron_super pipeline:

```
dynamo-run in=dyn://nemotron_super.backend.generate out=vllm /data/models/NemotronSuper
```

The ingress node will discover all four instances and route correctly. We have been planning for this for a long time now.

As part of this, auto-discovery is now always `out=dyn`, with no extra URL parts. Previously it could only route to a single pipeline.

Also:
- Refactor endpoint / instance naming now that I understand them
- Fix removing models when their instance stops.
parent 74221fd7
......@@ -53,13 +53,13 @@ pub struct Client {
// This is me
pub endpoint: Endpoint,
// These are the remotes I know about
pub endpoints: EndpointSource,
pub instances: InstanceSource,
}
#[derive(Clone, Debug)]
pub enum EndpointSource {
pub enum InstanceSource {
Static,
Dynamic(tokio::sync::watch::Receiver<Vec<ComponentEndpointInfo>>),
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
}
impl Client {
......@@ -67,18 +67,18 @@ impl Client {
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
endpoints: EndpointSource::Static,
instances: InstanceSource::Static,
})
}
// Client with auto-discover endpoints using etcd
// Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
// create live endpoint watcher
let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
};
let prefix_watcher = etcd_client
.kv_get_and_watch_prefix(endpoint.etcd_path())
.kv_get_and_watch_prefix(endpoint.etcd_root())
.await?;
let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
......@@ -114,7 +114,7 @@ impl Client {
match kv_event {
WatchEvent::Put(kv) => {
let key = String::from_utf8(kv.key().to_vec());
let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
let val = serde_json::from_slice::<Instance>(kv.value());
if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val);
} else {
......@@ -133,65 +133,64 @@ impl Client {
}
}
let endpoints: Vec<ComponentEndpointInfo> = map.values().cloned().collect();
let instances: Vec<Instance> = map.values().cloned().collect();
if watch_tx.send(endpoints).is_err() {
if watch_tx.send(instances).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
break;
}
}
tracing::debug!("Completed endpoint watcher for prefix: {}", prefix);
tracing::debug!("Completed endpoint watcher for prefix: {prefix}");
let _ = watch_tx.send(vec![]);
});
Ok(Client {
endpoint,
endpoints: EndpointSource::Dynamic(watch_rx),
instances: InstanceSource::Dynamic(watch_rx),
})
}
/// String identifying `<namespace>/<component>/<endpoint>`
pub fn path(&self) -> String {
self.endpoint.path()
}
/// String identifying `<namespace>/component/<component>/<endpoint>`
pub fn etcd_path(&self) -> String {
self.endpoint.etcd_path()
/// The root etcd path we watch in etcd to discover new instances to route to.
pub fn etcd_root(&self) -> String {
self.endpoint.etcd_root()
}
pub fn endpoints(&self) -> Vec<ComponentEndpointInfo> {
match &self.endpoints {
EndpointSource::Static => vec![],
EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
pub fn instances(&self) -> Vec<Instance> {
match &self.instances {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
pub fn endpoint_ids(&self) -> Vec<i64> {
self.endpoints().into_iter().map(|ep| ep.id()).collect()
pub fn instance_ids(&self) -> Vec<i64> {
self.instances().into_iter().map(|ep| ep.id()).collect()
}
/// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<Vec<ComponentEndpointInfo>> {
let mut endpoints: Vec<ComponentEndpointInfo> = vec![];
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
/// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![];
if let InstanceSource::Dynamic(mut rx) = self.instances.clone() {
// wait for there to be 1 or more instances
loop {
endpoints = rx.borrow_and_update().to_vec();
if endpoints.is_empty() {
instances = rx.borrow_and_update().to_vec();
if instances.is_empty() {
rx.changed().await?;
} else {
break;
}
}
}
Ok(endpoints)
Ok(instances)
}
/// Is this component known at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.endpoints, EndpointSource::Static)
matches!(self.instances, InstanceSource::Static)
}
}
......@@ -59,10 +59,7 @@ impl EndpointConfigBuilder {
let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
tracing::debug!(
"Starting endpoint: {}",
endpoint.etcd_path_with_id(lease_id)
);
tracing::debug!("Starting endpoint: {}", endpoint.etcd_path(lease_id));
let service_name = endpoint.component.service_name();
......@@ -115,11 +112,11 @@ impl EndpointConfigBuilder {
// make the components service endpoint discovery in etcd
// client.register_service()
let info = ComponentEndpointInfo {
let info = Instance {
component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.name.clone(),
lease_id,
instance_id: lease_id,
transport: TransportType::NatsTcp(endpoint.subject_to(lease_id)),
};
......@@ -127,7 +124,7 @@ impl EndpointConfigBuilder {
if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
if let Err(e) = etcd_client
.kv_create(endpoint.etcd_path_with_id(lease_id), info, Some(lease_id))
.kv_create(endpoint.etcd_path(lease_id), info, Some(lease_id))
.await
{
tracing::error!("Failed to register discoverable service: {:?}", e);
......
......@@ -85,8 +85,6 @@ impl ServiceConfigBuilder {
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
tracing::debug!("Service started TEMP");
// new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name();
......@@ -101,7 +99,6 @@ impl ServiceConfigBuilder {
// drop the guard to unlock the mutex
drop(guard);
tracing::debug!("create done");
Ok(component)
}
}
......
......@@ -25,7 +25,7 @@ use std::{
};
use crate::{
component::{Client, Endpoint, EndpointSource},
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn},
traits::DistributedRuntimeProvider,
......@@ -41,7 +41,7 @@ where
/// The Client is how we gather remote endpoint information from etcd.
pub client: Client,
/// How we choose which endpoint to send traffic to.
/// How we choose which instance to send traffic to.
///
/// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine.
......@@ -52,7 +52,7 @@ where
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// The next step in the chain. PushRouter (this object) picks an endpoint,
/// The next step in the chain. PushRouter (this object) picks an instance,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
......@@ -101,25 +101,25 @@ where
})
}
/// Issue a request to the next available endpoint in a round-robin fashion
/// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
let instance_id = {
let instances = self.client.instances();
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
"no instances found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let offset = counter % count as u64;
endpoints[offset as usize].id()
instances[offset as usize].id()
};
tracing::trace!("round robin router selected {endpoint_id}");
tracing::trace!("round robin router selected {instance_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
......@@ -127,22 +127,22 @@ where
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
let instance_id = {
let instances = self.client.instances();
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
"no instances found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
endpoints[offset as usize].id()
instances[offset as usize].id()
};
tracing::trace!("random router selected {endpoint_id}");
tracing::trace!("random router selected {instance_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
......@@ -152,22 +152,21 @@ where
pub async fn direct(
&self,
request: SingleIn<T>,
endpoint_id: i64,
instance_id: i64,
) -> anyhow::Result<ManyOut<U>> {
let found = {
let endpoints = self.client.endpoints();
endpoints.iter().any(|ep| ep.id() == endpoint_id)
let instances = self.client.instances();
instances.iter().any(|ep| ep.id() == instance_id)
};
if !found {
return Err(anyhow::anyhow!(
"endpoint_id={} not found for endpoint {:?}",
endpoint_id,
self.client.endpoint.etcd_path()
"instance_id={instance_id} not found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let subject = self.client.endpoint.subject_to(endpoint_id);
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
......@@ -189,12 +188,12 @@ where
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
match &self.client.instances {
InstanceSource::Static => self.r#static(request).await,
InstanceSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
......
......@@ -114,7 +114,7 @@ mod integration {
.client::<String, Annotated<String>>()
.await?;
client.wait_for_endpoints().await?;
client.wait_for_instances().await?;
let client = Arc::new(client);
let start = Instant::now();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment