// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashSet;
use std::time::Duration;

use anyhow::Result;

use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;

use super::decoders::MediaDecoder;
use super::rdma::RdmaMediaDataDescriptor;

#[cfg(feature = "media-nixl")]
use {
    super::common::EncodedMediaData, super::decoders::Decoder, super::rdma::get_nixl_agent,
    dynamo_memory::nixl::NixlAgent,
};

/// `User-Agent` value that [`MediaFetcher::default`] installs.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
/// Request timeout that [`MediaFetcher::default`] installs.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);

/// Settings that control how remote media is fetched and which URLs are acceptable.
///
/// The URL rules are enforced by `MediaLoader::check_if_url_allowed`; the user agent
/// and timeout are applied to the HTTP client built in `MediaLoader::new`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
    /// `User-Agent` header value configured on the HTTP client.
    pub user_agent: String,
    /// When `false`, HTTP(S) URLs whose host is not a domain name (i.e. an IP literal)
    /// are rejected.
    pub allow_direct_ip: bool,
    /// When `false`, HTTP(S) URLs that carry an explicit port are rejected.
    pub allow_direct_port: bool,
    /// Optional allow-list of hosts. `None` disables the allow-list check; `Some(set)`
    /// rejects HTTP(S) URLs whose host is not in `set`.
    pub allowed_media_domains: Option<HashSet<String>>,
    /// Timeout applied to the HTTP client; `None` leaves the client's own default
    /// in place.
    pub timeout: Option<Duration>,
}

impl Default for MediaFetcher {
    fn default() -> Self {
        Self {
            user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
            allow_direct_ip: false,
            allow_direct_port: false,
            allowed_media_domains: None,
39
            timeout: Some(DEFAULT_HTTP_TIMEOUT),
40
41
42
43
44
        }
    }
}

pub struct MediaLoader {
45
    #[allow(dead_code)]
46
    media_decoder: MediaDecoder,
47
    #[allow(dead_code)]
48
49
    http_client: reqwest::Client,
    media_fetcher: MediaFetcher,
50
51
    #[cfg(feature = "media-nixl")]
    nixl_agent: NixlAgent,
52
53
54
}

impl MediaLoader {
55
56
57
    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
        let media_fetcher = media_fetcher.unwrap_or_default();
        let mut http_client_builder: reqwest::ClientBuilder =
58
59
60
61
62
63
64
65
            reqwest::Client::builder().user_agent(&media_fetcher.user_agent);

        if let Some(timeout) = media_fetcher.timeout {
            http_client_builder = http_client_builder.timeout(timeout);
        }

        let http_client = http_client_builder.build()?;

66
67
68
        #[cfg(feature = "media-nixl")]
        let nixl_agent = get_nixl_agent()?;

69
        Ok(Self {
70
            media_decoder,
71
72
            http_client,
            media_fetcher,
73
74
            #[cfg(feature = "media-nixl")]
            nixl_agent,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
        })
    }

    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        if !matches!(url.scheme(), "http" | "https" | "data") {
            anyhow::bail!("Only HTTP(S) and data URLs are allowed");
        }

        if url.scheme() == "data" {
            return Ok(());
        }

        if !self.media_fetcher.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_)))
        {
            anyhow::bail!("Direct IP access is not allowed");
        }
        if !self.media_fetcher.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }
        if let Some(allowed_domains) = &self.media_fetcher.allowed_media_domains
            && let Some(host) = url.host_str()
            && !allowed_domains.contains(host)
        {
            anyhow::bail!("Domain '{host}' is not in allowed list");
        }

        Ok(())
    }

104
    pub async fn fetch_and_decode_media_part(
105
106
        &self,
        oai_content_part: &ChatCompletionRequestUserMessageContentPart,
107
        media_io_kwargs: Option<&MediaDecoder>,
108
109
110
    ) -> Result<RdmaMediaDataDescriptor> {
        #[cfg(not(feature = "media-nixl"))]
        anyhow::bail!(
111
            "NIXL is not supported, cannot decode and register media data {oai_content_part:?} with media_io_kwargs {media_io_kwargs:?}"
112
        );
113

114
115
116
117
118
        #[cfg(feature = "media-nixl")]
        {
            // fetch the media, decode and NIXL-register
            let decoded = match oai_content_part {
                ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
119
120
121
122
123
                    let mdc_decoder =
                        self.media_decoder.image.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support image inputs")
                        })?;

124
125
126
                    let url = &image_part.image_url.url;
                    self.check_if_url_allowed(url)?;
                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;
127
128
129
130
131

                    // Use runtime decoder if provided, with MDC limits enforced
                    let decoder =
                        mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
                    decoder.decode_async(data).await?
132
133
                }
                ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
134
135
136
137
138
139
                    #[cfg(not(feature = "media-ffmpeg"))]
                    anyhow::bail!(
                        "Video decoding requires the 'media-ffmpeg' feature to be enabled"
                    );

                    #[cfg(feature = "media-ffmpeg")]
140
141
142
143
144
145
146
147
148
149
150
151
152
153
                    {
                        let mdc_decoder = self.media_decoder.video.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support video inputs")
                        })?;

                        let url = &video_part.video_url.url;
                        self.check_if_url_allowed(url)?;
                        let data = EncodedMediaData::from_url(url, &self.http_client).await?;

                        // Use runtime decoder if provided, with MDC limits enforced
                        let decoder = mdc_decoder
                            .with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
                        decoder.decode_async(data).await?
                    }
154
155
156
157
158
159
160
161
162
163
                }
                ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
                    anyhow::bail!("Audio decoding is not supported yet");
                }
                _ => anyhow::bail!("Unsupported media type"),
            };

            let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
            Ok(rdma_descriptor)
        }
164
165
166
    }
}

// End-to-end test of the fetch -> decode -> NIXL registration path; only built when
// the `media-nixl` feature is enabled.
#[cfg(all(test, feature = "media-nixl"))]
mod tests {
    use super::super::decoders::ImageDecoder;
    use super::super::rdma::DataType;
    use super::*;
    use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};

    /// Serves a PNG from a local mock server and checks the resulting descriptor.
    #[tokio::test]
    async fn test_fetch_and_decode() {
        let test_image_bytes =
            include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/llm-optimize-deploy-graphic.png")
            .with_status(200)
            .with_header("content-type", "image/png")
            .with_body(&test_image_bytes[..])
            .create_async()
            .await;

        let media_decoder = MediaDecoder {
            image: Some(ImageDecoder::default()),
            #[cfg(feature = "media-ffmpeg")]
            video: None,
        };
        // The mock server URL uses an IP host and an explicit port, so both must be
        // allowed for the URL check to pass.
        let fetcher = MediaFetcher {
            allow_direct_ip: true,
            allow_direct_port: true,
            ..Default::default()
        };

        let loader: MediaLoader = MediaLoader::new(media_decoder, Some(fetcher)).unwrap();

        let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
        let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
            ChatCompletionRequestMessageContentPartImage { image_url },
        );

        let result = loader
            .fetch_and_decode_media_part(&content_part, None)
            .await;

        let descriptor = match result {
            Ok(descriptor) => descriptor,
            // Treat a missing NIXL agent as a skipped test rather than a failure.
            Err(e) if e.to_string().contains("NIXL agent is not available") => {
                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
                return;
            }
            Err(e) => panic!("Failed to fetch and decode image: {e}"),
        };
        mock.assert_async().await;
        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);

        // Verify image dimensions: 1,999px × 1,125px (width × height)
        // Shape format is [height, width, channels]
        assert_eq!(descriptor.tensor_info.shape.len(), 3);
        assert_eq!(
            descriptor.tensor_info.shape[0], 1125,
            "Height should be 1125"
        );
        assert_eq!(
            descriptor.tensor_info.shape[1], 1999,
            "Width should be 1999"
        );
        assert_eq!(
            descriptor.tensor_info.shape[2], 4,
            "RGBA channels should be 4"
        );

        assert!(
            descriptor.source_storage.is_some(),
            "Source storage should be present"
        );
        assert!(
            descriptor.source_storage.unwrap().is_registered(),
            "Source storage should be registered with NIXL"
        );
    }
}

// URL-policy tests; these only call `check_if_url_allowed` and perform no fetching.
#[cfg(test)]
mod tests_non_nixl {
    use super::*;

    /// An IP-literal host is rejected when `allow_direct_ip` is off.
    #[test]
    fn test_direct_ip_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        };
        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();

        let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
        let result = loader.check_if_url_allowed(&url);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Direct IP access is not allowed")
        );
    }

    /// An explicit port is rejected when `allow_direct_port` is off.
    #[test]
    fn test_direct_port_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        };
        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();

        let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
        let result = loader.check_if_url_allowed(&url);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Direct port access is not allowed")
        );
    }

    /// Only hosts present in the allow-list pass when one is configured.
    #[test]
    fn test_domain_allowlist() {
        let mut allowed_domains = HashSet::new();
        allowed_domains.insert("trusted.com".to_string());
        allowed_domains.insert("example.com".to_string());

        let fetcher = MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        };
        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();

        // Allowed domain should pass
        let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
        assert!(loader.check_if_url_allowed(&url).is_ok());

        // Disallowed domain should fail
        let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap();
        let result = loader.check_if_url_allowed(&url);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in allowed list")
        );
    }

    /// Non-HTTP(S) schemes are rejected, while `data:` URLs are always accepted.
    #[test]
    fn test_scheme_restrictions() {
        let loader = MediaLoader::new(MediaDecoder::default(), None).unwrap();

        let url = url::Url::parse("ftp://example.com/image.jpg").unwrap();
        let result = loader.check_if_url_allowed(&url);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Only HTTP(S) and data URLs are allowed")
        );

        let url = url::Url::parse("data:image/png;base64,AAAA").unwrap();
        assert!(loader.check_if_url_allowed(&url).is_ok());
    }
}