Unverified Commit 14af074e authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Media decoder and fetcher options in the MDC (#4094)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent e17b0460
...@@ -39,6 +39,7 @@ use dynamo_llm::{self as llm_rs}; ...@@ -39,6 +39,7 @@ use dynamo_llm::{self as llm_rs};
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig}; use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};
use crate::llm::local_model::ModelRuntimeConfig; use crate::llm::local_model::ModelRuntimeConfig;
use crate::llm::preprocessor::{MediaDecoder, MediaFetcher};
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
...@@ -161,6 +162,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -161,6 +162,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::local_model::ModelRuntimeConfig>()?; m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?; m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::preprocessor::MediaDecoder>()?;
m.add_class::<llm::preprocessor::MediaFetcher>()?;
m.add_class::<llm::backend::Backend>()?; m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?; m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?; m.add_class::<llm::kv::KvIndexer>()?;
...@@ -217,7 +220,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) ...@@ -217,7 +220,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
/// Create an engine and attach it to an endpoint to make it visible to the frontend. /// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend. /// This is the main way you create a Dynamo worker / backend.
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))] #[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
...@@ -233,6 +236,8 @@ fn register_llm<'p>( ...@@ -233,6 +236,8 @@ fn register_llm<'p>(
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
user_data: Option<&Bound<'p, PyDict>>, user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>, custom_template_path: Option<&str>,
media_decoder: Option<MediaDecoder>,
media_fetcher: Option<MediaFetcher>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Validate Prefill model type requirements // Validate Prefill model type requirements
if model_type.inner == llm_rs::model_type::ModelType::Prefill { if model_type.inner == llm_rs::model_type::ModelType::Prefill {
...@@ -305,7 +310,9 @@ fn register_llm<'p>( ...@@ -305,7 +310,9 @@ fn register_llm<'p>(
.migration_limit(Some(migration_limit)) .migration_limit(Some(migration_limit))
.runtime_config(runtime_config.unwrap_or_default().inner) .runtime_config(runtime_config.unwrap_or_default().inner)
.user_data(user_data_json) .user_data(user_data_json)
.custom_template_path(custom_template_path_owned); .custom_template_path(custom_template_path_owned)
.media_decoder(media_decoder.map(|m| m.inner))
.media_fetcher(media_fetcher.map(|m| m.inner));
// Load the ModelDeploymentCard // Load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
use super::*; use super::*;
use crate::llm::model_card::ModelDeploymentCard; use crate::llm::model_card::ModelDeploymentCard;
use std::time::Duration;
use llm_rs::{ use llm_rs::{
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
preprocessor::media::{MediaDecoder as RsMediaDecoder, MediaFetcher as RsMediaFetcher},
protocols::common::llm_backend::{BackendOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
types::{ types::{
Annotated, Annotated,
...@@ -74,3 +76,62 @@ impl OAIChatPreprocessor { ...@@ -74,3 +76,62 @@ impl OAIChatPreprocessor {
}) })
} }
} }
#[pyclass]
#[derive(Clone)]
pub struct MediaDecoder {
pub(crate) inner: RsMediaDecoder,
}
#[pymethods]
impl MediaDecoder {
#[new]
fn new() -> Self {
Self {
inner: RsMediaDecoder::default(),
}
}
fn image_decoder(&mut self, image_decoder: &Bound<'_, PyDict>) -> PyResult<()> {
let image_decoder = pythonize::depythonize(image_decoder).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to parse image_decoder: {}", err))
})?;
self.inner.image_decoder = image_decoder;
Ok(())
}
}
#[pyclass]
#[derive(Clone)]
pub struct MediaFetcher {
pub(crate) inner: RsMediaFetcher,
}
#[pymethods]
impl MediaFetcher {
#[new]
fn new() -> Self {
Self {
inner: RsMediaFetcher::default(),
}
}
fn user_agent(&mut self, user_agent: String) {
self.inner.user_agent = user_agent;
}
fn allow_direct_ip(&mut self, allow: bool) {
self.inner.allow_direct_ip = allow;
}
fn allow_direct_port(&mut self, allow: bool) {
self.inner.allow_direct_port = allow;
}
fn allowed_media_domains(&mut self, domains: Vec<String>) {
self.inner.allowed_media_domains = Some(domains.into_iter().collect());
}
fn timeout_ms(&mut self, timeout_ms: u64) {
self.inner.timeout = Some(Duration::from_millis(timeout_ms));
}
}
...@@ -18,6 +18,8 @@ from dynamo._core import KvPushRouter as KvPushRouter ...@@ -18,6 +18,8 @@ from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
......
...@@ -14,6 +14,7 @@ use crate::entrypoint::RouterConfig; ...@@ -14,6 +14,7 @@ use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
pub mod runtime_config; pub mod runtime_config;
...@@ -52,6 +53,8 @@ pub struct LocalModelBuilder { ...@@ -52,6 +53,8 @@ pub struct LocalModelBuilder {
namespace: Option<String>, namespace: Option<String>,
custom_backend_metrics_endpoint: Option<String>, custom_backend_metrics_endpoint: Option<String>,
custom_backend_metrics_polling_interval: Option<f64>, custom_backend_metrics_polling_interval: Option<f64>,
media_decoder: Option<MediaDecoder>,
media_fetcher: Option<MediaFetcher>,
} }
impl Default for LocalModelBuilder { impl Default for LocalModelBuilder {
...@@ -77,6 +80,8 @@ impl Default for LocalModelBuilder { ...@@ -77,6 +80,8 @@ impl Default for LocalModelBuilder {
namespace: Default::default(), namespace: Default::default(),
custom_backend_metrics_endpoint: Default::default(), custom_backend_metrics_endpoint: Default::default(),
custom_backend_metrics_polling_interval: Default::default(), custom_backend_metrics_polling_interval: Default::default(),
media_decoder: Default::default(),
media_fetcher: Default::default(),
} }
} }
} }
...@@ -184,6 +189,16 @@ impl LocalModelBuilder { ...@@ -184,6 +189,16 @@ impl LocalModelBuilder {
self self
} }
pub fn media_decoder(&mut self, media_decoder: Option<MediaDecoder>) -> &mut Self {
self.media_decoder = media_decoder;
self
}
pub fn media_fetcher(&mut self, media_fetcher: Option<MediaFetcher>) -> &mut Self {
self.media_fetcher = media_fetcher;
self
}
/// Make an LLM ready for use: /// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary /// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path /// - Resolve the path
...@@ -219,6 +234,8 @@ impl LocalModelBuilder { ...@@ -219,6 +234,8 @@ impl LocalModelBuilder {
self.runtime_config.max_num_batched_tokens = self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size; self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder::default());
self.media_fetcher = Some(MediaFetcher::default());
} }
// frontend and echo engine don't need a path. // frontend and echo engine don't need a path.
...@@ -230,6 +247,8 @@ impl LocalModelBuilder { ...@@ -230,6 +247,8 @@ impl LocalModelBuilder {
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
card.media_decoder = self.media_decoder.clone();
card.media_fetcher = self.media_fetcher.clone();
return Ok(LocalModel { return Ok(LocalModel {
card, card,
...@@ -280,6 +299,8 @@ impl LocalModelBuilder { ...@@ -280,6 +299,8 @@ impl LocalModelBuilder {
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
card.media_decoder = self.media_decoder.clone();
card.media_fetcher = self.media_fetcher.clone();
Ok(LocalModel { Ok(LocalModel {
card, card,
......
...@@ -25,6 +25,7 @@ use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned}; ...@@ -25,6 +25,7 @@ use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// Identify model deployment cards in the key-value store /// Identify model deployment cards in the key-value store
...@@ -217,6 +218,14 @@ pub struct ModelDeploymentCard { ...@@ -217,6 +218,14 @@ pub struct ModelDeploymentCard {
#[serde(default)] #[serde(default)]
pub runtime_config: ModelRuntimeConfig, pub runtime_config: ModelRuntimeConfig,
/// Media decoding configuration
#[serde(default)]
pub media_decoder: Option<MediaDecoder>,
/// Media fetching configuration
#[serde(default)]
pub media_fetcher: Option<MediaFetcher>,
#[serde(skip, default)] #[serde(skip, default)]
checksum: OnceLock<String>, checksum: OnceLock<String>,
} }
...@@ -520,6 +529,8 @@ impl ModelDeploymentCard { ...@@ -520,6 +529,8 @@ impl ModelDeploymentCard {
model_input: Default::default(), // set later model_input: Default::default(), // set later
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
media_decoder: None,
media_fetcher: None,
checksum: OnceLock::new(), checksum: OnceLock::new(),
}) })
} }
......
...@@ -7,4 +7,4 @@ mod loader; ...@@ -7,4 +7,4 @@ mod loader;
pub use common::EncodedMediaData; pub use common::EncodedMediaData;
pub use decoders::{Decoder, ImageDecoder, MediaDecoder}; pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
pub use loader::MediaLoader; pub use loader::{MediaFetcher, MediaLoader};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment