// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::collections::HashMap; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use crate::protocols::tensor; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] pub struct DisaggregatedEndpoint { #[serde(default, skip_serializing_if = "Option::is_none")] pub bootstrap_host: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub bootstrap_port: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ModelRuntimeConfig { pub total_kv_blocks: Option, pub max_num_seqs: Option, pub max_num_batched_tokens: Option, pub tool_call_parser: Option, pub reasoning_parser: Option, /// Starting rank of data parallel ranks for this worker (0 if DP not enabled) #[serde(default = "default_data_parallel_start_rank")] pub data_parallel_start_rank: u32, /// Total number of data parallel ranks for this worker (1 if DP not enabled) #[serde(default = "default_data_parallel_size")] pub data_parallel_size: u32, /// Enable worker-local KV indexer for tracking this worker's own KV cache state (default: true) #[serde(default = "default_local_indexer")] pub enable_local_indexer: bool, /// Mapping of engine-specific runtime configs #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub runtime_data: HashMap, // Provide tensor model config in the case where the model type is Tensor. // Currently use JSON object for convinence, the programmatic way is to // define the model config struct as part of the tensor protocol and // import it here. // [gluo TODO] switch to ModelConfig if desired and workout a way to // prepare it in a convinent way, the protobuf library used by tonic // doesn't provide JSON parsing. #[serde(default, skip_serializing_if = "Option::is_none")] pub tensor_model_config: Option, /// Bootstrap endpoint for disaggregated serving (prefill workers publish this) #[serde(default, skip_serializing_if = "Option::is_none")] pub disaggregated_endpoint: Option, } const fn default_data_parallel_start_rank() -> u32 { 0 } const fn default_data_parallel_size() -> u32 { 1 } const fn default_local_indexer() -> bool { true } impl Default for ModelRuntimeConfig { fn default() -> Self { Self { total_kv_blocks: None, max_num_seqs: None, max_num_batched_tokens: None, tool_call_parser: None, reasoning_parser: None, data_parallel_start_rank: default_data_parallel_start_rank(), data_parallel_size: default_data_parallel_size(), enable_local_indexer: true, runtime_data: HashMap::new(), tensor_model_config: None, disaggregated_endpoint: None, } } } impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig { fn data_parallel_start_rank(&self) -> u32 { self.data_parallel_start_rank } fn data_parallel_size(&self) -> u32 { self.data_parallel_size } fn max_num_batched_tokens(&self) -> Option { self.max_num_batched_tokens } fn total_kv_blocks(&self) -> Option { self.total_kv_blocks } } impl ModelRuntimeConfig { pub fn new() -> Self { Self::default() } pub fn set_engine_specific(&mut self, key: &str, value: T) -> anyhow::Result<()> { self.runtime_data .insert(key.to_string(), serde_json::to_value(value)?); Ok(()) } pub fn get_engine_specific(&self, key: &str) -> anyhow::Result> { if let Some(value) = self.runtime_data.get(key) { Ok(Some(serde_json::from_value(value.clone())?)) } else { Ok(None) } } }