// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::io::Cursor; use anyhow::Result; use image::{ColorType, GenericImageView, ImageFormat, ImageReader}; use ndarray::Array3; use serde::{Deserialize, Serialize}; use super::super::common::EncodedMediaData; use super::super::rdma::DecodedMediaData; use super::{DecodedMediaMetadata, Decoder}; const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct ImageDecoder { #[serde(default)] pub(crate) max_image_width: Option, #[serde(default)] pub(crate) max_image_height: Option, // maximum allowed total allocation of the decoder in bytes #[serde(default)] pub(crate) max_alloc: Option, } impl Default for ImageDecoder { fn default() -> Self { Self { max_image_width: None, max_image_height: None, max_alloc: Some(DEFAULT_MAX_ALLOC), } } } #[allow(clippy::upper_case_acronyms)] #[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub enum ImageLayout { HWC, } #[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub struct ImageMetadata { pub(crate) format: Option, pub(crate) color_type: ColorType, pub(crate) layout: ImageLayout, } impl Decoder for ImageDecoder { fn decode(&self, data: EncodedMediaData) -> Result { let bytes = data.into_bytes()?; let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?; let mut limits = image::Limits::no_limits(); limits.max_image_width = self.max_image_width; limits.max_image_height = self.max_image_height; limits.max_alloc = self.max_alloc; reader.limits(limits); let format = reader.format(); let img = reader.decode()?; let n_channels = img.color().channel_count(); let (data, color_type) = match n_channels { 1 => (img.to_luma8().into_raw(), ColorType::L8), 2 => (img.to_luma_alpha8().into_raw(), ColorType::La8), 3 => (img.to_rgb8().into_raw(), ColorType::Rgb8), 4 => (img.to_rgba8().into_raw(), ColorType::Rgba8), other => anyhow::bail!("Unsupported channel count {other}"), }; let (width, height) = img.dimensions(); let shape = (height as usize, width as usize, n_channels as usize); let array = Array3::from_shape_vec(shape, data)?; let mut decoded: DecodedMediaData = array.try_into()?; decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata { format, color_type, layout: ImageLayout::HWC, })); Ok(decoded) } } #[cfg(test)] mod tests { use super::super::super::rdma::DataType; use super::*; use image::{DynamicImage, ImageBuffer}; use rstest::rstest; use std::io::Cursor; fn create_encoded_media_data(bytes: Vec) -> EncodedMediaData { EncodedMediaData { bytes, b64_encoded: false, } } fn create_test_image( width: u32, height: u32, channels: u32, format: image::ImageFormat, ) -> Vec { // Create dynamic image based on number of channels with constant values let pixels = vec![128u8; channels as usize].repeat((width * height) as usize); let dynamic_image = match channels { 1 => DynamicImage::ImageLuma8( ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), ), 3 => DynamicImage::ImageRgb8( ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), ), 4 => DynamicImage::ImageRgba8( ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), ), _ => unreachable!("Already validated channel count above"), }; // Encode to bytes let mut bytes = Vec::new(); dynamic_image .write_to(&mut Cursor::new(&mut bytes), format) .expect("Failed to encode test image"); bytes } #[rstest] #[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")] #[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")] #[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")] #[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")] #[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")] #[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")] fn test_image_decode( #[case] input_channels: u32, #[case] format: image::ImageFormat, #[case] width: u32, #[case] height: u32, #[case] expected_channels: u32, #[case] description: &str, ) { let decoder = ImageDecoder::default(); let image_bytes = create_test_image(width, height, input_channels, format); let encoded_data = create_encoded_media_data(image_bytes); let result = decoder.decode(encoded_data); assert!(result.is_ok(), "Failed to decode {}", description); let decoded = result.unwrap(); assert_eq!( decoded.tensor_info.shape, vec![height as usize, width as usize, expected_channels as usize] ); assert_eq!(decoded.tensor_info.dtype, DataType::UINT8); } #[rstest] #[case(Some(100), None, 50, 50, ImageFormat::Png, true, "width ok")] #[case(Some(50), None, 100, 50, ImageFormat::Jpeg, false, "width too large")] #[case(None, Some(100), 50, 100, ImageFormat::Png, true, "height ok")] #[case(None, Some(50), 50, 100, ImageFormat::Png, false, "height too large")] #[case(None, None, 2000, 2000, ImageFormat::Png, true, "no limits")] #[case(None, None, 8000, 8000, ImageFormat::Png, false, "alloc too large")] fn test_limits( #[case] max_width: Option, #[case] max_height: Option, #[case] width: u32, #[case] height: u32, #[case] format: image::ImageFormat, #[case] should_succeed: bool, #[case] test_case: &str, ) { let decoder = ImageDecoder { max_image_width: max_width, max_image_height: max_height, max_alloc: Some(DEFAULT_MAX_ALLOC), }; let image_bytes = create_test_image(width, height, 3, format); // RGB let encoded_data = create_encoded_media_data(image_bytes); let result = decoder.decode(encoded_data); if should_succeed { assert!( result.is_ok(), "Should decode successfully for case: {} with format {:?}", test_case, format ); let decoded = result.unwrap(); assert_eq!( decoded.tensor_info.shape, vec![height as usize, width as usize, 3] ); assert_eq!( decoded.tensor_info.dtype, DataType::UINT8, "dtype should be uint8 for case: {}", test_case ); } else { assert!( result.is_err(), "Should fail for case: {} with format {:?}", test_case, format ); let error_msg = result.unwrap_err().to_string(); assert!( error_msg.contains("dimensions") || error_msg.contains("limit"), "Error should mention dimension limits, got: {} for case: {}", error_msg, test_case ); } } #[rstest] #[case(3, image::ImageFormat::Png)] fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) { let decoder = ImageDecoder::default(); let image_bytes = create_test_image(1, 1, input_channels, format); let encoded_data = create_encoded_media_data(image_bytes); let result = decoder.decode(encoded_data); assert!( result.is_ok(), "Should decode 1x1 image with {} channels in {:?} format successfully", input_channels, format ); let decoded = result.unwrap(); assert_eq!( decoded.tensor_info.shape.len(), 3, "Should have 3 dimensions" ); assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1"); assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1"); assert_eq!( decoded.tensor_info.dtype, DataType::UINT8, "dtype should be uint8 for {} channels {:?}", input_channels, format ); } }