Unverified Commit ae4b08ac authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Image decoder in the frontend (#3971)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 72b0aec1
......@@ -26,6 +26,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to NGC
......
......@@ -44,6 +44,8 @@ jobs:
contents: read
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up system dependencies
run: |
# Install protoc for Rust build dependencies (NOTE: much faster than apt install)
......@@ -94,6 +96,8 @@ jobs:
contents: read
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up system dependencies
run: |
# Install protoc for Rust build dependencies (NOTE: much faster than apt install)
......
......@@ -2244,6 +2244,7 @@ dependencies = [
"galil-seiferas",
"hf-hub",
"humantime",
"image",
"insta",
"itertools 0.14.0",
"json-five",
......@@ -2280,6 +2281,7 @@ dependencies = [
"tmq",
"tokenizers",
"tokio",
"tokio-rayon",
"tokio-stream",
"tokio-util",
"toktrie 1.2.0",
......
......@@ -21,7 +21,7 @@ testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
testing-etcd = []
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix", "dep:aligned-vec"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
......@@ -97,7 +97,6 @@ dialoguer = { version = "0.11", default-features = false, features = [
aligned-vec = { version = "0.6.4", optional = true }
nixl-sys = { version = "=0.7.0", optional = true }
cudarc = { workspace = true, optional = true }
ndarray = { version = "0.16", optional = true }
nix = { version = "0.26", optional = true }
......@@ -143,6 +142,9 @@ json-five = { version = "0.3" }
# media loading in the preprocessor
reqwest = { workspace = true }
base64 = { version = "0.22" }
image = { version = "0.25" }
tokio-rayon = {version = "2" }
ndarray = { version = "0.16" }
# Publishers
zeromq = "0.4.1"
......
......@@ -330,7 +330,7 @@ impl OpenAIPreprocessor {
let _results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_media_part(content_part)),
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
)
.await;
......
......@@ -2,7 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
mod common;
mod decoders;
mod loader;
pub use common::EncodedMediaData;
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
pub use loader::MediaLoader;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use super::common::EncodedMediaData;
use ndarray::{ArrayBase, Dimension, OwnedRepr};
mod image;
pub use image::{ImageDecoder, ImageMetadata};
#[derive(Debug)]
pub enum DecodedMediaMetadata {
#[allow(dead_code)] // used in followup MR
Image(ImageMetadata),
}
#[derive(Debug, PartialEq, Eq)]
pub enum DataType {
UINT8,
}
// Decoded media data (image RGB, video frames pixels, ...)
#[derive(Debug)]
pub struct DecodedMediaData {
#[allow(dead_code)] // used in followup MR
pub(crate) data: Vec<u8>,
#[allow(dead_code)] // used in followup MR
pub(crate) shape: Vec<usize>,
#[allow(dead_code)] // used in followup MR
pub(crate) dtype: DataType,
#[allow(dead_code)] // used in followup MR
pub(crate) metadata: Option<DecodedMediaMetadata>,
}
// convert Array{N}<u8> to DecodedMediaData
// TODO: Array1<f32> for audio
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
let shape = array.shape().to_vec();
let (data, _) = array.into_raw_vec_and_offset();
Self {
data,
shape,
dtype: DataType::UINT8,
metadata: None,
}
}
}
#[async_trait::async_trait]
pub trait Decoder: Clone + Send + 'static {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
async fn decode_async(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
// light clone (only config params)
let decoder = self.clone();
// compute heavy -> rayon
let result = tokio_rayon::spawn(move || decoder.decode(data)).await?;
Ok(result)
}
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct MediaDecoder {
#[serde(default)]
pub image_decoder: ImageDecoder,
// TODO: video, audio decoders
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::io::Cursor;
use anyhow::Result;
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
use ndarray::Array3;
use super::super::common::EncodedMediaData;
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
use super::Decoder;
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
#[serde(default)]
pub(crate) max_image_width: Option<u32>,
#[serde(default)]
pub(crate) max_image_height: Option<u32>,
// maximum allowed total allocation of the decoder in bytes
#[serde(default)]
pub(crate) max_alloc: Option<u64>,
}
impl Default for ImageDecoder {
fn default() -> Self {
Self {
max_image_width: None,
max_image_height: None,
max_alloc: Some(DEFAULT_MAX_ALLOC),
}
}
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub enum ImageLayout {
HWC,
}
#[derive(Debug)]
pub struct ImageMetadata {
#[allow(dead_code)] // used in followup MR
pub(crate) format: Option<ImageFormat>,
#[allow(dead_code)] // used in followup MR
pub(crate) color_type: ColorType,
#[allow(dead_code)] // used in followup MR
pub(crate) layout: ImageLayout,
}
impl Decoder for ImageDecoder {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
let bytes = data.into_bytes()?;
let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?;
let mut limits = image::Limits::no_limits();
limits.max_image_width = self.max_image_width;
limits.max_image_height = self.max_image_height;
limits.max_alloc = self.max_alloc;
reader.limits(limits);
let format = reader.format();
let img = reader.decode()?;
let n_channels = img.color().channel_count();
let (data, color_type) = match n_channels {
1 => (img.to_luma8().into_raw(), ColorType::L8),
2 => (img.to_luma_alpha8().into_raw(), ColorType::La8),
3 => (img.to_rgb8().into_raw(), ColorType::Rgb8),
4 => (img.to_rgba8().into_raw(), ColorType::Rgba8),
other => anyhow::bail!("Unsupported channel count {other}"),
};
let (width, height) = img.dimensions();
let shape = (height as usize, width as usize, n_channels as usize);
let array = Array3::from_shape_vec(shape, data)?;
let mut decoded: DecodedMediaData = array.into();
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
format,
color_type,
layout: ImageLayout::HWC,
}));
Ok(decoded)
}
}
#[cfg(test)]
mod tests {
use super::super::super::decoders::DataType;
use super::*;
use image::{DynamicImage, ImageBuffer};
use rstest::rstest;
use std::io::Cursor;
fn create_encoded_media_data(bytes: Vec<u8>) -> EncodedMediaData {
EncodedMediaData {
bytes,
b64_encoded: false,
}
}
fn create_test_image(
width: u32,
height: u32,
channels: u32,
format: image::ImageFormat,
) -> Vec<u8> {
// Create dynamic image based on number of channels with constant values
let pixels = vec![128u8; channels as usize].repeat((width * height) as usize);
let dynamic_image = match channels {
1 => DynamicImage::ImageLuma8(
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
),
3 => DynamicImage::ImageRgb8(
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
),
4 => DynamicImage::ImageRgba8(
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
),
_ => unreachable!("Already validated channel count above"),
};
// Encode to bytes
let mut bytes = Vec::new();
dynamic_image
.write_to(&mut Cursor::new(&mut bytes), format)
.expect("Failed to encode test image");
bytes
}
#[rstest]
#[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")]
#[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")]
#[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")]
#[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")]
#[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")]
#[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")]
fn test_image_decode(
#[case] input_channels: u32,
#[case] format: image::ImageFormat,
#[case] width: u32,
#[case] height: u32,
#[case] expected_channels: u32,
#[case] description: &str,
) {
let decoder = ImageDecoder::default();
let image_bytes = create_test_image(width, height, input_channels, format);
let encoded_data = create_encoded_media_data(image_bytes);
let result = decoder.decode(encoded_data);
assert!(result.is_ok(), "Failed to decode {}", description);
let decoded = result.unwrap();
assert_eq!(
decoded.shape,
vec![height as usize, width as usize, expected_channels as usize]
);
assert_eq!(decoded.dtype, DataType::UINT8);
}
#[rstest]
#[case(Some(100), None, 50, 50, ImageFormat::Png, true, "width ok")]
#[case(Some(50), None, 100, 50, ImageFormat::Jpeg, false, "width too large")]
#[case(None, Some(100), 50, 100, ImageFormat::Png, true, "height ok")]
#[case(None, Some(50), 50, 100, ImageFormat::Png, false, "height too large")]
#[case(None, None, 2000, 2000, ImageFormat::Png, true, "no limits")]
#[case(None, None, 8000, 8000, ImageFormat::Png, false, "alloc too large")]
fn test_limits(
#[case] max_width: Option<u32>,
#[case] max_height: Option<u32>,
#[case] width: u32,
#[case] height: u32,
#[case] format: image::ImageFormat,
#[case] should_succeed: bool,
#[case] test_case: &str,
) {
let decoder = ImageDecoder {
max_image_width: max_width,
max_image_height: max_height,
max_alloc: Some(DEFAULT_MAX_ALLOC),
};
let image_bytes = create_test_image(width, height, 3, format); // RGB
let encoded_data = create_encoded_media_data(image_bytes);
let result = decoder.decode(encoded_data);
if should_succeed {
assert!(
result.is_ok(),
"Should decode successfully for case: {} with format {:?}",
test_case,
format
);
let decoded = result.unwrap();
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
assert_eq!(
decoded.dtype,
DataType::UINT8,
"dtype should be uint8 for case: {}",
test_case
);
} else {
assert!(
result.is_err(),
"Should fail for case: {} with format {:?}",
test_case,
format
);
let error_msg = result.unwrap_err().to_string();
assert!(
error_msg.contains("dimensions") || error_msg.contains("limit"),
"Error should mention dimension limits, got: {} for case: {}",
error_msg,
test_case
);
}
}
#[rstest]
#[case(3, image::ImageFormat::Png)]
fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) {
let decoder = ImageDecoder::default();
let image_bytes = create_test_image(1, 1, input_channels, format);
let encoded_data = create_encoded_media_data(image_bytes);
let result = decoder.decode(encoded_data);
assert!(
result.is_ok(),
"Should decode 1x1 image with {} channels in {:?} format successfully",
input_channels,
format
);
let decoded = result.unwrap();
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
assert_eq!(decoded.shape[0], 1, "Height should be 1");
assert_eq!(decoded.shape[1], 1, "Width should be 1");
assert_eq!(
decoded.dtype,
DataType::UINT8,
"dtype should be uint8 for {} channels {:?}",
input_channels,
format
);
}
}
......@@ -9,8 +9,10 @@ use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
use super::common::EncodedMediaData;
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
......@@ -28,19 +30,20 @@ impl Default for MediaFetcher {
allow_direct_ip: false,
allow_direct_port: false,
allowed_media_domains: None,
timeout: None,
timeout: Some(DEFAULT_HTTP_TIMEOUT),
}
}
}
pub struct MediaLoader {
media_decoder: MediaDecoder,
http_client: reqwest::Client,
media_fetcher: MediaFetcher,
// TODO: decoders, NIXL agent
// TODO: NIXL agent
}
impl MediaLoader {
pub fn new(media_fetcher: MediaFetcher) -> Result<Self> {
pub fn new(media_decoder: MediaDecoder, media_fetcher: MediaFetcher) -> Result<Self> {
let mut http_client_builder =
reqwest::Client::builder().user_agent(&media_fetcher.user_agent);
......@@ -51,6 +54,7 @@ impl MediaLoader {
let http_client = http_client_builder.build()?;
Ok(Self {
media_decoder,
http_client,
media_fetcher,
})
......@@ -82,23 +86,25 @@ impl MediaLoader {
Ok(())
}
pub async fn fetch_media_part(
pub async fn fetch_and_decode_media_part(
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
// TODO: request-level options
) -> Result<EncodedMediaData> {
) -> Result<DecodedMediaData> {
// fetch the media
// TODO: decode and NIXL-register
let data = match oai_content_part {
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?;
EncodedMediaData::from_url(url, &self.http_client).await?
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
self.media_decoder.image_decoder.decode_async(data).await?
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
EncodedMediaData::from_url(url, &self.http_client).await?
EncodedMediaData::from_url(url, &self.http_client).await?;
anyhow::bail!("Video decoding is not supported yet");
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet");
......@@ -106,13 +112,63 @@ impl MediaLoader {
_ => anyhow::bail!("Unsupported media type"),
};
Ok(data)
Ok(decoded)
}
}
#[cfg(test)]
mod tests {
use super::super::decoders::DataType;
use super::*;
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
#[tokio::test]
async fn test_fetch_and_decode() {
let test_image_bytes =
include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/llm-optimize-deploy-graphic.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.create_async()
.await;
let media_decoder = MediaDecoder::default();
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_direct_port: true,
..Default::default()
};
let loader = MediaLoader::new(media_decoder, fetcher).unwrap();
let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
ChatCompletionRequestMessageContentPartImage { image_url },
);
let result = loader.fetch_and_decode_media_part(&content_part).await;
assert!(
result.is_ok(),
"Failed to fetch and decode image: {:?}",
result.err()
);
let data = result.unwrap();
assert_eq!(data.dtype, DataType::UINT8);
// Verify image dimensions: 1,999px × 1,125px (width × height)
// Shape format is [height, width, channels]
assert_eq!(data.shape.len(), 3);
assert_eq!(data.shape[0], 1125, "Height should be 1125");
assert_eq!(data.shape[1], 1999, "Width should be 1999");
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
mock.assert_async().await;
}
#[test]
fn test_direct_ip_blocked() {
......@@ -120,7 +176,7 @@ mod tests {
allow_direct_ip: false,
..Default::default()
};
let loader = MediaLoader::new(fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
......@@ -140,7 +196,7 @@ mod tests {
allow_direct_port: false,
..Default::default()
};
let loader = MediaLoader::new(fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
......@@ -164,7 +220,7 @@ mod tests {
allowed_media_domains: Some(allowed_domains),
..Default::default()
};
let loader = MediaLoader::new(fetcher).unwrap();
let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
// Allowed domain should pass
let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
......
llm-optimize-deploy-graphic.png filter=lfs diff=lfs merge=lfs -text
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment