// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashSet;
use std::time::Duration;

use anyhow::Result;

use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
10
use dynamo_memory::nixl::NixlAgent;
11

12
13
14
use super::common::EncodedMediaData;
use super::decoders::{Decoder, MediaDecoder};
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
15
16

/// Default value for `MediaFetcher::user_agent`, sent as the `User-Agent` header on media fetches.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";

/// Default value for `MediaFetcher::timeout`, applied as the HTTP client's request timeout.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);

/// Policy describing which media URLs may be fetched, plus the HTTP client settings
/// (`user_agent`, `timeout`) that `MediaLoader::new` applies.
///
/// The URL rules are enforced by `check_if_url_allowed`. `data:` URLs skip every
/// host/port/domain rule because they carry their payload inline.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
    /// Value sent as the `User-Agent` header on outbound fetches.
    pub user_agent: String,
    /// When `false`, HTTP(S) URLs whose host is not a domain name (i.e. an IP literal) are rejected.
    pub allow_direct_ip: bool,
    /// When `false`, HTTP(S) URLs for which `url::Url::port()` returns `Some` (an explicit
    /// port other than the scheme default) are rejected.
    pub allow_direct_port: bool,
    /// Optional allowlist checked against the URL's host; `None` allows any host.
    pub allowed_media_domains: Option<HashSet<String>>,
    /// Request timeout for the HTTP client; `None` means no timeout is set on the client.
    pub timeout: Option<Duration>,
}

impl Default for MediaFetcher {
    fn default() -> Self {
        Self {
            user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
            allow_direct_ip: false,
            allow_direct_port: false,
            allowed_media_domains: None,
35
            timeout: Some(DEFAULT_HTTP_TIMEOUT),
36
37
38
39
        }
    }
}

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
impl MediaFetcher {
    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        if !matches!(url.scheme(), "http" | "https" | "data") {
            anyhow::bail!("Only HTTP(S) and data URLs are allowed");
        }

        if url.scheme() == "data" {
            return Ok(());
        }

        if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
            anyhow::bail!("Direct IP access is not allowed");
        }
        if !self.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }
        if let Some(allowed_domains) = &self.allowed_media_domains
            && let Some(host) = url.host_str()
            && !allowed_domains.contains(host)
        {
            anyhow::bail!("Domain '{host}' is not in allowed list");
        }

        Ok(())
    }
}

67
pub struct MediaLoader {
68
    #[allow(dead_code)]
69
    media_decoder: MediaDecoder,
70
    #[allow(dead_code)]
71
    http_client: reqwest::Client,
72
    #[allow(dead_code)]
73
    media_fetcher: MediaFetcher,
74
    nixl_agent: NixlAgent,
75
76
77
}

impl MediaLoader {
78
79
80
    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
        let media_fetcher = media_fetcher.unwrap_or_default();
        let mut http_client_builder: reqwest::ClientBuilder =
81
82
83
84
85
86
87
88
            reqwest::Client::builder().user_agent(&media_fetcher.user_agent);

        if let Some(timeout) = media_fetcher.timeout {
            http_client_builder = http_client_builder.timeout(timeout);
        }

        let http_client = http_client_builder.build()?;

89
90
        let nixl_agent = get_nixl_agent()?;

91
        Ok(Self {
92
            media_decoder,
93
94
            http_client,
            media_fetcher,
95
            nixl_agent,
96
97
98
        })
    }

99
    pub async fn fetch_and_decode_media_part(
100
101
        &self,
        oai_content_part: &ChatCompletionRequestUserMessageContentPart,
102
        media_io_kwargs: Option<&MediaDecoder>,
103
    ) -> Result<RdmaMediaDataDescriptor> {
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
        // fetch the media, decode and NIXL-register
        let decoded = match oai_content_part {
            ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
                let mdc_decoder = self
                    .media_decoder
                    .image
                    .as_ref()
                    .ok_or_else(|| anyhow::anyhow!("Model does not support image inputs"))?;

                let url = &image_part.image_url.url;
                self.media_fetcher.check_if_url_allowed(url)?;
                let data = EncodedMediaData::from_url(url, &self.http_client).await?;

                // Use runtime decoder if provided, with MDC limits enforced
                let decoder =
                    mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
                decoder.decode_async(data).await?
            }
            #[allow(unused_variables)]
            ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
                #[cfg(not(feature = "media-ffmpeg"))]
                anyhow::bail!("Video decoding requires the 'media-ffmpeg' feature to be enabled");

                #[cfg(feature = "media-ffmpeg")]
                {
129
                    let mdc_decoder =
130
131
                        self.media_decoder.video.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support video inputs")
132
133
                        })?;

134
                    let url = &video_part.video_url.url;
135
                    self.media_fetcher.check_if_url_allowed(url)?;
136
                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;
137
138
139

                    // Use runtime decoder if provided, with MDC limits enforced
                    let decoder =
140
                        mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
141
                    decoder.decode_async(data).await?
142
                }
143
144
145
146
147
148
149
150
151
            }
            ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
                anyhow::bail!("Audio decoding is not supported yet");
            }
            _ => anyhow::bail!("Unsupported media type"),
        };

        let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
        Ok(rdma_descriptor)
152
153
154
    }
}

#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::decoders::ImageDecoder;
    use super::super::rdma::DataType;
    use super::*;
    use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};

    /// End-to-end check of the image path.
    ///
    /// Serves a known PNG from a local mock server, fetches and decodes it through
    /// `MediaLoader`, and inspects the resulting RDMA descriptor. The test returns early
    /// (reporting itself as ignored) when NIXL/UCX is unavailable on the host.
    #[tokio::test]
    async fn test_fetch_and_decode() {
        let test_image_bytes =
            include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/llm-optimize-deploy-graphic.png")
            .with_status(200)
            .with_header("content-type", "image/png")
            .with_body(&test_image_bytes[..])
            .create_async()
            .await;

        let media_decoder = MediaDecoder {
            image: Some(ImageDecoder::default()),
            #[cfg(feature = "media-ffmpeg")]
            video: None,
        };

        // The mock server listens on an IP literal with an explicit port, so both
        // guards must be relaxed for this fetch to be permitted.
        let fetcher = MediaFetcher {
            allow_direct_ip: true,
            allow_direct_port: true,
            ..Default::default()
        };

        let loader: MediaLoader = match MediaLoader::new(media_decoder, Some(fetcher)) {
            Ok(l) => l,
            Err(e) => {
                println!(
                    "test test_fetch_and_decode ... ignored (NIXL/UCX not available: {})",
                    e
                );
                return;
            }
        };

        let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
        let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
            ChatCompletionRequestMessageContentPartImage { image_url },
        );

        let result = loader
            .fetch_and_decode_media_part(&content_part, None)
            .await;

        let descriptor = match result {
            Ok(descriptor) => descriptor,
            Err(e) if e.to_string().contains("NIXL agent is not available") => {
                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
                return;
            }
            Err(e) => panic!("Failed to fetch and decode image: {}", e),
        };
        mock.assert_async().await;
        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);

        // Verify image dimensions: 1,999px × 1,125px (width × height)
        // Shape format is [height, width, channels]
        assert_eq!(descriptor.tensor_info.shape.len(), 3);
        assert_eq!(
            descriptor.tensor_info.shape[0], 1125,
            "Height should be 1125"
        );
        assert_eq!(
            descriptor.tensor_info.shape[1], 1999,
            "Width should be 1999"
        );
        assert_eq!(
            descriptor.tensor_info.shape[2], 4,
            "RGBA channels should be 4"
        );

        assert!(
            descriptor.source_storage.is_some(),
            "Source storage should be present"
        );
        assert!(
            descriptor.source_storage.unwrap().is_registered(),
            "Source storage should be registered with NIXL"
        );
    }
}

#[cfg(test)]
mod tests_non_nixl {
    use super::*;

    /// Parses `url` and asserts that `fetcher` rejects it with an error containing `expected`.
    fn assert_rejected(fetcher: &MediaFetcher, url: &str, expected: &str) {
        let url = url::Url::parse(url).unwrap();
        let err = fetcher
            .check_if_url_allowed(&url)
            .expect_err("URL should have been rejected");
        assert!(
            err.to_string().contains(expected),
            "unexpected error for {url}: {err}"
        );
    }

    /// Parses `url` and asserts that `fetcher` accepts it.
    fn assert_allowed(fetcher: &MediaFetcher, url: &str) {
        let url = url::Url::parse(url).unwrap();
        if let Err(err) = fetcher.check_if_url_allowed(&url) {
            panic!("URL {url} should have been allowed, got: {err}");
        }
    }

    #[test]
    fn test_direct_ip_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        };

        assert_rejected(
            &fetcher,
            "http://192.168.1.1/image.jpg",
            "Direct IP access is not allowed",
        );
        // IPv6 literals are IP hosts too.
        assert_rejected(
            &fetcher,
            "http://[::1]/image.jpg",
            "Direct IP access is not allowed",
        );
    }

    #[test]
    fn test_direct_ip_allowed_when_enabled() {
        let fetcher = MediaFetcher {
            allow_direct_ip: true,
            ..Default::default()
        };

        assert_allowed(&fetcher, "http://192.168.1.1/image.jpg");
    }

    #[test]
    fn test_direct_port_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        };

        assert_rejected(
            &fetcher,
            "http://example.com:8080/image.jpg",
            "Direct port access is not allowed",
        );
    }

    #[test]
    fn test_default_port_is_not_a_direct_port() {
        // `url` drops the scheme's default port, so `Url::port()` reports `None` here.
        assert_allowed(&MediaFetcher::default(), "http://example.com:80/image.jpg");
    }

    #[test]
    fn test_non_http_scheme_rejected() {
        assert_rejected(
            &MediaFetcher::default(),
            "file:///etc/passwd",
            "Only HTTP(S) and data URLs are allowed",
        );
    }

    #[test]
    fn test_data_url_skips_host_checks() {
        let fetcher = MediaFetcher {
            allowed_media_domains: Some(HashSet::from(["trusted.com".to_string()])),
            ..Default::default()
        };

        assert_allowed(&fetcher, "data:image/png;base64,AAAA");
    }

    #[test]
    fn test_domain_allowlist() {
        let mut allowed_domains = HashSet::new();
        allowed_domains.insert("trusted.com".to_string());
        allowed_domains.insert("example.com".to_string());

        let fetcher = MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        };

        // Allowed domain should pass
        assert_allowed(&fetcher, "https://trusted.com/image.jpg");
        // `url` lowercases domain hosts, so an uppercase URL host still matches.
        assert_allowed(&fetcher, "https://TRUSTED.com/image.jpg");

        // Disallowed domain should fail
        assert_rejected(
            &fetcher,
            "https://untrusted.com/image.jpg",
            "not in allowed list",
        );
    }
}