Commit 5ed8c1c0 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: rust - initial commit

the journey begins
parent 4017bd18
This diff is collapsed.
[package]
name = "triton-distributed"
version = "0.1.1"
edition = "2021"
authors = ["NVIDIA"]
homepage = "https://github.com/triton-inference-server/triton_distributed"
[dependencies]
# workspace - when we expand to multiple crates; put these in the workspace
anyhow = { version = "1" }
async-nats = { version = "0.38", features = ["service"] }
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
blake3 = "1"
bytes = "1"
derive_builder = "0.20"
derive-getters = "0.5"
either = { version = "1.13", features = ["serde"] }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
futures = { version = "0.3" }
once_cell = "1"
prometheus = { version = "0.13" }
regex = { version = "1" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = { version = "1" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net"] }
tracing = { version = "0.1" }
uuid = { version = "1", features = ["v4", "serde"] }
validator = { version = "0.20", features = ["derive"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
# non-workspace
async-once-cell = "0.5.4"
educe = "0.6.0"
etcd-client = "0.14"
local-ip-address = { version = "0.6.3" }
nid = { version = "3.0.0", features = ["serde"] }
nix = { version = "0.29", features = ["signal"] }
nuid = { version = "0.5" }
rand = { version = "0.8"}
[dev-dependencies]
assert_matches = "1.5.0"
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The [Component] module defines the top-level API for building distributed applications.
//!
//! A distributed application consists of a set of [Component][Component] that can host one
//! or more [Endpoint][Endpoint]. Each [Endpoint][Endpoint] is a network-accessible service
//! that can be accessed by other [Component][Component] in the distributed application.
//!
//! A [Component] is made discoverable by registering it with the distributed runtime under
//! a [`Namespace`].
//!
//! A [`Namespace`] is a logical grouping of [Component][Component] that are grouped together.
//!
//! We might extend namespace to include grouping behavior, which would define groups of
//! components that are tightly coupled.
//!
//! A [Component] is the core building block of a distributed application. It is a logical
//! unit of work such as a `Preprocessor` or `SmartRouter` that has a well-defined role in the
//! distributed application.
//!
//! A [Component] can present to the distributed application one or more configuration files
//! which define how that component was constructed/configured and what capabilities it can
//! provide.
//!
//! Other [Component][Component] can write to watching locations within a [Component] etcd
//! path. This allows the [Component] to take dynamic actions depending on the watch
//! triggers.
//!
//! TODO: Top-level Overview of Endpoints/Functions
use crate::discovery::Lease;
use super::{error, log, transports::nats::Slug, DistributedRuntime, Result};
use crate::pipeline::network::{ingress::push_endpoint::PushEndpoint, PushWorkHandler};
use async_nats::{
rustls::quic,
service::{Service, ServiceExt},
};
use derive_builder::Builder;
use derive_getters::Getters;
use educe::Educe;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use validator::{Validate, ValidationError};
mod client;
mod endpoint;
mod registry;
mod service;
pub use client::Client;
/// Transport mechanisms over which an [Endpoint] can be reached.
///
/// Serialized into etcd as part of [ComponentEndpointInfo], so the wire
/// format (snake_case variant names) must remain stable.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TransportType {
    /// NATS-backed transport; the payload is the NATS subject to publish to.
    NatsTcp(String),
}
/// Shared registry of NATS [Service]s created by this runtime, keyed by the
/// owning component's etcd path. Cloning shares the same underlying map; the
/// async mutex serializes service creation across tasks.
#[derive(Clone)]
pub struct Registry {
    services: Arc<tokio::sync::Mutex<HashMap<String, Service>>>,
}
/// Discovery record written to etcd for each live endpoint instance and
/// consumed by the endpoint watcher in [Client].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentEndpointInfo {
    /// Name of the component hosting the endpoint.
    pub component: String,
    /// Name of the endpoint.
    pub endpoint: String,
    /// Namespace the component is registered under.
    pub namespace: String,
    /// etcd lease id; the record disappears when the lease expires.
    pub lease_id: i64,
    /// How to reach the endpoint (e.g. the NATS subject).
    pub transport: TransportType,
}
/// A [Component] is a discoverable entity in the distributed runtime.
/// You can host [Endpoint][Endpoint] on a [Component] by first creating
/// a [Service] then adding one or more [Endpoint][Endpoint] to the [Service].
///
/// You can also issue a request to a [Component]'s [Endpoint] by creating a [Client].
#[derive(Educe, Builder, Clone)]
#[educe(Debug)]
#[builder(pattern = "owned")]
pub struct Component {
    /// Handle to the distributed runtime this component belongs to;
    /// hidden from `Debug` output to keep logs readable.
    #[builder(private)]
    #[educe(Debug(ignore))]
    drt: DistributedRuntime,
    // todo - restrict the namespace to a-z0-9-_A-Z
    /// Name of the component
    #[builder(setter(into))]
    name: String,
    // todo - restrict the namespace to a-z0-9-_A-Z
    /// Namespace
    #[builder(setter(into))]
    namespace: String,
}
impl Component {
    /// Fully-qualified etcd key prefix for this component:
    /// `{namespace}/components/{name}`.
    pub fn etcd_path(&self) -> String {
        [self.namespace.as_str(), "components", self.name.as_str()].join("/")
    }

    /// Slugified form of the etcd path, suitable for use as a NATS token.
    fn slug(&self) -> Slug {
        let path = self.etcd_path();
        Slug::from_string(path)
    }

    /// Create a handle to a named [Endpoint] hosted on this component.
    pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
        let name = endpoint.into();
        let component = self.clone();
        Endpoint { component, name }
    }

    /// Get keys from etcd on the slug, splitting the endpoints and only returning the
    /// set of unique endpoints.
    pub async fn list_endpoints(&self) -> Vec<Endpoint> {
        unimplemented!("endpoints")
    }

    /// This method will scrape the stats for all available services
    /// Returns a stream of [`ServiceInfo`] objects.
    /// This should be consumed by a `[tokio::time::timeout_at`] because each services
    /// will only respond once, but there is no way to know when all services have responded.
    pub async fn stats_stream(&self) -> Result<()> {
        unimplemented!("collect_stats")
    }

    /// Begin configuring this component's NATS service.
    pub fn service_builder(&self) -> service::ServiceConfigBuilder {
        let component = self.clone();
        service::ServiceConfigBuilder::from_component(component)
    }
}
impl ComponentBuilder {
    /// Seed a builder with the [DistributedRuntime] the component will live in;
    /// `name` and `namespace` must still be set before calling `build()`.
    pub fn from_runtime(drt: DistributedRuntime) -> Self {
        Self::default().drt(drt)
    }
}
/// A named, network-accessible service hosted on a [Component].
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Owning component; endpoint paths/subjects are derived from it.
    component: Component,
    // todo - restrict alphabet
    /// Endpoint name
    name: String,
}
impl Endpoint {
    /// Endpoint name as supplied at creation time.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// etcd key prefix for this endpoint: `{component_path}/{name}`.
    pub fn etcd_path(&self) -> String {
        let component_path = self.component.etcd_path();
        format!("{component_path}/{}", self.name)
    }

    /// etcd key for a specific instance: `{path}:{lease_id:x}`.
    pub fn etcd_path_with_id(&self, lease_id: i64) -> String {
        let path = self.etcd_path();
        format!("{path}:{lease_id:x}")
    }

    /// Instance-qualified endpoint name: `{name}-{lease_id:x}`.
    pub fn name_with_id(&self, lease_id: i64) -> String {
        format!("{}-{lease_id:x}", self.name)
    }

    /// NATS subject for a specific instance of this endpoint.
    pub fn subject(&self, lease_id: i64) -> String {
        let slug = self.component.slug();
        let name = self.name_with_id(lease_id);
        format!("{slug}.{name}")
    }

    /// Create a typed [`client::Client`] for issuing requests to this endpoint.
    pub async fn client<Req, Resp>(&self) -> Result<client::Client<Req, Resp>>
    where
        Req: Serialize + Send + Sync + 'static,
        Resp: for<'de> Deserialize<'de> + Send + Sync + 'static,
    {
        let endpoint = self.clone();
        client::Client::new(endpoint).await
    }

    /// Begin configuring how this endpoint is served.
    pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder {
        let endpoint = self.clone();
        endpoint::EndpointConfigBuilder::from_endpoint(endpoint)
    }
}
/// A logical grouping of [Component]s within the distributed runtime.
#[derive(Educe, Builder, Clone, Validate)]
#[educe(Debug)]
#[builder(pattern = "owned")]
pub struct Namespace {
    /// Runtime handle; hidden from `Debug` output.
    #[builder(private)]
    #[educe(Debug(ignore))]
    runtime: DistributedRuntime,
    // NOTE(review): empty `#[validate()]` applies no rules yet; the intended
    // restriction is `validate_allowed_chars` per the TODO later in this file.
    #[validate()]
    name: String,
}
impl Namespace {
    /// Construct a namespace bound to `runtime`. Errors only if the builder
    /// rejects the inputs.
    pub(crate) fn new(runtime: DistributedRuntime, name: String) -> Result<Self> {
        let namespace = NamespaceBuilder::default()
            .runtime(runtime)
            .name(name)
            .build()?;
        Ok(namespace)
    }

    /// Create a [`Component`] in the namespace
    pub fn component(&self, name: impl Into<String>) -> Result<Component> {
        let component = ComponentBuilder::from_runtime(self.runtime.clone())
            .name(name)
            .namespace(self.name.clone())
            .build()?;
        Ok(component)
    }
}
// Custom validator function
/// Validate that `input` is non-empty and contains only lowercase ASCII
/// letters, digits, hyphens, and underscores.
///
/// Returns a [`ValidationError`] with code `invalid_characters` otherwise.
fn validate_allowed_chars(input: &str) -> Result<(), ValidationError> {
    // Compile the pattern once and reuse it; recompiling a regex on every
    // call is needlessly expensive.
    static ALLOWED: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
    let regex = ALLOWED
        .get_or_init(|| regex::Regex::new(r"^[a-z0-9-_]+$").expect("static pattern is valid"));
    if regex.is_match(input) {
        Ok(())
    } else {
        Err(ValidationError::new("invalid_characters"))
    }
}
// TODO - enable restrictions to the character sets allowed for namespaces,
// components, and endpoints.
//
// Put Validate traits on the struct and use the `validate_allowed_chars` method
// to validate the fields.
// #[cfg(test)]
// mod tests {
// use super::*;
// use validator::Validate;
// #[test]
// fn test_valid_names() {
// // Valid strings
// let valid_inputs = vec![
// "abc", // Lowercase letters
// "abc123", // Letters and numbers
// "a-b-c", // Letters with hyphens
// "a_b_c", // Letters with underscores
// "a-b_c-123", // Mixed valid characters
// "a", // Single character
// "a_b", // Short valid pattern
// "123456", // Only numbers
// "a---b_c123", // Repeated hyphens/underscores
// ];
// for input in valid_inputs {
// let result = validate_allowed_chars(input);
// assert!(result.is_ok(), "Expected '{}' to be valid", input);
// }
// }
// #[test]
// fn test_invalid_names() {
// // Invalid strings
// let invalid_inputs = vec![
// "abc!", // Invalid character `!`
// "abc@", // Invalid character `@`
// "123$", // Invalid character `$`
// "foo.bar", // Invalid character `.`
// "foo/bar", // Invalid character `/`
// "foo\\bar", // Invalid character `\`
// "abc#", // Invalid character `#`
// "abc def", // Spaces are not allowed
// "foo,", // Invalid character `,`
// "", // Empty string
// ];
// for input in invalid_inputs {
// let result = validate_allowed_chars(input);
// assert!(result.is_err(), "Expected '{}' to be invalid", input);
// }
// }
// // #[test]
// // fn test_struct_validation_valid() {
// // // Struct with valid data
// // let valid_data = InputData {
// // name: "valid-name_123".to_string(),
// // };
// // assert!(valid_data.validate().is_ok());
// // }
// // #[test]
// // fn test_struct_validation_invalid() {
// // // Struct with invalid data
// // let invalid_data = InputData {
// // name: "invalid!name".to_string(),
// // };
// // let result = invalid_data.validate();
// // assert!(result.is_err());
// // if let Err(errors) = result {
// // let error_map = errors.field_errors();
// // assert!(error_map.contains_key("name"));
// // let name_errors = &error_map["name"];
// // assert_eq!(name_errors[0].code, "invalid_characters");
// // }
// // }
// #[test]
// fn test_edge_cases() {
// // Edge cases
// let edge_inputs = vec![
// ("-", true), // Single hyphen
// ("_", true), // Single underscore
// ("a-", true), // Letter with hyphen
// ("-", false), // Repeated hyphens
// ("-a", false), // Hyphen at the beginning
// ("a-", false), // Hyphen at the end
// ];
// for (input, expected_validity) in edge_inputs {
// let result = validate_allowed_chars(input);
// if expected_validity {
// assert!(result.is_ok(), "Expected '{}' to be valid", input);
// } else {
// assert!(result.is_err(), "Expected '{}' to be invalid", input);
// }
// }
// }
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use crate::pipeline::{
network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
AsyncEngine, Data, ManyOut, SingleIn,
};
use rand::Rng;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use tokio::{net::unix::pipe::Receiver, sync::Mutex};
use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
use super::*;
/// Each state will have a nonce associated with it.
/// The state will be emitted in a watch channel, so we can observe the
/// critical state transitions.
// NOTE(review): not referenced by the visible client code — presumably
// reserved for a future, richer watch-state API; confirm before removing.
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),
    /// The map is not-empty; values are (nonce, count)
    NonEmpty(u64, u64),
    /// The watcher has finished, no more events will be emitted
    Finished,
}
/// Parsed form of an etcd watch event on an endpoint prefix.
// NOTE(review): not referenced by the visible client code — the watcher task
// currently matches on `WatchEvent` directly; confirm before removing.
enum EndpointEvent {
    /// Endpoint registered: (etcd key, lease id).
    Put(String, i64),
    /// Endpoint removed: etcd key.
    Delete(String),
}
/// Typed client for issuing requests to the live instances of an [`Endpoint`].
///
/// `T` is the request payload type, `U` the response payload type.
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
    /// The endpoint this client targets.
    endpoint: Endpoint,
    /// Router that pushes addressed requests to a specific instance.
    router: PushRouter<T, U>,
    /// Receives the set of live endpoint lease ids from the etcd watcher task.
    watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
    /// Monotonic counter driving round-robin instance selection.
    counter: Arc<AtomicU64>,
}
impl<T, U> Client<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    /// Create a new client for `endpoint`.
    ///
    /// Builds an addressed push router over the runtime's NATS client and TCP
    /// server, then spawns a background task on the secondary runtime that
    /// watches the endpoint's etcd prefix and publishes the live set of
    /// lease ids into a watch channel.
    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
        let router = AddressedPushRouter::new(
            endpoint.component.drt.nats_client.client().clone(),
            endpoint.component.drt.tcp_server().await?,
        )?;
        // create live endpoint watcher
        let prefix_watcher = endpoint
            .component
            .drt
            .etcd_client
            .kv_get_and_watch_prefix(endpoint.etcd_path())
            .await?;
        let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
        let secondary = endpoint.component.drt.runtime.secondary().clone();
        // this task should be included in the registry
        // currently this is created once per client, but this object/task should only be instantiated
        // once per worker/instance
        secondary.spawn(async move {
            log::debug!("Starting endpoint watcher for prefix: {}", prefix);
            // key (etcd path with lease id) -> lease id of each live instance
            let mut map = HashMap::new();
            loop {
                // Exit when either every receiver is dropped or the etcd
                // watch stream ends.
                let kv_event = tokio::select! {
                    _ = watch_tx.closed() => {
                        log::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
                        break;
                    }
                    kv_event = kv_event_rx.recv() => {
                        match kv_event {
                            Some(kv_event) => kv_event,
                            None => {
                                log::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
                                break;
                            }
                        }
                    }
                };
                match kv_event {
                    WatchEvent::Put(kv) => {
                        // A new (or refreshed) instance record; parse the key
                        // and its ComponentEndpointInfo payload.
                        let key = String::from_utf8(kv.key().to_vec());
                        let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
                        if let (Ok(key), Ok(val)) = (key, val) {
                            map.insert(key.clone(), val.lease_id);
                        } else {
                            log::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
                            break;
                        }
                    }
                    WatchEvent::Delete(kv) => {
                        // Instance gone (lease expired or revoked).
                        match String::from_utf8(kv.key().to_vec()) {
                            Ok(key) => { map.remove(&key); }
                            Err(_) => {
                                log::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
                                break;
                            }
                        }
                    }
                }
                // Publish the current (unordered) set of live lease ids after
                // every event; send fails only when all receivers are gone.
                let endpoint_ids: Vec<i64> = map.values().cloned().collect();
                if watch_tx.send(endpoint_ids).is_err() {
                    log::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
                    break;
                }
            }
            log::debug!("Completed endpoint watcher for prefix: {}", prefix);
            // Leave receivers observing an empty set so no caller keeps
            // routing to instances we can no longer track.
            let _ = watch_tx.send(vec![]);
        });
        Ok(Client {
            endpoint,
            router,
            watch_rx,
            counter: Arc::new(AtomicU64::new(0)),
        })
    }

    /// Watch channel holding the lease ids of the currently-live instances.
    pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
        &self.watch_rx
    }

    /// Wait for at least one [`Endpoint`] to be available
    pub async fn wait_for_endpoints(&self) -> Result<()> {
        let mut rx = self.watch_rx.clone();
        // wait for there to be 1 or more endpoints
        loop {
            if rx.borrow_and_update().is_empty() {
                rx.changed().await?;
            } else {
                break;
            }
        }
        Ok(())
    }

    /// Issue a request to the next available endpoint in a round-robin fashion
    pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
        // Relaxed ordering suffices: the counter only spreads load, no
        // cross-thread ordering is required.
        let counter = self.counter.fetch_add(1, Ordering::Relaxed);
        let endpoint_id = {
            // Scope the borrow so the watch value is released before awaiting.
            let endpoints = self.watch_rx.borrow();
            let count = endpoints.len();
            if count == 0 {
                return Err(error!(
                    "no endpoints found for endpoint {:?}",
                    self.endpoint.etcd_path()
                ));
            }
            let offset = counter % count as u64;
            endpoints[offset as usize]
        };
        let subject = self.endpoint.subject(endpoint_id);
        let request = request.map(|req| AddressedRequest::new(req, subject));
        self.router.generate(request).await
    }

    /// Issue a request to a random endpoint
    pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
        let endpoint_id = {
            // Scope the borrow so the watch value is released before awaiting.
            let endpoints = self.watch_rx.borrow();
            let count = endpoints.len();
            if count == 0 {
                return Err(error!(
                    "no endpoints found for endpoint {:?}",
                    self.endpoint.etcd_path()
                ));
            }
            let counter = rand::thread_rng().gen::<u64>();
            let offset = counter % count as u64;
            endpoints[offset as usize]
        };
        let subject = self.endpoint.subject(endpoint_id);
        let request = request.map(|req| AddressedRequest::new(req, subject));
        self.router.generate(request).await
    }

    /// Issue a request to a specific endpoint
    pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
        // Verify the instance is still live before addressing it.
        let found = {
            let endpoints = self.watch_rx.borrow();
            endpoints.contains(&endpoint_id)
        };
        if !found {
            return Err(error!(
                "endpoint_id={} not found for endpoint {:?}",
                endpoint_id,
                self.endpoint.etcd_path()
            ));
        }
        let subject = self.endpoint.subject(endpoint_id);
        let request = request.map(|req| AddressedRequest::new(req, subject));
        self.router.generate(request).await
    }
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    /// Default engine behavior: dispatch the request to a random live instance.
    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
        self.random(request).await
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use derive_getters::Dissolve;
use super::*;
/// Configuration for serving a single [`Endpoint`]: its identity, an optional
/// lease controlling its lifetime, and the handler that performs the work.
#[derive(Educe, Builder, Dissolve)]
#[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct EndpointConfig {
    /// The endpoint being served.
    #[builder(private)]
    endpoint: Endpoint,
    /// Lease
    // When None, the runtime's primary lease is used (see `start`).
    #[educe(Debug(ignore))]
    #[builder(default)]
    lease: Option<Lease>,
    /// Endpoint handler
    #[educe(Debug(ignore))]
    handler: Arc<dyn PushWorkHandler>,
}
impl EndpointConfigBuilder {
pub(crate) fn from_endpoint(endpoint: Endpoint) -> Self {
Self::default().endpoint(endpoint)
}
pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler) = self.build_internal()?.dissolve();
let lease = lease.unwrap_or(endpoint.component.drt.primary_lease());
log::debug!(
"Starting endpoint: {}",
endpoint.etcd_path_with_id(lease.id())
);
let group = endpoint
.component
.drt
.component_registry
.services
.lock()
.await
.get(&endpoint.component.etcd_path())
.map(|service| service.group(endpoint.component.slug()))
.ok_or(error!("Service not found"))?;
// let group = service.group(service_name.as_str());
// creates an endpoint for the service
let service_endpoint = group
.endpoint(&endpoint.name_with_id(lease.id()))
.await
.map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
let cancel_token = lease.child_token();
let push_endpoint = PushEndpoint::builder()
.service_handler(handler)
.cancellation_token(cancel_token.clone())
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
// launch in primary runtime
let task = tokio::spawn(push_endpoint.start(service_endpoint));
// log::debug!(worker_id, "endpoint subject: {}", subject);
// make the components service endpoint discovery in etcd
// client.register_service()
let info = ComponentEndpointInfo {
component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.clone(),
lease_id: lease.id(),
transport: TransportType::NatsTcp(endpoint.subject(lease.id())),
};
let info = serde_json::to_vec_pretty(&info)?;
if let Err(e) = endpoint
.component
.drt
.etcd_client
.kv_create(
endpoint.etcd_path_with_id(lease.id()),
info,
Some(lease.id()),
)
.await
{
log::error!("Failed to register discoverable service: {:?}", e);
cancel_token.cancel();
return Err(error!("Failed to register discoverable service"));
}
task.await??;
Ok(())
}
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::{Component, Registry, Result};
use async_once_cell::OnceCell;
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use tokio::sync::Mutex;
impl Default for Registry {
    /// Equivalent to [`Registry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
impl Registry {
    /// Construct an empty registry with no services recorded.
    pub fn new() -> Self {
        let services = HashMap::new();
        Self {
            services: Arc::new(Mutex::new(services)),
        }
    }
}
// impl ComponentRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
// #[derive(Clone)]
// pub struct ServiceRegistry {
// clients: Arc<Mutex<HashMap<String, Arc<Service>>>>,
// }
// impl ServiceRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use derive_getters::Dissolve;
use super::*;
use async_nats::service::{endpoint, Service};
/// Callback invoked when the NATS service receives a stats request: given the
/// endpoint name and its [`endpoint::Stats`], returns arbitrary JSON to embed
/// in the stats reply.
pub type StatsHandler =
    Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>;
/// Configuration for creating a [`Component`]'s NATS service.
#[derive(Educe, Builder, Dissolve)]
#[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct ServiceConfig {
    /// Component the service belongs to.
    #[builder(private)]
    component: Component,
    /// Description
    // Defaults to "Triton Component {name} in {namespace}" when None.
    #[builder(default)]
    description: Option<String>,
    // todo - make optional - if None, then skip making the endpoint
    // and skip making the service-endpoint discoverable.
    /// Stats handler invoked on NATS service stats requests.
    #[educe(Debug(ignore))]
    #[builder(default)]
    stats_handler: Option<StatsHandler>,
}
impl ServiceConfigBuilder {
    /// Create the [`Component`]'s service and store it in the registry.
    ///
    /// # Errors
    /// Fails if the builder is missing required fields, a service already
    /// exists for this component, or the NATS service cannot be started.
    pub async fn create(self) -> Result<Component> {
        let version = "0.0.1".to_string();
        let (component, description, stat_handler) = self.build_internal()?.dissolve();
        let service_name = component.slug();
        // Lazily build the default description — only allocate when the
        // caller did not supply one.
        let description = description.unwrap_or_else(|| {
            format!(
                "Triton Component {} in {}",
                component.name, component.namespace
            )
        });
        // Hold the registry lock across the whole creation so concurrent
        // callers cannot race and create the same service twice.
        let mut guard = component.drt.component_registry.services.lock().await;
        if guard.contains_key(&component.etcd_path()) {
            return Err(anyhow::anyhow!("Service already exists"));
        }
        // create service on the secondary runtime
        let secondary = component.drt.runtime.secondary.clone();
        let builder = component.drt.nats_client.client().service_builder();
        let service = secondary
            .spawn(async move {
                // unwrap the stats handler
                let builder = match stat_handler {
                    Some(handler) => builder.stats_handler(handler),
                    None => builder,
                };
                log::debug!("Starting service: {}", service_name);
                builder
                    .description(description)
                    .start(service_name.to_string(), version)
                    .await
            })
            .await?
            .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
        guard.insert(component.etcd_path(), service);
        drop(guard);
        Ok(component)
    }
}
impl ServiceConfigBuilder {
    /// Seed a builder bound to `component`; remaining fields are optional.
    pub(crate) fn from_component(component: Component) -> Self {
        Self::default().component(component)
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::Result;
use derive_builder::Builder;
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use serde::{Deserialize, Serialize};
use validator::Validate;
/// Worker-level settings, loadable from the environment via
/// [`WorkerConfig::from_settings`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
    /// Grace shutdown period for http-service.
    /// Units are seconds (per the `Default` impl: 1s debug / 30s release).
    pub graceful_shutdown_timeout: u64,
}
impl WorkerConfig {
    /// Load worker settings: compiled-in defaults overlaid with
    /// `TRITON_WORKER_`-prefixed environment variables.
    ///
    /// # Panics
    /// Panics if the merged configuration cannot be deserialized into
    /// `WorkerConfig` (e.g. a non-numeric timeout in the environment);
    /// the `expect` message makes the failure mode explicit.
    pub fn from_settings() -> Self {
        // Instantiates and reads server configurations from appropriate sources.
        // All calls should be global and thread safe.
        Figment::new()
            .merge(Serialized::defaults(Self::default()))
            .merge(Env::prefixed("TRITON_WORKER_"))
            .extract()
            .expect("invalid TRITON_WORKER_ configuration")
    }
}
impl Default for WorkerConfig {
    /// Debug builds shut down quickly (1s); release builds give in-flight
    /// work longer to drain (30s).
    fn default() -> Self {
        let graceful_shutdown_timeout = if cfg!(debug_assertions) { 1 } else { 30 };
        WorkerConfig {
            graceful_shutdown_timeout,
        }
    }
}
/// Runtime configuration
/// Defines the configuration for Tokio runtimes
// NOTE(review): the builder defaults below are 16/16 while `Default` (via
// `single_threaded`) is 1/1 — confirm which default callers expect.
#[derive(Serialize, Deserialize, Validate, Debug, Builder, Clone)]
#[builder(build_fn(private, name = "build_internal"), derive(Debug, Serialize))]
pub struct RuntimeConfig {
    /// Maximum number of async worker threads
    /// If set to 1, the runtime will run in single-threaded mode
    #[validate(range(min = 1))]
    #[builder(default = "16")]
    #[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))]
    pub max_worker_threads: usize,
    /// Maximum number of blocking threads
    /// Blocking threads are used for blocking operations, this value must be greater than 0.
    #[validate(range(min = 1))]
    #[builder(default = "16")]
    #[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))]
    pub max_blocking_threads: usize,
}
impl RuntimeConfig {
    /// Start building a [`RuntimeConfig`] field by field.
    pub fn builder() -> RuntimeConfigBuilder {
        RuntimeConfigBuilder::default()
    }

    /// Layered configuration source: compiled-in defaults, then the two TOML
    /// files, then environment variables — later merges take precedence.
    pub(crate) fn figment() -> Figment {
        Figment::new()
            .merge(Serialized::defaults(RuntimeConfig::default()))
            .merge(Toml::file("/opt/triton/defaults/runtime.toml"))
            .merge(Toml::file("/opt/triton/etc/runtime.toml"))
            .merge(Env::prefixed("TRITON_RUNTIME_"))
    }

    /// Load the runtime configuration from the environment and configuration files
    /// Configuration is prioritized in the following order, where the last has the lowest priority:
    /// 1. Environment variables (top priority)
    /// 2. /opt/triton/etc/runtime.toml
    /// 3. /opt/triton/defaults/runtime.toml (lowest priority)
    ///
    /// Environment variables are prefixed with `TRITON_RUNTIME_`
    ///
    /// # Errors
    /// Fails if extraction or validation of the merged config fails.
    pub fn from_settings() -> Result<RuntimeConfig> {
        let config: RuntimeConfig = Self::figment().extract()?;
        config.validate()?;
        Ok(config)
    }

    /// Configuration with one worker thread and one blocking thread.
    pub fn single_threaded() -> Self {
        RuntimeConfig {
            max_worker_threads: 1,
            max_blocking_threads: 1,
        }
    }

    /// Create a new default runtime configuration
    // NOTE(review): always builds a multi-thread runtime; a "single-threaded"
    // config still yields a multi-thread runtime with one worker thread.
    pub(crate) fn create_runtime(&self) -> Result<tokio::runtime::Runtime> {
        Ok(tokio::runtime::Builder::new_multi_thread()
            .worker_threads(self.max_worker_threads)
            .max_blocking_threads(self.max_blocking_threads)
            .enable_all()
            .build()?)
    }
}
impl Default for RuntimeConfig {
    /// Defaults to the single-threaded configuration (1 worker, 1 blocking).
    fn default() -> Self {
        Self::single_threaded()
    }
}
impl RuntimeConfigBuilder {
    /// Build and validate the runtime configuration
    ///
    /// # Errors
    /// Fails if required fields are missing or validation rejects the values
    /// (thread counts must be >= 1).
    pub fn build(&self) -> Result<RuntimeConfig> {
        let config = self.build_internal()?;
        config.validate()?;
        Ok(config)
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use crate::{transports::etcd, Result};
pub use etcd::Lease;
/// Discovery facade over the etcd client, scoped to a namespace.
pub struct DiscoveryClient {
    // NOTE(review): stored but not read by any method visible here —
    // presumably used for key scoping elsewhere; confirm before removing.
    namespace: String,
    etcd_client: etcd::Client,
}
impl DiscoveryClient {
    /// Create a new [`DiscoveryClient`]
    ///
    /// This will establish a connection to the etcd server, create a primary lease,
    /// and spawn a task to keep the lease alive and tie the lifetime of the [`Runtime`]
    /// to the lease.
    ///
    /// If the lease expires, the [`Runtime`] will be shutdown.
    /// If the [`Runtime`] is shutdown, the lease will be revoked.
    // NOTE(review): this constructor only stores the fields; the connection
    // and lease behavior described above is established by the caller that
    // builds `etcd_client` — confirm the doc matches the call site.
    pub(crate) fn new(namespace: String, etcd_client: etcd::Client) -> Self {
        DiscoveryClient {
            namespace,
            etcd_client,
        }
    }

    /// Get the primary lease ID
    pub fn primary_lease_id(&self) -> i64 {
        self.etcd_client.lease_id()
    }

    /// Create a [`Lease`] with a given time-to-live (TTL).
    /// This [`Lease`] will be tied to the [`Runtime`], but has its own independent [`crate::CancellationToken`].
    pub async fn create_lease(&self, ttl: i64) -> Result<Lease> {
        self.etcd_client.create_lease(ttl).await
    }

    // the following two commented out codes are not implemented, but are placeholders for proposed ectd usage patterns
    // /// Create an ephemeral key/value pair tied to a lease_id.
    // /// This is an atomic create. If the key already exists, this will fail.
    // /// The [`etcd_client::KeyValue`] will be removed when the lease expires or is revoked.
    // pub async fn create_ephemerial_key(&self, key: &str, value: &str, lease_id: i64) -> Result<()> {
    //     // self.etcd_client.create_ephemeral_key(key, value, lease_id).await
    //     unimplemented!()
    // }

    // /// Create a shared [`etcd_client::KeyValue`] which behaves similar to a C++ `std::shared_ptr` or a
    // /// Rust [std::sync::Arc]. Instead of having one owner of the lease, multiple owners participate in
    // /// maintaining the lease. In this manner, when the last member of the group sharing the lease is gone,
    // /// the lease will be expired.
    // ///
    // /// Implementation notes: At the time of writing, it is unclear if we have atomics that control leases,
    // /// so in our initial implementation, the last member of the group will not revoke the lease, so the object
    // /// will live for upto the TTL after the last member is gone.
    // ///
    // /// Notes
    // /// -----
    // ///
    // /// - Multiple members sharing the lease and contributing to the heartbeat might cause some overheads.
    // ///   The implementation will try to randomize the heartbeat intervals to avoid thundering herd problem,
    // ///   and with any luck, the heartbeat watchers will be able to detect when if a external member triggered
    // ///   the heartbeat checking this interval and skip unnecessary heartbeat messages.
    // ///
    // /// A new lease will be created for this object. If you wish to add an object to a shared group s
    // ///
    // /// The [`etcd_client::KeyValue`] will be removed when the lease expires or is revoked.
    // pub async fn create_shared_key(&self, key: &str, value: &str, lease_id: i64) -> Result<()> {
    //     // self.etcd_client.create_ephemeral_key(key, value, lease_id).await
    //     unimplemented!()
    // }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
pub use crate::component::Component;
use crate::{
component::{self, ComponentBuilder, Namespace},
discovery::DiscoveryClient,
service::ServiceClient,
transports::{etcd, nats, tcp},
ErrorContext,
};
use super::{error, Arc, DistributedRuntime, OnceCell, Result, Runtime, OK};
use derive_getters::Dissolve;
use figment::error;
impl DistributedRuntime {
pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
let secondary = runtime.secondary();
let (etcd_config, nats_config) = config.dissolve();
let runtime_clone = runtime.clone();
let etcd_client = secondary
.spawn(async move {
let client = etcd::Client::new(etcd_config.clone(), runtime_clone)
.await
.context(format!(
"Failed to connect to etcd server with config {:?}",
etcd_config
))?;
OK(client)
})
.await??;
let nats_client = secondary
.spawn(async move {
let client = nats_config.clone().connect().await.context(format!(
"Failed to connect to NATS server with config {:?}",
nats_config
))?;
anyhow::Ok(client)
})
.await??;
Ok(Self {
runtime,
etcd_client,
nats_client,
tcp_server: Arc::new(OnceCell::new()),
component_registry: component::Registry::new(),
})
}
pub async fn from_settings(runtime: Runtime) -> Result<Self> {
let config = DistributedConfig::from_settings();
Self::new(runtime, config).await
}
pub fn runtime(&self) -> &Runtime {
&self.runtime
}
pub fn primary_lease(&self) -> etcd::Lease {
self.etcd_client.primary_lease()
}
pub fn shutdown(&self) {
self.runtime.shutdown();
}
/// Create a [`Namespace`]
pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
Namespace::new(self.clone(), name.into())
}
// /// Create a [`Component`]
// pub fn component(
// &self,
// name: impl Into<String>,
// namespace: impl Into<String>,
// ) -> Result<Component> {
// Ok(ComponentBuilder::from_runtime(self.clone())
// .name(name.into())
// .namespace(namespace.into())
// .build()?)
// }
pub(crate) fn discovery_client(&self, namespace: impl Into<String>) -> DiscoveryClient {
DiscoveryClient::new(namespace.into(), self.etcd_client.clone())
}
pub(crate) fn service_client(&self) -> ServiceClient {
ServiceClient::new(self.nats_client.clone())
}
pub(crate) async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
Ok(self
.tcp_server
.get_or_try_init(async move {
let options = tcp::server::ServerOptions::default();
let server = tcp::server::TcpStreamServer::new(options).await?;
OK(server)
})
.await?
.clone())
}
pub fn nats_client(&self) -> nats::Client {
self.nats_client.clone()
}
pub fn etcd_client(&self) -> etcd::Client {
self.etcd_client.clone()
}
}
/// Connection settings for the distributed transports (etcd + NATS).
///
/// `Dissolve` (from `derive_getters`) generates `dissolve()`, which
/// decomposes the struct into an `(etcd_config, nats_config)` tuple.
#[derive(Dissolve)]
pub struct DistributedConfig {
    pub etcd_config: etcd::ClientOptions,
    pub nats_config: nats::ClientOptions,
}
impl DistributedConfig {
    /// Build a configuration from settings.
    ///
    /// Currently this simply takes the default client options for both the
    /// etcd and NATS transports.
    pub fn from_settings() -> DistributedConfig {
        DistributedConfig {
            etcd_config: Default::default(),
            nats_config: Default::default(),
        }
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
pub use async_trait::async_trait;
use futures::stream::Stream;
/// All [`Send`] + [`Sync`] + `'static` types can be used as [`AsyncEngine`] request and response types.
pub trait Data: Send + Sync + 'static {}
// Blanket impl: every eligible type is automatically `Data`.
impl<T: Send + Sync + 'static> Data for T {}
/// [`DataUnary`] is a boxed future resolving to a single [`Data`] item.
pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;
/// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`]
/// by associating it with a [`AsyncEngineContext`].
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
/// Shared handle to a type-erased [`AsyncEngine`].
pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
/// Pinned, boxed unary (single-response) engine output.
pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
/// Pinned, boxed streaming engine output.
pub type EngineStream<Resp> = Pin<Box<dyn AsyncEngineStream<Resp>>>;
/// Type-erased, shared [`AsyncEngineContext`].
pub type Context = Arc<dyn AsyncEngineContext>;
// An `EngineStream` can always be down-converted into a plain `DataStream`,
// dropping access to its context (the reverse requires `ResponseStream::new`).
impl<T: Data> From<EngineStream<T>> for DataStream<T> {
    fn from(stream: EngineStream<T>) -> Self {
        Box::pin(stream)
    }
}
// Placeholder: the Controller and the Context traits will be unified when
// trait upcasting (https://github.com/rust-lang/rust/issues/65991) becomes stable.
pub trait AsyncEngineController: Send + Sync {}
/// The [`AsyncEngineContext`] trait defines the interface to control the resulting stream
/// produced by the engine.
#[async_trait]
pub trait AsyncEngineContext: Send + Sync + Debug {
    /// Unique ID for the Stream
    fn id(&self) -> &str;
    /// Returns true if `stop_generating()` has been called; otherwise, false.
    fn is_stopped(&self) -> bool;
    /// Returns true if `kill()` has been called; otherwise, false.
    /// This can be used with a `.take_while()` stream combinator to immediately terminate
    /// the stream.
    ///
    /// An ideal location for a `[.take_while(!ctx.is_killed())]` stream combinator is on
    /// the most downstream return stream.
    fn is_killed(&self) -> bool;
    /// Returns immediately if [`AsyncEngineContext::is_stopped`] is already `true`;
    /// otherwise, waits until [`AsyncEngineContext::is_stopped`] returns `true`.
    async fn stopped(&self);
    /// Returns immediately if [`AsyncEngineContext::is_killed`] is already `true`;
    /// otherwise, waits until [`AsyncEngineContext::is_killed`] returns `true`.
    async fn killed(&self);
    // Controller
    /// Informs the [`AsyncEngine`] to stop producing results for this particular stream.
    /// This method is idempotent. This method does not invalidate results currently in the
    /// stream. It might take some time for the engine to stop producing results. The caller
    /// can decide to drain the stream or drop the stream.
    fn stop_generating(&self);
    /// See [`AsyncEngineContext::stop_generating`].
    fn stop(&self);
    /// Extends [`AsyncEngineContext::stop_generating`] by also indicating a preference to
    /// terminate without draining the remaining items in the stream. This is implementation
    /// specific and may not be supported by all engines.
    fn kill(&self);
}
/// Anything that can hand out the [`AsyncEngineContext`] associated with its stream.
pub trait AsyncEngineContextProvider: Send + Sync + Debug {
    fn context(&self) -> Arc<dyn AsyncEngineContext>;
}
/// A unary engine response: a future resolving to the response that can also
/// surface its [`AsyncEngineContext`].
pub trait AsyncEngineUnary<Resp: Data>:
    Future<Output = Resp> + AsyncEngineContextProvider + Send + Sync
{
}
/// A streaming engine response: a stream of response items that can also
/// surface its [`AsyncEngineContext`].
pub trait AsyncEngineStream<Resp: Data>:
    Stream<Item = Resp> + AsyncEngineContextProvider + Send + Sync
{
}
/// Engine is a trait that defines the interface for a streaming LLM completion engine.
/// The response itself carries the context provider, so the returned stream/future
/// can be controlled after `generate` resolves.
#[async_trait]
pub trait AsyncEngine<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data>:
    Send + Sync
{
    /// Generate a stream of completion responses.
    async fn generate(&self, request: Req) -> Result<Resp, E>;
}
/// Adapter for a [`DataStream`] to a [`ResponseStream`].
///
/// A common pattern is to consume the [`ResponseStream`] with standard stream combinators
/// which produces a [`DataStream`] stream, then form a [`ResponseStream`] by propagating the
/// original [`AsyncEngineContext`].
pub struct ResponseStream<R: Data> {
    // the inner item stream being adapted
    stream: DataStream<R>,
    // context propagated from the originating engine call
    ctx: Arc<dyn AsyncEngineContext>,
}
impl<R: Data> ResponseStream<R> {
    /// Pair `stream` with `ctx`; returned pinned+boxed so it can be polled as a stream.
    pub fn new(stream: DataStream<R>, ctx: Arc<dyn AsyncEngineContext>) -> Pin<Box<Self>> {
        Box::pin(Self { stream, ctx })
    }
}
impl<R: Data> Stream for ResponseStream<R> {
    type Item = R;
    #[inline]
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Delegate to the inner stream. `Pin::new` is sound here because the
        // inner `DataStream` is itself a pinned `Box`, which is `Unpin`.
        Pin::new(&mut self.stream).poll_next(cx)
    }
}
// With Stream + the provider below, ResponseStream is a full streaming engine response.
impl<R: Data> AsyncEngineStream<R> for ResponseStream<R> {}
impl<R: Data> AsyncEngineContextProvider for ResponseStream<R> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.ctx.clone()
    }
}
impl<R: Data> Debug for ResponseStream<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResponseStream")
            // todo: add debug for stream - possibly propagate some information about what
            // engine created the stream
            // .field("stream", &self.stream)
            .field("ctx", &self.ctx)
            .finish()
    }
}
// Forward `context()` through the pinned box to the inner trait object, so the
// `EngineUnary`/`EngineStream` aliases are themselves context providers.
impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineUnary<T>>> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        AsyncEngineContextProvider::context(&**self)
    }
}
impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineStream<T>>> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        AsyncEngineContextProvider::context(&**self)
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! Triton
#![allow(dead_code)]
#![allow(unused_imports)]
use std::sync::{Arc, Mutex};
pub use anyhow::{anyhow as error, Context as ErrorContext, Error, Ok as OK, Result};
use async_once_cell::OnceCell;
use tracing as log;
mod config;
pub use config::RuntimeConfig;
pub mod component;
pub mod discovery;
pub mod engine;
pub mod pipeline;
pub mod protocols;
pub mod runtime;
pub mod service;
pub mod transports;
pub mod worker;
pub mod distributed;
pub use tokio_util::sync::CancellationToken;
pub use worker::Worker;
/// Types of Tokio runtimes that can be used to construct a Triton [Runtime].
///
/// NOTE(review): `Runtime` below derives `Debug`, but this enum derives only
/// `Clone` — a manual `Debug` impl presumably exists elsewhere; confirm.
#[derive(Clone)]
enum RuntimeType {
    /// Runtime owned here and shared via `Arc`.
    Shared(Arc<tokio::runtime::Runtime>),
    /// Handle to a runtime owned by the embedding application.
    External(tokio::runtime::Handle),
}
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
#[derive(Debug, Clone)]
pub struct Runtime {
    // unique id for this runtime instance
    id: Arc<String>,
    // primary runtime carrying the main workload
    primary: RuntimeType,
    // secondary runtime for background/auxiliary tasks
    secondary: Arc<tokio::runtime::Runtime>,
    // token used to propagate shutdown across the runtime
    cancellation_token: CancellationToken,
}
/// Distributed [Runtime] which provides access to shared resources across the cluster, this includes
/// communication protocols and transports.
#[derive(Clone)]
pub struct DistributedRuntime {
    // local runtime
    runtime: Runtime,
    // we might consider a unified transport manager here
    etcd_client: transports::etcd::Client,
    nats_client: transports::nats::Client,
    // lazily-initialized shared TCP server (see `DistributedRuntime::tcp_server`)
    tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
    // local registry for components
    // The registry allows us to share runtime resources across instances of the same component
    // object. Take for example two instances of a client to the same remote component: the
    // registry allows a single endpoint watcher to serve both clients, keeping the number of
    // background tasks watching specific paths in etcd to a minimum.
    component_registry: component::Registry,
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/// In a Pipeline, the [`AsyncEngine`] is constrained to take a [`Context`] as input and return
/// a [`super::engine::ResponseStream`] as output.
use serde::{Deserialize, Serialize};
mod nodes;
pub use nodes::{
Operator, PipelineNode, PipelineOperator, SegmentSink, SegmentSource, Service, ServiceBackend,
ServiceFrontend, Sink, Source,
};
pub mod context;
pub mod error;
pub mod network;
pub mod registry;
pub use crate::engine::{
self as engine, async_trait, AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, Data,
DataStream, Engine, EngineStream, EngineUnary, ResponseStream,
};
pub use anyhow::Error;
pub use context::Context;
pub use error::{PipelineError, PipelineErrorExt, TwoPartCodecError};
/// Pipeline inputs carry a [`Context`] which can be used to carry metadata or additional information
/// about the request. This information propagates through the stages, both local and distributed.
pub type SingleIn<T> = Context<T>;
/// Streaming pipeline input: a [`Context`]-wrapped stream of items.
pub type ManyIn<T> = Context<DataStream<T>>;
/// Type alias for the output of a pipeline that returns a single value
pub type SingleOut<T> = EngineUnary<T>;
/// Type alias for the output of a pipeline that returns multiple values
pub type ManyOut<T> = EngineStream<T>;
/// An [`Engine`] whose error type is fixed to [`anyhow::Error`].
pub type ServiceEngine<T, U> = Engine<T, U, Error>;
/// Unary Engine is a pipeline that takes a single input and returns a single output
pub type UnaryEngine<T, U> = ServiceEngine<SingleIn<T>, SingleOut<U>>;
/// `ClientStreaming` Engine is a pipeline that takes multiple inputs and returns a single output.
/// Typically the engine will consume the entire input stream; however, it can also decide to exit
/// early and emit a response without consuming the entire input stream.
pub type ClientStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, SingleOut<U>>;
/// `ServerStreaming` takes a single input and returns multiple outputs.
pub type ServerStreamingEngine<T, U> = ServiceEngine<SingleIn<T>, ManyOut<U>>;
/// `BidirectionalStreaming` takes multiple inputs and returns multiple outputs. Input and output values
/// are considered independent of each other; however, they could be constrained to be related.
pub type BidirectionalStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, ManyOut<U>>;
/// Marker trait for engines whose request and response types are both
/// [`PipelineIO`], making them eligible for use across transports.
pub trait AsyncTransportEngine<T: PipelineIO, U: PipelineIO>:
    AsyncEngine<T, U, Error> + Send + Sync + 'static
{
}
// pub type TransportEngine<T, U> = Arc<dyn AsyncTransportEngine<T, U>>;
// Sealed-trait machinery: `Connectable` can only be implemented inside this
// module, which restricts which types may implement `PipelineIO`.
mod sealed {
    use super::*;
    #[allow(dead_code)]
    pub struct Token;
    /// Types that can participate in pipeline connections; exposes the carried data type.
    pub trait Connectable {
        type DataType: Data;
    }
    impl<T: Data> Connectable for Context<T> {
        type DataType = T;
    }
    impl<T: Data> Connectable for EngineUnary<T> {
        type DataType = T;
    }
    impl<T: Data> Connectable for EngineStream<T> {
        type DataType = T;
    }
}
/// Common interface for values flowing through a pipeline: they are [`Data`],
/// connectable (sealed), and can surface their [`AsyncEngineContext`].
pub trait PipelineIO: Data + sealed::Connectable + AsyncEngineContextProvider {
    /// The id of the request/stream this value belongs to.
    fn id(&self) -> String;
}
impl<T: Data> PipelineIO for Context<T> {
    fn id(&self) -> String {
        // resolves to the inherent `Context::id` (inherent methods take
        // precedence over this trait method), which reads the controller's id
        self.id().to_string()
    }
}
impl<T: Data> PipelineIO for EngineUnary<T> {
    fn id(&self) -> String {
        self.context().id().to_string()
    }
}
impl<T: Data> PipelineIO for EngineStream<T> {
    fn id(&self) -> String {
        self.context().id().to_string()
    }
}
/// Serializable event identified by a stream/request id.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Event {
    pub id: String,
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! Context Module
//!
//! There are two context objects defined in this module:
//!
//! - [`Context`] is an input context which is propagated through the processing pipeline,
//!   up to the point where the input is passed to a [`crate::engine::AsyncEngine`] for processing.
//! - [`StreamContext`] is the input context transformed into a type-erased context that maintains
//!   the input's registry and visitors. Stream adaptors amend themselves to the [`StreamContext`]
//!   as the stream moves through its stages.
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
use crate::engine::AsyncEngineController;
use async_trait::async_trait;
use super::registry::Registry;
/// Input context propagated through the processing pipeline alongside the
/// payload `current`.
pub struct Context<T: Data> {
    // the payload being carried through the pipeline
    current: T,
    // shared cancellation/identity controller; behind an `Arc` so it can be
    // handed out via `AsyncEngineContextProvider`
    controller: Arc<Controller>,
    // type-erased key/value registry for stage-to-stage metadata
    registry: Registry,
    // names of the stages this context has passed through
    stages: Vec<String>,
}
impl<T: Send + Sync + 'static> Context<T> {
    /// Wrap `current` in a fresh context with a default (random-id) controller,
    /// an empty registry, and no recorded stages.
    pub fn new(current: T) -> Self {
        Self {
            current,
            controller: Arc::new(Controller::default()),
            registry: Registry::new(),
            stages: Vec::new(),
        }
    }
    /// Wrap `current`, adopting an existing `controller`.
    pub fn with_controller(current: T, controller: Controller) -> Self {
        Self {
            current,
            controller: Arc::new(controller),
            registry: Registry::new(),
            stages: Vec::new(),
        }
    }
    /// Wrap `current` with an explicitly chosen id.
    pub fn with_id(current: T, id: String) -> Self {
        Self::with_controller(current, Controller::new(id))
    }
    /// The unique id of this context.
    pub fn id(&self) -> &str {
        self.controller.id()
    }
    /// Borrow the underlying [`Controller`].
    pub fn controller(&self) -> &Controller {
        self.controller.as_ref()
    }
    /// Insert a shared object into the registry under `key`.
    pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        self.registry.insert_shared(key, value);
    }
    /// Insert a unique (takable) object into the registry under `key`.
    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        self.registry.insert_unique(key, value);
    }
    /// Look up a shared object in the registry by key and type.
    pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
        self.registry.get_shared(key)
    }
    /// Clone a unique object out of the registry by key and type.
    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
        self.registry.clone_unique(key)
    }
    /// Move a unique object out of the registry by key and type.
    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
        self.registry.take_unique(key)
    }
    /// Swap the payload for `new_current` without touching the registry,
    /// returning the previous payload together with the re-typed context.
    pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
        let Self {
            current,
            controller,
            registry,
            stages,
        } = self;
        let next = Context {
            current: new_current,
            controller,
            registry,
            stages,
        };
        (current, next)
    }
    /// Split into the payload and a payload-less context.
    pub fn into_parts(self) -> (T, Context<()>) {
        self.transfer(())
    }
    /// Names of the stages this context has passed through so far.
    pub fn stages(&self) -> &Vec<String> {
        &self.stages
    }
    /// Record that the context passed through `stage`.
    pub fn add_stage(&mut self, stage: &str) {
        self.stages.push(stage.to_owned());
    }
    /// Map the payload to another type, keeping id/registry/stages intact.
    pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
    where
        F: FnOnce(T) -> U,
    {
        let (current, ctx) = self.transfer(());
        ctx.transfer(f(current)).1
    }
    /// Fallible variant of [`Context::map`]; the context is consumed either way.
    pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
    where
        F: FnOnce(T) -> Result<U, E>,
        U: Send + Sync + 'static,
    {
        let (current, ctx) = self.transfer(());
        Ok(ctx.transfer(f(current)?).1)
    }
}
impl<T: Data> std::fmt::Debug for Context<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // only the id is printed; the payload is not required to be `Debug`
        f.debug_struct("Context")
            .field("id", &self.controller.id())
            .finish()
    }
}
// Implement Deref to allow Context<T> to act like &T
impl<T: Data> Deref for Context<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self.current
    }
}
// Implement DerefMut to allow Context<T> to act like &mut T
impl<T: Data> DerefMut for Context<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.current
    }
}
// Any eligible payload can be lifted into a fresh Context.
impl<T> From<T> for Context<T>
where
    T: Send + Sync + 'static,
{
    fn from(current: T) -> Self {
        Context::new(current)
    }
}
/// Conversion from `Context<T>` to `Context<U>` (a custom trait, since a
/// blanket `From` impl would conflict with `From<T> for Context<T>`).
pub trait IntoContext<U: Data> {
    fn into_context(self) -> Context<U>;
}
/// Any `Context<T>` converts into `Context<U>` whenever `T: Into<U>`; the
/// controller, registry and stages are preserved.
impl<T, U> IntoContext<U> for Context<T>
where
    T: Send + Sync + 'static + Into<U>,
    U: Send + Sync + 'static,
{
    fn into_context(self) -> Context<U> {
        self.map(Into::into)
    }
}
impl<T: Data> AsyncEngineContextProvider for Context<T> {
    /// Hand out the shared controller as a type-erased [`AsyncEngineContext`].
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        Arc::clone(&self.controller)
    }
}
/// Type-erased counterpart of [`Context`]: keeps the controller and registry
/// (now shared) but drops the payload type.
#[derive(Debug, Clone)]
pub struct StreamContext {
    // shared cancellation/identity controller
    controller: Arc<Controller>,
    // registry carried over from the originating `Context`, now shared
    registry: Arc<Registry>,
    // names of the stages this context has passed through
    stages: Vec<String>,
}
impl StreamContext {
fn new(controller: Arc<Controller>, registry: Registry) -> Self {
StreamContext {
controller,
registry: Arc::new(registry),
stages: Vec::new(),
}
}
/// Retrieve an object from the registry by key and type.
pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
self.registry.get_shared(key)
}
/// Clone a unique object from the registry by key and type.
pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
self.registry.clone_unique(key)
}
pub fn registry(&self) -> Arc<Registry> {
self.registry.clone()
}
pub fn stages(&self) -> &Vec<String> {
&self.stages
}
pub fn add_stage(&mut self, stage: &str) {
self.stages.push(stage.to_string());
}
}
#[async_trait]
impl AsyncEngineContext for StreamContext {
    // Every method forwards to the shared `Controller`; methods are listed in
    // trait-declaration order.
    fn id(&self) -> &str {
        self.controller.id()
    }
    fn is_stopped(&self) -> bool {
        self.controller.is_stopped()
    }
    fn is_killed(&self) -> bool {
        self.controller.is_killed()
    }
    async fn stopped(&self) {
        self.controller.stopped().await;
    }
    async fn killed(&self) {
        self.controller.killed().await;
    }
    fn stop_generating(&self) {
        self.controller.stop_generating();
    }
    fn stop(&self) {
        self.controller.stop();
    }
    fn kill(&self) {
        self.controller.kill();
    }
}
impl AsyncEngineContextProvider for StreamContext {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.controller.clone()
    }
}
// A typed Context collapses into a StreamContext by discarding its payload
// type. Note the stage list is NOT carried over — `StreamContext::new` starts
// a fresh one.
impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
    fn from(value: Context<T>) -> Self {
        StreamContext::new(value.controller, value.registry)
    }
}
// TODO - refactor here - this came from the nim-llm-async-engine crate
use tokio::sync::watch::{channel, Receiver, Sender};
/// Lifecycle state of a stream, broadcast over a `tokio::sync::watch` channel.
#[derive(Debug, Eq, PartialEq)]
enum State {
    // stream is producing normally
    Live,
    // graceful stop requested; in-flight items may still be drained
    Stopped,
    // hard stop requested; the stream should terminate immediately
    Killed,
}
/// A context implementation with cancellation propagation.
#[derive(Debug)]
pub struct Controller {
    // unique id of the stream this controller governs
    id: String,
    // state broadcaster; `stop`/`kill` publish new states here
    tx: Sender<State>,
    // prototype receiver, cloned by `stopped()`/`killed()` waiters
    rx: Receiver<State>,
}
impl Controller {
    /// Construct a controller for the stream identified by `id`, starting in
    /// the `Live` state.
    pub fn new(id: String) -> Self {
        let (sender, receiver) = channel(State::Live);
        Self {
            id,
            tx: sender,
            rx: receiver,
        }
    }
    /// Unique id of the stream this controller governs.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }
}
/// Default controllers are assigned a random UUIDv4 id.
impl Default for Controller {
    fn default() -> Self {
        Self::new(uuid::Uuid::new_v4().to_string())
    }
}
impl AsyncEngineController for Controller {}
#[async_trait]
impl AsyncEngineContext for Controller {
    fn id(&self) -> &str {
        &self.id
    }
    // a stream counts as "stopped" once it has left `Live` (stopped OR killed)
    fn is_stopped(&self) -> bool {
        *self.rx.borrow() != State::Live
    }
    fn is_killed(&self) -> bool {
        *self.rx.borrow() == State::Killed
    }
    /// Resolve once the stream leaves the `Live` state.
    ///
    /// A single `changed()` await is sufficient here because every possible
    /// transition out of `Live` (to `Stopped` or `Killed`) satisfies
    /// `is_stopped`.
    async fn stopped(&self) {
        let mut rx = self.rx.clone();
        if *rx.borrow_and_update() != State::Live {
            return;
        }
        let _ = rx.changed().await;
    }
    /// Resolve once the stream reaches the `Killed` state.
    ///
    /// Bug fix: the previous implementation awaited a single `changed()`
    /// notification, so a `Live -> Stopped` transition would wake it and make
    /// it resolve even though the stream was never killed. Loop until the
    /// observed state is actually `Killed`.
    async fn killed(&self) {
        let mut rx = self.rx.clone();
        loop {
            if *rx.borrow_and_update() == State::Killed {
                return;
            }
            // sender dropped => no further transitions can occur; bail out
            if rx.changed().await.is_err() {
                return;
            }
        }
    }
    fn stop_generating(&self) {
        self.stop();
    }
    // NOTE(review): calling `stop()` after `kill()` downgrades the state from
    // `Killed` back to `Stopped` (making `is_killed` false again) — confirm
    // this is intended.
    fn stop(&self) {
        let _ = self.tx.send(State::Stopped);
    }
    fn kill(&self) {
        let _ = self.tx.send(State::Killed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Simple three-stage payload chain (Input -> Processed -> Final) used to
    // exercise `transfer`, `map` and `into_context`.
    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }
    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }
    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }
    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            Processed {
                length: input.value.len(),
            }
        }
    }
    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            Final {
                message: format!("Processed length: {}", processed.length),
            }
        }
    }
    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new(Input {
            value: "Hello".to_string(),
        });
        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());
        assert_eq!(*ctx.get::<i32>("key1").unwrap(), 42);
        assert_eq!(*ctx.get::<String>("key2").unwrap(), "some data");
        assert!(ctx.get::<f64>("key1").is_err()); // Testing a downcast failure
    }
    #[test]
    fn test_transfer() {
        let ctx = Context::new(Input {
            value: "Hello".to_string(),
        });
        let (input, ctx) = ctx.transfer(Processed { length: 5 });
        assert_eq!(input.value, "Hello");
        // `ctx.length` resolves through `Deref` to the new `Processed` payload
        assert_eq!(ctx.length, 5);
    }
    #[test]
    fn test_map() {
        let ctx = Context::new(Input {
            value: "Hello".to_string(),
        });
        let ctx: Context<Processed> = ctx.map(|input| input.into());
        let ctx: Context<Final> = ctx.map(|processed| processed.into());
        assert_eq!(ctx.current.message, "Processed length: 5");
    }
    #[test]
    fn test_into_context() {
        let ctx = Context::new(Input {
            value: "Hello".to_string(),
        });
        let ctx: Context<Processed> = ctx.into_context();
        let ctx: Context<Final> = ctx.into_context();
        assert_eq!(ctx.current.message, "Processed length: 5");
    }
}
This diff is collapsed.
This diff is collapsed.
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! Codec Module
//!
//! Codecs map structured messages into blobs of bytes and streams of bytes.
//!
//! In this module, we define the primary codecs used to issue single, two-part or multi-part
//! messages on a byte stream.
use tokio_util::{
bytes::{Buf, BufMut, BytesMut},
codec::{Decoder, Encoder},
};
mod two_part;
pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
// // Custom codec that reads a u64 length header and the message of that length
// #[derive(Default)]
// pub struct LengthPrefixedCodec;
// impl LengthPrefixedCodec {
// pub fn new() -> Self {
// LengthPrefixedCodec {}
// }
// }
// impl Decoder for LengthPrefixedCodec {
// type Item = Vec<u8>;
// type Error = tokio::io::Error;
// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// // Check if enough bytes are available to read the length (u64 = 8 bytes)
// if src.len() < 8 {
// return Ok(None); // Not enough data to read the length
// }
// // Read the u64 length header
// let len = src.get_u64() as usize;
// // Check if enough bytes are available to read the full message
// if src.len() < len {
// src.reserve(len - src.len()); // Reserve space for the remaining bytes
// return Ok(None);
// }
// // Read the actual message bytes of the specified length
// let data = src.split_to(len).to_vec();
// Ok(Some(data))
// }
// }
// impl Encoder<Vec<u8>> for LengthPrefixedCodec {
// type Error = tokio::io::Error;
// fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
// // Write the length of the message as a u64 header
// dst.put_u64(item.len() as u64);
// // Write the actual message bytes
// dst.put_slice(&item);
// Ok(())
// }
// }
This diff is collapsed.
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
pub mod push;
use super::*;
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment