Commit 3b7a462d authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: event plane + count


Signed-off-by: default avatarRyan Olson <ryanolson@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 6e0ccccb
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE;
use triton_distributed::{
logging,
protocols::annotated::Annotated,
utils::{stream, Duration, Instant},
DistributedRuntime, Result, Runtime, Worker,
};
fn main() -> Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
let component = namespace.component("backend")?;
let client = component
.endpoint("generate")
.client::<String, Annotated<String>>()
.await?;
client.wait_for_endpoints().await?;
let mut stream = client.random("hello world".to_string().into()).await?;
while let Some(resp) = stream.next().await {
println!("{:?}", resp);
}
let service_set = component.scrape_stats(Duration::from_millis(100)).await?;
println!("{:?}", service_set);
runtime.shutdown();
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use service_metrics::DEFAULT_NAMESPACE;
use std::sync::Arc;
use triton_distributed::{
logging,
pipeline::{
async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
ResponseStream, SingleIn,
},
protocols::annotated::Annotated,
stream, DistributedRuntime, Result, Runtime, Worker,
};
fn main() -> Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
backend(distributed).await
}
struct RequestHandler {}
impl RequestHandler {
fn new() -> Arc<Self> {
Arc::new(Self {})
}
}
#[async_trait]
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for RequestHandler {
async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
let (data, ctx) = input.into_parts();
let chars = data
.chars()
.map(|c| Annotated::from_data(c.to_string()))
.collect::<Vec<_>>();
let stream = stream::iter(chars);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
async fn backend(runtime: DistributedRuntime) -> Result<()> {
// attach an ingress to an engine
let ingress = Ingress::for_engine(RequestHandler::new())?;
// make the ingress discoverable via a component service
// we must first create a service, then we can attach one more more endpoints
runtime
.namespace(DEFAULT_NAMESPACE)?
.component("backend")?
.service_builder()
.create()
.await?
.endpoint("generate")
.endpoint_builder()
.handler(ingress)
.start()
.await
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub const DEFAULT_NAMESPACE: &str = "triton-init";
......@@ -17,9 +17,9 @@ use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use uuid::Uuid;
use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log;
use uuid::Uuid;
use triton_distributed::{DistributedRuntime, Worker};
use triton_llm::kv_router::{
......@@ -37,8 +37,7 @@ fn initialize_tracing() {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
tracing::subscriber::set_global_default(subscriber)
.expect("setting default subscriber failed");
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
log::debug!("Tracing initialized");
}
......@@ -193,11 +192,20 @@ fn kv_event_create_stored_from_parts(
let num_toks = unsafe { *num_block_tokens.offset(block_idx.try_into().unwrap()) };
// compute hash only apply to full block (KV_BLOCK_SIZE token)
if num_toks != 64 {
if WARN_COUNT.fetch_update(
Ordering::SeqCst,
Ordering::SeqCst,
|c| if c < 3 { Some(c + 1) } else { None }).is_ok() {
log::warn!("Block size must be 64 tokens to be published. Block size is: {}", num_toks);
if WARN_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
if c < 3 {
Some(c + 1)
} else {
None
}
})
.is_ok()
{
log::warn!(
"Block size must be 64 tokens to be published. Block size is: {}",
num_toks
);
}
break;
}
......
......@@ -1188,12 +1188,16 @@ mod tests {
assert_eq!(hashes.len(), 1);
// create a sequence of 65 elements
let sequence = (0..(KV_BLOCK_SIZE + 1)).map(|i| i as u32).collect::<Vec<u32>>();
let sequence = (0..(KV_BLOCK_SIZE + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence);
assert_eq!(hashes.len(), 1);
// create a sequence of 129 elements
let sequence = (0..(2 * KV_BLOCK_SIZE + 1)).map(|i| i as u32).collect::<Vec<u32>>();
let sequence = (0..(2 * KV_BLOCK_SIZE + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence);
assert_eq!(hashes.len(), 2);
}
......
......@@ -15,9 +15,9 @@
use crate::kv_router::{indexer::RouterEvent, protocols::KvCacheEvent, KV_EVENT_SUBJECT};
use tokio::sync::mpsc;
use tracing as log;
use triton_distributed::{component::Component, DistributedRuntime, Result};
use uuid::Uuid;
use tracing as log;
pub struct KvPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
......
......@@ -23,6 +23,7 @@ use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
#[allow(dead_code)]
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
......
......@@ -41,9 +41,11 @@
//!
//! TODO: Top-level Overview of Endpoints/Functions
use crate::discovery::Lease;
use crate::{discovery::Lease, service::ServiceSet};
use super::{error, traits::*, transports::nats::Slug, DistributedRuntime, Result, Runtime};
use super::{
error, traits::*, transports::nats::Slug, utils::Duration, DistributedRuntime, Result, Runtime,
};
use crate::pipeline::network::{ingress::push_endpoint::PushEndpoint, PushWorkHandler};
use async_nats::{
......@@ -59,6 +61,7 @@ use validator::{Validate, ValidationError};
mod client;
mod endpoint;
mod namespace;
mod registry;
mod service;
......@@ -125,20 +128,17 @@ impl Component {
format!("{}/components/{}", self.namespace, self.name)
}
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
fn slug(&self) -> Slug {
Slug::from_string(self.etcd_path())
}
pub fn service_name(&self) -> String {
self.slug().to_string()
Slug::from_string(format!("{}|{}", self.namespace, self.name)).to_string()
}
// todo - move to EventPlane
pub fn event_subject(&self, name: impl AsRef<str>) -> String {
format!("{}.events.{}", self.slug(), name.as_ref())
format!("{}.events.{}", self.service_name(), name.as_ref())
}
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
......@@ -154,6 +154,14 @@ impl Component {
unimplemented!("endpoints")
}
pub async fn scrape_stats(&self, duration: Duration) -> Result<ServiceSet> {
let service_name = self.service_name();
let service_client = self.drt().service_client();
service_client
.collect_services(&service_name, duration)
.await
}
/// TODO
///
/// This method will scrape the stats for all available services
......@@ -217,8 +225,17 @@ impl Endpoint {
format!("{}-{:x}", self.name, lease_id)
}
pub fn subject(&self, lease_id: i64) -> String {
format!("{}.{}", self.component.slug(), self.name_with_id(lease_id))
pub fn subject(&self) -> String {
format!("{}.{}", self.component.service_name(), self.name)
}
/// Subject to an instance of the [Endpoint] with a specific lease id
pub fn subject_to(&self, lease_id: i64) -> String {
format!(
"{}.{}",
self.component.service_name(),
self.name_with_id(lease_id)
)
}
pub async fn client<Req, Resp>(&self) -> Result<client::Client<Req, Resp>>
......@@ -273,6 +290,10 @@ impl Namespace {
.namespace(self.name.clone())
.build()?)
}
pub fn name(&self) -> &str {
&self.name
}
}
// Custom validator function
......
......@@ -184,7 +184,7 @@ where
endpoints[offset as usize]
};
let subject = self.endpoint.subject(endpoint_id);
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
......@@ -206,7 +206,7 @@ where
endpoints[offset as usize]
};
let subject = self.endpoint.subject(endpoint_id);
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
......@@ -227,7 +227,7 @@ where
));
}
let subject = self.endpoint.subject(endpoint_id);
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
......
......@@ -57,7 +57,7 @@ impl EndpointConfigBuilder {
.lock()
.await
.get(&endpoint.component.etcd_path())
.map(|service| service.group(endpoint.component.slug()))
.map(|service| service.group(endpoint.component.service_name()))
.ok_or(error!("Service not found"))?;
// let group = service.group(service_name.as_str());
......@@ -89,7 +89,7 @@ impl EndpointConfigBuilder {
endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.clone(),
lease_id: lease.id(),
transport: TransportType::NatsTcp(endpoint.subject(lease.id())),
transport: TransportType::NatsTcp(endpoint.subject_to(lease.id())),
};
let info = serde_json::to_vec_pretty(&info)?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use super::*;
use crate::traits::events::EventPublisher;
#[async_trait]
impl EventPublisher for Namespace {
fn subject(&self) -> String {
format!("namespace.{}", self.name)
}
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl Serialize + Send + Sync),
) -> Result<()> {
let bytes = serde_json::to_vec(event)?;
self.publish_bytes(event_name, bytes).await
}
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self
.drt()
.nats_client()
.client()
.publish(subject, bytes.into())
.await?)
}
}
#[cfg(test)]
mod tests {
use super::*;
// todo - make a distributed runtime fixture
// todo - two options - fully mocked or integration test
#[cfg(feature = "integration")]
#[tokio::test]
async fn test_publish() {
// todo - use rtest - make fixtures
let dtr = DistributedRuntime::from_settings(Runtime::single_threaded().unwrap())
.await
.unwrap();
let ns = dtr.namespace("test".to_string()).unwrap();
ns.publish("test", &"test".to_string()).await.unwrap();
}
}
......@@ -48,7 +48,7 @@ impl ServiceConfigBuilder {
let (component, description, stat_handler) = self.build_internal()?.dissolve();
let service_name = component.slug();
let service_name = component.service_name();
let description = description.unwrap_or(format!(
"Triton Component {} in {}",
component.name, component.namespace
......
......@@ -39,7 +39,9 @@ pub mod runnable;
pub mod runtime;
pub mod service;
pub mod slug;
pub mod traits;
pub mod transports;
pub mod utils;
pub mod worker;
pub mod distributed;
......@@ -82,23 +84,3 @@ pub struct DistributedRuntime {
// paths in etcd to a minimum.
component_registry: component::Registry,
}
pub mod traits {
use super::*;
/// A trait for objects taht proivde access to the [Runtime]
pub trait RuntimeProvider {
fn rt(&self) -> &Runtime;
}
/// A trait for objects that provide access to the [DistributedRuntime].
pub trait DistributedRuntimeProvider {
fn drt(&self) -> &DistributedRuntime;
}
impl RuntimeProvider for DistributedRuntime {
fn rt(&self) -> &Runtime {
&self.runtime
}
}
}
......@@ -19,13 +19,13 @@
// we will want to associate the components cancellation token with the
// component's "service state"
use crate::{transports::nats, Result};
use crate::{error, transports::nats, utils::stream, Result};
use async_nats::Message;
use async_stream::try_stream;
use bytes::Bytes;
use derive_getters::Dissolve;
use futures::stream::StreamExt;
use futures::stream::{StreamExt, TryStreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::time::Duration;
......@@ -39,6 +39,7 @@ impl ServiceClient {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSet {
services: Vec<ServiceInfo>,
}
......@@ -58,14 +59,25 @@ pub struct EndpointInfo {
pub subject: String,
#[serde(flatten)]
pub data: Metrics,
pub data: Option<Metrics>,
}
impl EndpointInfo {
pub fn id(&self) -> Result<i64> {
let id = self
.subject
.split('-')
.last()
.ok_or_else(|| error!("No id found in subject"))?;
i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct Metrics(pub serde_json::Value);
impl Metrics {
pub fn decode<T: DeserializeOwned>(self) -> Result<T> {
pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
serde_json::from_value(self.0).map_err(Into::into)
}
}
......@@ -89,32 +101,30 @@ impl ServiceClient {
service_name: &str,
duration: Duration,
) -> Result<ServiceSet> {
let mut sub = self.nats_client.service_subscriber(service_name).await?;
let sub = self.nats_client.scrape_service(service_name).await?;
if duration.is_zero() {
tracing::warn!("collect_services: duration is zero");
}
if duration > Duration::from_secs(10) {
tracing::warn!("collect_services: duration is greater than 10 seconds");
}
let deadline = tokio::time::Instant::now() + duration;
let services: Vec<Result<ServiceInfo>> = try_stream! {
while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await {
if message.payload.is_empty() {
continue;
let services = stream::until_deadline(sub, deadline)
.map(|message| serde_json::from_slice::<ServiceInfo>(&message.payload))
.filter_map(|info| async move {
match info {
Ok(info) => Some(info),
Err(e) => {
log::debug!("error decoding service info: {:?}", e);
None
}
}
let service = serde_json::from_slice::<ServiceInfo>(&message.payload)?;
tracing::trace!("service: {:?}", service);
yield service;
}
}
.collect()
.await;
// split ok and error results
let (ok, err): (Vec<_>, Vec<_>) = services.into_iter().partition(Result::is_ok);
if !err.is_empty() {
tracing::error!("failed to collect services: {:?}", err);
}
})
.collect()
.await;
Ok(ServiceSet {
services: ok.into_iter().map(Result::unwrap).collect(),
})
Ok(ServiceSet { services })
}
}
......@@ -143,12 +153,12 @@ mod tests {
EndpointInfo {
name: "endpoint1".to_string(),
subject: "subject1".to_string(),
data: Metrics(serde_json::json!({"key": "value1"})),
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
},
EndpointInfo {
name: "endpoint2-foo".to_string(),
subject: "subject2".to_string(),
data: Metrics(serde_json::json!({"key": "value1"})),
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
},
],
},
......@@ -161,12 +171,12 @@ mod tests {
EndpointInfo {
name: "endpoint1".to_string(),
subject: "subject1".to_string(),
data: Metrics(serde_json::json!({"key": "value1"})),
data: Some(Metrics(serde_json::json!({"key": "value1"}))),
},
EndpointInfo {
name: "endpoint2-bar".to_string(),
subject: "subject2".to_string(),
data: Metrics(serde_json::json!({"key": "value2"})),
data: Some(Metrics(serde_json::json!({"key": "value2"}))),
},
],
},
......
......@@ -60,7 +60,7 @@ impl Slug {
.to_lowercase()
.chars()
.map(|c| {
let is_valid = c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_';
let is_valid = c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_';
if is_valid {
c
} else {
......@@ -69,7 +69,7 @@ impl Slug {
})
.collect::<String>();
let hash = blake3::hash(s.as_bytes()).to_string();
let out = format!("{out}-{}", &hash[(hash.len() - 8)..]);
let out = format!("{out}_{}", &hash[(hash.len() - 8)..]);
Slug::new(out)
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod events;
use super::{DistributedRuntime, Runtime};
/// A trait for objects taht proivde access to the [Runtime]
pub trait RuntimeProvider {
fn rt(&self) -> &Runtime;
}
/// A trait for objects that provide access to the [DistributedRuntime].
pub trait DistributedRuntimeProvider {
fn drt(&self) -> &DistributedRuntime;
}
impl RuntimeProvider for DistributedRuntime {
fn rt(&self) -> &Runtime {
&self.runtime
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use crate::Result;
// #[async_trait]
// pub trait Publisher: Debug + Clone + Send + Sync {
// async fn publish(&self, event: &(impl Serialize + Send + Sync)) -> Result<()>;
// }
/// A [EventPlane] is a component that can publish and/or subscribe to events.
///
/// Each implementation of [EventPlane] will define the root subject.
#[async_trait]
pub trait EventPublisher {
/// The base subject used for this implementation of the [EventPlane].
fn subject(&self) -> String;
/// Publish a single event to the event plane. The `event_name` will be `.` concatenated with the
/// base subject provided by the implementation.
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl Serialize + Send + Sync),
) -> Result<()>;
/// Publish a single event as bytes to the event plane. The `event_name` will be `.` concatenated with the
/// base subject provided by the implementation.
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> Result<()>;
// /// Create a new publisher for the given event name. The `event_name` will be `.` concatenated with the
// /// base subject provided by the implementation.
// fn publisher(&self, event_name: impl AsRef<str>) -> impl Publisher;
// /// Create a new publisher for the given event name. The `event_name` will be `.` concatenated with the
// fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>;
// fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes;
}
......@@ -88,7 +88,14 @@ impl Client {
Ok(stream)
}
pub async fn service_subscriber(&self, service_name: &str) -> Result<Subscriber> {
/// Issues a broadcast request for all services with the provided `service_name` to report their
/// current stats. Each service will only respond once. The service may have customized the reply
/// so the caller should select which endpoint and what concrete data model should be used to
/// extract the details.
///
/// Note: Because each endpoint will only reply once, the caller must drop the subscription after
/// some time or it will await forever.
pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
let subject = format!("$SRV.STATS.{}", service_name);
let reply_subject = format!("_INBOX.{}", nuid::next());
let subscription = self.client.subscribe(reply_subject.clone()).await?;
......@@ -98,20 +105,6 @@ impl Client {
.publish_with_reply(subject, reply_subject, "".into())
.await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = time::Instant::now();
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
// tx.send(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
Ok(subscription)
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use tokio::time::{Duration, Instant};
pub mod stream;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::stream::{Stream, StreamExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::time::{self, sleep_until, Duration, Instant, Sleep};
pub struct DeadlineStream<S> {
stream: S,
sleep: Pin<Box<Sleep>>,
}
impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
type Item = S::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// First, check if our sleep future has completed
if Pin::new(&mut self.sleep).poll(cx).is_ready() {
// The deadline expired; end the stream now
return Poll::Ready(None);
}
// Otherwise, poll the underlying stream
self.as_mut().stream.poll_next_unpin(cx)
}
}
pub fn until_deadline<S: Stream + Unpin>(stream: S, deadline: Instant) -> DeadlineStream<S> {
DeadlineStream {
stream,
sleep: Box::pin(sleep_until(deadline)),
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment