Commit 5ed8c1c0 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: rust - initial commit

the journey begins
parent 4017bd18
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
// TODO - refactor this entire module
//
// we want to carry forward the concept of live vs ready for the components
// we will want to associate the components cancellation token with the
// component's "service state"
use crate::{log, transports::nats, Result};
use async_nats::Message;
use async_stream::try_stream;
use bytes::Bytes;
use derive_getters::Dissolve;
use futures::stream::StreamExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::time::Duration;
/// Thin wrapper around a NATS [`nats::Client`] used to issue service
/// stats / discovery requests (see [`ServiceClient::collect_services`]).
pub struct ServiceClient {
    nats_client: nats::Client,
}
impl ServiceClient {
    #[allow(dead_code)]
    /// Construct a [`ServiceClient`] from an already-connected NATS client.
    pub(crate) fn new(nats_client: nats::Client) -> Self {
        ServiceClient { nats_client }
    }
}
/// The collection of [`ServiceInfo`] replies gathered from a single
/// `$SRV.STATS` broadcast — one entry per responding service instance.
pub struct ServiceSet {
    services: Vec<ServiceInfo>,
}
/// Stats payload reported by a single service instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceInfo {
    /// Service name.
    pub name: String,
    /// Unique id of this service instance.
    pub id: String,
    /// Service version string.
    pub version: String,
    /// Start timestamp (string-encoded; exact format set by the publisher — TODO confirm).
    pub started: String,
    /// Endpoints advertised by this instance, with their metrics.
    pub endpoints: Vec<EndpointInfo>,
}
/// A single endpoint advertised by a service, plus its metrics.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct EndpointInfo {
    /// Endpoint name.
    pub name: String,
    /// NATS subject the endpoint is served on.
    pub subject: String,
    // `flatten` merges the metrics' JSON fields directly into this object
    // on (de)serialization.
    #[serde(flatten)]
    pub data: Metrics,
}
/// Opaque, schema-less metrics blob carried as raw JSON.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct Metrics(pub serde_json::Value);
impl Metrics {
    /// Deserialize the raw JSON into a concrete metrics type `T`.
    ///
    /// Returns an error if the JSON does not match `T`'s shape.
    pub fn decode<T: DeserializeOwned>(self) -> Result<T> {
        serde_json::from_value(self.0).map_err(Into::into)
    }
}
impl ServiceClient {
    /// Issue a unary request/reply over NATS and return the single reply
    /// [`Message`].
    pub async fn unary(
        &self,
        subject: impl Into<String>,
        payload: impl Into<Bytes>,
    ) -> Result<Message> {
        let response = self
            .nats_client
            .client()
            .request(subject.into(), payload.into())
            .await?;
        Ok(response)
    }

    /// Broadcast a `$SRV.STATS` request for `service_name` and gather replies
    /// for up to one second.
    ///
    /// Replies that fail to deserialize are logged as errors and dropped;
    /// successfully parsed entries are returned as a [`ServiceSet`].
    pub async fn collect_services(&self, service_name: &str) -> Result<ServiceSet> {
        let mut sub = self.nats_client.service_subscriber(service_name).await?;
        // Hard collection window: stop listening one second after the request.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(1);
        let services: Vec<Result<ServiceInfo>> = try_stream! {
            while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await {
                // Skip empty payloads (no stats to parse).
                if message.payload.is_empty() {
                    continue;
                }
                // NOTE: `?` inside try_stream! emits the error as the stream's
                // final item and terminates it, so one malformed payload ends
                // collection early.
                let service = serde_json::from_slice::<ServiceInfo>(&message.payload)?;
                log::trace!("service: {:?}", service);
                yield service;
            }
        }
        .collect()
        .await;
        // split ok and error results
        let (ok, err): (Vec<_>, Vec<_>) = services.into_iter().partition(Result::is_ok);
        if !err.is_empty() {
            log::error!("failed to collect services: {:?}", err);
        }
        // `unwrap` is safe: `ok` contains only `Ok` values by construction.
        Ok(ServiceSet {
            services: ok.into_iter().map(Result::unwrap).collect(),
        })
    }
}
impl ServiceSet {
    /// Consume the set and iterate over every endpoint of every service,
    /// in service order.
    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
        self.services.into_iter().flat_map(|svc| svc.endpoints)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ServiceInfo` with two endpoints: a fixed `endpoint1` and a
    /// caller-named second endpoint carrying `second_value` as its metric.
    fn make_service(id: &str, second_name: &str, second_value: &str) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints: vec![
                EndpointInfo {
                    name: "endpoint1".to_string(),
                    subject: "subject1".to_string(),
                    data: Metrics(serde_json::json!({"key": "value1"})),
                },
                EndpointInfo {
                    name: second_name.to_string(),
                    subject: "subject2".to_string(),
                    data: Metrics(serde_json::json!({"key": second_value})),
                },
            ],
        }
    }

    /// `into_endpoints` flattens all services; filtering by name prefix
    /// should find exactly one `endpoint2-*` per service.
    #[test]
    fn test_service_set() {
        let service_set = ServiceSet {
            services: vec![
                make_service("1", "endpoint2-foo", "value1"),
                make_service("2", "endpoint2-bar", "value2"),
            ],
        };
        let endpoints: Vec<_> = service_set
            .into_endpoints()
            .filter(|e| e.name.starts_with("endpoint2"))
            .collect();
        assert_eq!(endpoints.len(), 2);
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The Transports module hosts all the network communication stacks used for talking
//! to services or moving data around the network.
//!
//! These are the low-level building blocks for the distributed system.
/// etcd-backed discovery / key-value transport.
pub mod etcd;
/// NATS messaging transport.
pub mod nats;
/// TCP transport (contents not shown in this view).
pub mod tcp;
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use crate::{error, log, CancellationToken, ErrorContext, Result, Runtime};
use async_nats::jetstream::kv;
use derive_builder::Builder;
use derive_getters::Dissolve;
use futures::StreamExt;
use tokio::sync::mpsc;
use validator::Validate;
use etcd_client::{
Compare, CompareOp, GetOptions, KeyValue, PutOptions, Txn, TxnOp, WatchOptions, Watcher,
};
pub use etcd_client::{ConnectOptions, LeaseClient};
mod lease;
use lease::*;
//pub use etcd::ConnectOptions as EtcdConnectOptions;
/// ETCD Client
///
/// Wraps an [`etcd_client::Client`] together with the primary lease id and
/// the [`Runtime`] whose lifetime is tied to that lease.
#[derive(Clone)]
pub struct Client {
    client: etcd_client::Client,
    // Lease created at connection time; keep-alive for it is spawned in `create`.
    primary_lease: i64,
    runtime: Runtime,
}
/// An etcd lease paired with the [`CancellationToken`] that controls it.
#[derive(Debug, Clone)]
pub struct Lease {
    /// ETCD lease ID
    id: i64,
    /// [`CancellationToken`] associated with the lease
    cancel_token: CancellationToken,
}
impl Lease {
    /// Get the lease ID
    pub fn id(&self) -> i64 {
        self.id
    }
    /// Get the primary [`CancellationToken`] associated with the lease.
    /// This token will revoke the lease if canceled.
    pub fn primary_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }
    /// Get a child [`CancellationToken`] from the lease's [`CancellationToken`].
    /// This child token will be triggered if the lease is revoked, but will not revoke the lease if canceled.
    pub fn child_token(&self) -> CancellationToken {
        self.cancel_token.child_token()
    }
    /// Revoke the lease triggering the [`CancellationToken`].
    ///
    /// The actual etcd revoke happens in the keep-alive task, which observes
    /// this cancellation (see `lease::keep_alive`).
    pub fn revoke(&self) {
        self.cancel_token.cancel();
    }
}
impl Client {
pub fn builder() -> ClientOptionsBuilder {
ClientOptionsBuilder::default()
}
/// Create a new discovery client
///
/// This will establish a connection to the etcd server, create a primary lease,
/// and spawn a task to keep the lease alive and tie the lifetime of the [`Runtime`]
/// to the lease.
///
/// If the lease expires, the [`Runtime`] will be shutdown.
/// If the [`Runtime`] is shutdown, the lease will be revoked.
pub async fn new(config: ClientOptions, runtime: Runtime) -> Result<Self> {
runtime
.secondary()
.spawn(Self::create(config, runtime.clone()))
.await?
}
/// Create a new etcd client and tie the primary [`CancellationToken`] to the primary etcd lease.
async fn create(config: ClientOptions, runtime: Runtime) -> Result<Self> {
let token = runtime.primary_token();
let client =
etcd_client::Client::connect(config.etcd_url, config.etcd_connect_options).await?;
let lease_client = client.lease_client();
let lease = create_lease(lease_client, 10, token)
.await
.context("creating primary lease")?;
Ok(Client {
client,
primary_lease: lease.id,
runtime,
})
}
/// Get a reference to the underlying [`etcd_client::Client`] instance.
pub fn etcd_client(&self) -> &etcd_client::Client {
&self.client
}
/// Get the primary lease ID.
pub fn lease_id(&self) -> i64 {
self.primary_lease
}
/// Primary [`Lease`]
pub fn primary_lease(&self) -> Lease {
Lease {
id: self.primary_lease,
cancel_token: self.runtime.primary_token(),
}
}
/// Create a [`Lease`] with a given time-to-live (TTL).
/// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`].
pub async fn create_lease(&self, ttl: i64) -> Result<Lease> {
let token = self.runtime.child_token();
let lease_client = self.client.lease_client();
self.runtime
.secondary()
.spawn(create_lease(lease_client, ttl, token))
.await?
}
pub async fn kv_create(
&self,
key: String,
value: Vec<u8>,
lease_id: Option<i64>,
) -> Result<()> {
let put_options = lease_id.map(|id| PutOptions::new().with_lease(id));
// Build the transaction
let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
.and_then(vec![
TxnOp::put(key.as_str(), value, put_options), // Create the object
]);
// Execute the transaction
let _ = self.client.kv_client().txn(txn).await?;
Ok(())
}
pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
let mut get_response = self
.client
.kv_client()
.get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
.await?;
Ok(get_response.take_kvs())
}
pub async fn kv_get_and_watch_prefix(&self, prefix: impl AsRef<str>) -> Result<PrefixWatcher> {
let mut kv_client = self.client.kv_client();
let mut watch_client = self.client.watch_client();
let mut get_response = kv_client
.get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
.await?;
let start_revision = get_response
.header()
.ok_or(error!("missing header; unable to get revision"))?
.revision();
let (watcher, mut watch_stream) = watch_client
.watch(
prefix.as_ref(),
Some(
WatchOptions::new()
.with_prefix()
.with_start_revision(start_revision),
),
)
.await?;
let kvs = get_response.take_kvs();
let (tx, rx) = mpsc::channel(32);
self.runtime.secondary().spawn(async move {
for kv in kvs {
if tx.send(WatchEvent::Put(kv)).await.is_err() {
// receiver is closed
break;
}
}
while let Some(Ok(response)) = watch_stream.next().await {
for event in response.events() {
match event.event_type() {
etcd_client::EventType::Put => {
if let Some(kv) = event.kv() {
if tx.send(WatchEvent::Put(kv.clone())).await.is_err() {
// receiver is closed
break;
}
}
}
etcd_client::EventType::Delete => {
if let Some(kv) = event.kv() {
if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
// receiver is closed
break;
}
}
}
}
}
}
});
Ok(PrefixWatcher {
prefix: prefix.as_ref().to_string(),
watcher,
rx,
})
}
}
/// A live watch over an etcd key prefix.
///
/// Holds the underlying [`Watcher`] (dropping it cancels the watch) and the
/// receiver of [`WatchEvent`]s produced by the background task spawned in
/// [`Client::kv_get_and_watch_prefix`].
#[derive(Dissolve)]
pub struct PrefixWatcher {
    // The watched key prefix.
    prefix: String,
    // Handle controlling the server-side watch.
    watcher: Watcher,
    // Snapshot replay followed by live events.
    rx: mpsc::Receiver<WatchEvent>,
}
/// A single key-space change observed under a watched prefix.
pub enum WatchEvent {
    /// Key created or updated.
    Put(KeyValue),
    /// Key deleted.
    Delete(KeyValue),
}
/// ETCD client configuration options
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
    /// etcd endpoint URLs; at least one is required.
    #[validate(length(min = 1))]
    etcd_url: Vec<String>,
    /// Optional connection options (auth, timeouts, ...).
    #[builder(default)]
    etcd_connect_options: Option<ConnectOptions>,
}
impl Default for ClientOptions {
    /// Defaults: endpoints from `ETCD_ENDPOINTS` (or localhost), no connect options.
    fn default() -> Self {
        ClientOptions {
            etcd_url: default_servers(),
            etcd_connect_options: None,
        }
    }
}
/// Resolve the default etcd endpoints.
///
/// Reads the comma-separated `ETCD_ENDPOINTS` environment variable when set;
/// entries are trimmed and empty entries (e.g. from a trailing comma or a
/// blank variable) are dropped. Falls back to `http://localhost:2379` when
/// the variable is unset or yields no usable entries.
fn default_servers() -> Vec<String> {
    match std::env::var("ETCD_ENDPOINTS") {
        Ok(possible_list_of_urls) => {
            let urls: Vec<String> = possible_list_of_urls
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(str::to_string)
                .collect();
            if urls.is_empty() {
                // A blank variable is treated the same as unset.
                vec!["http://localhost:2379".to_string()]
            } else {
                urls
            }
        }
        Err(_) => vec!["http://localhost:2379".to_string()],
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
/// Create a [`Lease`] with a given time-to-live (TTL) attached to the [`CancellationToken`].
///
/// Spawns a background task that keeps the lease alive. If the keep-alive
/// task fails, the *parent* `token` is cancelled; the keep-alive task itself
/// watches a child token, so cancelling the parent also revokes the lease.
pub async fn create_lease(
    mut lease_client: LeaseClient,
    ttl: i64,
    token: CancellationToken,
) -> Result<Lease> {
    let lease = lease_client.grant(ttl, None).await?;
    let id = lease.id();
    // Use the server-granted TTL, which may differ from the requested one.
    let ttl = lease.ttl();
    // The keep-alive task gets a child token so that its own failure path
    // (cancelling the parent below) is distinguishable from an external cancel.
    let child = token.child_token();
    let clone = token.clone();
    tokio::spawn(async move {
        match keep_alive(lease_client, id, ttl, child).await {
            Ok(_) => log::trace!("keep alive task exited successfully"),
            Err(e) => {
                log::info!("keep alive task failed: {:?}", e);
                // Propagate the failure: cancel everything tied to this lease.
                token.cancel();
            }
        }
    });
    Ok(Lease {
        id,
        cancel_token: clone,
    })
}
/// Task to keep leases alive.
///
/// Periodically sends keep-alive heartbeats for `lease_id` until either the
/// `token` is cancelled (the lease is then revoked and the task returns Ok)
/// or the lease can no longer be kept alive (the task returns an error).
///
/// If this task returns an error, the caller is expected to trigger the
/// associated cancellation token (see [`create_lease`]).
pub async fn keep_alive(
    client: LeaseClient,
    lease_id: i64,
    ttl: i64,
    token: CancellationToken,
) -> Result<()> {
    let mut ttl = ttl;
    let mut deadline = create_deadline(ttl)?;
    let mut client = client;
    let (mut heartbeat_sender, mut heartbeat_receiver) = client.keep_alive(lease_id).await?;
    loop {
        // if the deadline is exceeded, then we have failed to issue a heartbeat in time
        // we may be permanently disconnected from the etcd server, so we are now officially done
        if deadline < std::time::Instant::now() {
            return Err(error!("failed to issue heartbeat in time"));
        }
        tokio::select! {
            biased;
            status = heartbeat_receiver.message() => {
                if let Some(resp) = status? {
                    log::trace!(lease_id, "keep alive response received: {:?}", resp);
                    // A ttl of 0 means the lease is gone; report that directly.
                    // FIX: this check must come before `create_deadline`, which
                    // errors on ttl <= 0 and previously masked lease expiry as
                    // an "invalid ttl: 0" error.
                    if resp.ttl() == 0 {
                        return Err(error!("lease expired or revoked"));
                    }
                    // update ttl and deadline
                    ttl = resp.ttl();
                    deadline = create_deadline(ttl)?;
                }
            }
            _ = token.cancelled() => {
                log::trace!(lease_id, "cancellation token triggered; revoking lease");
                let _ = client.revoke(lease_id).await?;
                return Ok(());
            }
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(ttl as u64 / 2)) => {
                log::trace!(lease_id, "sending keep alive");
                // if we get an error issuing the heartbeat, set the ttl to 0
                // this will allow us to poll the response stream once and the cancellation token once, then
                // immediately try to tick the heartbeat
                // this will repeat until either the heartbeat is reestablished or the deadline is exceeded
                if let Err(e) = heartbeat_sender.keep_alive().await {
                    log::warn!(lease_id, "keep alive failed: {:?}", e);
                    ttl = 0;
                }
            }
        }
    }
}
/// Create a deadline for a given time-to-live (TTL).
fn create_deadline(ttl: i64) -> Result<std::time::Instant> {
if ttl <= 0 {
return Err(error!("invalid ttl: {}", ttl));
}
Ok(std::time::Instant::now() + std::time::Duration::from_secs(ttl as u64))
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! NATS transport
//!
//! The following environment variables are used to configure the NATS client:
//!
//! - `NATS_SERVER`: the NATS server address
//!
//! For authentication, the following environment variables are used and prioritized in the following order:
//!
//! - `NATS_AUTH_USERNAME`: the username for authentication
//! - `NATS_AUTH_PASSWORD`: the password for authentication
//! - `NATS_AUTH_TOKEN`: the token for authentication
//! - `NATS_AUTH_NKEY`: the nkey for authentication
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
use crate::Result;
use async_nats::{client, jetstream, Subscriber};
use derive_builder::Builder;
use futures::TryStreamExt;
use std::path::PathBuf;
use validator::{Validate, ValidationError};
mod slug;
pub use slug::Slug;
/// NATS client wrapper holding both the core client and a JetStream context
/// created from it (see [`ClientOptions::connect`]).
#[derive(Clone)]
pub struct Client {
    client: client::Client,
    js_ctx: jetstream::Context,
}
impl Client {
    /// Create a NATS [`ClientOptionsBuilder`].
    pub fn builder() -> ClientOptionsBuilder {
        ClientOptionsBuilder::default()
    }

    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
    pub fn client(&self) -> &client::Client {
        &self.client
    }

    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
    pub fn jetstream(&self) -> &jetstream::Context {
        &self.js_ctx
    }

    /// fetch the list of streams
    pub async fn list_streams(&self) -> Result<Vec<String>> {
        let names = self.js_ctx.stream_names();
        let stream_names: Vec<String> = names.try_collect().await?;
        Ok(stream_names)
    }

    /// fetch the list of consumers for a given stream
    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
        let stream = self.js_ctx.get_stream(stream_name).await?;
        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
        Ok(consumers)
    }

    /// Fetch the current state of a JetStream stream.
    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
        let mut stream = self.js_ctx.get_stream(stream_name).await?;
        let info = stream.info().await?;
        Ok(info.state.clone())
    }

    /// Look up a JetStream stream by name.
    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
        let stream = self.js_ctx.get_stream(name).await?;
        Ok(stream)
    }

    /// Broadcast a `$SRV.STATS` request for `service_name` and return a
    /// [`Subscriber`] on the reply inbox; each responding service instance
    /// publishes its stats there. The caller decides how long to listen
    /// (see `ServiceClient::collect_services`).
    pub async fn service_subscriber(&self, service_name: &str) -> Result<Subscriber> {
        let subject = format!("$SRV.STATS.{}", service_name);
        let reply_subject = format!("_INBOX.{}", nuid::next());
        // Subscribe to the inbox *before* publishing so no replies are missed.
        let subscription = self.client.subscribe(reply_subject.clone()).await?;
        // Publish the request with the reply-to subject
        self.client
            .publish_with_reply(subject, reply_subject, "".into())
            .await?;
        Ok(subscription)
    }

    // NOTE(review): a large block of commented-out work-queue / specialized
    // client code previously lived here; it was dead code and has been
    // removed. Recover it from version control if the work-queue API is
    // resurrected.
}
/// NATS client options
///
/// This object uses the builder pattern with default values that are evaluated
/// from the environment variables if they are not explicitly set by the builder.
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
    /// Server address; must use the `nats://` scheme (see `validate_nats_server`).
    #[builder(setter(into), default = "default_server()")]
    #[validate(custom(function = "validate_nats_server"))]
    server: String,
    /// Authentication; resolved from `NATS_AUTH_*` env vars by default.
    #[builder(default)]
    auth: NatsAuth,
}
/// Default NATS server address: the `NATS_SERVER` environment variable when
/// set, otherwise the local default `nats://localhost:4222`.
fn default_server() -> String {
    std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string())
}
/// Validate that a NATS server address uses the `nats://` scheme.
fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
    match server.strip_prefix("nats://") {
        Some(_) => Ok(()),
        None => Err(ValidationError::new("server must start with 'nats://'")),
    }
}
#[allow(dead_code)]
impl ClientOptions {
    /// Create a new [`ClientOptionsBuilder`]
    pub fn builder() -> ClientOptionsBuilder {
        ClientOptionsBuilder::default()
    }
    /// Validate the config and attempt a connection to the NATS server.
    ///
    /// Consumes the options; returns a connected [`Client`] with a JetStream
    /// context, or an error if validation or the connection fails.
    pub async fn connect(self) -> Result<Client> {
        self.validate()?;
        // Build the async-nats connect options matching the configured auth.
        let client = match self.auth {
            NatsAuth::UserPass(username, password) => {
                async_nats::ConnectOptions::with_user_and_password(username, password)
            }
            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
            NatsAuth::CredentialsFile(path) => {
                async_nats::ConnectOptions::with_credentials_file(path).await?
            }
        };
        let client = client.connect(self.server).await?;
        let js_ctx = jetstream::new(client.clone());
        Ok(Client { client, js_ctx })
    }
}
impl Default for ClientOptions {
    /// Defaults resolved from the environment (`NATS_SERVER`, `NATS_AUTH_*`).
    fn default() -> Self {
        ClientOptions {
            server: default_server(),
            auth: NatsAuth::default(),
        }
    }
}
/// Supported NATS authentication mechanisms.
///
/// The [`Default`] impl resolves these from `NATS_AUTH_*` environment
/// variables in priority order (see the module docs).
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username/password pair.
    UserPass(String, String),
    /// Auth token.
    Token(String),
    /// NKey.
    NKey(String),
    /// Path to a credentials file.
    CredentialsFile(PathBuf),
}
impl std::fmt::Debug for NatsAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NatsAuth::UserPass(user, _pass) => {
write!(f, "UserPass({}, <redacted>)", user)
}
NatsAuth::Token(_token) => write!(f, "Token(<redacted>)"),
NatsAuth::NKey(_nkey) => write!(f, "NKey(<redacted>)"),
NatsAuth::CredentialsFile(path) => write!(f, "CredentialsFile({:?})", path),
}
}
}
impl Default for NatsAuth {
    /// Resolve authentication from `NATS_AUTH_*` environment variables, in
    /// priority order: username+password, token, nkey, credentials file.
    ///
    /// NOTE(review): when nothing is set this falls back to a hardcoded
    /// `user`/`user` pair — presumably a development default; confirm this
    /// is intentional before relying on it outside local environments.
    fn default() -> Self {
        if let (Ok(username), Ok(password)) = (
            std::env::var("NATS_AUTH_USERNAME"),
            std::env::var("NATS_AUTH_PASSWORD"),
        ) {
            return NatsAuth::UserPass(username, password);
        }
        if let Ok(token) = std::env::var("NATS_AUTH_TOKEN") {
            return NatsAuth::Token(token);
        }
        if let Ok(nkey) = std::env::var("NATS_AUTH_NKEY") {
            return NatsAuth::NKey(nkey);
        }
        if let Ok(path) = std::env::var("NATS_AUTH_CREDENTIALS_FILE") {
            return NatsAuth::CredentialsFile(PathBuf::from(path));
        }
        NatsAuth::UserPass("user".to_string(), "user".to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use figment::Jail;

    /// Builder defaults resolve from the environment; `Jail` isolates the
    /// env-var mutation per closure.
    #[test]
    fn test_client_options_builder() {
        // No env set: plain defaults must build.
        Jail::expect_with(|_jail| {
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            Ok(())
        });
        // Env-provided server and user/pass auth are picked up by the defaults.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();
            assert_eq!(opts.server, "nats://localhost:5222");
            assert_eq!(
                opts.auth,
                NatsAuth::UserPass("user".to_string(), "pass".to_string())
            );
            Ok(())
        });
        // Explicit builder values override the environment.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");
            let opts = ClientOptions::builder()
                .server("nats://localhost:6222")
                .auth(NatsAuth::Token("token".to_string()))
                .build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();
            assert_eq!(opts.server, "nats://localhost:6222");
            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));
            Ok(())
        });
    }

    // NOTE(review): a large block of commented-out work-queue integration
    // tests previously lived here; it referenced APIs that no longer exist
    // and has been removed as dead code. Recover from version control if
    // those tests are resurrected.
}
// let frontend_client = client.frontend_client("test".to_string());
// // the represents the frontend response subscription
// let mut frontend_sub = frontend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let backend_client = client.backend_client("test".to_string());
// let mut backend_sub = backend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let msg = messages[0].clone();
// let req = serde_json::from_slice::<CompletionRequest>(&msg.payload)
// .expect("failed to deserialize message");
// msg.ack().await.expect("failed to ack");
// assert_eq!(req.prompt, request.prompt);
// // ping pong message between backend and frontend
// // backend publishes to frontend
// backend_client
// .publish(&MessageKind::Initialize(Prologue {
// formatted_prompt: None,
// input_token_ids: None,
// }))
// .await
// .expect("failed to publish");
// // frontend receives initialize message
// let msg = frontend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match msg {
// MessageKind::Initialize(_) => {}
// _ => panic!("unexpected message"),
// }
// // frontend publishes to backend
// frontend_client
// .publish(&MessageKind::Finalize(Epilogue {}))
// .await
// .expect("failed to publish");
// // backend receives finalize message
// let msg = backend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match &msg {
// MessageKind::Finalize(_) => {}
// _ => panic!("unexpected message"),
// }
// // delete the work queue
// client
// .remove_work_queue(model_name, TEST_STREAM)
// .await
// .expect("failed to remove work queue");
// // new work queue count
// let work_queue_count = client
// .list_work_queues(TEST_STREAM)
// .await
// .expect("failed to list work queues")
// .len();
// // compare against the initial work queue count
// assert_eq!(initial_work_queue_count, work_queue_count);
// }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = time::Instant::now();
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// async fn connect(config: Arc<Config>) -> Result<NatsClient> {
// let client = ClientOptions::builder()
// .server(config.nats_address.clone())
// .build()
// .await
// .context("Creating NATS Client")?;
// Ok(client)
// }
// async fn create_service(
// nats: NatsClient,
// config: Arc<Config>,
// observer: ServiceObserver,
// ) -> Result<NatsService> {
// let service = nats
// .client()
// .service_builder()
// .description(config.service_description.as_str())
// .stats_handler(move |_name, _stats| {
// let stats = InstanceStats {
// stage: observer.stage(),
// };
// serde_json::to_value(&stats).unwrap()
// })
// .start(
// config.service_name.as_str(),
// config.service_version.as_str(),
// )
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
// Ok(service)
// }
// async fn create_endpoint(
// endpoint_name: impl Into<String>,
// service: &NatsService,
// ) -> Result<Endpoint> {
// let info = service.info().await;
// let group_name = format!("{}-{}", info.name, info.id);
// let group = service.group(group_name);
// let endpoint = group
// .endpoint(endpoint_name.into())
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
// Ok(endpoint)
// }
// async fn shutdown_endpoint_handler(
// controller: ServiceController,
// endpoint: Endpoint,
// ) -> Result<()> {
// let mut endpoint = endpoint;
// // note: this is a child cancellation token, canceling it will not cancel the parent
// // but the parent will cancel the child -- we only use this to observe if another
// // controller has cancelled the service
// let cancellation_token = controller.cancel_token();
// loop {
// let req = tokio::select! {
// _ = cancellation_token.cancelled() => {
// // log::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
// return Ok(());
// }
// // await on service request
// req = endpoint.next() => {
// req
// }
// };
// if let Some(req) = req {
// let response = "DONE".to_string();
// if let Err(e) = req.respond(Ok(response.into())).await {
// log::warn!("Failed to respond to the shutdown request: {:?}", e);
// }
// controller.set_stage(ServiceStage::ShuttingDown);
// }
// }
// }
// #[derive(Debug, Clone, Builder)]
// pub struct Config {
// /// The NATS server address
// #[builder(default = "String::from(\"nats://localhost:4222\")")]
// pub nats_address: String,
// #[builder(setter(into), default = "String::from(SERVICE_NAME)")]
// pub service_name: String,
// #[builder(setter(into), default = "String::from(SERVICE_VERSION)")]
// pub service_version: String,
// #[builder(setter(into), default = "String::from(SERVICE_DESCRIPTION)")]
// pub service_description: String,
// }
// impl Config {
// pub fn new() -> Result<Config> {
// Ok(ConfigBuilder::default().build()?)
// }
// /// Create a new [`ConfigBuilder`]
// pub fn builder() -> ConfigBuilder {
// ConfigBuilder::default()
// }
// }
// // todo: move to icp - transports
// #[derive(Clone, Debug)]
// pub struct NatsClient {
// client: Client,
// js_ctx: jetstream::Context,
// }
// impl NatsClient {
// pub fn client(&self) -> &Client {
// &self.client
// }
// pub fn jetstream(&self) -> &jetstream::Context {
// &self.js_ctx
// }
// pub fn service_builder(&self) -> NatsServiceBuilder {
// self.client.service_builder()
// }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = tokio::time::Instant::now();
// while let Ok(Some(message)) = tokio::time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct ServiceInfo {
// pub name: String,
// pub id: String,
// pub version: String,
// pub started: String,
// pub endpoints: Vec<EndpointInfo>,
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct EndpointInfo {
// pub name: String,
// pub subject: String,
// pub data: serde_json::Value,
// }
// impl EndpointInfo {
// pub fn get<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
// serde_json::from_value(self.data.clone()).map_err(Into::into)
// }
// }
// #[derive(Clone, Debug, Builder)]
// #[builder(build_fn(private, name = "build_internal"))]
// pub struct ClientOptions {
// #[builder(setter(into))]
// server: String,
// #[builder(setter(into, strip_option), default)]
// username: Option<String>,
// #[builder(setter(into, strip_option), default)]
// password: Option<String>,
// }
// #[allow(dead_code)]
// impl ClientOptions {
// pub fn builder() -> ClientOptionsBuilder {
// ClientOptionsBuilder::default()
// }
// }
// impl ClientOptionsBuilder {
// pub async fn build(&self) -> Result<NatsClient> {
// let opts = self.build_internal()?;
// // Create an unauthenticated connection to NATS.
// let client = async_nats::ConnectOptions::new();
// let client = if let (Some(username), Some(password)) = (opts.username, opts.password) {
// client.user_and_password(username, password)
// } else {
// client
// };
// let client = client.connect(&opts.server).await?;
// let js_ctx = jetstream::new(client.clone());
// Ok(NatsClient { client, js_ctx })
// }
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use serde::de::{self, Deserializer, Visitor};
use serde::{Deserialize, Serialize};
use std::fmt;
// Character substituted for anything that is not web/NATS safe when slugifying.
const REPLACEMENT_CHAR: char = '_';
/// URL and NATS friendly string.
/// Only a-z, 0-9, - and _.
#[derive(Serialize, Clone, Debug, Eq, PartialEq)]
pub struct Slug(String);
impl Slug {
    /// Wrap an already-sanitized string, dropping any leading
    /// `REPLACEMENT_CHAR` so a slug never starts with '_'.
    fn new(s: String) -> Slug {
        let trimmed = s.trim_start_matches(REPLACEMENT_CHAR);
        Slug(trimmed.to_string())
    }

    /// Create [`Slug`] from a string.
    pub fn from_string(s: impl AsRef<str>) -> Slug {
        Slug::slugify_unique(s.as_ref())
    }

    /// Lowercase the input and replace every character outside [a-z0-9_-]
    /// with `REPLACEMENT_CHAR`, then append a four-byte hash suffix so two
    /// different inputs that sanitize to the same text still yield distinct
    /// slugs.
    fn slugify_unique(s: &str) -> Slug {
        let sanitized: String = s
            .to_lowercase()
            .chars()
            .map(|c| match c {
                'a'..='z' | '0'..='9' | '-' | '_' => c,
                _ => REPLACEMENT_CHAR,
            })
            .collect();
        // hash the *original* string so collisions in sanitization are disambiguated
        let hash = blake3::hash(s.as_bytes()).to_string();
        let suffix = &hash[hash.len() - 8..];
        Slug::new(format!("{sanitized}-{suffix}"))
    }
}
impl fmt::Display for Slug {
    /// Render the slug as its inner sanitized string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// Error returned when a string contains a character outside [a-z0-9_-].
#[derive(Debug)]
pub struct InvalidSlugError(char);

impl fmt::Display for InvalidSlugError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // destructure for readability; the offending character is the payload
        let InvalidSlugError(c) = self;
        write!(
            f,
            "Invalid char '{}'. String can only contain a-z, 0-9, - and _.",
            c
        )
    }
}

impl std::error::Error for InvalidSlugError {}
impl TryFrom<&str> for Slug {
type Error = InvalidSlugError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
s.to_string().try_into()
}
}
impl TryFrom<String> for Slug {
    type Error = InvalidSlugError;

    /// Accept the string unchanged if every character is in [a-z0-9_-];
    /// otherwise report the first offending character.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        let first_invalid = s
            .chars()
            .find(|c| !matches!(*c, 'a'..='z' | '0'..='9' | '-' | '_'));
        match first_invalid {
            Some(bad) => Err(InvalidSlugError(bad)),
            None => Ok(Slug(s)),
        }
    }
}
impl<'de> Deserialize<'de> for Slug {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct SlugVisitor;
impl Visitor<'_> for SlugVisitor {
type Value = Slug;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter
.write_str("a valid slug string containing only characters a-z, 0-9, - and _.")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Slug::try_from(v).map_err(de::Error::custom)
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Slug::try_from(v.as_ref()).map_err(de::Error::custom)
}
}
deserializer.deserialize_string(SlugVisitor)
}
}
impl AsRef<str> for Slug {
    /// Borrow the slug's inner string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl PartialEq<str> for Slug {
    /// Allow direct comparison of a slug against a raw `str`.
    fn eq(&self, other: &str) -> bool {
        self.0.as_str() == other
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
pub use crate::pipeline::network::tcp::{client, server};
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The [Worker] class is a convenience wrapper around the construction of the [Runtime]
//! and execution of the users application.
//!
//! In the future, the [Worker] should probably be moved to a procedural macro similar
//! to the `#[tokio::main]` attribute, where we might annotate an async main function with
//! #[triton::main] or similar.
//!
//! The [Worker::execute] method is designed to be called once from main and will block
//! the calling thread until the application completes or is canceled. The method initializes
//! the signal handler used to trap `SIGINT` and `SIGTERM` signals and trigger a graceful shutdown.
//!
//! On termination, the user application is given a graceful shutdown period controlled by
//! the [TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT] environment variable. If the application does not
//! shutdown in time, the worker will terminate the application with an exit code of 911.
//!
//! The default values of `TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT` differ between the development
//! and release builds. In development, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG] and
//! in release, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE].
use super::{error, log, CancellationToken, Result, Runtime, RuntimeConfig};
use futures::Future;
use once_cell::sync::OnceCell;
use std::{sync::Mutex, time::Duration};
use tokio::{signal, task::JoinHandle};
// Process-wide tokio runtime; populated at most once by `Worker::from_config`.
static RT: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
// JoinHandle for the application task spawned by `Worker::execute`; the
// Mutex<Option<..>> lets the handle be taken exactly once for awaiting.
static INIT: OnceCell<Mutex<Option<tokio::task::JoinHandle<Result<()>>>>> = OnceCell::new();
// Printed when a shutdown signal is trapped.
const SHUTDOWN_MESSAGE: &str =
    "Application received shutdown signal; attempting to gracefully shutdown";
// Hint printed alongside the shutdown message.
const SHUTDOWN_TIMEOUT_MESSAGE: &str =
    "Use TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT to control the graceful shutdown timeout";
/// Environment variable to control the graceful shutdown timeout
pub const TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT: &str = "TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT";
/// Default graceful shutdown timeout in seconds in debug mode
pub const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG: u64 = 5;
/// Default graceful shutdown timeout in seconds in release mode
pub const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE: u64 = 30;
/// Convenience wrapper that owns the process [`Runtime`] and drives the user
/// application to completion, handling signals and graceful shutdown.
pub struct Worker {
    runtime: Runtime,
}
impl Worker {
    /// Create a new [`Worker`] instance from [`RuntimeConfig`] settings which is sourced from the environment
    pub fn from_settings() -> Result<Worker> {
        let config = RuntimeConfig::from_settings()?;
        Worker::from_config(config)
    }
    /// Create a new [`Worker`] instance from a provided [`RuntimeConfig`]
    ///
    /// Only one [`Worker`] may exist per process: the tokio runtime is stored
    /// in a process-wide `OnceCell`, so a second call returns an error.
    pub fn from_config(config: RuntimeConfig) -> Result<Worker> {
        // if the runtime is already initialized, return an error
        if RT.get().is_some() {
            return Err(error!("Worker already initialized"));
        }
        // create a new runtime and insert it into the OnceCell
        // there is still a potential race-condition here, two threads could have passed the first check
        // but only one will succeed in inserting the runtime
        let rt = RT.try_insert(config.create_runtime()?).map_err(|_| {
            error!("Failed to create worker; Only a single Worker should ever be created")
        })?;
        let runtime = Runtime::from_handle(rt.handle().clone())?;
        Ok(Worker { runtime })
    }
    /// Access the process-wide tokio runtime; errors if no [`Worker`] was created.
    pub fn tokio_runtime(&self) -> Result<&'static tokio::runtime::Runtime> {
        RT.get().ok_or_else(|| error!("Worker not initialized"))
    }
    /// Borrow the [`Runtime`] wrapper owned by this worker.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }
    /// Executes the provided application/closure on the [`Runtime`].
    /// This is designed to be called once from main and will block the calling thread until the application completes.
    ///
    /// Shutdown sequence: a signal handler cancels the runtime's token; the
    /// application then gets a grace period (from
    /// `TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT`, with a build-profile default)
    /// before the process is force-terminated with exit code 911.
    pub fn execute<F, Fut>(self, f: F) -> Result<()>
    where
        F: FnOnce(Runtime) -> Fut + Send + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        let runtime = self.runtime;
        let primary = runtime.primary();
        let secondary = runtime.secondary.clone();
        // resolve the graceful shutdown timeout: env var wins, otherwise the
        // default depends on the build profile (shorter in debug builds)
        let timeout = std::env::var(TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT)
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or({
                if cfg!(debug_assertions) {
                    DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG
                } else {
                    DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE
                }
            });
        INIT.set(Mutex::new(Some(secondary.spawn(async move {
            // start signal handler
            tokio::spawn(signal_handler(runtime.cancellation_token.clone()));
            let cancel_token = runtime.child_token();
            // the oneshot pair is only used to observe app completion: app_rx is
            // moved into the task, so app_tx.closed() resolves when the task ends
            let (mut app_tx, app_rx) = tokio::sync::oneshot::channel::<()>();
            // spawn a task to run the application
            let task: JoinHandle<Result<()>> = primary.spawn(async move {
                let _rx = app_rx;
                f(runtime).await
            });
            // wait for either a shutdown request or natural application exit
            tokio::select! {
                _ = cancel_token.cancelled() => {
                    eprintln!("{}", SHUTDOWN_MESSAGE);
                    eprintln!("{} {} seconds", SHUTDOWN_TIMEOUT_MESSAGE, timeout);
                }
                _ = app_tx.closed() => {
                }
            };
            // give the application the grace period to finish; force-exit otherwise
            let result = tokio::select! {
                result = task => {
                    result
                }
                _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
                    eprintln!("Application did not shutdown in time; terminating");
                    std::process::exit(911);
                }
            }?;
            match &result {
                Ok(_) => {
                    log::info!("Application shutdown successfully");
                }
                Err(e) => {
                    log::error!("Application shutdown with error: {:?}", e);
                }
            }
            result
        }))))
        .map_err(|e| error!("Failed to spawn application task: {:?}", e))?;
        // take the JoinHandle back out of INIT and block on it from the secondary runtime
        let task = INIT
            .get()
            .expect("Application task not initialized")
            .lock()
            .unwrap()
            .take()
            .expect("Application initialized; but another thread is awaiting it; Worker.execute() can only be called once");
        secondary.block_on(task)?
    }
}
/// Block until SIGINT (Ctrl+C), SIGTERM, or an external cancellation is
/// observed, then cancel the token so the graceful shutdown begins.
async fn signal_handler(cancel_token: CancellationToken) -> Result<()> {
    let ctrl_c = async {
        signal::ctrl_c().await?;
        anyhow::Ok(())
    };
    let sigterm = async {
        let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())?;
        term.recv().await;
        anyhow::Ok(())
    };
    tokio::select! {
        _ = ctrl_c => tracing::info!("Ctrl+C received, starting graceful shutdown"),
        _ = sigterm => tracing::info!("SIGTERM received, starting graceful shutdown"),
        _ = cancel_token.cancelled() => tracing::info!("CancellationToken triggered; shutting down"),
    }
    // propagate the shutdown to every child token, whichever branch fired
    cancel_token.cancel();
    Ok(())
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
#![allow(dead_code)]
use std::{future::Future, pin::Pin, sync::Arc};
use async_trait::async_trait;
use futures::Stream;
use tokio::sync::mpsc;
use triton_distributed::engine::{
AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream,
Data as DataType, Engine, EngineStream,
};
use triton_distributed::pipeline::{
context::{Context, StreamContext},
Error, ManyOut, SingleIn,
};
/// A boxed, pinned async function from `T` to `U`, shareable across threads.
pub type AsyncFn<T, U> = dyn Fn(T) -> Pin<Box<dyn Future<Output = U> + Send>> + Send + Sync;

// Define a struct that holds an async closure
pub struct AsyncProcessor<T, U> {
    func: Arc<AsyncFn<T, U>>,
}

// Manual `Clone` instead of `#[derive(Clone)]`: the derived impl would add
// spurious `T: Clone, U: Clone` bounds even though cloning only bumps the
// `Arc` refcount. This impl is strictly more permissive, hence backward
// compatible with all existing callers.
impl<T, U> Clone for AsyncProcessor<T, U> {
    fn clone(&self) -> Self {
        Self {
            func: Arc::clone(&self.func),
        }
    }
}
impl<T, U> AsyncProcessor<T, U>
where
T: Send + 'static,
U: Send + 'static,
{
// Define a `new` method that captures the already pinned async block
pub fn new<F, Fut>(f: F) -> Self
where
F: Fn(T) -> Fut + Send + Sync + 'static,
Fut: Future<Output = U> + Send + 'static,
{
// Wrap the closure in Arc and Box it for internal management
AsyncProcessor {
func: Arc::new(move |input: T| Box::pin(f(input))),
}
}
// Method to execute the captured async function
pub async fn process(&self, input: T) -> U {
(self.func)(input).await
}
}
/// Sender half handed to a user generator: lets it emit responses into the
/// stream and observe stop/cancellation requests via the stream context.
#[derive(Debug, Clone)]
pub struct ResponseSource<T: Send + Sync + 'static> {
    // channel feeding the consumer-side response stream
    tx: mpsc::Sender<T>,
    // stream context used to observe stop requests
    ctx: StreamContext,
}
impl<T: Send + Sync + 'static> ResponseSource<T> {
    // Pair a response channel with its stream context.
    fn new(tx: mpsc::Sender<T>, ctx: StreamContext) -> Self {
        ResponseSource { tx, ctx }
    }
    /// Emit a response to the stream
    ///
    /// Returns `Err(())` when the receiving side has been dropped.
    pub async fn emit(&self, data: T) -> Result<(), ()> {
        self.tx.send(data).await.map_err(|_| ())
    }
    /// Check if a stop has been requested
    pub fn stop_requested(&self) -> bool {
        self.ctx.is_stopped()
    }
    /// Yield control until a stop is requested
    /// This is useful in a tokio::select! block
    pub async fn stopped(&self) {
        self.ctx.stopped().await;
    }
}
/// An async generator: a processor that takes `(request, response-sink)` and
/// resolves to unit once the stream has been fully produced.
pub type AsyncGenerator<Req, Resp> = AsyncProcessor<(Req, ResponseSource<Resp>), ()>;
/// Adapts a tokio mpsc receiver into an engine response stream, carrying the
/// engine context alongside it.
pub struct ReceiverStream<Resp: DataType> {
    receiver: tokio::sync::mpsc::Receiver<Resp>,
    context: Arc<dyn AsyncEngineContext>,
}
impl<Resp: DataType> ReceiverStream<Resp> {
    /// Wrap `receiver` and associate it with `context`.
    pub fn new(
        receiver: tokio::sync::mpsc::Receiver<Resp>,
        context: Arc<dyn AsyncEngineContext>,
    ) -> Self {
        Self { receiver, context }
    }
}
impl<Resp: DataType> Stream for ReceiverStream<Resp> {
    type Item = Resp;
    // Delegates to the inner mpsc receiver; the stream ends (yields None)
    // once every sender has been dropped and the channel is drained.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // if self.context.stop_issued() {
        //     return std::task::Poll::Ready(None);
        // }
        // Pinning the receiver to safely call poll_recv
        Pin::new(&mut self.receiver).poll_recv(cx)
    }
}
impl<Resp: DataType> std::fmt::Debug for ReceiverStream<Resp> {
    /// Debug output shows only the context; the receiver is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ReceiverStream");
        dbg.field("context", &self.context);
        dbg.finish()
    }
}

// Marker impl: a ReceiverStream is a valid engine response stream.
impl<Resp: DataType> AsyncEngineStream<Resp> for ReceiverStream<Resp> {}

impl<Resp: DataType> AsyncEngineContextProvider for ReceiverStream<Resp> {
    /// Hand out a shared reference to the engine context.
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        Arc::clone(&self.context)
    }
}
/// Engine adapter that drives an [`AsyncGenerator`] lambda.
/// NOTE(review): the name looks like a misspelling of "LambdaEngine" — confirm
/// before renaming, since callers reference it as-is.
pub struct LlmdbaEngine<Req: DataType, Resp: DataType> {
    lambda: Arc<AsyncGenerator<Req, Resp>>,
}
impl<Req: DataType, Resp: DataType> LlmdbaEngine<Req, Resp> {
    // Wrap a generator in an engine.
    fn new(lambda: AsyncGenerator<Req, Resp>) -> Self {
        LlmdbaEngine {
            lambda: Arc::new(lambda),
        }
    }
    /// Build a type-erased [`Engine`] from a generator.
    pub fn from_generator(
        generator: AsyncGenerator<Req, Resp>,
    ) -> Engine<SingleIn<Req>, ManyOut<Resp>, Error> {
        Arc::new(LlmdbaEngine::new(generator))
    }
}
#[async_trait]
impl<Req: DataType, Resp: DataType> AsyncEngine<SingleIn<Req>, ManyOut<Resp>, Error>
    for LlmdbaEngine<Req, Resp>
{
    /// Spawn the generator on a background task and return the stream of
    /// responses it emits.
    async fn generate(&self, request: Context<Req>) -> Result<EngineStream<Resp>, Error> {
        // capacity 1: `emit` awaits until the consumer polls, giving natural backpressure
        let (tx, rx) = mpsc::channel::<Resp>(1);
        // detach the request payload from its context
        let (req, ctx) = request.transfer(());
        let ctx: StreamContext = ctx.into();
        let s = ResponseSource::new(tx, ctx.clone());
        let lambda = self.lambda.clone();
        // run the generator concurrently; dropping the handle detaches the task
        let _handle = tokio::spawn(async move { lambda.process((req, s)).await });
        let ctx = Arc::new(ctx);
        let stream = ReceiverStream::<Resp>::new(rx, ctx);
        let stream = Box::pin(stream);
        Ok(stream)
    }
}
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use super::*;
    /// The processor should run the captured async closure to completion and
    /// return its output for each call.
    #[tokio::test]
    async fn test_async_processor() {
        let processor = AsyncProcessor::new(move |x: i32| {
            async move {
                // Simulate some async work
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                format!("Processed value: {}", x)
            }
        });
        // Use the processor to run the async closure
        let result = processor.process(42).await;
        println!("{}", result); // Output: Processed value: 42
        let result2 = processor.process(100).await;
        println!("{}", result2); // Output: Processed value: 100
    }
    /// A generator that emits one item per input character should yield a
    /// stream of exactly `input.len()` items ("test" -> 4).
    #[tokio::test]
    async fn test_generator() {
        let generator = AsyncGenerator::<String, String>::new(|(req, stream)| async move {
            let chars = req.chars().collect::<Vec<char>>();
            for c in chars {
                match stream.emit(c.to_string()).await {
                    Ok(_) => {}
                    Err(_) => break,
                }
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        });
        let engine = LlmdbaEngine::new(generator);
        let mut stream = engine.generate("test".to_string().into()).await.unwrap();
        let mut counter = 0;
        while let Some(_output) = stream.next().await {
            counter += 1;
        }
        assert_eq!(counter, 4);
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use triton_distributed::engine::{AsyncEngine, AsyncEngineContext, Data, ResponseStream};
use triton_distributed::pipeline::{
context::{Context, StreamContext},
Error, ManyOut, PipelineError, PipelineIO, SegmentSource, SingleIn,
};
/// Models the simulated latency applied by the mock network transport.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum LatencyModel {
    // no artificial delay
    NoDelay,
    // fixed delay in nanoseconds
    ConstantDelayInNanos(u64),
    // delay drawn from a normal distribution; presumably (mean, stddev) in
    // nanoseconds — TODO confirm against the code that samples it
    NormalDistributionInNanos(u64, u64),
}
/// Latency configuration for the mock network transport.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MockNetworkOptions {
    // latency applied to requests
    request_latency: LatencyModel,
    // latency applied to responses
    response_latency: LatencyModel,
}
impl Default for MockNetworkOptions {
    // Default options inject no artificial latency in either direction.
    fn default() -> Self {
        Self {
            request_latency: LatencyModel::NoDelay,
            response_latency: LatencyModel::NoDelay,
        }
    }
}
/// A request submitted over the mock control plane, pairing the serialized
/// request bytes with the channel on which data-plane messages come back.
#[derive(Debug, Clone)]
struct ControlPlaneRequest {
    // request/stream identifier
    id: String,
    // serialized request payload
    request: Vec<u8>,
    // convert this into an interface where it describes the worker address
    // and how to communicate with the worker
    resp_tx: mpsc::Sender<DataPlaneMessage>,
}
/// Events carried over the mock control-plane channel.
enum MockNetworkControlEvents {
    // a new request for the receiving side to process
    ControlPlaneRequest(ControlPlaneRequest),
    // presumably the request id of the stream to cancel — confirm with the
    // ingress handler
    Cancel(String),
}
/// Headers carried on the mock data plane alongside (or instead of) a payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MockNetworkDataPlaneHeaders {
    // first message of a response stream; reports whether the remote side
    // accepted the request (see the egress `generate` handshake handling)
    Handshake(Handshake),
    // transport-level error propagated to the subscriber
    Error(String),
    // tells the subscriber that the stream has ended
    // not all transports will be sender side closable, therefore,
    // we need a way to signal the end of the stream
    //
    // note: for transports like nats where the subscriber could
    // be left dangling, we will also want to have a keep alive
    // and a timeout mechanism
    Sentinel,
    // heart beat / keep-alive signal to maintain the connection
    HeartBeat,
}
/// Outcome of the remote segment attempting to process a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Status {
    Ok,
    // error message describing why the request could not be processed
    Error(String),
}
// for transports that support headers, we will use headers for events and the body for the bytes
// for transports like tcp, we may send them as two separate messages on the same socket or as a single
// compound message like the [`DataEnvelope`] object below
/// First data-plane message of a response stream: identifies the request and
/// reports whether the remote segment accepted it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Handshake {
    // id of the request this stream belongs to
    request_id: String,
    // identifier of the worker serving the stream, when known
    worker_id: Option<String>,
    status: Status,
}
/// A single message on the data plane: optional headers plus a byte payload.
struct DataPlaneMessage {
    pub headers: Option<MockNetworkDataPlaneHeaders>,
    pub body: Vec<u8>,
}
/// This is an example transport that will inject latency into the response stream.
/// This mimics a network transport that has a delay in the response.
pub struct MockNetworkTransport<T: PipelineIO, U: PipelineIO> {
    // zero-sized markers tying the transport to its request/response types
    req: std::marker::PhantomData<T>,
    resp: std::marker::PhantomData<U>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkTransport<Req, Resp> {
    /// Build a connected egress/ingress pair sharing a control-plane channel:
    /// requests submitted to the egress are observed by the ingress.
    ///
    /// Fix: the original cloned `ctrl_tx` and `options` on their final uses;
    /// both values are moved directly now, avoiding two redundant clones.
    pub fn new_egress_ingress(
        options: MockNetworkOptions,
    ) -> (
        Arc<MockNetworkEgress<Req, Resp>>,
        MockNetworkIngress<Req, Resp>,
    ) {
        let (ctrl_tx, ctrl_rx) = mpsc::channel::<MockNetworkControlEvents>(8);
        // construct the egress/request-sender/response-receiver
        let egress = Arc::new(MockNetworkEgress::<Req, Resp>::new(
            options.clone(),
            ctrl_tx,
        ));
        // construct the ingress/request-receiver/response-sender
        let ingress = MockNetworkIngress::<Req, Resp>::new(options, ctrl_rx);
        (egress, ingress)
    }
}
/// Client ("egress") side of the mock transport: submits requests over the
/// shared control-plane channel and receives responses on a per-request
/// data-plane channel.
#[allow(dead_code)]
pub struct MockNetworkEgress<Req: PipelineIO, Resp: PipelineIO> {
    options: MockNetworkOptions,
    // sender half of the shared control-plane channel
    ctrl_tx: mpsc::Sender<MockNetworkControlEvents>,
    // zero-sized markers for the request/response types
    req: std::marker::PhantomData<Req>,
    resp: std::marker::PhantomData<Resp>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkEgress<Req, Resp> {
    // Construct an egress bound to the given options and control-plane sender.
    fn new(options: MockNetworkOptions, ctrl_tx: mpsc::Sender<MockNetworkControlEvents>) -> Self {
        Self {
            options,
            ctrl_tx,
            req: std::marker::PhantomData,
            resp: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
for MockNetworkEgress<SingleIn<T>, ManyOut<U>>
where
T: Data + Serialize,
U: for<'de> Deserialize<'de> + Data,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
let id = request.id().to_string();
// serialze the request
let request = request.try_map(|req| serde_json::to_vec(&req))?;
// transfer the request context to a stream context
let (data, context) = request.transfer(());
let context = Arc::new(StreamContext::from(context));
// subscribe to the response stream
// but in this case, we are doing a mock, so we are going to be more explicit
// since we are transferring data over a channel instead of the networ, creating the channel
// is the same as subscribing to the response stream
let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16);
let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx);
// prepare the stateful objects that will be used to monitor the response stream
// finish_rx is a oneshot channel that will be used to signal the natural termination of the stream
let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>();
let stream_monitor = ResponseMonitor {
ctx: context.clone(),
finish_rx: finished_rx,
};
// create the control plane request
// when this is issued, control is handed off to the control plane and the downstream segment
// sometimes we might include the local server address and port for the response find its way home
// todo(design) this will be part of the generalization error for multiple transport types
let request = ControlPlaneRequest {
id,
request: data,
resp_tx: data_tx,
};
// send the request to the control plane
self.ctrl_tx
.send(MockNetworkControlEvents::ControlPlaneRequest(request))
.await
.map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?;
// the first message from the remote publisher on the data plane needs to be a handshake message
// the handshake will indicate to what stream the data belongs to and if the remote segment was
// able to process the request.
//
// note: in the case of the mock transport, the handshaking of the request id is not strictly
// because the channel is specific to the request. this is similar to other transports like nats
// where we will subscribe to a response stream on a subject unique to the stream.
match byte_stream.next().await {
Some(DataPlaneMessage { headers, body }) => {
if !body.is_empty() {
Err(PipelineError::ControlPlaneRequestError(
"Expected an empty body for the handshake message".to_string(),
))?;
}
match headers {
Some(header) => {
match header {
MockNetworkDataPlaneHeaders::Handshake(handshake) => {
match handshake.status {
Status::Ok => {}
Status::Error(e) => {
// todo(metrics): increment metric counter for failed handshakes
Err(PipelineError::ControlPlaneRequestError(format!(
"remote segment was unable to process request: {}",
e
)))?;
}
}
}
_ => {
Err(PipelineError::ControlPlaneRequestError(format!(
"Expected a handshake message; got: {:?}",
header
)))?;
}
}
}
_ => {
Err(PipelineError::ControlPlaneRequestError(
"Failed to receive properly formatted handshake on data plane"
.to_string(),
))?;
}
}
}
None => {
// todo(metrics): increment metric counter for failed requests
Err(PipelineError::ControlPlaneRequestError(
"Failed data plane connection closed before receiving handshake".to_string(),
))?;
}
}
let decoded = byte_stream
// .inspect(|_item| {
// // todo(metrics) increment the metrics counter by the number of bytes
// })
.scan(Some(stream_monitor), move |_stream_monitor, item| {
// we could check the kill state of the context and terminate the stream here
// if our transport needs a heartbeat, trigger a heartbeat here the monitor
if let Some(headers) = &item.headers {
match headers {
MockNetworkDataPlaneHeaders::HeartBeat => {
// todo(metrics): increment metric counter for heartbeats
// send a heartbeat to the control plane
// this is a good place to send a heartbeat to the control plane
// to keep the connection alive
}
MockNetworkDataPlaneHeaders::Sentinel => {
// todo(metrics): increment metric counter for sentinels
// the stream has ended
// send a sentinel to the control plane
// this is a good place to send a sentinel to the control plane
// to indicate the end of the stream
return futures::future::ready(None);
}
_ => {}
}
}
futures::future::ready(Some(item))
})
// decode the response
.map(move |item| {
serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response")
});
// cancellation can be tricky and is transport / protocol specific
// in this case, our channel for this is both ordered and 1:1, thus we can
// use that fact to first send the request, then forward any cancellation requests
// this ensures the downstream node should register the context/request id before any
// cancellation requests are sent
// create the cancellation monitor object
let cancellation_monitor = CancellationMonitor {
ctx: context.clone(),
ctrl_tx: self.ctrl_tx.clone(),
finish_tx: finished_tx,
};
// launch the cancellation monitor task
tokio::spawn(cancellation_monitor.execute());
Ok(ResponseStream::new(Box::pin(decoded), context))
}
}
/// For our MockNetworkTransport, the Ingress will be the one that will be receiving the requests
/// and pushing back the responses.
///
/// As such, the Ingress will be the one that will be responsible for receiving control plane messages.
#[allow(dead_code)]
pub struct MockNetworkIngress<Req: PipelineIO, Resp: PipelineIO> {
    // transport tuning options (appears unread by the ingress in this file)
    options: MockNetworkOptions,
    // receiving half of the control plane; requests and cancellations arrive here
    ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>,
    // the segment that serves incoming requests; set once via `segment()` before `execute()`
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkIngress<Req, Resp> {
    /// Builds an ingress from transport options and the control-plane receiver.
    /// The serving segment must be attached separately via [`Self::segment`].
    fn new(options: MockNetworkOptions, ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>) -> Self {
        let segment = OnceLock::new();
        Self {
            options,
            ctrl_rx,
            segment,
        }
    }

    /// Attaches the segment that will serve incoming requests.
    ///
    /// Fails with [`PipelineError::EdgeAlreadySet`] if a segment was attached before.
    pub fn segment(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<(), PipelineError> {
        match self.segment.set(segment) {
            Ok(()) => Ok(()),
            Err(_) => Err(PipelineError::EdgeAlreadySet),
        }
    }
}
impl<T: Data, U: Data> MockNetworkIngress<SingleIn<T>, ManyOut<U>>
where
T: Data + for<'de> Deserialize<'de>,
U: Data + Serialize,
{
pub async fn execute(self) -> Result<(), PipelineError> {
let mut state = HashMap::<String, Arc<dyn AsyncEngineContext>>::new();
let worker_id = uuid::Uuid::new_v4().to_string();
let mut ctrl_rx = self.ctrl_rx;
let segment = self.segment.get().expect("segment not set").clone();
while let Some(event) = ctrl_rx.recv().await {
match event {
MockNetworkControlEvents::ControlPlaneRequest(req) => {
// todo(metrics): increment metric counter for bytes received
// todo(metrics): increment metric counter for requests received
let id = req.id.clone();
tracing::debug!("[ingress] received request [id: {}]", id);
// deserialize the request
let request = serde_json::from_slice::<T>(&req.request)
.expect("failed to deserialize request");
// extend request with context
let request = Context::<T>::with_id(request, req.id.clone());
// create the response stream
let response = segment.generate(request).await;
let handshake = match &response {
Ok(_) => Handshake {
request_id: req.id,
worker_id: Some(worker_id.clone()),
status: Status::Ok,
},
Err(e) => Handshake {
request_id: req.id,
worker_id: Some(worker_id.clone()),
status: Status::Error(e.to_string()),
},
};
tracing::debug!("[ingress] sending handshake [id: {}]: {:?}", id, handshake);
// serialize the handshake
let handshake = DataPlaneMessage {
headers: Some(MockNetworkDataPlaneHeaders::Handshake(handshake)),
body: vec![],
};
// send the handshake
req.resp_tx
.send(handshake)
.await
.expect("failed to send handshake");
tracing::trace!("[ingress] handshake sent [id: {}]", id);
if let Ok(response) = response {
// spawn a task to process the response stream:
// - serialize each response
// - forward the bytes to the data plane
tracing::debug!("[ingress] processing response stream [id: {}]", id);
tokio::spawn(async move {
let mut response = response;
while let Some(resp) = response.next().await {
tracing::trace!("[ingress] received response [id: {}]", id);
let resp_bytes = serde_json::to_vec(&resp)
.expect("failed to serialize response");
let msg = DataPlaneMessage {
headers: None,
body: resp_bytes,
};
// send the response
req.resp_tx
.send(msg)
.await
.expect("failed to send response");
tracing::trace!("[ingress] sent response [id: {}]", id);
}
tracing::debug!("response stream completed [id: {}]", id);
});
}
}
MockNetworkControlEvents::Cancel(id) => {
// todo(metrics): increment metric counter for cancelled requests
// todo(metrics): increment metric counter for bytes received
// todo(metrics): increment metric counter for requests received
// cancel the request
if let Some(tx) = state.remove(&id) {
tx.stop_generating();
}
}
}
}
Ok(())
}
}
// fn create_error_message(id: &str, e: &str) -> Hand {
// format!("Failed to deserialize request [id: {}]: {}", id, e)
// }
/// Object transferred to the Cancellation Monitor Task.
///
/// The cancellation monitor task will be responsible for taking action on a
/// cancellation request.
///
/// This object holds a oneshot channel that will be used to signal the natural
/// termination of the stream.
///
/// Our cancellation monitor task selects on those two signals and completes when
/// either of them fires.
struct CancellationMonitor {
    // per-request stream context; watched for a stop/kill signal
    ctx: Arc<StreamContext>,
    // control plane sender
    ctrl_tx: tokio::sync::mpsc::Sender<MockNetworkControlEvents>,
    // sending half of the completion signal; its `closed()` future resolves when
    // the receiving half (held by the response stream's monitor) is dropped,
    // which marks the stream as completed
    finish_tx: tokio::sync::oneshot::Sender<()>,
}
impl CancellationMonitor {
async fn execute(self) {
// select on the finish_rx and the kill signal
let ctx = self.ctx;
let ctrl_tx = self.ctrl_tx;
let mut finish_tx = self.finish_tx;
tokio::select! {
_ = ctx.stopped() => {
// todo(metrics): increment metric counter for cancelled requests
// send a cancellation request to the control plane
let _ = ctrl_tx.send(MockNetworkControlEvents::Cancel(ctx.id().to_string())).await;
}
_ = finish_tx.closed() => {
// the stream has completed naturally
}
}
}
}
// held by the scan combinator so it lives exactly as long as the response stream
#[allow(dead_code)]
struct ResponseMonitor {
    // per-request stream context, kept alive for the duration of the stream
    ctx: Arc<StreamContext>,
    // receiving half of the completion signal; dropping it (when the stream is
    // dropped or finishes) resolves the cancellation monitor's `closed()` branch
    finish_rx: tokio::sync::oneshot::Receiver<()>,
}
pub mod engines;
pub mod mock;
use triton_distributed::{worker::Worker, Result, Runtime};
/// Minimal no-op application body used by the lifecycle test; returns `Ok` immediately.
async fn hello_world(_runtime: Runtime) -> Result<()> {
    Ok(())
}
#[test]
fn test_lifecycle() {
    // build a worker from settings and run the trivial app through a full
    // start/execute/shutdown cycle
    Worker::from_settings()
        .unwrap()
        .execute(hello_world)
        .unwrap();
}
// async fn discoverable(runtime: Runtime) -> Result<()> {
// let config = DiscoveryConfig {
// etcd_url: vec!["http://localhost:2379".to_string()],
// etcd_connect_options: None,
// };
// let client = DiscoveryClient::new(config, runtime.clone()).await?;
// println!("Primary lease id: {:x}", client.lease_id());
// let lease = client.create_lease(60).await?;
// // Keys and values
// let lock_key = "lock_key"; // Key for the lock
// let object_key = "object_key"; // Key for the object
// let object_value = "This is the object value"; // Value for the object
// let lock_value = "locked"; // Value indicating a lock
// let put_options = Some(PutOptions::new().with_lease(lease.id()));
// // Build the transaction
// let txn = Txn::new()
// .when(vec![Compare::version(lock_key, CompareOp::Equal, 0)]) // Ensure the lock does not exist
// .and_then(vec![
// TxnOp::put(object_key, object_value, put_options.clone()), // Create the object
// TxnOp::put(lock_key, lock_value, put_options), // Set the lock
// ]);
// // Execute the transaction
// let txn_response = client.etc_client().kv_client().txn(txn).await?;
// tokio::spawn(async move {
// println!("custom lease id: {:x}", lease.id());
// lease.cancel_token().cancelled().await;
// println!("custom lease revoked");
// });
// runtime.child_token().cancelled().await;
// Ok(())
// }
// #[test]
// fn test_discovery_client() {
// let runtime = Runtime::new(RuntimeConfig::default()).unwrap();
// runtime.execute(discoverable).unwrap();
// }
use futures::{stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
use triton_distributed::engine::ResponseStream;
use triton_distributed::{
pipeline::{
async_trait, AsyncEngine, Data, Event, ManyOut, Operator, ServiceBackend, ServiceEngine,
ServiceFrontend, SingleIn, *,
},
Error,
};
mod common;
use common::engines::{AsyncGenerator, LlmdbaEngine as LambdaEngine};
use common::mock;
/// The [`super::engine::ResponseStream`] is annotated with the following types.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Annotated<T: Data> {
    /// The primary data which is expected to be returned.
    Data(T),
    /// An actionable [`Event`] that can be handled.
    Event(Event),
    /// Additional information or metadata produced by the pipeline.
    Comment(String),
    /// An error produced by the pipeline. Multiple errors can be produced.
    Error(String),
    /// A sentinel value to indicate the end of the stream. This should not be emitted publicly.
    /// The implementation should be able to do the equivalent of a `.take_while` and trigger a
    /// stop if detected.
    End,
}
/// An [`Operator`] is used when you want to transform both the input and output of a pipeline.
/// In this case, our operator will perform the preprocessing step, but also add an annotation
/// to the output stream.
// NOTE(review): name is missing an 's' ("PreprocessOperator"); renaming would
// touch every call site, so it is left as-is here.
struct PreprocesOperator {}
#[async_trait]
impl
    Operator<
        SingleIn<String>,
        ManyOut<Annotated<String>>,
        SingleIn<String>,
        ManyOut<Annotated<String>>,
    > for PreprocesOperator
{
    /// Rewrites the request, forwards it to `next`, and prepends a single
    /// comment annotation (describing the original request) to the response stream.
    async fn generate(
        &self,
        req: SingleIn<String>,
        next: Arc<dyn AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error>>,
    ) -> Result<ManyOut<Annotated<String>>, Error> {
        // record the incoming request before it is rewritten
        let annotation = Annotated::<String>::Comment(format!(
            "PreprocessOperator: {:?}",
            req
        ));
        let header = stream::iter(std::iter::once(annotation));
        // rewrite the request payload
        let req = req.map(|body| format!("{} from operator", body));
        // forward the rewritten request to the next engine
        let downstream = next.generate(req).await?;
        // keep the downstream context for the combined stream
        let ctx = downstream.context();
        // emit the annotation first, then the downstream responses
        Ok(ResponseStream::new(Box::pin(header.chain(downstream)), ctx))
    }
}
/// Builds the mock backend engine: streams the input string back one character
/// at a time, pausing 10ms between emissions. Stops early if the receiver is gone.
fn make_backend_engine() -> ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>> {
    LambdaEngine::from_generator(AsyncGenerator::<String, Annotated<String>>::new(
        |(req, stream)| async move {
            for ch in req.chars() {
                // an emit error means the consumer dropped the stream; bail out
                if stream.emit(Annotated::Data(ch.to_string())).await.is_err() {
                    return;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        },
    ))
}
#[tokio::test]
async fn test_service_source_sink() {
    // frontend linked straight to the backend, no intermediate nodes
    let source = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let sink = ServiceBackend::from_engine(make_backend_engine());
    let service = source.link(sink).unwrap().link(source).unwrap();
    let stream = service.generate("test".to_string().into()).await.unwrap();
    // "test" has four characters and the backend emits one item per character
    let counter = stream.count().await;
    assert_eq!(counter, 4);
}
/// Builds the preprocessing node: appends " world" to the request payload.
fn make_preprocessor() -> Arc<PipelineNode<SingleIn<String>, SingleIn<String>>> {
    PipelineNode::<SingleIn<String>, SingleIn<String>>::new(Box::new(|request| {
        let transformed = request.map(|body| format!("{} world", body));
        Ok(transformed)
    }))
}
/// Builds the postprocessing node: duplicates every item in the response
/// stream, demonstrating arbitrary stream combinators on the output path.
#[allow(clippy::type_complexity)]
fn make_postprocessor() -> Arc<PipelineNode<ManyOut<Annotated<String>>, ManyOut<Annotated<String>>>>
{
    PipelineNode::<ManyOut<Annotated<String>>, ManyOut<Annotated<String>>>::new(Box::new(|req| {
        // keep the original stream context for the transformed stream
        let ctx = req.context();
        // emit each item twice
        let doubled = req.flat_map(|item| stream::iter(vec![item.clone(), item]));
        Ok(ResponseStream::new(Box::pin(doubled), ctx))
    }))
}
// Node 0:
// [frontend] -------[pre processor]-----> [backend]
// [frontend] <----- [post processor] ---- [backend]
/// Assembles the single-node pipeline drawn above: the preprocessor sits on the
/// forward (input) path and the postprocessor on the return (output) path.
fn make_service(
) -> Result<ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>>, PipelineError> {
    // callable interface used to issue requests
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    // mimics prompt processing / tokenization
    let preprocess = make_preprocessor();
    // mimics decoding; any stream operation (map, flat_map, fold, scan, ...)
    // can be used to transform the response stream
    let postprocess = make_postprocessor();
    // mimics backend streaming by emitting each character of the input string
    let backend = ServiceBackend::from_engine(make_backend_engine());
    // pipelines are built by linking frontend -> backend for input handling,
    // then backend -> frontend for output handling
    Ok(frontend
        .link(preprocess)?
        .link(backend)?
        .link(postprocess)?
        .link(frontend)?)
}
#[tokio::test]
async fn test_service_source_node_sink() {
    let service = make_service().unwrap();
    let stream = service.generate("test".to_string().into()).await.unwrap();
    // "test" -> preprocessor appends " world" (10 chars), backend emits one item
    // per char, postprocessor doubles each -> 20 items total
    let total = stream.count().await;
    assert_eq!(total, 20);
}
// Put the post process on node 0, but the preprocessor and the compute on node1
// Node 0:
// [frontend] ---------------------------> [segment_sink]
// [frontend] <----- [post processor] ---- [segment_sink]
//
// Node 1:
// [segment_source] ---- [preprocessor] ---> [backend]
// [segment_source] <----------------------- [backend]
#[tokio::test]
async fn test_disaggregated_service() {
    println!("Running test_disaggregated_service");
    // Node 0: frontend plus postprocessor, terminating in a segment sink
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let postprocessor = make_postprocessor();
    let end_node_0 = SegmentSink::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let node0_service = frontend
        .link(end_node_0.clone())
        .unwrap()
        .link(postprocessor)
        .unwrap()
        .link(frontend)
        .unwrap();
    // Node 1: segment source feeding the preprocessor and the backend
    let start_node1 = SegmentSource::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let preprocessor = make_preprocessor();
    let backend = ServiceBackend::from_engine(make_backend_engine());
    let node1_service = start_node1
        .link(preprocessor)
        .unwrap()
        .link(backend)
        .unwrap()
        .link(start_node1.clone())
        .unwrap();
    // bridge the two nodes with the mock network transport
    let opts = mock::MockNetworkOptions::default();
    let (egress, ingress) = mock::MockNetworkTransport::<
        SingleIn<String>,
        ManyOut<Annotated<String>>,
    >::new_egress_ingress(opts);
    end_node_0.attach(egress).unwrap();
    ingress.segment(node1_service).unwrap();
    tokio::spawn(ingress.execute());
    let stream = node0_service
        .generate("test".to_string().into())
        .await
        .unwrap();
    // same pipeline as the single-node test: "test world" (10 chars) doubled -> 20
    let total = stream.count().await;
    assert_eq!(total, 20);
}
// Node 0:
// [frontend] --> [pre processor] --> [operator] ----------------------> [backend]
// [frontend] <---------------------- [operator] <--[post processor] <-- [backend]
/// Assembles the single-node pipeline drawn above, with a [`PreprocesOperator`]
/// spliced into both the forward and backward paths.
fn make_service_with_operator(
) -> Result<ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>>, PipelineError> {
    // callable interface used to issue requests
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    // mimics prompt processing / tokenization
    let preprocess = make_preprocessor();
    // mimics decoding; any stream operation (map, flat_map, fold, scan, ...)
    // can be used to transform the response stream
    let postprocess = make_postprocessor();
    // mimics backend streaming by emitting each character of the input string
    let backend = ServiceBackend::from_engine(make_backend_engine());
    let operator = PipelineOperator::new(Arc::new(PreprocesOperator {}));
    // pipelines are built by linking frontend -> backend for input handling,
    // then backend -> frontend for output handling
    Ok(frontend
        .link(preprocess)?
        .link(operator.forward_edge())?
        .link(backend)?
        .link(postprocess)?
        .link(operator.backward_edge())?
        .link(frontend)?)
}
#[tokio::test]
async fn test_service_source_node_sink_with_operator() {
    let service = make_service_with_operator().unwrap();
    let stream = service.generate("test".to_string().into()).await.unwrap();
    // tally data items and comment annotations in a single fold pass;
    // "test world from operator" = 24 chars, doubled -> 48 data items,
    // plus the one comment prepended by the operator's backward edge
    let (data_count, comment_count) = stream
        .fold((0, 0), |(data, comments), output| async move {
            match output {
                Annotated::Data(_) => (data + 1, comments),
                Annotated::Comment(_) => (data, comments + 1),
                _ => (data, comments),
            }
        })
        .await;
    assert_eq!(comment_count, 1);
    assert_eq!(data_count, 48);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment