// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::engine::AsyncEngineContextProvider;

use super::*;

/// Implements the frontend plumbing for a wrapper type `$type<In, Out>`.
///
/// The wrapper is expected to be a struct whose only field is `inner`,
/// initialised from `Frontend::default()`. Every generated method simply
/// forwards to `self.inner`. The macro generates:
///
/// - an inherent `new()` constructor returning `Arc<Self>`,
/// - `Source<In>` (`on_next`, `set_edge`),
/// - `Sink<Out>` (`on_data`), and
/// - `AsyncEngine<In, Out, Error>` (`generate`).
macro_rules! impl_frontend {
    ($type:ident) => {
        impl<In: PipelineIO, Out: PipelineIO> $type<In, Out> {
            /// Creates a new instance wrapping a default `Frontend`, returned
            /// inside an `Arc`.
            pub fn new() -> Arc<Self> {
                Arc::new(Self {
                    inner: Frontend::default(),
                })
            }
        }

        /// Upstream side: forwards incoming data and edge configuration to
        /// the inner `Frontend`.
        #[async_trait]
        impl<In: PipelineIO, Out: PipelineIO> Source<In> for $type<In, Out> {
            async fn on_next(&self, data: In, token: private::Token) -> Result<(), Error> {
                self.inner.on_next(data, token).await
            }

            fn set_edge(&self, edge: Edge<In>, token: private::Token) -> Result<(), PipelineError> {
                self.inner.set_edge(edge, token)
            }
        }

        /// Downstream side: forwards produced data to the inner `Frontend`.
        /// `Out` must additionally implement `AsyncEngineContextProvider`.
        #[async_trait]
        impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out>
            for $type<In, Out>
        {
            async fn on_data(&self, data: Out, token: private::Token) -> Result<(), Error> {
                self.inner.on_data(data, token).await
            }
        }

        /// Engine entry point: `generate` is delegated to the inner
        /// `Frontend`. `In` must additionally be `Sync`.
        #[async_trait]
        impl<In: PipelineIO + Sync, Out: PipelineIO> AsyncEngine<In, Out, Error>
            for $type<In, Out>
        {
            async fn generate(&self, request: In) -> Result<Out, Error> {
                self.inner.generate(request).await
            }
        }
    };
}

// Generate the `new` constructor and the Source / Sink / AsyncEngine
// forwarding impls for both frontend wrapper types.
impl_frontend! { ServiceFrontend }
impl_frontend! { SegmentSource }

#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{ManyOut, PipelineErrorExt, SingleIn};

    /// Calling `generate` on a `Frontend` that has no edge set must fail
    /// with `PipelineError::NoEdge`.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let source = Frontend::<SingleIn<()>, ManyOut<()>>::default();
        // The result is an error, not a stream: name it accordingly.
        let err = source
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .expect("generate without an edge should fail with a PipelineError");

        assert!(matches!(err, PipelineError::NoEdge), "Expected NoEdge error");
    }
}