Unverified Commit 0a065918 authored by milesial's avatar milesial Committed by GitHub
Browse files
parent b5db4e75
......@@ -2678,6 +2678,7 @@ dependencies = [
"derive_builder",
"dialoguer",
"dynamo-async-openai",
"dynamo-memory",
"dynamo-parsers",
"dynamo-runtime",
"either",
......@@ -4641,6 +4642,7 @@ dependencies = [
"ravif",
"rayon",
"rgb",
"serde",
"tiff",
"zune-core 0.5.0",
"zune-jpeg 0.5.5",
......
......@@ -2972,6 +2972,7 @@ dependencies = [
"ravif",
"rayon",
"rgb",
"serde",
"tiff",
"zune-core",
"zune-jpeg",
......
......@@ -24,6 +24,7 @@ testing-etcd = []
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
media-nixl = ["dep:nixl-sys", "dep:dynamo-memory"]
[[bench]]
name = "tokenizer"
......@@ -33,9 +34,11 @@ harness = false
name = "transfer_context_v2"
harness = false
required-features = ["block-manager", "testing-cuda"]
[dependencies]
# repo
dynamo-runtime = { workspace = true }
dynamo-memory = { path = "../memory", version = "0.7.0", optional = true }
# workspace
aho-corasick = "1.1"
......@@ -145,7 +148,7 @@ json-five = { version = "0.3" }
# media loading in the preprocessor
reqwest = { workspace = true }
base64 = { version = "0.22" }
image = { version = "0.25" }
image = { version = "0.25", features = ["serde"] }
tokio-rayon = {version = "2" }
ndarray = { version = "0.16" }
ndarray-npy = { version = "0.9" }
......
......@@ -27,7 +27,8 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::media::MediaLoader;
#[cfg(feature = "media-nixl")]
use crate::preprocessor::media::{MediaDecoder, MediaFetcher, MediaLoader};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
......@@ -114,6 +115,7 @@ pub struct OpenAIPreprocessor {
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>,
#[cfg(feature = "media-nixl")]
media_loader: Option<MediaLoader>,
}
......@@ -143,7 +145,13 @@ impl OpenAIPreprocessor {
// // Initialize runtime config from the ModelDeploymentCard
let runtime_config = mdc.runtime_config.clone();
let media_loader = None; // TODO: enable with decoder config from MDC
#[cfg(feature = "media-nixl")]
let media_loader = match mdc.media_decoder {
Some(media_decoder) => Some(MediaLoader::new(media_decoder, mdc.media_fetcher)?),
None => None,
};
Ok(Arc::new(Self {
formatter,
tokenizer,
......@@ -151,6 +159,7 @@ impl OpenAIPreprocessor {
mdcsum,
runtime_config,
tool_call_parser,
#[cfg(feature = "media-nixl")]
media_loader,
}))
}
......@@ -280,7 +289,9 @@ impl OpenAIPreprocessor {
let messages = request.messages();
let message_count = messages.len().unwrap_or(0);
let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks = Vec::new();
#[cfg(feature = "media-nixl")]
let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
Vec::new();
for idx in 0..message_count {
let msg = messages
......@@ -313,29 +324,39 @@ impl OpenAIPreprocessor {
_ => continue,
};
#[cfg(feature = "media-nixl")]
if self.media_loader.is_some() {
fetch_tasks.push((type_str, content_part.clone()));
} else {
// No loader, just pass the URL through
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Url(url));
continue;
}
//Fallback: ust pass the URL through
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Url(url));
}
}
// Execute all fetch tasks
#[cfg(feature = "media-nixl")]
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let _results = futures::future::join_all(
let results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
)
.await;
// TODO: decode and pass NIXL descriptors to the media map
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
// if one item fails, errors the whole request, other items will be cleaned up by Drop
let rdma_descriptor = result?;
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Decoded(rdma_descriptor));
}
}
if !media_map.is_empty() {
......
......@@ -4,7 +4,12 @@
mod common;
mod decoders;
mod loader;
mod rdma;
pub use common::EncodedMediaData;
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
pub use loader::{MediaFetcher, MediaLoader};
pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor};
#[cfg(feature = "media-nixl")]
pub use rdma::{get_nixl_agent, get_nixl_metadata};
# Media decoding in the frontend
This component performs media download, base64 decoding, media decoding and NIXL registration. Today, this is used in the OpenAI preprocessor, to transform multimodal inputs (image_url, video_url, audio_url) into fully decoded data (pixel values, ...) accessible to the backends via NIXL.
## Usage
Media decoding is enabled when registering the MDC:
Set HTTP download options:
```python
from dynamo.llm import MediaFetcher
fetcher = MediaFetcher()
fetcher.user_agent("dynamo")
fetcher.timeout_ms(15000)
fetcher.allow_direct_ip(True)
fetcher.allow_direct_port(False)
fetcher.allowed_media_domains(["google.com"])
```
Set media decoding options:
```python
from dynamo.llm import MediaDecoder
decoder = MediaDecoder()
decoder.image_decoder({"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024})
```
And register the LLM as usual, adding the media configuration:
```python
register_llm(
...,
media_decoder=decoder,
media_fetcher=fetcher,
)
```
## TODOs
### Modalities
- [x] Image decoding
- [ ] Video decoding
- [ ] Audio decoding
### Performance
- [x] Image SW decoding
- [ ] Video HW decoding (NVDEC)
- [ ] JPEG HW decoding (nvJPEG)
- [ ] Sparse video sampling (seek-forward)
- [ ] Memory slab pre-allocation/registration
### Memory management
- [ ] Memory spilling to lower storage tiers
- [ ] Early-free memory on client notifications
### Misc
- [ ] Observability on performance, memory usage and input distributions
- [ ] Per-request decoding options
......@@ -2,52 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use serde::{Deserialize, Serialize};
use super::common::EncodedMediaData;
use ndarray::{ArrayBase, Dimension, OwnedRepr};
mod image;
use super::rdma::DecodedMediaData;
pub mod image;
pub use image::{ImageDecoder, ImageMetadata};
#[derive(Debug)]
pub enum DecodedMediaMetadata {
#[allow(dead_code)] // used in followup MR
Image(ImageMetadata),
}
#[derive(Debug, PartialEq, Eq)]
pub enum DataType {
UINT8,
}
// Decoded media data (image RGB, video frames pixels, ...)
#[derive(Debug)]
pub struct DecodedMediaData {
#[allow(dead_code)] // used in followup MR
pub(crate) data: Vec<u8>,
#[allow(dead_code)] // used in followup MR
pub(crate) shape: Vec<usize>,
#[allow(dead_code)] // used in followup MR
pub(crate) dtype: DataType,
#[allow(dead_code)] // used in followup MR
pub(crate) metadata: Option<DecodedMediaMetadata>,
}
// convert Array{N}<u8> to DecodedMediaData
// TODO: Array1<f32> for audio
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
let shape = array.shape().to_vec();
let (data, _) = array.into_raw_vec_and_offset();
Self {
data,
shape,
dtype: DataType::UINT8,
metadata: None,
}
}
}
#[async_trait::async_trait]
pub trait Decoder: Clone + Send + 'static {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
......@@ -67,3 +29,8 @@ pub struct MediaDecoder {
pub image_decoder: ImageDecoder,
// TODO: video, audio decoders
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum DecodedMediaMetadata {
Image(ImageMetadata),
}
......@@ -6,14 +6,15 @@ use std::io::Cursor;
use anyhow::Result;
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
use ndarray::Array3;
use serde::{Deserialize, Serialize};
use super::super::common::EncodedMediaData;
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
use super::Decoder;
use super::super::rdma::DecodedMediaData;
use super::{DecodedMediaMetadata, Decoder};
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
#[serde(default)]
......@@ -36,18 +37,15 @@ impl Default for ImageDecoder {
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum ImageLayout {
HWC,
}
#[derive(Debug)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub struct ImageMetadata {
#[allow(dead_code)] // used in followup MR
pub(crate) format: Option<ImageFormat>,
#[allow(dead_code)] // used in followup MR
pub(crate) color_type: ColorType,
#[allow(dead_code)] // used in followup MR
pub(crate) layout: ImageLayout,
}
......@@ -78,8 +76,8 @@ impl Decoder for ImageDecoder {
let (width, height) = img.dimensions();
let shape = (height as usize, width as usize, n_channels as usize);
let array = Array3::from_shape_vec(shape, data)?;
let mut decoded: DecodedMediaData = array.into();
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
let mut decoded: DecodedMediaData = array.try_into()?;
decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
format,
color_type,
layout: ImageLayout::HWC,
......@@ -90,7 +88,7 @@ impl Decoder for ImageDecoder {
#[cfg(test)]
mod tests {
use super::super::super::decoders::DataType;
use super::super::super::rdma::DataType;
use super::*;
use image::{DynamicImage, ImageBuffer};
use rstest::rstest;
......@@ -156,10 +154,10 @@ mod tests {
let decoded = result.unwrap();
assert_eq!(
decoded.shape,
decoded.tensor_info.shape,
vec![height as usize, width as usize, expected_channels as usize]
);
assert_eq!(decoded.dtype, DataType::UINT8);
assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
}
#[rstest]
......@@ -196,9 +194,12 @@ mod tests {
format
);
let decoded = result.unwrap();
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
assert_eq!(
decoded.dtype,
decoded.tensor_info.shape,
vec![height as usize, width as usize, 3]
);
assert_eq!(
decoded.tensor_info.dtype,
DataType::UINT8,
"dtype should be uint8 for case: {}",
test_case
......@@ -236,11 +237,15 @@ mod tests {
);
let decoded = result.unwrap();
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
assert_eq!(decoded.shape[0], 1, "Height should be 1");
assert_eq!(decoded.shape[1], 1, "Width should be 1");
assert_eq!(
decoded.dtype,
decoded.tensor_info.shape.len(),
3,
"Should have 3 dimensions"
);
assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
assert_eq!(
decoded.tensor_info.dtype,
DataType::UINT8,
"dtype should be uint8 for {} channels {:?}",
input_channels,
......
......@@ -8,8 +8,14 @@ use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
use super::common::EncodedMediaData;
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
use super::decoders::MediaDecoder;
use super::rdma::RdmaMediaDataDescriptor;
#[cfg(feature = "media-nixl")]
use {
super::common::EncodedMediaData, super::decoders::Decoder, super::rdma::get_nixl_agent,
dynamo_memory::nixl::NixlAgent,
};
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
......@@ -36,15 +42,19 @@ impl Default for MediaFetcher {
}
pub struct MediaLoader {
#[allow(dead_code)]
media_decoder: MediaDecoder,
#[allow(dead_code)]
http_client: reqwest::Client,
media_fetcher: MediaFetcher,
// TODO: NIXL agent
#[cfg(feature = "media-nixl")]
nixl_agent: NixlAgent,
}
impl MediaLoader {
pub fn new(media_decoder: MediaDecoder, media_fetcher: MediaFetcher) -> Result<Self> {
let mut http_client_builder =
pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
let media_fetcher = media_fetcher.unwrap_or_default();
let mut http_client_builder: reqwest::ClientBuilder =
reqwest::Client::builder().user_agent(&media_fetcher.user_agent);
if let Some(timeout) = media_fetcher.timeout {
......@@ -53,10 +63,15 @@ impl MediaLoader {
let http_client = http_client_builder.build()?;
#[cfg(feature = "media-nixl")]
let nixl_agent = get_nixl_agent()?;
Ok(Self {
media_decoder,
http_client,
media_fetcher,
#[cfg(feature = "media-nixl")]
nixl_agent,
})
}
......@@ -90,35 +105,43 @@ impl MediaLoader {
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
// TODO: request-level options
) -> Result<DecodedMediaData> {
// fetch the media
// TODO: decode and NIXL-register
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
self.media_decoder.image_decoder.decode_async(data).await?
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
EncodedMediaData::from_url(url, &self.http_client).await?;
anyhow::bail!("Video decoding is not supported yet");
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet");
}
_ => anyhow::bail!("Unsupported media type"),
};
) -> Result<RdmaMediaDataDescriptor> {
#[cfg(not(feature = "media-nixl"))]
anyhow::bail!(
"NIXL is not supported, cannot decode and register media data {oai_content_part:?}"
);
Ok(decoded)
#[cfg(feature = "media-nixl")]
{
// fetch the media, decode and NIXL-register
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
self.media_decoder.image_decoder.decode_async(data).await?
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
EncodedMediaData::from_url(url, &self.http_client).await?;
anyhow::bail!("Video decoding is not supported yet");
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet");
}
_ => anyhow::bail!("Unsupported media type"),
};
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
Ok(rdma_descriptor)
}
}
}
#[cfg(test)]
#[cfg(all(test, feature = "media-nixl"))]
mod tests {
use super::super::decoders::DataType;
use super::super::rdma::DataType;
use super::*;
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
......@@ -143,7 +166,7 @@ mod tests {
..Default::default()
};
let loader = MediaLoader::new(media_decoder, fetcher).unwrap();
let loader: MediaLoader = MediaLoader::new(media_decoder, fetcher).unwrap();
let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
......@@ -151,24 +174,48 @@ mod tests {
);
let result = loader.fetch_and_decode_media_part(&content_part).await;
assert!(
result.is_ok(),
"Failed to fetch and decode image: {:?}",
result.err()
);
let data = result.unwrap();
assert_eq!(data.dtype, DataType::UINT8);
let descriptor = match result {
Ok(descriptor) => descriptor,
Err(e) if e.to_string().contains("NIXL agent is not available") => {
println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
return;
}
Err(e) => panic!("Failed to fetch and decode image: {}", e),
};
mock.assert_async().await;
assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
// Verify image dimensions: 1,999px × 1,125px (width × height)
// Shape format is [height, width, channels]
assert_eq!(data.shape.len(), 3);
assert_eq!(data.shape[0], 1125, "Height should be 1125");
assert_eq!(data.shape[1], 1999, "Width should be 1999");
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
assert_eq!(descriptor.tensor_info.shape.len(), 3);
assert_eq!(
descriptor.tensor_info.shape[0], 1125,
"Height should be 1125"
);
assert_eq!(
descriptor.tensor_info.shape[1], 1999,
"Width should be 1999"
);
assert_eq!(
descriptor.tensor_info.shape[2], 4,
"RGBA channels should be 4"
);
mock.assert_async().await;
assert!(
descriptor.source_storage.is_some(),
"Source storage should be present"
);
assert!(
descriptor.source_storage.unwrap().is_registered(),
"Source storage should be registered with NIXL"
);
}
}
#[cfg(test)]
mod tests_non_nixl {
use super::*;
#[test]
fn test_direct_ip_blocked() {
......@@ -176,7 +223,7 @@ mod tests {
allow_direct_ip: false,
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
......@@ -196,7 +243,7 @@ mod tests {
allow_direct_port: false,
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
......@@ -220,7 +267,7 @@ mod tests {
allowed_media_domains: Some(allowed_domains),
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
// Allowed domain should pass
let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use ndarray::{ArrayBase, Dimension, OwnedRepr};
use serde::{Deserialize, Serialize};
#[cfg(feature = "media-nixl")]
use {
base64::{Engine as _, engine::general_purpose},
dynamo_memory::SystemStorage,
dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView},
std::sync::Arc,
};
use super::decoders::DecodedMediaMetadata;
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum DataType {
UINT8,
}
// Common tensor metadata shared between decoded and RDMA descriptors
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct MediaTensorInfo {
pub(crate) shape: Vec<usize>,
pub(crate) dtype: DataType,
pub(crate) metadata: Option<DecodedMediaMetadata>,
}
// Decoded media data (image RGB, video frames pixels, ...)
#[derive(Debug)]
pub struct DecodedMediaData {
#[cfg(feature = "media-nixl")]
pub(crate) data: SystemStorage,
pub(crate) tensor_info: MediaTensorInfo,
}
// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS)
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct RdmaMediaDataDescriptor {
// b64 agent metadata
#[cfg(feature = "media-nixl")]
pub(crate) nixl_metadata: String,
// tensor descriptor
#[cfg(feature = "media-nixl")]
pub(crate) nixl_descriptor: NixlDescriptor,
#[serde(flatten)]
pub(crate) tensor_info: MediaTensorInfo,
// reference to the actual data, kept alive while the rdma descriptor is alive
#[serde(skip, default)]
#[allow(dead_code)]
#[cfg(feature = "media-nixl")]
pub(crate) source_storage: Option<Arc<nixl::NixlRegistered<SystemStorage>>>,
}
impl DecodedMediaData {
#[cfg(feature = "media-nixl")]
pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result<RdmaMediaDataDescriptor> {
let source_storage = self.data;
let registered = nixl::register_with_nixl(source_storage, nixl_agent, None)
.map_err(|_| anyhow::anyhow!("Failed to register storage with NIXL"))?;
let nixl_descriptor = registered.descriptor();
let nixl_metadata = get_nixl_metadata(nixl_agent, registered.storage())?;
Ok(RdmaMediaDataDescriptor {
nixl_metadata,
nixl_descriptor,
tensor_info: self.tensor_info,
// Keep registered storage alive
source_storage: Some(Arc::new(registered)),
})
}
}
// convert Array{N}<u8> to DecodedMediaData
// TODO: Array1<f32> for audio
impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
type Error = anyhow::Error;
fn try_from(array: ArrayBase<OwnedRepr<u8>, D>) -> Result<Self, Self::Error> {
let shape = array.shape().to_vec();
#[cfg(feature = "media-nixl")]
let (data_vec, _) = array.into_raw_vec_and_offset();
#[cfg(feature = "media-nixl")]
let mut storage = SystemStorage::new(data_vec.len())?;
#[cfg(feature = "media-nixl")]
unsafe {
std::ptr::copy_nonoverlapping(data_vec.as_ptr(), storage.as_mut_ptr(), data_vec.len());
}
Ok(Self {
#[cfg(feature = "media-nixl")]
data: storage,
tensor_info: MediaTensorInfo {
shape,
dtype: DataType::UINT8,
metadata: None,
},
})
}
}
// Get NIXL metadata for a descriptor
// Avoids cross-request leak possibility and reduces metadata size
// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
#[cfg(feature = "media-nixl")]
pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
// WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md
let nixl_md = agent.raw_agent().get_local_md()?;
// let mut reg_desc_list = RegDescList::new(MemType::Dram)?;
// reg_desc_list.add_storage_desc(storage)?;
// let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?;
let b64_encoded = general_purpose::STANDARD.encode(&nixl_md);
Ok(format!("b64:{}", b64_encoded))
}
#[cfg(feature = "media-nixl")]
pub fn get_nixl_agent() -> Result<NixlAgent> {
let name = format!("media-loader-{}", uuid::Uuid::new_v4());
let nixl_agent = NixlAgent::with_backends(&name, &["UCX"])?;
Ok(nixl_agent)
}
......@@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
#[cfg(feature = "media-nixl")]
use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType;
#[derive(Serialize, Deserialize, Debug, Clone)]
......@@ -20,7 +22,8 @@ pub struct PrefillResult {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
// TODO: Decoded(DecodedMediaData),
#[cfg(feature = "media-nixl")]
Decoded(RdmaMediaDataDescriptor),
}
// multimodal map containing {mm_part_type: [data...]}
......@@ -40,6 +43,7 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>,
/// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment