Unverified Commit d2faf0e6 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Runtime media decoder config (#5011)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 481dc636
...@@ -92,20 +92,20 @@ impl MediaDecoder { ...@@ -92,20 +92,20 @@ impl MediaDecoder {
} }
} }
fn image_decoder(&mut self, image_decoder: &Bound<'_, PyDict>) -> PyResult<()> { fn enable_image(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
let image_decoder = pythonize::depythonize(image_decoder).map_err(|err| { let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse image_decoder: {}", err)) PyErr::new::<PyException, _>(format!("Failed to parse image decoder config: {}", err))
})?; })?;
self.inner.image_decoder = image_decoder; self.inner.image = Some(decoder_options);
Ok(()) Ok(())
} }
#[cfg(feature = "media-ffmpeg")] #[cfg(feature = "media-ffmpeg")]
fn video_decoder(&mut self, video_decoder: &Bound<'_, PyDict>) -> PyResult<()> { fn enable_video(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
let video_decoder = pythonize::depythonize(video_decoder).map_err(|err| { let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse video_decoder: {}", err)) PyErr::new::<PyException, _>(format!("Failed to parse video decoder config: {}", err))
})?; })?;
self.inner.video_decoder = video_decoder; self.inner.video = Some(decoder_options);
Ok(()) Ok(())
} }
} }
......
...@@ -229,6 +229,7 @@ async fn evaluate( ...@@ -229,6 +229,7 @@ async fn evaluate(
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
......
...@@ -114,6 +114,7 @@ async fn main_loop( ...@@ -114,6 +114,7 @@ async fn main_loop(
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
......
...@@ -181,7 +181,7 @@ impl ErrorMessage { ...@@ -181,7 +181,7 @@ impl ErrorMessage {
// Then check for HttpError // Then check for HttpError
match err.downcast::<HttpError>() { match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error), Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")), Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err:#}")),
} }
} }
...@@ -1735,6 +1735,7 @@ mod tests { ...@@ -1735,6 +1735,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
...@@ -1766,6 +1767,7 @@ mod tests { ...@@ -1766,6 +1767,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
...@@ -1981,6 +1983,7 @@ mod tests { ...@@ -1981,6 +1983,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
...@@ -2010,6 +2013,7 @@ mod tests { ...@@ -2010,6 +2013,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -2038,6 +2042,7 @@ mod tests { ...@@ -2038,6 +2042,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -2066,6 +2071,7 @@ mod tests { ...@@ -2066,6 +2071,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -2096,6 +2102,7 @@ mod tests { ...@@ -2096,6 +2102,7 @@ mod tests {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -2124,6 +2131,7 @@ mod tests { ...@@ -2124,6 +2131,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
......
...@@ -15,7 +15,7 @@ use crate::entrypoint::RouterConfig; ...@@ -15,7 +15,7 @@ use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{MediaDecoder, MediaFetcher}; use crate::preprocessor::media::{ImageDecoder, MediaDecoder, MediaFetcher};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
pub mod runtime_config; pub mod runtime_config;
...@@ -243,7 +243,11 @@ impl LocalModelBuilder { ...@@ -243,7 +243,11 @@ impl LocalModelBuilder {
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer; self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size; self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder::default()); self.media_decoder = Some(MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
});
self.media_fetcher = Some(MediaFetcher::default()); self.media_fetcher = Some(MediaFetcher::default());
} }
......
...@@ -345,11 +345,10 @@ impl OpenAIPreprocessor { ...@@ -345,11 +345,10 @@ impl OpenAIPreprocessor {
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
if !fetch_tasks.is_empty() { if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap(); let loader = self.media_loader.as_ref().unwrap();
let results = futures::future::join_all( let media_io_kwargs = request.media_io_kwargs();
fetch_tasks let results = futures::future::join_all(fetch_tasks.iter().map(|(_, content_part)| {
.iter() loader.fetch_and_decode_media_part(content_part, media_io_kwargs)
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)), }))
)
.await; .await;
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) { for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
......
...@@ -19,16 +19,18 @@ fetcher.allow_direct_port(False) ...@@ -19,16 +19,18 @@ fetcher.allow_direct_port(False)
fetcher.allowed_media_domains(["google.com"]) fetcher.allowed_media_domains(["google.com"])
``` ```
Set media decoding options: Set media decoding default options and limits:
```python ```python
from dynamo.llm import MediaDecoder from dynamo.llm import MediaDecoder
decoder = MediaDecoder() decoder = MediaDecoder()
decoder.image_decoder({"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024}) decoder.enable_image({"limits": {"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024}})
decoder.video_decoder({"strict": True, "fps": 2.0, "max_frames": 128, "max_alloc": 1024*1024*128*3}) decoder.enable_video({"fps": 2.0, "max_frames": 128, "limits": {"max_alloc": 1024*1024*128*3}})
``` ```
And register the LLM as usual, adding the media configuration: If `enable_image` or `enable_video` are not called, requests containing the corresponding modality will be rejected.
Register the LLM as usual, adding the media configuration:
```python ```python
register_llm( register_llm(
...@@ -47,11 +49,15 @@ register_llm( ...@@ -47,11 +49,15 @@ register_llm(
> [!WARNING] > [!WARNING]
> **Requires GPU node**: The frontend must run on a node with GPU access. During media processing, decoded tensors are written to GPU memory via NIXL, which requires `libcuda.so.1` to be available. Running the frontend on a CPU-only node will fail with something like: `Failed to initialize required backends: [UCX: No UCX plugin found]`. > **Requires GPU node**: The frontend must run on a node with GPU access. During media processing, decoded tensors are written to GPU memory via NIXL, which requires `libcuda.so.1` to be available. Running the frontend on a CPU-only node will fail with something like: `Failed to initialize required backends: [UCX: No UCX plugin found]`.
> [!WARNING]
> **Video decoding**: Video decoding needs to be enabled via the `dynamo-llm/media-ffmpeg` rust feature. The following ffmpeg dynamic libraries must be available on the system: `libavcodec`, `libavdevice`, `libavfilter`, `libavformat`, `libswresample`, `libswscale`. These are available in dynamo containers built with `container/build.sh --enable-media-ffmpeg ...`
## Image decoding options ## Image decoding options
- **max_image_width** (uint32, > 0): If the image width exceeds this value, abort the decoding. ### Limits (not overridable at runtime via `media_io_kwargs`)
- **max_image_height** (uint32, > 0): If the image height exceeds this value, abort the decoding. - **limits.max_image_width** (uint32, > 0): If the image width exceeds this value, abort the decoding.
- **max_alloc** (uint64, > 0): Maximum allowed total allocation (RAM) of the decoder in bytes - **limits.max_image_height** (uint32, > 0): If the image height exceeds this value, abort the decoding.
- **limits.max_alloc** (uint64, > 0): Maximum allowed total allocation (RAM) of the decoder in bytes
## Video decoding options ## Video decoding options
### Sampling ### Sampling
...@@ -63,9 +69,30 @@ There are two ways to configure video sampling: either with a fixed number of fr ...@@ -63,9 +69,30 @@ There are two ways to configure video sampling: either with a fixed number of fr
### Others ### Others
- **strict** (bool): if strict mode is enabled, any failure to decode a requested frame will abort the whole video decoding and error out. When strict mode is disabled, it is possible that the decoding of some requested frame fails, and the resulting set of decoded frames might container fewer frames than expected. - **strict** (bool): if strict mode is enabled, any failure to decode a requested frame will abort the whole video decoding and error out. When strict mode is disabled, it is possible that the decoding of some requested frame fails, and the resulting set of decoded frames might container fewer frames than expected.
- **max_alloc** (usize, > 0): If the total number of bytes in the decoded frames would exceed this value, abort the decoding. ### Limits (not overridable at runtime via `media_io_kwargs`)
- **limits.max_alloc** (usize, > 0): If the total number of bytes in the decoded frames would exceed this value, abort the decoding.
## Runtime media decoding options (`media_io_kwargs`)
Parameters of the decoders, can also be set at runtime via an extension to the OpenAI chat completions API. Limits defined in the MDC such as maximum image size, maximum RAM allocation, cannot be overridden at runtime.
This can be used for example to set the video sampling strategy for a request, that differs from the default one registered in the MDC:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": ...,
"messages": ...,
"media_io_kwargs": {
"video": {
"fps": 1.0,
"max_frames": 16
}
}
}'
```
## TODOs ## TODOs
...@@ -80,7 +107,7 @@ There are two ways to configure video sampling: either with a fixed number of fr ...@@ -80,7 +107,7 @@ There are two ways to configure video sampling: either with a fixed number of fr
- [x] Image SW decoding - [x] Image SW decoding
- [ ] Video HW decoding (NVDEC) - [ ] Video HW decoding (NVDEC)
- [ ] JPEG HW decoding (nvJPEG) - [ ] JPEG HW decoding (nvJPEG)
- [ ] Sparse video sampling (seek-forward) - [x] Sparse video sampling (seek-forward)
- [ ] Memory slab pre-allocation/registration - [ ] Memory slab pre-allocation/registration
### Memory management ### Memory management
...@@ -89,4 +116,4 @@ There are two ways to configure video sampling: either with a fixed number of fr ...@@ -89,4 +116,4 @@ There are two ways to configure video sampling: either with a fixed number of fr
### Misc ### Misc
- [ ] Observability on performance, memory usage and input distributions - [ ] Observability on performance, memory usage and input distributions
- [ ] Per-request decoding options - [x] Per-request decoding options
...@@ -18,6 +18,10 @@ pub use video::{VideoDecoder, VideoMetadata}; ...@@ -18,6 +18,10 @@ pub use video::{VideoDecoder, VideoMetadata};
pub trait Decoder: Clone + Send + 'static { pub trait Decoder: Clone + Send + 'static {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>; fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
// Merges this decoder with an optional runtime override.
// Limits should always be enforced from the MDC config
fn with_runtime(&self, runtime: Option<&Self>) -> Self;
async fn decode_async(&self, data: EncodedMediaData) -> Result<DecodedMediaData> { async fn decode_async(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
// light clone (only config params) // light clone (only config params)
let decoder = self.clone(); let decoder = self.clone();
...@@ -27,13 +31,16 @@ pub trait Decoder: Clone + Send + 'static { ...@@ -27,13 +31,16 @@ pub trait Decoder: Clone + Send + 'static {
} }
} }
/// Media decoder configuration.
/// Used both for MDC server config and runtime `media_io_kwargs`.
/// When used at runtime, limits are enforced from MDC and cannot be overridden.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct MediaDecoder { pub struct MediaDecoder {
#[serde(default)] #[serde(default, skip_serializing_if = "Option::is_none")]
pub image_decoder: ImageDecoder, pub image: Option<ImageDecoder>,
#[cfg(feature = "media-ffmpeg")] #[cfg(feature = "media-ffmpeg")]
#[serde(default)] #[serde(default, skip_serializing_if = "Option::is_none")]
pub video_decoder: VideoDecoder, pub video: Option<VideoDecoder>,
// TODO: audio decoder // TODO: audio decoder
} }
......
...@@ -14,19 +14,20 @@ use super::{DecodedMediaMetadata, Decoder}; ...@@ -14,19 +14,20 @@ use super::{DecodedMediaMetadata, Decoder};
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
/// Image decoder limits - can only be set via server config, not runtime kwargs.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct ImageDecoder { pub struct ImageDecoderLimits {
#[serde(default)] #[serde(default)]
pub(crate) max_image_width: Option<u32>, pub max_image_width: Option<u32>,
#[serde(default)] #[serde(default)]
pub(crate) max_image_height: Option<u32>, pub max_image_height: Option<u32>,
// maximum allowed total allocation of the decoder in bytes /// Maximum allowed total allocation of the decoder in bytes
#[serde(default)] #[serde(default)]
pub(crate) max_alloc: Option<u64>, pub max_alloc: Option<u64>,
} }
impl Default for ImageDecoder { impl Default for ImageDecoderLimits {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_image_width: None, max_image_width: None,
...@@ -36,6 +37,13 @@ impl Default for ImageDecoder { ...@@ -36,6 +37,13 @@ impl Default for ImageDecoder {
} }
} }
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
#[serde(default)]
pub(crate) limits: ImageDecoderLimits,
}
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)] #[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum ImageLayout { pub enum ImageLayout {
...@@ -50,14 +58,25 @@ pub struct ImageMetadata { ...@@ -50,14 +58,25 @@ pub struct ImageMetadata {
} }
impl Decoder for ImageDecoder { impl Decoder for ImageDecoder {
fn with_runtime(&self, runtime: Option<&Self>) -> Self {
match runtime {
Some(r) => {
let mut d = r.clone();
d.limits.clone_from(&self.limits);
d
}
None => self.clone(),
}
}
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> { fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
let bytes = data.into_bytes()?; let bytes = data.into_bytes()?;
let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?; let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?;
let mut limits = image::Limits::no_limits(); let mut limits = image::Limits::no_limits();
limits.max_image_width = self.max_image_width; limits.max_image_width = self.limits.max_image_width;
limits.max_image_height = self.max_image_height; limits.max_image_height = self.limits.max_image_height;
limits.max_alloc = self.max_alloc; limits.max_alloc = self.limits.max_alloc;
reader.limits(limits); reader.limits(limits);
let format = reader.format(); let format = reader.format();
...@@ -177,9 +196,11 @@ mod tests { ...@@ -177,9 +196,11 @@ mod tests {
#[case] test_case: &str, #[case] test_case: &str,
) { ) {
let decoder = ImageDecoder { let decoder = ImageDecoder {
limits: ImageDecoderLimits {
max_image_width: max_width, max_image_width: max_width,
max_image_height: max_height, max_image_height: max_height,
max_alloc: Some(DEFAULT_MAX_ALLOC), max_alloc: Some(DEFAULT_MAX_ALLOC),
},
}; };
let image_bytes = create_test_image(width, height, 3, format); // RGB let image_bytes = create_test_image(width, height, 3, format); // RGB
let encoded_data = create_encoded_media_data(image_bytes); let encoded_data = create_encoded_media_data(image_bytes);
...@@ -252,4 +273,33 @@ mod tests { ...@@ -252,4 +273,33 @@ mod tests {
format format
); );
} }
#[test]
fn test_with_runtime_limit_enforcement() {
let server_limits = ImageDecoderLimits {
max_image_width: Some(100),
max_image_height: Some(100),
max_alloc: Some(1024),
};
let server_config = ImageDecoder {
limits: server_limits.clone(),
};
// Runtime config tries to override limits (should be ignored)
let runtime_limits = ImageDecoderLimits {
max_image_width: Some(9999),
max_image_height: Some(9999),
max_alloc: Some(999999),
};
let runtime_config = ImageDecoder {
limits: runtime_limits,
};
let merged = server_config.with_runtime(Some(&runtime_config));
// Check that server limits are preserved
assert_eq!(merged.limits.max_image_width, Some(100));
assert_eq!(merged.limits.max_image_height, Some(100));
assert_eq!(merged.limits.max_alloc, Some(1024));
}
} }
...@@ -20,10 +20,30 @@ use crate::preprocessor::media::{ ...@@ -20,10 +20,30 @@ use crate::preprocessor::media::{
/// Small time buffer (seconds) to avoid edge cases when seeking near frame boundaries /// Small time buffer (seconds) to avoid edge cases when seeking near frame boundaries
const FRAME_TIME_BUFFER_SECS: f64 = 0.001; const FRAME_TIME_BUFFER_SECS: f64 = 0.001;
const DEFAULT_MAX_ALLOC: u64 = 512 * 1024 * 1024; // 512 MB
#[derive(Clone, Default, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct VideoDecoderLimits {
/// Maximum allowed total allocation of decoded frames in bytes
#[serde(default)]
pub max_alloc: Option<u64>,
}
impl Default for VideoDecoderLimits {
fn default() -> Self {
Self {
max_alloc: Some(DEFAULT_MAX_ALLOC),
}
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct VideoDecoder { pub struct VideoDecoder {
#[serde(default)]
pub(crate) limits: VideoDecoderLimits,
/// sample N frames per second /// sample N frames per second
#[serde(default)] #[serde(default)]
pub(crate) fps: Option<f64>, pub(crate) fps: Option<f64>,
...@@ -36,9 +56,6 @@ pub struct VideoDecoder { ...@@ -36,9 +56,6 @@ pub struct VideoDecoder {
/// fail if some frames fail to decode /// fail if some frames fail to decode
#[serde(default)] #[serde(default)]
pub(crate) strict: bool, pub(crate) strict: bool,
/// maximum allowed total allocation of the decoded frames in bytes
#[serde(default)]
pub(crate) max_alloc: Option<u64>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
...@@ -182,6 +199,17 @@ fn decode_frame_at_timestamp( ...@@ -182,6 +199,17 @@ fn decode_frame_at_timestamp(
} }
impl Decoder for VideoDecoder { impl Decoder for VideoDecoder {
fn with_runtime(&self, runtime: Option<&Self>) -> Self {
match runtime {
Some(r) => {
let mut d = r.clone();
d.limits.clone_from(&self.limits);
d
}
None => self.clone(),
}
}
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> { fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
anyhow::ensure!( anyhow::ensure!(
self.fps.is_none() || self.num_frames.is_none(), self.fps.is_none() || self.num_frames.is_none(),
...@@ -212,7 +240,7 @@ impl Decoder for VideoDecoder { ...@@ -212,7 +240,7 @@ impl Decoder for VideoDecoder {
"Invalid video dimensions {width}x{height}" "Invalid video dimensions {width}x{height}"
); );
let max_alloc = self.max_alloc.unwrap_or(u64::MAX); let max_alloc = self.limits.max_alloc.unwrap_or(u64::MAX);
anyhow::ensure!( anyhow::ensure!(
(width as u64) * (height as u64) * requested_frames * 3 <= max_alloc, (width as u64) * (height as u64) * requested_frames * 3 <= max_alloc,
"Video dimensions {requested_frames}x{width}x{height}x3 exceed max alloc {max_alloc}" "Video dimensions {requested_frames}x{width}x{height}x3 exceed max alloc {max_alloc}"
...@@ -320,11 +348,11 @@ mod tests { ...@@ -320,11 +348,11 @@ mod tests {
let requested_frames = 5u64; let requested_frames = 5u64;
let decoder = VideoDecoder { let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: None, fps: None,
max_frames: None, max_frames: None,
num_frames: Some(requested_frames), num_frames: Some(requested_frames),
strict: false, strict: false,
max_alloc: None,
}; };
let decoded = decoder.decode(encoded_data).unwrap(); let decoded = decoder.decode(encoded_data).unwrap();
...@@ -342,11 +370,11 @@ mod tests { ...@@ -342,11 +370,11 @@ mod tests {
let target_fps = 0.5f64; let target_fps = 0.5f64;
let decoder = VideoDecoder { let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: Some(target_fps), fps: Some(target_fps),
max_frames: None, max_frames: None,
num_frames: None, num_frames: None,
strict: false, strict: false,
max_alloc: None,
}; };
let decoded = decoder.decode(encoded_data).unwrap(); let decoded = decoder.decode(encoded_data).unwrap();
...@@ -375,11 +403,11 @@ mod tests { ...@@ -375,11 +403,11 @@ mod tests {
let (encoded_data, width, height, _) = load_test_video(video_file); let (encoded_data, width, height, _) = load_test_video(video_file);
let decoder = VideoDecoder { let decoder = VideoDecoder {
limits: VideoDecoderLimits { max_alloc },
fps: None, fps: None,
max_frames: None, max_frames: None,
num_frames: Some(num_frames), num_frames: Some(num_frames),
strict: false, strict: false,
max_alloc,
}; };
let result = decoder.decode(encoded_data); let result = decoder.decode(encoded_data);
...@@ -403,11 +431,11 @@ mod tests { ...@@ -403,11 +431,11 @@ mod tests {
let (encoded_data, ..) = load_test_video("240p_10.mp4"); let (encoded_data, ..) = load_test_video("240p_10.mp4");
let decoder = VideoDecoder { let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: Some(2.0f64), fps: Some(2.0f64),
max_frames: None, max_frames: None,
num_frames: Some(5u64), num_frames: Some(5u64),
strict: false, strict: false,
max_alloc: None,
}; };
let result = decoder.decode(encoded_data); let result = decoder.decode(encoded_data);
...@@ -442,4 +470,35 @@ mod tests { ...@@ -442,4 +470,35 @@ mod tests {
last_time last_time
); );
} }
#[test]
fn test_with_runtime_limit_enforcement() {
let server_limits = VideoDecoderLimits {
max_alloc: Some(1024),
};
let server_config = VideoDecoder {
limits: server_limits,
fps: Some(1.0),
..Default::default()
};
// Runtime config tries to override limits (should be ignored)
// And sets different FPS (should be accepted)
let runtime_limits = VideoDecoderLimits {
max_alloc: Some(999999),
};
let runtime_config = VideoDecoder {
limits: runtime_limits,
fps: Some(60.0),
..Default::default()
};
let merged = server_config.with_runtime(Some(&runtime_config));
// Check that server limits are preserved
assert_eq!(merged.limits.max_alloc, Some(1024));
// Check that other fields are overridden
assert_eq!(merged.fps, Some(60.0));
}
} }
...@@ -104,11 +104,11 @@ impl MediaLoader { ...@@ -104,11 +104,11 @@ impl MediaLoader {
pub async fn fetch_and_decode_media_part( pub async fn fetch_and_decode_media_part(
&self, &self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart, oai_content_part: &ChatCompletionRequestUserMessageContentPart,
// TODO: request-level options media_io_kwargs: Option<&MediaDecoder>,
) -> Result<RdmaMediaDataDescriptor> { ) -> Result<RdmaMediaDataDescriptor> {
#[cfg(not(feature = "media-nixl"))] #[cfg(not(feature = "media-nixl"))]
anyhow::bail!( anyhow::bail!(
"NIXL is not supported, cannot decode and register media data {oai_content_part:?}" "NIXL is not supported, cannot decode and register media data {oai_content_part:?} with media_io_kwargs {media_io_kwargs:?}"
); );
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
...@@ -116,23 +116,41 @@ impl MediaLoader { ...@@ -116,23 +116,41 @@ impl MediaLoader {
// fetch the media, decode and NIXL-register // fetch the media, decode and NIXL-register
let decoded = match oai_content_part { let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => { ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let mdc_decoder =
self.media_decoder.image.as_ref().ok_or_else(|| {
anyhow::anyhow!("Model does not support image inputs")
})?;
let url = &image_part.image_url.url; let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?; self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?; let data = EncodedMediaData::from_url(url, &self.http_client).await?;
self.media_decoder.image_decoder.decode_async(data).await?
// Use runtime decoder if provided, with MDC limits enforced
let decoder =
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
decoder.decode_async(data).await?
} }
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => { ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
#[cfg(not(feature = "media-ffmpeg"))] #[cfg(not(feature = "media-ffmpeg"))]
anyhow::bail!( anyhow::bail!(
"Video decoding requires the 'media-ffmpeg' feature to be enabled" "Video decoding requires the 'media-ffmpeg' feature to be enabled"
); );
#[cfg(feature = "media-ffmpeg")] #[cfg(feature = "media-ffmpeg")]
self.media_decoder.video_decoder.decode_async(data).await? {
let mdc_decoder = self.media_decoder.video.as_ref().ok_or_else(|| {
anyhow::anyhow!("Model does not support video inputs")
})?;
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced
let decoder = mdc_decoder
.with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
decoder.decode_async(data).await?
}
} }
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => { ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet"); anyhow::bail!("Audio decoding is not supported yet");
...@@ -148,6 +166,7 @@ impl MediaLoader { ...@@ -148,6 +166,7 @@ impl MediaLoader {
#[cfg(all(test, feature = "media-nixl"))] #[cfg(all(test, feature = "media-nixl"))]
mod tests { mod tests {
use super::super::decoders::ImageDecoder;
use super::super::rdma::DataType; use super::super::rdma::DataType;
use super::*; use super::*;
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl}; use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
...@@ -166,7 +185,11 @@ mod tests { ...@@ -166,7 +185,11 @@ mod tests {
.create_async() .create_async()
.await; .await;
let media_decoder = MediaDecoder::default(); let media_decoder = MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
};
let fetcher = MediaFetcher { let fetcher = MediaFetcher {
allow_direct_ip: true, allow_direct_ip: true,
allow_direct_port: true, allow_direct_port: true,
...@@ -180,7 +203,9 @@ mod tests { ...@@ -180,7 +203,9 @@ mod tests {
ChatCompletionRequestMessageContentPartImage { image_url }, ChatCompletionRequestMessageContentPartImage { image_url },
); );
let result = loader.fetch_and_decode_media_part(&content_part).await; let result = loader
.fetch_and_decode_media_part(&content_part, None)
.await;
let descriptor = match result { let descriptor = match result {
Ok(descriptor) => descriptor, Ok(descriptor) => descriptor,
......
...@@ -23,6 +23,8 @@ use minijinja::value::Value; ...@@ -23,6 +23,8 @@ use minijinja::value::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::preprocessor::media::MediaDecoder;
pub mod deepseek_v32; pub mod deepseek_v32;
mod template; mod template;
...@@ -77,6 +79,10 @@ pub trait OAIChatLikeRequest { ...@@ -77,6 +79,10 @@ pub trait OAIChatLikeRequest {
fn extract_text(&self) -> Option<TextInput> { fn extract_text(&self) -> Option<TextInput> {
None None
} }
fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
None
}
} }
pub trait OAIPromptFormatter: Send + Sync + 'static { pub trait OAIPromptFormatter: Send + Sync + 'static {
......
...@@ -6,6 +6,7 @@ use super::*; ...@@ -6,6 +6,7 @@ use super::*;
use minijinja::{context, value::Value}; use minijinja::{context, value::Value};
use std::result::Result::Ok; use std::result::Result::Ok;
use crate::preprocessor::media::MediaDecoder;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest, chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
}; };
...@@ -214,6 +215,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -214,6 +215,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> { fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
self.chat_template_args.as_ref() self.chat_template_args.as_ref()
} }
fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
self.media_io_kwargs.as_ref()
}
} }
impl OAIChatLikeRequest for NvCreateCompletionRequest { impl OAIChatLikeRequest for NvCreateCompletionRequest {
......
...@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; ...@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use crate::preprocessor::media::MediaDecoder;
use super::{ use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
...@@ -45,6 +46,12 @@ pub struct NvCreateChatCompletionRequest { ...@@ -45,6 +46,12 @@ pub struct NvCreateChatCompletionRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>, pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
/// Runtime media decoding parameters.
/// When provided, these override the MDC defaults
/// Example: `{"video": {"num_frames": 16}}`
#[serde(default, skip_serializing_if = "Option::is_none")]
pub media_io_kwargs: Option<MediaDecoder>,
/// Catch-all for unsupported fields - checked during validation /// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)] #[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>, pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
......
...@@ -508,6 +508,7 @@ mod tests { ...@@ -508,6 +508,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
} }
} }
......
...@@ -193,6 +193,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest { ...@@ -193,6 +193,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
common: Default::default(), common: Default::default(),
nvext: resp.nvext, nvext: resp.nvext,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}) })
} }
......
...@@ -773,6 +773,7 @@ async fn test_nv_custom_client() { ...@@ -773,6 +773,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
...@@ -814,6 +815,7 @@ async fn test_nv_custom_client() { ...@@ -814,6 +815,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
...@@ -856,6 +858,7 @@ async fn test_nv_custom_client() { ...@@ -856,6 +858,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
......
...@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest { ...@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest {
common: CommonExt::default(), common: CommonExt::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
} }
} }
......
...@@ -283,6 +283,7 @@ impl Request { ...@@ -283,6 +283,7 @@ impl Request {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
} }
} }
......
...@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() { ...@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
...@@ -297,6 +298,7 @@ fn test_serialization_preserves_structure() { ...@@ -297,6 +298,7 @@ fn test_serialization_preserves_structure() {
..Default::default() ..Default::default()
}), }),
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
...@@ -348,6 +350,7 @@ fn test_sampling_parameters_extraction() { ...@@ -348,6 +350,7 @@ fn test_sampling_parameters_extraction() {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(), unsupported_fields: Default::default(),
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment