// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;
use std::time::Duration;

use llm_rs::preprocessor::media::{MediaDecoder as RsMediaDecoder, MediaFetcher as RsMediaFetcher};
#[pyclass]
#[derive(Clone)]
pub struct MediaDecoder {
    pub(crate) inner: RsMediaDecoder,
}

#[pymethods]
impl MediaDecoder {
    #[new]
    fn new() -> Self {
        Self {
            inner: RsMediaDecoder::default(),
        }
    }

24
25
26
    fn enable_image(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
        let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse image decoder config: {}", err))
27
        })?;
28
        self.inner.image = Some(decoder_options);
29
30
        Ok(())
    }
31
32

    #[cfg(feature = "media-ffmpeg")]
33
34
35
    fn enable_video(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
        let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to parse video decoder config: {}", err))
36
        })?;
37
        self.inner.video = Some(decoder_options);
38
39
        Ok(())
    }
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
}

/// Python-visible configuration for fetching media.
///
/// Thin wrapper around the Rust-side fetcher (imported here as `RsMediaFetcher`).
/// Each setter overwrites exactly one field on the wrapped value and returns `None`
/// to Python.
#[pyclass]
#[derive(Clone)]
pub struct MediaFetcher {
    /// Wrapped Rust fetcher configuration (visible within the crate).
    pub(crate) inner: RsMediaFetcher,
}

#[pymethods]
impl MediaFetcher {
    /// Create a fetcher configuration starting from `RsMediaFetcher::default()`.
    #[new]
    fn new() -> Self {
        let inner = RsMediaFetcher::default();
        Self { inner }
    }

    /// Set the user-agent string stored on the wrapped fetcher.
    fn user_agent(&mut self, user_agent: String) {
        let fetcher = &mut self.inner;
        fetcher.user_agent = user_agent;
    }

    /// Set the wrapped fetcher's `allow_direct_ip` flag.
    ///
    /// NOTE(review): presumably controls whether URLs with a literal IP host are
    /// accepted; confirm against the Rust `MediaFetcher`.
    fn allow_direct_ip(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_ip = allow;
    }

    /// Set the wrapped fetcher's `allow_direct_port` flag.
    ///
    /// NOTE(review): presumably controls whether URLs with an explicit port are
    /// accepted; confirm against the Rust `MediaFetcher`.
    fn allow_direct_port(&mut self, allow: bool) {
        let fetcher = &mut self.inner;
        fetcher.allow_direct_port = allow;
    }

    /// Store `domains` as the wrapped fetcher's `allowed_media_domains`.
    ///
    /// The list is collected into whatever collection type that field holds and
    /// wrapped in `Some(..)`, replacing any previous value.
    fn allowed_media_domains(&mut self, domains: Vec<String>) {
        // The target collection type is inferred from the field being assigned.
        let allowed = domains.into_iter().collect();
        self.inner.allowed_media_domains = Some(allowed);
    }

    /// Set the wrapped fetcher's timeout to `timeout_ms` milliseconds.
    fn timeout_ms(&mut self, timeout_ms: u64) {
        let timeout = Duration::from_millis(timeout_ms);
        self.inner.timeout = Some(timeout);
    }
}