preprocessor.rs 4.42 KB
Newer Older
1
2
3
4
5
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;
use crate::llm::model_card::ModelDeploymentCard;
6
use std::time::Duration;
7
8
9

use llm_rs::{
    preprocessor::OpenAIPreprocessor,
10
    preprocessor::media::{MediaDecoder as RsMediaDecoder, MediaFetcher as RsMediaFetcher},
11
    protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
12
    types::{
13
        Annotated,
14
15
16
17
18
19
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
        },
    },
};

20
21
22
use dynamo_runtime::pipeline::{
    ManyOut, Operator, PushRouter, SegmentSink, ServiceFrontend, SingleIn, Source,
};
23
24
25
26
27
28
29
30
31
32
33
34

/// Python-exposed handle that couples an OpenAI chat preprocessor with the two
/// endpoints it sits between.
#[pyclass]
pub(crate) struct OAIChatPreprocessor {
    /// Shared preprocessor built from the model deployment card in `new`.
    inner: Arc<llm_rs::preprocessor::OpenAIPreprocessor>,
    /// Endpoint whose builder receives the ingress handler in `start`.
    current: Endpoint,
    /// Endpoint whose client backs the push router attached to the network sink in `start`.
    next: Endpoint,
}

#[pymethods]
impl OAIChatPreprocessor {
    #[new]
    fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> {
35
        let preprocessor = OpenAIPreprocessor::new(mdc.inner.clone()).map_err(to_pyerr)?;
36
37
38
39
40
41
42
43
44
45
46
47
48
49
        Ok(Self {
            inner: preprocessor,
            current,
            next,
        })
    }

    fn start<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let frontend = ServiceFrontend::<
            SingleIn<NvCreateChatCompletionRequest>,
            ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        >::new();

        let network =
50
            SegmentSink::<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>>::new();
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

        let preprocessor = self.inner.into_operator();
        let pipeline = frontend
            .link(preprocessor.forward_edge())
            .map_err(to_pyerr)?
            .link(network.clone())
            .map_err(to_pyerr)?
            .link(preprocessor.backward_edge())
            .map_err(to_pyerr)?
            .link(frontend)
            .map_err(to_pyerr)?;
        let ingress = Ingress::for_engine(pipeline).map_err(to_pyerr)?;
        let builder = self.current.inner.endpoint_builder().handler(ingress);
        let endpoint = Arc::new(self.next.inner.clone());
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
66
            let client = endpoint.client().await.map_err(to_pyerr)?;
67
            let router = PushRouter::<PreprocessedRequest, Annotated<BackendOutput>>::from_client(
68
69
70
71
72
73
                client,
                Default::default(),
            )
            .await
            .map_err(to_pyerr)?;
            network.attach(Arc::new(router)).map_err(to_pyerr)?;
74
75
76
77
78
            builder.start().await.map_err(to_pyerr)?;
            Ok(())
        })
    }
}
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

/// Python-exposed, cloneable wrapper around the Rust media decoder configuration.
#[pyclass]
#[derive(Clone)]
pub struct MediaDecoder {
    /// Wrapped `RsMediaDecoder`; crate-visible so other binding code can read it.
    pub(crate) inner: RsMediaDecoder,
}

#[pymethods]
impl MediaDecoder {
    #[new]
    fn new() -> Self {
        Self {
            inner: RsMediaDecoder::default(),
        }
    }

95
96
97
    fn enable_image(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
        let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse image decoder config: {}", err))
98
        })?;
99
        self.inner.image = Some(decoder_options);
100
101
        Ok(())
    }
102
103

    #[cfg(feature = "media-ffmpeg")]
104
105
106
    fn enable_video(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
        let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse video decoder config: {}", err))
107
        })?;
108
        self.inner.video = Some(decoder_options);
109
110
        Ok(())
    }
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
}

/// Python-exposed, cloneable wrapper around the Rust media fetcher configuration.
#[pyclass]
#[derive(Clone)]
pub struct MediaFetcher {
    /// Wrapped `RsMediaFetcher`; crate-visible so other binding code can read it.
    pub(crate) inner: RsMediaFetcher,
}

#[pymethods]
impl MediaFetcher {
    /// Creates a fetcher wrapper whose inner `RsMediaFetcher` is its `Default` value.
    #[new]
    fn new() -> Self {
        let inner = RsMediaFetcher::default();
        Self { inner }
    }

    /// Replaces the inner fetcher's `user_agent` string.
    fn user_agent(&mut self, user_agent: String) {
        let fetcher = &mut self.inner;
        fetcher.user_agent = user_agent;
    }

    /// Sets the inner fetcher's `allow_direct_ip` flag.
    // NOTE(review): presumably governs URLs whose host is a literal IP — confirm in llm_rs.
    fn allow_direct_ip(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_ip = allow;
    }

    /// Sets the inner fetcher's `allow_direct_port` flag.
    // NOTE(review): presumably governs URLs with an explicit port — confirm in llm_rs.
    fn allow_direct_port(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_port = allow;
    }

    /// Stores `domains` as the inner fetcher's `allowed_media_domains`, collected
    /// into whatever collection type that field holds.
    fn allowed_media_domains(&mut self, domains: Vec<String>) {
        // The collection type is inferred from the field assignment on the next line.
        let allowed = domains.into_iter().collect();
        self.inner.allowed_media_domains = Some(allowed);
    }

    /// Sets the inner fetcher's `timeout` to `timeout_ms` milliseconds.
    fn timeout_ms(&mut self, timeout_ms: u64) {
        let timeout = Duration::from_millis(timeout_ms);
        self.inner.timeout = Some(timeout);
    }
}