// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use super::*; use crate::pipeline::{AsyncEngine, PipelineIO}; mod base; mod common; pub struct Frontend { edge: OnceLock>, sinks: Arc>>>, } /// A [`ServiceFrontend`] is the interface for an [`AsyncEngine>, ManyOut>, Error>`] pub struct ServiceFrontend { inner: Frontend, } pub struct SegmentSource { inner: Frontend, } // impl Frontend { // pub fn new() -> Arc { // Arc::new(Self { // edge: OnceLock::new(), // sinks: Arc::new(Mutex::new(HashMap::new())), // }) // } // } // impl SegmentSource { // pub fn new() -> Arc { // Arc::new(Self { // edge: OnceLock::new(), // sinks: Arc::new(Mutex::new(HashMap::new())), // }) // } // } // #[async_trait] // impl Source> for Frontend { // async fn on_next(&self, data: Context, _: private::Token) -> Result<(), PipelineError> { // self.edge // .get() // .ok_or(PipelineError::NoEdge)? // .write(data) // .await // } // fn set_edge( // &self, // edge: Edge>>, // _: private::Token, // ) -> Result<(), PipelineError> { // self.edge // .set(edge) // .map_err(|_| PipelineError::EdgeAlreadySet)?; // Ok(()) // } // } // #[async_trait] // impl Sink> for Frontend { // async fn on_data( // &self, // data: PipelineStream, // _: private::Token, // ) -> Result<(), PipelineError> { // let context = data.context(); // let mut sinks = self.sinks.lock().unwrap(); // let tx = sinks // .remove(context.id()) // .ok_or(PipelineError::DetatchedStreamReceiver) // .map_err(|e| { // data.context().stop_generating(); // e // })?; // drop(sinks); // let ctx = data.context(); // tx.send(data) // .map_err(|_| PipelineError::DetatchedStreamReceiver) // .map_err(|e| { // ctx.stop_generating(); // e // }) // } // } // impl Link> for Frontend { // fn link> + 'static>(&self, sink: Arc) -> Result, PipelineError> { // let edge = Edge::new(sink.clone()); // self.set_edge(edge.into(), private::Token {})?; // Ok(sink) // } // } // #[async_trait] // impl AsyncEngine, Annotated, PipelineError> // for Frontend // { // async fn generate(&self, request: Context) -> Result, PipelineError> { // let (tx, rx) = oneshot::channel::>(); // { // let mut sinks = self.sinks.lock().unwrap(); // sinks.insert(request.id().to_string(), tx); // } // self.on_next(request, private::Token {}).await?; // rx.await.map_err(|_| PipelineError::DetatchedStreamSender) // } // } // // SegmentSource // #[async_trait] // impl Source> for SegmentSource { // async fn on_next(&self, data: Context, _: private::Token) -> Result<(), PipelineError> { // self.edge // .get() // .ok_or(PipelineError::NoEdge)? // .write(data) // .await // } // fn set_edge( // &self, // edge: Edge>>, // _: private::Token, // ) -> Result<(), PipelineError> { // self.edge // .set(edge) // .map_err(|_| PipelineError::EdgeAlreadySet)?; // Ok(()) // } // } // #[async_trait] // impl Sink> for SegmentSource { // async fn on_data( // &self, // data: PipelineStream, // _: private::Token, // ) -> Result<(), PipelineError> { // let context = data.context(); // let mut sinks = self.sinks.lock().unwrap(); // let tx = sinks // .remove(context.id()) // .ok_or(PipelineError::DetatchedStreamReceiver) // .map_err(|e| { // data.context().stop_generating(); // e // })?; // drop(sinks); // let ctx = data.context(); // tx.send(data) // .map_err(|_| PipelineError::DetatchedStreamReceiver) // .map_err(|e| { // ctx.stop_generating(); // e // }) // } // } // impl Link> for SegmentSource { // fn link> + 'static>(&self, sink: Arc) -> Result, PipelineError> { // let edge = Edge::new(sink.clone()); // self.set_edge(edge.into(), private::Token {})?; // Ok(sink) // } // } // #[async_trait] // impl AsyncEngine, Annotated, PipelineError> // for SegmentSource // { // async fn generate(&self, request: Context) -> Result, PipelineError> { // let (tx, rx) = oneshot::channel::>(); // { // let mut sinks = self.sinks.lock().unwrap(); // sinks.insert(request.id().to_string(), tx); // } // self.on_next(request, private::Token {}).await?; // rx.await.map_err(|_| PipelineError::DetatchedStreamSender) // } // } // #[cfg(test)] // mod tests { // use super::*; // #[tokio::test] // async fn test_pipeline_source_no_edge() { // let source = Frontend::<(), ()>::new(); // let stream = source.generate(().into()).await; // match stream { // Err(PipelineError::NoEdge) => (), // _ => panic!("Expected NoEdge error"), // } // } // } // pub struct IngressPort { // edge: OnceLock>, // } // impl IngressPort // where // In: for<'de> Deserialize<'de> + DataType, // Out: PipelineIO + Serialize, // { // pub fn new() -> Arc { // Arc::new(IngressPort { // edge: OnceLock::new(), // }) // } // } // #[async_trait] // impl AsyncEngine>, Vec> for IngressPort // where // In: for<'de> Deserialize<'de> + DataType, // Out: PipelineIO + Serialize, // { // async fn generate( // &self, // request: Context>, // ) -> Result>, PipelineError> { // // Deserialize request // let request = request.try_map(|bytes| { // bincode::deserialize::(&bytes) // .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err))) // })?; // // Forward request to edge // let stream = self // .edge // .get() // .ok_or(PipelineError("No engine to forward request to".to_string()))? // .generate(request) // .await?; // // Serialize response stream // let stream = // stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response")); // Err(PipelineError(format!("Not implemented"))) // } // } // fn convert_stream( // stream: impl Stream> + Send + 'static, // ctx: Arc, // transform: Arc Result + Send + Sync>, // ) -> Pin> + Send>> // where // T: Send + 'static, // U: Send + 'static, // { // Box::pin(stream.flat_map(move |item| { // let ctx = ctx.clone(); // let transform = transform.clone(); // match item { // ServerStream::Data(data) => match transform(data) { // Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]), // Err(e) => { // // Trigger cancellation and propagate the error, followed by Sentinel // ctx.stop_generating(); // futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel]) // } // }, // other => futures::stream::iter(vec![other]), // } // }) // // Use take_while to stop processing when encountering the Sentinel // .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel)))) // }