// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashSet;
use std::time::Duration;

use anyhow::Result;

use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;

use super::decoders::MediaDecoder;
use super::rdma::RdmaMediaDataDescriptor;

#[cfg(feature = "media-nixl")]
use {
    super::common::EncodedMediaData, super::decoders::Decoder, super::rdma::get_nixl_agent,
    dynamo_memory::nixl::NixlAgent,
};
/// Value sent as the HTTP `User-Agent` by [`MediaFetcher::default`].
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";

/// Request timeout installed by [`MediaFetcher::default`].
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
    pub user_agent: String,
    pub allow_direct_ip: bool,
    pub allow_direct_port: bool,
    pub allowed_media_domains: Option<HashSet<String>>,
    pub timeout: Option<Duration>,
}

impl Default for MediaFetcher {
    fn default() -> Self {
        Self {
            user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
            allow_direct_ip: false,
            allow_direct_port: false,
            allowed_media_domains: None,
39
            timeout: Some(DEFAULT_HTTP_TIMEOUT),
40
41
42
43
        }
    }
}

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
impl MediaFetcher {
    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        if !matches!(url.scheme(), "http" | "https" | "data") {
            anyhow::bail!("Only HTTP(S) and data URLs are allowed");
        }

        if url.scheme() == "data" {
            return Ok(());
        }

        if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
            anyhow::bail!("Direct IP access is not allowed");
        }
        if !self.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }
        if let Some(allowed_domains) = &self.allowed_media_domains
            && let Some(host) = url.host_str()
            && !allowed_domains.contains(host)
        {
            anyhow::bail!("Domain '{host}' is not in allowed list");
        }

        Ok(())
    }
}

71
pub struct MediaLoader {
72
    #[allow(dead_code)]
73
    media_decoder: MediaDecoder,
74
    #[allow(dead_code)]
75
76
    http_client: reqwest::Client,
    media_fetcher: MediaFetcher,
77
78
    #[cfg(feature = "media-nixl")]
    nixl_agent: NixlAgent,
79
80
81
}

impl MediaLoader {
82
83
84
    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
        let media_fetcher = media_fetcher.unwrap_or_default();
        let mut http_client_builder: reqwest::ClientBuilder =
85
86
87
88
89
90
91
92
            reqwest::Client::builder().user_agent(&media_fetcher.user_agent);

        if let Some(timeout) = media_fetcher.timeout {
            http_client_builder = http_client_builder.timeout(timeout);
        }

        let http_client = http_client_builder.build()?;

93
94
95
        #[cfg(feature = "media-nixl")]
        let nixl_agent = get_nixl_agent()?;

96
        Ok(Self {
97
            media_decoder,
98
99
            http_client,
            media_fetcher,
100
101
            #[cfg(feature = "media-nixl")]
            nixl_agent,
102
103
104
        })
    }

105
    pub async fn fetch_and_decode_media_part(
106
107
        &self,
        oai_content_part: &ChatCompletionRequestUserMessageContentPart,
108
        media_io_kwargs: Option<&MediaDecoder>,
109
110
111
    ) -> Result<RdmaMediaDataDescriptor> {
        #[cfg(not(feature = "media-nixl"))]
        anyhow::bail!(
112
            "NIXL is not supported, cannot decode and register media data {oai_content_part:?} with media_io_kwargs {media_io_kwargs:?}"
113
        );
114

115
116
117
118
119
        #[cfg(feature = "media-nixl")]
        {
            // fetch the media, decode and NIXL-register
            let decoded = match oai_content_part {
                ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
120
121
122
123
124
                    let mdc_decoder =
                        self.media_decoder.image.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support image inputs")
                        })?;

125
                    let url = &image_part.image_url.url;
126
                    self.media_fetcher.check_if_url_allowed(url)?;
127
                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;
128
129
130
131
132

                    // Use runtime decoder if provided, with MDC limits enforced
                    let decoder =
                        mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
                    decoder.decode_async(data).await?
133
                }
134
                #[allow(unused_variables)]
135
                ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
136
137
138
139
140
141
                    #[cfg(not(feature = "media-ffmpeg"))]
                    anyhow::bail!(
                        "Video decoding requires the 'media-ffmpeg' feature to be enabled"
                    );

                    #[cfg(feature = "media-ffmpeg")]
142
143
144
145
146
147
                    {
                        let mdc_decoder = self.media_decoder.video.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support video inputs")
                        })?;

                        let url = &video_part.video_url.url;
148
                        self.media_fetcher.check_if_url_allowed(url)?;
149
150
151
152
153
154
155
                        let data = EncodedMediaData::from_url(url, &self.http_client).await?;

                        // Use runtime decoder if provided, with MDC limits enforced
                        let decoder = mdc_decoder
                            .with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
                        decoder.decode_async(data).await?
                    }
156
157
158
159
160
161
162
163
164
165
                }
                ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
                    anyhow::bail!("Audio decoding is not supported yet");
                }
                _ => anyhow::bail!("Unsupported media type"),
            };

            let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
            Ok(rdma_descriptor)
        }
166
167
168
    }
}

#[cfg(all(test, feature = "media-nixl", feature = "testing-nixl"))]
mod tests {
    use super::super::decoders::ImageDecoder;
    use super::super::rdma::DataType;
    use super::*;
    use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};

    /// End-to-end check of the NIXL path.
    ///
    /// The test serves a PNG fixture from a local mock server, fetches and decodes
    /// it through `MediaLoader`, and inspects the returned descriptor. It returns
    /// early (reporting itself as ignored) when the NIXL agent is unavailable.
    #[tokio::test]
    async fn test_fetch_and_decode() {
        let png_bytes: &[u8] =
            include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/llm-optimize-deploy-graphic.png")
            .with_status(200)
            .with_header("content-type", "image/png")
            .with_body(png_bytes)
            .create_async()
            .await;

        let media_decoder = MediaDecoder {
            image: Some(ImageDecoder::default()),
            #[cfg(feature = "media-ffmpeg")]
            video: None,
        };
        // The mock server URL uses an IP host and an explicit port, so both
        // direct-IP and direct-port access have to be permitted.
        let fetcher = MediaFetcher {
            allow_direct_ip: true,
            allow_direct_port: true,
            ..Default::default()
        };
        let loader = MediaLoader::new(media_decoder, Some(fetcher)).unwrap();

        let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
        let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
            ChatCompletionRequestMessageContentPartImage { image_url },
        );

        let descriptor = match loader.fetch_and_decode_media_part(&content_part, None).await {
            Ok(descriptor) => descriptor,
            // Skip rather than fail on hosts where the NIXL agent cannot be created.
            Err(err) if err.to_string().contains("NIXL agent is not available") => {
                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
                return;
            }
            Err(err) => panic!("Failed to fetch and decode image: {}", err),
        };
        mock.assert_async().await;
        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);

        // The fixture is 1,999px wide and 1,125px tall; shape is [height, width, channels].
        let shape = &descriptor.tensor_info.shape;
        assert_eq!(shape.len(), 3);
        assert_eq!(shape[0], 1125, "Height should be 1125");
        assert_eq!(shape[1], 1999, "Width should be 1999");
        assert_eq!(shape[2], 4, "RGBA channels should be 4");

        let storage = descriptor.source_storage;
        assert!(storage.is_some(), "Source storage should be present");
        assert!(
            storage.unwrap().is_registered(),
            "Source storage should be registered with NIXL"
        );
    }
}

#[cfg(test)]
mod tests_non_nixl {
    use super::*;

    /// Asserts that `fetcher` rejects `raw_url` with an error containing `expected`.
    fn assert_rejected(fetcher: &MediaFetcher, raw_url: &str, expected: &str) {
        let url = url::Url::parse(raw_url).unwrap();
        let err = fetcher
            .check_if_url_allowed(&url)
            .expect_err("URL should have been rejected");
        assert!(
            err.to_string().contains(expected),
            "unexpected error for {url}: {err}"
        );
    }

    /// Asserts that `fetcher` accepts `raw_url`.
    fn assert_allowed(fetcher: &MediaFetcher, raw_url: &str) {
        let url = url::Url::parse(raw_url).unwrap();
        assert!(
            fetcher.check_if_url_allowed(&url).is_ok(),
            "{url} should be allowed"
        );
    }

    #[test]
    fn test_direct_ip_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        };
        assert_rejected(
            &fetcher,
            "http://192.168.1.1/image.jpg",
            "Direct IP access is not allowed",
        );
    }

    #[test]
    fn test_ipv6_literal_blocked() {
        assert_rejected(
            &MediaFetcher::default(),
            "http://[::1]/image.jpg",
            "Direct IP access is not allowed",
        );
    }

    #[test]
    fn test_direct_port_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        };
        assert_rejected(
            &fetcher,
            "http://example.com:8080/image.jpg",
            "Direct port access is not allowed",
        );
    }

    #[test]
    fn test_default_port_is_not_direct_port() {
        // `url` normalizes the scheme's default port away, so this is not a "direct port".
        assert_allowed(&MediaFetcher::default(), "https://example.com:443/image.jpg");
    }

    #[test]
    fn test_non_http_scheme_blocked() {
        assert_rejected(
            &MediaFetcher::default(),
            "ftp://example.com/image.jpg",
            "Only HTTP(S) and data URLs are allowed",
        );
    }

    #[test]
    fn test_data_url_bypasses_host_checks() {
        let fetcher = MediaFetcher {
            allowed_media_domains: Some(HashSet::from(["trusted.com".to_string()])),
            ..Default::default()
        };
        assert_allowed(&fetcher, "data:image/png;base64,iVBORw0KGgo=");
    }

    #[test]
    fn test_domain_allowlist() {
        let allowed_domains: HashSet<String> = ["trusted.com", "example.com"]
            .into_iter()
            .map(String::from)
            .collect();
        let fetcher = MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        };

        assert_allowed(&fetcher, "https://trusted.com/image.jpg");
        assert_rejected(
            &fetcher,
            "https://untrusted.com/image.jpg",
            "not in allowed list",
        );
    }
}