// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #![allow(dead_code)] use std::{future::Future, pin::Pin, sync::Arc}; use async_trait::async_trait; use futures::Stream; use tokio::sync::mpsc; use triton_distributed_runtime::engine::{ AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data as DataType, Engine, EngineStream, }; use triton_distributed_runtime::pipeline::{ context::{Context, StreamContext}, Error, ManyOut, SingleIn, }; pub type AsyncFn = dyn Fn(T) -> Pin + Send>> + Send + Sync; #[derive(Clone)] // Define a struct that holds an async closure pub struct AsyncProcessor { func: Arc>, } impl AsyncProcessor where T: Send + 'static, U: Send + 'static, { // Define a `new` method that captures the already pinned async block pub fn new(f: F) -> Self where F: Fn(T) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { // Wrap the closure in Arc and Box it for internal management AsyncProcessor { func: Arc::new(move |input: T| Box::pin(f(input))), } } // Method to execute the captured async function pub async fn process(&self, input: T) -> U { (self.func)(input).await } } #[derive(Debug, Clone)] pub struct ResponseSource { tx: mpsc::Sender, ctx: StreamContext, } impl ResponseSource { fn new(tx: mpsc::Sender, ctx: StreamContext) -> Self { ResponseSource { tx, ctx } } /// Emit a response to the stream pub async fn emit(&self, data: T) -> Result<(), ()> { self.tx.send(data).await.map_err(|_| ()) } /// Check if a stop has been requested pub fn stop_requested(&self) -> bool { self.ctx.is_stopped() } /// Yield control until a stop is requested /// This is useful in a tokio::select! block pub async fn stopped(&self) { self.ctx.stopped().await; } } pub type AsyncGenerator = AsyncProcessor<(Req, ResponseSource), ()>; pub struct ReceiverStream { receiver: tokio::sync::mpsc::Receiver, context: Arc, } impl ReceiverStream { pub fn new( receiver: tokio::sync::mpsc::Receiver, context: Arc, ) -> Self { Self { receiver, context } } } impl Stream for ReceiverStream { type Item = Resp; fn poll_next( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { // if self.context.stop_issued() { // return std::task::Poll::Ready(None); // } // Pinning the receiver to safely call poll_recv Pin::new(&mut self.receiver).poll_recv(cx) } } impl std::fmt::Debug for ReceiverStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ReceiverStream") .field("context", &self.context) .finish() } } impl AsyncEngineStream for ReceiverStream {} impl AsyncEngineContextProvider for ReceiverStream { fn context(&self) -> Arc { self.context.clone() } } pub struct LlmdbaEngine { lambda: Arc>, } impl LlmdbaEngine { fn new(lambda: AsyncGenerator) -> Self { LlmdbaEngine { lambda: Arc::new(lambda), } } pub fn from_generator( generator: AsyncGenerator, ) -> Engine, ManyOut, Error> { Arc::new(LlmdbaEngine::new(generator)) } } #[async_trait] impl AsyncEngine, ManyOut, Error> for LlmdbaEngine { async fn generate(&self, request: Context) -> Result, Error> { let (tx, rx) = mpsc::channel::(1); let (req, ctx) = request.transfer(()); let ctx: StreamContext = ctx.into(); let s = ResponseSource::new(tx, ctx.clone()); let lambda = self.lambda.clone(); let _handle = tokio::spawn(async move { lambda.process((req, s)).await }); let ctx = Arc::new(ctx); let stream = ReceiverStream::::new(rx, ctx); let stream = Box::pin(stream); Ok(stream) } } #[cfg(test)] mod tests { use futures::StreamExt; use super::*; #[tokio::test] async fn test_async_processor() { let processor = AsyncProcessor::new(move |x: i32| { async move { // Simulate some async work tokio::time::sleep(std::time::Duration::from_millis(100)).await; format!("Processed value: {}", x) } }); // Use the processor to run the async closure let result = processor.process(42).await; println!("{}", result); // Output: Processed value: 42 let result2 = processor.process(100).await; println!("{}", result2); // Output: Processed value: 100 } #[tokio::test] async fn test_generator() { let generator = AsyncGenerator::::new(|(req, stream)| async move { let chars = req.chars().collect::>(); for c in chars { match stream.emit(c.to_string()).await { Ok(_) => {} Err(_) => break, } tokio::time::sleep(std::time::Duration::from_millis(100)).await; } }); let engine = LlmdbaEngine::new(generator); let mut stream = engine.generate("test".to_string().into()).await.unwrap(); let mut counter = 0; while let Some(_output) = stream.next().await { counter += 1; } assert_eq!(counter, 4); } }