// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;
use crate::llm::model_card::ModelDeploymentCard;
use std::time::Duration;

use llm_rs::{
    preprocessor::OpenAIPreprocessor,
    preprocessor::media::{MediaDecoder as RsMediaDecoder, MediaFetcher as RsMediaFetcher},
    protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
    types::{
        Annotated,
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
        },
    },
};

use dynamo_runtime::pipeline::{
    ManyOut, Operator, PushRouter, SegmentSink, ServiceFrontend, SingleIn, Source,
};

#[pyclass]
pub(crate) struct OAIChatPreprocessor {
    inner: Arc<llm_rs::preprocessor::OpenAIPreprocessor>,
    current: Endpoint,
    next: Endpoint,
}

#[pymethods]
impl OAIChatPreprocessor {
    #[new]
    fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> {
35
        let preprocessor = OpenAIPreprocessor::new(mdc.inner.clone()).map_err(to_pyerr)?;
36
37
38
39
40
41
42
43
44
45
46
47
48
49
        Ok(Self {
            inner: preprocessor,
            current,
            next,
        })
    }

    fn start<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let frontend = ServiceFrontend::<
            SingleIn<NvCreateChatCompletionRequest>,
            ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        >::new();

        let network =
50
            SegmentSink::<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>>::new();
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

        let preprocessor = self.inner.into_operator();
        let pipeline = frontend
            .link(preprocessor.forward_edge())
            .map_err(to_pyerr)?
            .link(network.clone())
            .map_err(to_pyerr)?
            .link(preprocessor.backward_edge())
            .map_err(to_pyerr)?
            .link(frontend)
            .map_err(to_pyerr)?;
        let ingress = Ingress::for_engine(pipeline).map_err(to_pyerr)?;
        let builder = self.current.inner.endpoint_builder().handler(ingress);
        let endpoint = Arc::new(self.next.inner.clone());
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
66
            let client = endpoint.client().await.map_err(to_pyerr)?;
67
            let router = PushRouter::<PreprocessedRequest, Annotated<BackendOutput>>::from_client(
68
69
70
71
72
73
                client,
                Default::default(),
            )
            .await
            .map_err(to_pyerr)?;
            network.attach(Arc::new(router)).map_err(to_pyerr)?;
74
75
76
77
78
            builder.start().await.map_err(to_pyerr)?;
            Ok(())
        })
    }
}
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

/// Python-exposed, cloneable wrapper around the Rust `RsMediaDecoder`
/// (`llm_rs::preprocessor::media::MediaDecoder`).
#[pyclass]
#[derive(Clone)]
pub struct MediaDecoder {
    // Wrapped decoder configuration; its fields are overwritten by the
    // setter methods on this class.
    pub(crate) inner: RsMediaDecoder,
}

#[pymethods]
impl MediaDecoder {
    /// Creates a decoder wrapper holding `RsMediaDecoder::default()`.
    #[new]
    fn new() -> Self {
        Self {
            inner: RsMediaDecoder::default(),
        }
    }

    /// Replaces the image decoder settings with values parsed from a Python dict.
    ///
    /// The dict is converted with `pythonize::depythonize` into the type of
    /// `inner.image_decoder`; the accepted keys are defined by that type, which
    /// is not visible here.
    ///
    /// # Errors
    /// Raises a Python `Exception` with the message
    /// `Failed to parse image_decoder: <cause>` when the dict cannot be converted.
    /// The existing settings are left untouched in that case.
    fn image_decoder(&mut self, image_decoder: &Bound<'_, PyDict>) -> PyResult<()> {
        let image_decoder = pythonize::depythonize(image_decoder).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse image_decoder: {err}"))
        })?;
        self.inner.image_decoder = image_decoder;
        Ok(())
    }

    /// Replaces the video decoder settings with values parsed from a Python dict.
    ///
    /// Only compiled when the `media-ffmpeg` feature is enabled. The dict is
    /// converted with `pythonize::depythonize` into the type of
    /// `inner.video_decoder`.
    ///
    /// # Errors
    /// Raises a Python `Exception` with the message
    /// `Failed to parse video_decoder: <cause>` when the dict cannot be converted.
    /// The existing settings are left untouched in that case.
    #[cfg(feature = "media-ffmpeg")]
    fn video_decoder(&mut self, video_decoder: &Bound<'_, PyDict>) -> PyResult<()> {
        let video_decoder = pythonize::depythonize(video_decoder).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse video_decoder: {err}"))
        })?;
        self.inner.video_decoder = video_decoder;
        Ok(())
    }
}

/// Python-exposed, cloneable wrapper around the Rust `RsMediaFetcher`
/// (`llm_rs::preprocessor::media::MediaFetcher`).
#[pyclass]
#[derive(Clone)]
pub struct MediaFetcher {
    // Wrapped fetcher configuration; its fields are overwritten by the
    // setter methods on this class.
    pub(crate) inner: RsMediaFetcher,
}

#[pymethods]
impl MediaFetcher {
    /// Creates a fetcher wrapper holding `RsMediaFetcher::default()`.
    #[new]
    fn new() -> Self {
        let inner = RsMediaFetcher::default();
        Self { inner }
    }

    /// Overwrites the wrapped fetcher's `user_agent` field.
    fn user_agent(&mut self, user_agent: String) {
        let fetcher = &mut self.inner;
        fetcher.user_agent = user_agent;
    }

    /// Overwrites the wrapped fetcher's `allow_direct_ip` flag.
    // NOTE(review): presumably controls whether URLs with a literal IP host
    // are accepted; confirm against the fetcher implementation.
    fn allow_direct_ip(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_ip = allow;
    }

    /// Overwrites the wrapped fetcher's `allow_direct_port` flag.
    // NOTE(review): presumably controls whether URLs with an explicit port
    // are accepted; confirm against the fetcher implementation.
    fn allow_direct_port(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_port = allow;
    }

    /// Sets `allowed_media_domains` to `Some(..)` of the given domains,
    /// collected into whatever collection type the field declares.
    fn allowed_media_domains(&mut self, domains: Vec<String>) {
        let collected = domains.into_iter().collect();
        self.inner.allowed_media_domains = Some(collected);
    }

    /// Sets `timeout` to `Some(Duration)` built from `timeout_ms` milliseconds.
    fn timeout_ms(&mut self, timeout_ms: u64) {
        let limit = Duration::from_millis(timeout_ms);
        self.inner.timeout = Some(limit);
    }
}