"tests/vscode:/vscode.git/clone" did not exist on "a96197f564cb848c5a89ae3438a8bc21f5f5d2c6"
Unverified Commit d2faf0e6 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Runtime media decoder config (#5011)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 481dc636
......@@ -92,20 +92,20 @@ impl MediaDecoder {
}
}
fn image_decoder(&mut self, image_decoder: &Bound<'_, PyDict>) -> PyResult<()> {
let image_decoder = pythonize::depythonize(image_decoder).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse image_decoder: {}", err))
fn enable_image(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse image decoder config: {}", err))
})?;
self.inner.image_decoder = image_decoder;
self.inner.image = Some(decoder_options);
Ok(())
}
#[cfg(feature = "media-ffmpeg")]
fn video_decoder(&mut self, video_decoder: &Bound<'_, PyDict>) -> PyResult<()> {
let video_decoder = pythonize::depythonize(video_decoder).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse video_decoder: {}", err))
fn enable_video(&mut self, decoder_options: &Bound<'_, PyDict>) -> PyResult<()> {
let decoder_options = pythonize::depythonize(decoder_options).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse video decoder config: {}", err))
})?;
self.inner.video_decoder = video_decoder;
self.inner.video = Some(decoder_options);
Ok(())
}
}
......
......@@ -229,6 +229,7 @@ async fn evaluate(
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let mut stream = engine.generate(Context::new(req)).await?;
......
......@@ -114,6 +114,7 @@ async fn main_loop(
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......
......@@ -181,7 +181,7 @@ impl ErrorMessage {
// Then check for HttpError
match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err:#}")),
}
}
......@@ -1735,6 +1735,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
......@@ -1766,6 +1767,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
......@@ -1981,6 +1983,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......@@ -2010,6 +2013,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -2038,6 +2042,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -2066,6 +2071,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -2096,6 +2102,7 @@ mod tests {
.unwrap(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -2124,6 +2131,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......
......@@ -15,7 +15,7 @@ use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::preprocessor::media::{ImageDecoder, MediaDecoder, MediaFetcher};
use crate::request_template::RequestTemplate;
pub mod runtime_config;
......@@ -243,7 +243,11 @@ impl LocalModelBuilder {
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder::default());
self.media_decoder = Some(MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
});
self.media_fetcher = Some(MediaFetcher::default());
}
......
......@@ -345,11 +345,10 @@ impl OpenAIPreprocessor {
#[cfg(feature = "media-nixl")]
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
)
let media_io_kwargs = request.media_io_kwargs();
let results = futures::future::join_all(fetch_tasks.iter().map(|(_, content_part)| {
loader.fetch_and_decode_media_part(content_part, media_io_kwargs)
}))
.await;
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
......
......@@ -19,16 +19,18 @@ fetcher.allow_direct_port(False)
fetcher.allowed_media_domains(["google.com"])
```
Set media decoding options:
Set media decoding default options and limits:
```python
from dynamo.llm import MediaDecoder
decoder = MediaDecoder()
decoder.image_decoder({"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024})
decoder.video_decoder({"strict": True, "fps": 2.0, "max_frames": 128, "max_alloc": 1024*1024*128*3})
decoder.enable_image({"limits": {"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024}})
decoder.enable_video({"fps": 2.0, "max_frames": 128, "limits": {"max_alloc": 1024*1024*128*3}})
```
And register the LLM as usual, adding the media configuration:
If `enable_image` or `enable_video` are not called, requests containing the corresponding modality will be rejected.
Register the LLM as usual, adding the media configuration:
```python
register_llm(
......@@ -47,11 +49,15 @@ register_llm(
> [!WARNING]
> **Requires GPU node**: The frontend must run on a node with GPU access. During media processing, decoded tensors are written to GPU memory via NIXL, which requires `libcuda.so.1` to be available. Running the frontend on a CPU-only node will fail with something like: `Failed to initialize required backends: [UCX: No UCX plugin found]`.
> [!WARNING]
> **Video decoding**: Video decoding needs to be enabled via the `dynamo-llm/media-ffmpeg` rust feature. The following ffmpeg dynamic libraries must be available on the system: `libavcodec`, `libavdevice`, `libavfilter`, `libavformat`, `libswresample`, `libswscale`. These are available in dynamo containers built with `container/build.sh --enable-media-ffmpeg ...`
## Image decoding options
- **max_image_width** (uint32, > 0): If the image width exceeds this value, abort the decoding.
- **max_image_height** (uint32, > 0): If the image height exceeds this value, abort the decoding.
- **max_alloc** (uint64, > 0): Maximum allowed total allocation (RAM) of the decoder in bytes
### Limits (not overridable at runtime via `media_io_kwargs`)
- **limits.max_image_width** (uint32, > 0): If the image width exceeds this value, abort the decoding.
- **limits.max_image_height** (uint32, > 0): If the image height exceeds this value, abort the decoding.
- **limits.max_alloc** (uint64, > 0): Maximum allowed total allocation (RAM) of the decoder in bytes
## Video decoding options
### Sampling
......@@ -63,9 +69,30 @@ There are two ways to configure video sampling: either with a fixed number of fr
### Others
- **strict** (bool): if strict mode is enabled, any failure to decode a requested frame will abort the whole video decoding and error out. When strict mode is disabled, it is possible that the decoding of some requested frame fails, and the resulting set of decoded frames might container fewer frames than expected.
- **max_alloc** (usize, > 0): If the total number of bytes in the decoded frames would exceed this value, abort the decoding.
### Limits (not overridable at runtime via `media_io_kwargs`)
- **limits.max_alloc** (usize, > 0): If the total number of bytes in the decoded frames would exceed this value, abort the decoding.
## Runtime media decoding options (`media_io_kwargs`)
Parameters of the decoders, can also be set at runtime via an extension to the OpenAI chat completions API. Limits defined in the MDC such as maximum image size, maximum RAM allocation, cannot be overridden at runtime.
This can be used for example to set the video sampling strategy for a request, that differs from the default one registered in the MDC:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": ...,
"messages": ...,
"media_io_kwargs": {
"video": {
"fps": 1.0,
"max_frames": 16
}
}
}'
```
## TODOs
......@@ -80,7 +107,7 @@ There are two ways to configure video sampling: either with a fixed number of fr
- [x] Image SW decoding
- [ ] Video HW decoding (NVDEC)
- [ ] JPEG HW decoding (nvJPEG)
- [ ] Sparse video sampling (seek-forward)
- [x] Sparse video sampling (seek-forward)
- [ ] Memory slab pre-allocation/registration
### Memory management
......@@ -89,4 +116,4 @@ There are two ways to configure video sampling: either with a fixed number of fr
### Misc
- [ ] Observability on performance, memory usage and input distributions
- [ ] Per-request decoding options
- [x] Per-request decoding options
......@@ -18,6 +18,10 @@ pub use video::{VideoDecoder, VideoMetadata};
pub trait Decoder: Clone + Send + 'static {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
// Merges this decoder with an optional runtime override.
// Limits should always be enforced from the MDC config
fn with_runtime(&self, runtime: Option<&Self>) -> Self;
async fn decode_async(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
// light clone (only config params)
let decoder = self.clone();
......@@ -27,13 +31,16 @@ pub trait Decoder: Clone + Send + 'static {
}
}
/// Media decoder configuration.
/// Used both for MDC server config and runtime `media_io_kwargs`.
/// When used at runtime, limits are enforced from MDC and cannot be overridden.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct MediaDecoder {
#[serde(default)]
pub image_decoder: ImageDecoder,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image: Option<ImageDecoder>,
#[cfg(feature = "media-ffmpeg")]
#[serde(default)]
pub video_decoder: VideoDecoder,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub video: Option<VideoDecoder>,
// TODO: audio decoder
}
......
......@@ -14,19 +14,20 @@ use super::{DecodedMediaMetadata, Decoder};
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
/// Image decoder limits - can only be set via server config, not runtime kwargs.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
pub struct ImageDecoderLimits {
#[serde(default)]
pub(crate) max_image_width: Option<u32>,
pub max_image_width: Option<u32>,
#[serde(default)]
pub(crate) max_image_height: Option<u32>,
// maximum allowed total allocation of the decoder in bytes
pub max_image_height: Option<u32>,
/// Maximum allowed total allocation of the decoder in bytes
#[serde(default)]
pub(crate) max_alloc: Option<u64>,
pub max_alloc: Option<u64>,
}
impl Default for ImageDecoder {
impl Default for ImageDecoderLimits {
fn default() -> Self {
Self {
max_image_width: None,
......@@ -36,6 +37,13 @@ impl Default for ImageDecoder {
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
#[serde(default)]
pub(crate) limits: ImageDecoderLimits,
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum ImageLayout {
......@@ -50,14 +58,25 @@ pub struct ImageMetadata {
}
impl Decoder for ImageDecoder {
fn with_runtime(&self, runtime: Option<&Self>) -> Self {
match runtime {
Some(r) => {
let mut d = r.clone();
d.limits.clone_from(&self.limits);
d
}
None => self.clone(),
}
}
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
let bytes = data.into_bytes()?;
let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?;
let mut limits = image::Limits::no_limits();
limits.max_image_width = self.max_image_width;
limits.max_image_height = self.max_image_height;
limits.max_alloc = self.max_alloc;
limits.max_image_width = self.limits.max_image_width;
limits.max_image_height = self.limits.max_image_height;
limits.max_alloc = self.limits.max_alloc;
reader.limits(limits);
let format = reader.format();
......@@ -177,9 +196,11 @@ mod tests {
#[case] test_case: &str,
) {
let decoder = ImageDecoder {
limits: ImageDecoderLimits {
max_image_width: max_width,
max_image_height: max_height,
max_alloc: Some(DEFAULT_MAX_ALLOC),
},
};
let image_bytes = create_test_image(width, height, 3, format); // RGB
let encoded_data = create_encoded_media_data(image_bytes);
......@@ -252,4 +273,33 @@ mod tests {
format
);
}
#[test]
fn test_with_runtime_limit_enforcement() {
let server_limits = ImageDecoderLimits {
max_image_width: Some(100),
max_image_height: Some(100),
max_alloc: Some(1024),
};
let server_config = ImageDecoder {
limits: server_limits.clone(),
};
// Runtime config tries to override limits (should be ignored)
let runtime_limits = ImageDecoderLimits {
max_image_width: Some(9999),
max_image_height: Some(9999),
max_alloc: Some(999999),
};
let runtime_config = ImageDecoder {
limits: runtime_limits,
};
let merged = server_config.with_runtime(Some(&runtime_config));
// Check that server limits are preserved
assert_eq!(merged.limits.max_image_width, Some(100));
assert_eq!(merged.limits.max_image_height, Some(100));
assert_eq!(merged.limits.max_alloc, Some(1024));
}
}
......@@ -20,10 +20,30 @@ use crate::preprocessor::media::{
/// Small time buffer (seconds) to avoid edge cases when seeking near frame boundaries
const FRAME_TIME_BUFFER_SECS: f64 = 0.001;
const DEFAULT_MAX_ALLOC: u64 = 512 * 1024 * 1024; // 512 MB
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct VideoDecoderLimits {
/// Maximum allowed total allocation of decoded frames in bytes
#[serde(default)]
pub max_alloc: Option<u64>,
}
impl Default for VideoDecoderLimits {
fn default() -> Self {
Self {
max_alloc: Some(DEFAULT_MAX_ALLOC),
}
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct VideoDecoder {
#[serde(default)]
pub(crate) limits: VideoDecoderLimits,
/// sample N frames per second
#[serde(default)]
pub(crate) fps: Option<f64>,
......@@ -36,9 +56,6 @@ pub struct VideoDecoder {
/// fail if some frames fail to decode
#[serde(default)]
pub(crate) strict: bool,
/// maximum allowed total allocation of the decoded frames in bytes
#[serde(default)]
pub(crate) max_alloc: Option<u64>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
......@@ -182,6 +199,17 @@ fn decode_frame_at_timestamp(
}
impl Decoder for VideoDecoder {
fn with_runtime(&self, runtime: Option<&Self>) -> Self {
match runtime {
Some(r) => {
let mut d = r.clone();
d.limits.clone_from(&self.limits);
d
}
None => self.clone(),
}
}
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
anyhow::ensure!(
self.fps.is_none() || self.num_frames.is_none(),
......@@ -212,7 +240,7 @@ impl Decoder for VideoDecoder {
"Invalid video dimensions {width}x{height}"
);
let max_alloc = self.max_alloc.unwrap_or(u64::MAX);
let max_alloc = self.limits.max_alloc.unwrap_or(u64::MAX);
anyhow::ensure!(
(width as u64) * (height as u64) * requested_frames * 3 <= max_alloc,
"Video dimensions {requested_frames}x{width}x{height}x3 exceed max alloc {max_alloc}"
......@@ -320,11 +348,11 @@ mod tests {
let requested_frames = 5u64;
let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: None,
max_frames: None,
num_frames: Some(requested_frames),
strict: false,
max_alloc: None,
};
let decoded = decoder.decode(encoded_data).unwrap();
......@@ -342,11 +370,11 @@ mod tests {
let target_fps = 0.5f64;
let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: Some(target_fps),
max_frames: None,
num_frames: None,
strict: false,
max_alloc: None,
};
let decoded = decoder.decode(encoded_data).unwrap();
......@@ -375,11 +403,11 @@ mod tests {
let (encoded_data, width, height, _) = load_test_video(video_file);
let decoder = VideoDecoder {
limits: VideoDecoderLimits { max_alloc },
fps: None,
max_frames: None,
num_frames: Some(num_frames),
strict: false,
max_alloc,
};
let result = decoder.decode(encoded_data);
......@@ -403,11 +431,11 @@ mod tests {
let (encoded_data, ..) = load_test_video("240p_10.mp4");
let decoder = VideoDecoder {
limits: VideoDecoderLimits::default(),
fps: Some(2.0f64),
max_frames: None,
num_frames: Some(5u64),
strict: false,
max_alloc: None,
};
let result = decoder.decode(encoded_data);
......@@ -442,4 +470,35 @@ mod tests {
last_time
);
}
#[test]
fn test_with_runtime_limit_enforcement() {
let server_limits = VideoDecoderLimits {
max_alloc: Some(1024),
};
let server_config = VideoDecoder {
limits: server_limits,
fps: Some(1.0),
..Default::default()
};
// Runtime config tries to override limits (should be ignored)
// And sets different FPS (should be accepted)
let runtime_limits = VideoDecoderLimits {
max_alloc: Some(999999),
};
let runtime_config = VideoDecoder {
limits: runtime_limits,
fps: Some(60.0),
..Default::default()
};
let merged = server_config.with_runtime(Some(&runtime_config));
// Check that server limits are preserved
assert_eq!(merged.limits.max_alloc, Some(1024));
// Check that other fields are overridden
assert_eq!(merged.fps, Some(60.0));
}
}
......@@ -104,11 +104,11 @@ impl MediaLoader {
pub async fn fetch_and_decode_media_part(
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
// TODO: request-level options
media_io_kwargs: Option<&MediaDecoder>,
) -> Result<RdmaMediaDataDescriptor> {
#[cfg(not(feature = "media-nixl"))]
anyhow::bail!(
"NIXL is not supported, cannot decode and register media data {oai_content_part:?}"
"NIXL is not supported, cannot decode and register media data {oai_content_part:?} with media_io_kwargs {media_io_kwargs:?}"
);
#[cfg(feature = "media-nixl")]
......@@ -116,23 +116,41 @@ impl MediaLoader {
// fetch the media, decode and NIXL-register
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let mdc_decoder =
self.media_decoder.image.as_ref().ok_or_else(|| {
anyhow::anyhow!("Model does not support image inputs")
})?;
let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
self.media_decoder.image_decoder.decode_async(data).await?
// Use runtime decoder if provided, with MDC limits enforced
let decoder =
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
decoder.decode_async(data).await?
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
#[cfg(not(feature = "media-ffmpeg"))]
anyhow::bail!(
"Video decoding requires the 'media-ffmpeg' feature to be enabled"
);
#[cfg(feature = "media-ffmpeg")]
self.media_decoder.video_decoder.decode_async(data).await?
{
let mdc_decoder = self.media_decoder.video.as_ref().ok_or_else(|| {
anyhow::anyhow!("Model does not support video inputs")
})?;
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced
let decoder = mdc_decoder
.with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
decoder.decode_async(data).await?
}
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet");
......@@ -148,6 +166,7 @@ impl MediaLoader {
#[cfg(all(test, feature = "media-nixl"))]
mod tests {
use super::super::decoders::ImageDecoder;
use super::super::rdma::DataType;
use super::*;
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
......@@ -166,7 +185,11 @@ mod tests {
.create_async()
.await;
let media_decoder = MediaDecoder::default();
let media_decoder = MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
};
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_direct_port: true,
......@@ -180,7 +203,9 @@ mod tests {
ChatCompletionRequestMessageContentPartImage { image_url },
);
let result = loader.fetch_and_decode_media_part(&content_part).await;
let result = loader
.fetch_and_decode_media_part(&content_part, None)
.await;
let descriptor = match result {
Ok(descriptor) => descriptor,
......
......@@ -23,6 +23,8 @@ use minijinja::value::Value;
use std::collections::HashMap;
use std::sync::Arc;
use crate::preprocessor::media::MediaDecoder;
pub mod deepseek_v32;
mod template;
......@@ -77,6 +79,10 @@ pub trait OAIChatLikeRequest {
fn extract_text(&self) -> Option<TextInput> {
None
}
fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
None
}
}
pub trait OAIPromptFormatter: Send + Sync + 'static {
......
......@@ -6,6 +6,7 @@ use super::*;
use minijinja::{context, value::Value};
use std::result::Result::Ok;
use crate::preprocessor::media::MediaDecoder;
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
};
......@@ -214,6 +215,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
self.chat_template_args.as_ref()
}
fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
self.media_io_kwargs.as_ref()
}
}
impl OAIChatLikeRequest for NvCreateCompletionRequest {
......
......@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use validator::Validate;
use crate::engines::ValidateRequest;
use crate::preprocessor::media::MediaDecoder;
use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
......@@ -45,6 +46,12 @@ pub struct NvCreateChatCompletionRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
/// Runtime media decoding parameters.
/// When provided, these override the MDC defaults
/// Example: `{"video": {"num_frames": 16}}`
#[serde(default, skip_serializing_if = "Option::is_none")]
pub media_io_kwargs: Option<MediaDecoder>,
/// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
......
......@@ -508,6 +508,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
......
......@@ -193,6 +193,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
common: Default::default(),
nvext: resp.nvext,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
})
}
......
......@@ -773,6 +773,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......@@ -814,6 +815,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......@@ -856,6 +858,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......
......@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest {
common: CommonExt::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
......
......@@ -283,6 +283,7 @@ impl Request {
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
......
......@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.unwrap(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......@@ -297,6 +298,7 @@ fn test_serialization_preserves_structure() {
..Default::default()
}),
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......@@ -348,6 +350,7 @@ fn test_sampling_parameters_extraction() {
.unwrap(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment