Unverified Commit d0a63635 authored by Jorge António's avatar Jorge António Committed by GitHub
Browse files

feat: add RuntimeConfig to ModelEntry (#2311)


Co-authored-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent b74b887b
......@@ -16,9 +16,11 @@ import zmq
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_ip, get_zmq_socket
from dynamo._core import Endpoint
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelRuntimeConfig,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
......@@ -334,13 +336,8 @@ async def init(
await component.create_service()
endpoint = component.endpoint("generate")
await register_llm(
ModelType.Backend,
endpoint,
server_args.model_path,
server_args.served_model_name,
kv_cache_block_size=server_args.page_size,
migration_limit=migration_limit,
await register_llm_with_runtime_config(
engine, endpoint, server_args, migration_limit
)
if server_args.disaggregation_mode != "null":
......@@ -372,12 +369,75 @@ async def init(
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
tasks = [endpoint.serve_endpoint(handler.generate)]
tasks.extend(setup_native_endpoints(server_args, component, handler))
await asyncio.gather(*tasks)
async def register_llm_with_runtime_config(
engine: sgl.Engine,
endpoint: Endpoint,
server_args: ServerArgs,
migration_limit: int,
):
"""Register LLM with runtime config"""
runtime_config = await _get_runtime_config(engine)
try:
await register_llm(
ModelType.Backend,
endpoint,
server_args.model_path,
server_args.served_model_name,
kv_cache_block_size=server_args.page_size,
migration_limit=migration_limit,
runtime_config=runtime_config,
)
except Exception as e:
logging.error(f"Failed to register with runtime config: {e}")
return None
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
"""Get runtime config from SGLang engine"""
try:
# Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
runtime_config = ModelRuntimeConfig()
# Get max_total_num_tokens from scheduler_info
if "max_total_num_tokens" in engine.scheduler_info:
max_total_tokens = engine.scheduler_info["max_total_num_tokens"]
if max_total_tokens and hasattr(
engine.tokenizer_manager, "server_args"
):
page_size = engine.tokenizer_manager.server_args.page_size
if page_size:
runtime_config.total_kv_blocks = (
max_total_tokens + page_size - 1
) // page_size
logging.info(
f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
)
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
# TODO: figure out where they are
return runtime_config
# If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config
logging.warning(
"Could not access runtime config from SGLang engine. "
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
)
return None
except Exception as e:
logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
return None
def main():
uvloop.install()
asyncio.run(worker())
......
......@@ -20,10 +20,10 @@ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count
from transformers import AutoConfig
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import get_llm_engine
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import get_publisher
from dynamo.trtllm.request_handlers.handlers import (
......@@ -49,6 +49,39 @@ async def graceful_shutdown(runtime):
logging.info("DistributedRuntime shutdown complete")
async def get_engine_runtime_config(
engine: TensorRTLLMEngine, config: Config
) -> ModelRuntimeConfig:
"""Retrieve runtime configuration from TensorRT-LLM engine."""
runtime_config = ModelRuntimeConfig()
try:
# Extract total_kv_blocks from engine stats
stats = engine.llm.get_stats_async(timeout=5)
stat = await anext(stats)
runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"]
logging.info(
f"Set runtime config total_kv_blocks: {runtime_config.total_kv_blocks}"
)
# Extract max number of sequences
runtime_config.max_num_seqs = config.max_batch_size
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
# Get max_num_batched_tokens from config
runtime_config.max_num_batched_tokens = config.max_num_tokens
logging.info(
f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
)
return runtime_config
except Exception as e:
logging.error(f"Failed to get runtime config from TensorRT-LLM engine: {e}")
# Return config with default/None values if retrieval fails
return runtime_config
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Set up signal handler for graceful shutdown
......@@ -196,7 +229,10 @@ async def init(runtime: DistributedRuntime, config: Config):
endpoint = component.endpoint(config.endpoint)
if is_first_worker(config):
# Register the model with the endpoint if only the worker is first in the disaggregation chain.
# Get runtime configuration from the engine
runtime_config = await get_engine_runtime_config(engine, config)
# Register the model with runtime config
await register_llm(
modelType,
endpoint,
......@@ -204,6 +240,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config, # Add runtime config here
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
......
......@@ -12,6 +12,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from dynamo.llm import (
ModelRuntimeConfig,
ModelType,
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
......@@ -213,6 +214,17 @@ async def init(runtime: DistributedRuntime, config: Config):
handler.kv_publisher = kv_publisher
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
runtime_config = ModelRuntimeConfig()
# make a `collective_rpc` call to get runtime configuration values
logging.info(
"Getting engine runtime configuration metadata from vLLM engine..."
)
runtime_values = get_engine_cache_info(engine_client)
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
await register_llm(
ModelType.Backend,
generate_endpoint,
......@@ -220,6 +232,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
)
try:
......@@ -237,6 +250,32 @@ async def init(runtime: DistributedRuntime, config: Config):
handler.cleanup()
def get_engine_cache_info(engine: AsyncLLM):
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
try:
# Get values directly from vllm_config instead of collective_rpc
cache_values = {
"num_gpu_blocks": engine.vllm_config.cache_config.num_gpu_blocks,
}
scheduler_values = {
"max_num_seqs": engine.vllm_config.scheduler_config.max_num_seqs,
"max_num_batched_tokens": engine.vllm_config.scheduler_config.max_num_batched_tokens,
}
logging.info(f"Cache config values: {cache_values}")
logging.info(f"Scheduler config values: {scheduler_values}")
return {
"num_gpu_blocks": cache_values["num_gpu_blocks"],
"max_num_seqs": scheduler_values["max_num_seqs"],
"max_num_batched_tokens": scheduler_values["max_num_batched_tokens"],
}
except Exception as e:
logging.error(f"Failed to get configuration values from vLLM config: {e}")
raise
def main():
uvloop.run(worker())
......
......@@ -20,6 +20,7 @@
// 2. Update the backend component to produce a config in a standard location.
// 3. Update the KvRouter to read the config from the backend component.
use std::collections::HashMap;
use std::sync::Arc;
use clap::Parser;
......@@ -29,7 +30,7 @@ use dynamo_llm::kv_router::{
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
KvRouter, WorkerSelector,
};
use dynamo_runtime::component::Instance;
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::{
logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker,
};
......@@ -86,7 +87,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector);
impl WorkerSelector for CustomWorkerSelector {
fn select_worker(
&self,
workers: &[Instance],
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......
......@@ -25,6 +25,8 @@ use dynamo_runtime::{
use dynamo_llm::{self as llm_rs};
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};
use crate::llm::local_model::ModelRuntimeConfig;
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
......@@ -82,6 +84,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?;
......@@ -131,7 +134,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, user_data=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
......@@ -143,6 +146,7 @@ fn register_llm<'p>(
kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>,
migration_limit: u32,
runtime_config: Option<ModelRuntimeConfig>,
user_data: Option<&Bound<'p, PyDict>>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
......@@ -173,6 +177,7 @@ fn register_llm<'p>(
.kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config))
.migration_limit(Some(migration_limit))
.runtime_config(runtime_config.unwrap_or_default().inner)
.user_data(user_data_json);
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
......
......@@ -31,6 +31,7 @@ pub mod block_manager;
pub mod disagg_router;
pub mod entrypoint;
pub mod kv;
pub mod local_model;
pub mod model_card;
pub mod nats;
pub mod preprocessor;
......@@ -164,7 +164,8 @@ pub fn make_engine<'p>(
.kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into()))
.http_port(args.http_port)
.is_mocker(matches!(args.engine_type, EngineType::Mocker));
.is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_model = builder.build().await.map_err(to_pyerr)?;
let inner = select_engine(distributed_runtime, args, local_model)
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig;
#[pyclass]
#[derive(Clone, Default)]
pub struct ModelRuntimeConfig {
pub(crate) inner: RsModelRuntimeConfig,
}
#[pymethods]
impl ModelRuntimeConfig {
#[new]
fn new() -> Self {
Self {
inner: RsModelRuntimeConfig::new(),
}
}
#[setter]
fn set_total_kv_blocks(&mut self, total_kv_blocks: u64) {
self.inner.total_kv_blocks = Some(total_kv_blocks);
}
#[setter]
fn set_max_num_seqs(&mut self, max_num_seqs: u64) {
self.inner.max_num_seqs = Some(max_num_seqs);
}
#[setter]
fn set_max_num_batched_tokens(&mut self, max_num_batched_tokens: u64) {
self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
.set_engine_specific(key, value)
.map_err(to_pyerr)?;
Ok(())
}
#[getter]
fn total_kv_blocks(&self) -> Option<u64> {
self.inner.total_kv_blocks
}
#[getter]
fn max_num_seqs(&self) -> Option<u64> {
self.inner.max_num_seqs
}
#[getter]
fn max_num_batched_tokens(&self) -> Option<u64> {
self.inner.max_num_batched_tokens
}
#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
for (key, value) in self.inner.runtime_data.clone() {
dict.set_item(key, value.to_string())?;
}
Ok(dict.into())
}
fn get_engine_specific(&self, key: &str) -> PyResult<Option<String>> {
self.inner.get_engine_specific(key).map_err(to_pyerr)
}
}
......@@ -442,6 +442,12 @@ class ModelDeploymentCard:
...
class ModelRuntimeConfig:
"""
A model runtime configuration is a collection of runtime information
"""
...
class OAIChatPreprocessor:
"""
A preprocessor for OpenAI chat completions
......
......@@ -26,6 +26,7 @@ from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import RadixTree as RadixTree
......
......@@ -12,6 +12,7 @@ use dynamo_runtime::{
use serde::{Deserialize, Serialize};
use crate::{
local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
model_type::ModelType,
};
......@@ -28,6 +29,10 @@ pub struct ModelEntry {
/// Specifies whether the model is a chat, completions, etc model.
pub model_type: ModelType,
/// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
}
impl ModelEntry {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
......@@ -34,16 +35,16 @@ use crate::{
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent,
},
// metrics_aggregator::EndpointCollector,
metrics_aggregator::watch_model_runtime_configs,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
local_model::runtime_config::ModelRuntimeConfig,
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
};
use dynamo_runtime::component::Instance;
use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public
......@@ -65,7 +66,7 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub trait WorkerSelector {
fn select_worker(
&self,
workers: &[Instance],
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
......@@ -176,6 +177,15 @@ impl KvRouter {
}
};
// Create runtime config watcher
// TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
let etcd_client = component
.drt()
.etcd_client()
.expect("Cannot KV route without etcd client");
let runtime_configs_rx =
watch_model_runtime_configs(etcd_client, cancellation_token.clone()).await?;
let indexer = if kv_router_config.use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
} else {
......@@ -191,6 +201,7 @@ impl KvRouter {
component.clone(),
block_size,
instances_rx,
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
)
......
......@@ -18,10 +18,14 @@ use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::component::Component;
use dynamo_runtime::transports::etcd::{Client as EtcdClient, WatchEvent};
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use std::collections::HashMap;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
......@@ -208,3 +212,71 @@ pub async fn collect_endpoints_task(
}
}
}
pub async fn watch_model_runtime_configs(
etcd_client: EtcdClient,
cancellation_token: CancellationToken,
) -> Result<watch::Receiver<HashMap<i64, ModelRuntimeConfig>>> {
let (watch_tx, watch_rx) = watch::channel(HashMap::new());
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> = HashMap::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("Runtime config watcher cancelled");
break;
}
event = events_rx.recv() => {
let Some(event) = event else {
tracing::debug!("Runtime config watch stream closed");
break;
};
match event {
WatchEvent::Put(kv) => {
let Ok(model_entry) = serde_json::from_slice::<ModelEntry>(kv.value()) else {
tracing::warn!(
"Failed to parse ModelEntry from etcd. Key: {}",
kv.key_str().unwrap_or("<invalid>")
);
continue;
};
let lease_id = kv.lease();
if let Some(runtime_config) = model_entry.runtime_config {
runtime_configs.insert(lease_id, runtime_config);
tracing::trace!("Updated runtime config for lease_id: {}", lease_id);
} else {
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config (no config in ModelEntry)");
}
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
WatchEvent::Delete(kv) => {
let lease_id = kv.lease();
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config for deleted entry");
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
}
}
}
}
});
Ok(watch_rx)
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
......@@ -8,6 +9,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
......@@ -77,12 +79,15 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
mut instances_rx: watch::Receiver<Vec<Instance>>,
mut runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();
let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> =
runtime_configs_rx.borrow_and_update().clone();
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
let ns_clone = component.namespace().clone();
......@@ -112,10 +117,15 @@ impl KvScheduler {
tokio::spawn(async move {
let mut request_rx = request_rx;
tracing::trace!("scheduler background task started");
let mut workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>> = HashMap::new();
let mut needs_rebuild = true;
loop {
// First, check for instance updates (non-blocking)
match instances_rx.has_changed() {
// Check for instance updates (non-blocking)
let instances_changed = instances_rx.has_changed();
let configs_changed = runtime_configs_rx.has_changed();
match instances_changed {
Ok(true) => {
instances = instances_rx.borrow_and_update().clone();
let worker_ids: Vec<i64> = instances
......@@ -123,17 +133,42 @@ impl KvScheduler {
.map(|instance| instance.instance_id)
.collect();
slots_clone.update_workers(worker_ids);
needs_rebuild = true;
}
Ok(false) => {
// No changes, continue. This is the happy path.
}
Ok(false) => {}
Err(_) => {
tracing::warn!("endpoint watch sender shutdown");
break;
}
}
// Then, wait for a new request
// Check for runtime config updates
match configs_changed {
Ok(true) => {
runtime_configs = runtime_configs_rx.borrow_and_update().clone();
needs_rebuild = true;
}
Ok(false) => {}
Err(_) => {
tracing::warn!("runtime configs watch sender shutdown");
}
}
// Rebuild workers hashmap only when needed
if needs_rebuild {
workers_with_configs.clear();
for instance in &instances {
let worker_id = instance.instance_id;
let config = runtime_configs.get(&worker_id).cloned();
if config.is_none() {
tracing::warn!("Runtime config not found for worker_id: {}", worker_id);
}
workers_with_configs.insert(worker_id, config);
}
needs_rebuild = false;
}
// Wait for a new request
let Some(mut request) = request_rx.recv().await else {
tracing::warn!("scheduler shutdown");
break;
......@@ -150,7 +185,7 @@ impl KvScheduler {
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
match selector.select_worker(&instances, &request, block_size) {
match selector.select_worker(&workers_with_configs, &request, block_size) {
Ok(selection) => {
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id,
......@@ -333,7 +368,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &[Instance],
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......@@ -354,17 +389,16 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for instance in workers.iter() {
let worker_id = instance.instance_id;
let overlap = *overlaps.get(&worker_id).unwrap_or(&0);
for worker_id in workers.keys() {
let overlap = *overlaps.get(worker_id).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl);
let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
.get(&worker_id)
.get(worker_id)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
......@@ -373,7 +407,7 @@ impl WorkerSelector for DefaultWorkerSelector {
self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit);
worker_logits.insert(worker_id, logit);
worker_logits.insert(*worker_id, logit);
let overlap_weight = self.kv_router_config.overlap_score_weight;
tracing::info!(
......@@ -388,10 +422,20 @@ impl WorkerSelector for DefaultWorkerSelector {
let best_worker_id = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id];
let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0);
let total_blocks_info = workers
.get(&best_worker_id)
.and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
tracing::info!(
"Selected worker: {}, logit: {:.3}",
"Selected worker: {}, logit: {:.3}, cached blocks: {}{}",
best_worker_id,
best_logit
best_logit,
best_overlap,
total_blocks_info
);
Ok(WorkerSelectionResult {
......
......@@ -16,12 +16,16 @@ use dynamo_runtime::{
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType;
use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
pub mod runtime_config;
use runtime_config::ModelRuntimeConfig;
/// Prefix for Hugging Face model repository
const HF_SCHEME: &str = "hf://";
......@@ -48,6 +52,8 @@ pub struct LocalModelBuilder {
http_port: u16,
migration_limit: u32,
is_mocker: bool,
extra_engine_args: Option<PathBuf>,
runtime_config: ModelRuntimeConfig,
user_data: Option<serde_json::Value>,
}
......@@ -65,6 +71,8 @@ impl Default for LocalModelBuilder {
router_config: Default::default(),
migration_limit: Default::default(),
is_mocker: Default::default(),
extra_engine_args: Default::default(),
runtime_config: Default::default(),
user_data: Default::default(),
}
}
......@@ -128,6 +136,16 @@ impl LocalModelBuilder {
self
}
pub fn extra_engine_args(&mut self, extra_engine_args: Option<PathBuf>) -> &mut Self {
self.extra_engine_args = extra_engine_args;
self
}
pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self {
self.runtime_config = runtime_config;
self
}
pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
self.user_data = user_data;
self
......@@ -170,6 +188,7 @@ impl LocalModelBuilder {
template,
http_port: self.http_port,
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
});
}
......@@ -218,6 +237,20 @@ impl LocalModelBuilder {
card.context_length = context_length;
}
// Override runtime configs with mocker engine args
if self.is_mocker {
if let Some(path) = &self.extra_engine_args {
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks =
Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs =
mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
}
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
......@@ -228,6 +261,7 @@ impl LocalModelBuilder {
template,
http_port: self.http_port,
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
})
}
}
......@@ -240,6 +274,7 @@ pub struct LocalModel {
template: Option<RequestTemplate>,
http_port: u16, // Only used if input is HTTP server
router_config: RouterConfig,
runtime_config: ModelRuntimeConfig,
}
impl LocalModel {
......@@ -274,6 +309,10 @@ impl LocalModel {
&self.router_config
}
pub fn runtime_config(&self) -> &ModelRuntimeConfig {
&self.runtime_config
}
pub fn is_gguf(&self) -> bool {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension
// We will error when we come to parse it
......@@ -323,6 +362,7 @@ impl LocalModel {
name: self.display_name().to_string(),
endpoint: endpoint.id(),
model_type,
runtime_config: Some(self.runtime_config.clone()),
};
etcd_client
.kv_create(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>,
pub max_num_seqs: Option<u64>,
pub max_num_batched_tokens: Option<u64>,
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
}
impl ModelRuntimeConfig {
pub fn new() -> Self {
Self::default()
}
pub fn set_engine_specific<T: Serialize>(&mut self, key: &str, value: T) -> anyhow::Result<()> {
self.runtime_data
.insert(key.to_string(), serde_json::to_value(value)?);
Ok(())
}
pub fn get_engine_specific<T: DeserializeOwned>(&self, key: &str) -> anyhow::Result<Option<T>> {
if let Some(value) = self.runtime_data.get(key) {
Ok(Some(serde_json::from_value(value.clone())?))
} else {
Ok(None)
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment