Unverified Commit a1a10365 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Split PushRouter from Client (#817)

In a distributed system we don't know if the remote workers need pre-processing done ingress-side or not. Previously Client required us to decide this before discovering the remote endpoints, which was fine because pre-processing was worker-side.

As part of moving pre-processing back to ingress-side we need to split this into two steps:
- Client discovers the endpoints, and (later PR) will fetch their Model Deployment Card.
- PushRouter will use the Model Deployment Card to decide if they need pre-processing or not, which affects the types of the generic parameters.

Part of #743
parent 97bf8184
......@@ -68,7 +68,7 @@ mod namespace;
mod registry;
pub mod service;
pub use client::{Client, RouterMode};
pub use client::{Client, EndpointSource};
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
......@@ -96,6 +96,12 @@ pub struct ComponentEndpointInfo {
pub transport: TransportType,
}
impl ComponentEndpointInfo {
pub fn id(&self) -> i64 {
self.lease_id
}
}
/// A [Component] a discoverable entity in the distributed runtime.
/// You can host [Endpoint] on a [Component] by first creating
/// a [Service] then adding one or more [Endpoint] to the [Service].
......@@ -160,6 +166,10 @@ impl Component {
&self.namespace
}
pub fn name(&self) -> String {
self.name.clone()
}
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
Endpoint {
component: self.clone(),
......@@ -272,11 +282,7 @@ impl Endpoint {
)
}
pub async fn client<Req, Resp>(&self) -> Result<client::Client<Req, Resp>>
where
Req: Serialize + Send + Sync + 'static,
Resp: for<'de> Deserialize<'de> + Send + Sync + 'static,
{
pub async fn client(&self) -> Result<client::Client> {
if self.is_static {
client::Client::new_static(self.clone()).await
} else {
......
......@@ -14,8 +14,8 @@
// limitations under the License.
use crate::pipeline::{
network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
AsyncEngine, Data, ManyOut, SingleIn,
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn,
};
use rand::Rng;
use std::collections::HashMap;
......@@ -25,7 +25,7 @@ use std::sync::{
};
use tokio::{net::unix::pipe::Receiver, sync::Mutex};
use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
use crate::{pipeline::async_trait, transports::etcd::WatchEvent};
use super::*;
......@@ -48,46 +48,26 @@ enum EndpointEvent {
Delete(String),
}
#[derive(Default, Debug, Clone, Copy)]
pub enum RouterMode {
#[default]
Random,
RoundRobin,
//KV,
//
// Always and only go to the given endpoint ID.
// TODO: Is this useful?
Direct(i64),
}
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
endpoint: Endpoint,
router: PushRouter<T, U>,
counter: Arc<AtomicU64>,
endpoints: EndpointSource,
router_mode: RouterMode,
#[derive(Clone, Debug)]
pub struct Client {
// This is me
pub endpoint: Endpoint,
// These are the remotes I know about
pub endpoints: EndpointSource,
}
#[derive(Clone, Debug)]
enum EndpointSource {
pub enum EndpointSource {
Static,
Dynamic(tokio::sync::watch::Receiver<Vec<i64>>),
Dynamic(tokio::sync::watch::Receiver<Vec<ComponentEndpointInfo>>),
}
impl<T, U> Client<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
impl Client {
// Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
router: router(&endpoint).await?,
endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Static,
router_mode: Default::default(),
})
}
......@@ -136,7 +116,7 @@ where
let key = String::from_utf8(kv.key().to_vec());
let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val.lease_id);
map.insert(key.clone(), val);
} else {
tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
break;
......@@ -153,9 +133,9 @@ where
}
}
let endpoint_ids: Vec<i64> = map.values().cloned().collect();
let endpoints: Vec<ComponentEndpointInfo> = map.values().cloned().collect();
if watch_tx.send(endpoint_ids).is_err() {
if watch_tx.send(endpoints).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
break;
}
......@@ -167,11 +147,8 @@ where
});
Ok(Client {
router: router(&endpoint).await?,
endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Dynamic(watch_rx),
router_mode: Default::default(),
})
}
......@@ -185,135 +162,36 @@ where
self.endpoint.etcd_path()
}
pub fn endpoint_ids(&self) -> Vec<i64> {
pub fn endpoints(&self) -> Vec<ComponentEndpointInfo> {
match &self.endpoints {
EndpointSource::Static => vec![0],
EndpointSource::Static => vec![],
EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
pub fn set_router_mode(&mut self, mode: RouterMode) {
self.router_mode = mode
pub fn endpoint_ids(&self) -> Vec<i64> {
self.endpoints().into_iter().map(|ep| ep.id()).collect()
}
/// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<()> {
pub async fn wait_for_endpoints(&self) -> Result<Vec<ComponentEndpointInfo>> {
let mut endpoints: Vec<ComponentEndpointInfo> = vec![];
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
// wait for there to be 1 or more endpoints
loop {
if rx.borrow_and_update().is_empty() {
endpoints = rx.borrow_and_update().to_vec();
if endpoints.is_empty() {
rx.changed().await?;
} else {
break;
}
}
}
Ok(())
Ok(endpoints)
}
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.endpoints, EndpointSource::Static)
}
/// Issue a request to the next available endpoint in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let counter = self.counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = {
let endpoints = self.endpoint_ids();
let count = endpoints.len();
if count == 0 {
return Err(error!(
"no endpoints found for endpoint {:?}",
self.endpoint.etcd_path()
));
}
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let endpoint_id = {
let endpoints = self.endpoint_ids();
let count = endpoints.len();
if count == 0 {
return Err(error!(
"no endpoints found for endpoint {:?}",
self.endpoint.etcd_path()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("random router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
/// Issue a request to a specific endpoint
pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
let found = {
let endpoints = self.endpoint_ids();
endpoints.contains(&endpoint_id)
};
if !found {
return Err(error!(
"endpoint_id={} not found for endpoint {:?}",
endpoint_id,
self.endpoint.etcd_path()
));
}
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
pub async fn r#static(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let subject = self.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.router.generate(request).await
}
}
async fn router(endpoint: &Endpoint) -> Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::new(
endpoint.component.drt.nats_client.client().clone(),
endpoint.component.drt.tcp_server().await?,
)
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
}
}
}
......@@ -26,6 +26,8 @@ pub use nodes::{
pub mod context;
pub mod error;
pub mod network;
pub use network::egress::addressed_router::{AddressedPushRouter, AddressedRequest};
pub use network::egress::push_router::{PushRouter, RouterMode};
pub mod registry;
pub use crate::engine::{
......
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod push;
pub mod addressed_router;
pub mod push_router;
use super::*;
......@@ -41,9 +41,6 @@ struct RequestControlMessage {
connection_info: ConnectionInfo,
}
pub type PushRouter<In, Out> =
Arc<dyn AsyncEngine<SingleIn<AddressedRequest<In>>, ManyOut<Out>, Error>>;
pub struct AddressedRequest<T> {
request: T,
address: String,
......@@ -111,7 +108,7 @@ where
}
};
// separate out the the connection info and the stream provider from the registered stream
// separate out the connection info and the stream provider from the registered stream
let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
// package up the connection info as part of the "header" component of the two part message
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
marker::PhantomData,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use crate::{
component::{Client, Endpoint, EndpointSource},
engine::{AsyncEngine, Data},
pipeline::{AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn},
traits::DistributedRuntimeProvider,
};
#[derive(Clone)]
pub struct PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
// TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
/// The Client is how we gather remote endpoint information from etcd.
pub client: Client,
/// How we choose which endpoint to send traffic to.
router_mode: RouterMode,
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// The next step in the chain. PushRouter (this object) picks an endpoint,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time.
_phantom: PhantomData<(T, U)>,
}
#[derive(Default, Debug, Clone, Copy)]
pub enum RouterMode {
#[default]
Random,
RoundRobin,
//KV,
//
// Always and only go to the given endpoint ID. Used by Python bindings.
Direct(i64),
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::new(
endpoint.drt().nats_client.client().clone(),
endpoint.drt().tcp_server().await?,
)
}
impl<T, U> PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
_phantom: PhantomData,
})
}
/// Issue a request to the next available endpoint in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
));
}
let offset = counter % count as u64;
endpoints[offset as usize].id()
};
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
endpoints[offset as usize].id()
};
tracing::trace!("random router selected {endpoint_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
/// Issue a request to a specific endpoint
pub async fn direct(
&self,
request: SingleIn<T>,
endpoint_id: i64,
) -> anyhow::Result<ManyOut<U>> {
let found = {
let endpoints = self.client.endpoints();
endpoints.iter().any(|ep| ep.id() == endpoint_id)
};
if !found {
return Err(anyhow::anyhow!(
"endpoint_id={} not found for endpoint {:?}",
endpoint_id,
self.client.endpoint.etcd_path()
));
}
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.addressed.generate(request).await
}
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
}
}
}
......@@ -147,7 +147,7 @@ impl TcpStreamServer {
PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
})?;
tracing::info!("tcp transport service on {}:{}", local_ip, local_port);
tracing::debug!("tcp transport service on {local_ip}:{local_port}");
Ok(Arc::new(Self {
local_ip,
......
......@@ -33,7 +33,7 @@ impl Slug {
/// Create [`Slug`] from a string.
pub fn from_string(s: impl AsRef<str>) -> Slug {
Slug::slugify_unique(s.as_ref())
Slug::slugify(s.as_ref())
}
/// Turn the string into a valid slug, replacing any not-web-or-nats-safe characters with '-'
......@@ -54,7 +54,7 @@ impl Slug {
}
/// Like slugify but also add a four byte hash on the end, in case two different strings slug
/// to the same thing.
/// to the same thing (e.g. because of case differences).
pub fn slugify_unique(s: &str) -> Slug {
let out = s
.to_lowercase()
......
......@@ -25,7 +25,8 @@ use tokio::sync::{mpsc, RwLock};
use validator::Validate;
use etcd_client::{
Compare, CompareOp, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse, WatchOptions, Watcher,
Compare, CompareOp, DeleteOptions, GetOptions, PutOptions, PutResponse, Txn, TxnOp,
TxnOpResponse, WatchOptions, Watcher,
};
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
......@@ -172,13 +173,14 @@ impl Client {
value: Vec<u8>,
lease_id: Option<i64>,
) -> Result<()> {
let put_options = lease_id.map(|id| PutOptions::new().with_lease(id));
let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
// Build the transaction
let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
.and_then(vec![
TxnOp::put(key.as_str(), value, put_options), // Create the object
TxnOp::put(key.as_str(), value, Some(put_options)), // Create the object
]);
// Execute the transaction
......@@ -201,14 +203,15 @@ impl Client {
value: Vec<u8>,
lease_id: Option<i64>,
) -> Result<()> {
let put_options = lease_id.map(|id| PutOptions::new().with_lease(id));
let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
// Build the transaction that either creates the key if it doesn't exist,
// or validates the existing value matches what we expect
let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Key doesn't exist
.and_then(vec![
TxnOp::put(key.as_str(), value.clone(), put_options), // Create it
TxnOp::put(key.as_str(), value.clone(), Some(put_options)), // Create it
])
.or_else(vec![
// If key exists but values don't match, this will fail the transaction
......@@ -245,19 +248,54 @@ impl Client {
value: impl AsRef<[u8]>,
lease_id: Option<i64>,
) -> Result<()> {
let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
let _ = self
.client
.kv_client()
.put(
key.as_ref(),
value.as_ref(),
lease_id.map(|id| PutOptions::new().with_lease(id)),
)
.put(key.as_ref(), value.as_ref(), Some(put_options))
.await?;
Ok(())
}
pub async fn kv_put_with_options(
&self,
key: impl AsRef<str>,
value: impl AsRef<[u8]>,
options: Option<PutOptions>,
) -> Result<PutResponse> {
let options = options
.unwrap_or_default()
.with_lease(self.primary_lease().id());
self.client
.kv_client()
.put(key.as_ref(), value.as_ref(), Some(options))
.await
.map_err(|err| err.into())
}
pub async fn kv_get(
&self,
key: impl Into<Vec<u8>>,
options: Option<GetOptions>,
) -> Result<Vec<KeyValue>> {
let mut get_response = self.client.kv_client().get(key, options).await?;
Ok(get_response.take_kvs())
}
pub async fn kv_delete(
&self,
key: impl Into<Vec<u8>>,
options: Option<DeleteOptions>,
) -> Result<i64> {
self.client
.kv_client()
.delete(key, options)
.await
.map(|del_response| del_response.deleted())
.map_err(|err| err.into())
}
pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
let mut get_response = self
.client
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment