Unverified Commit 8dd1fc07 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: in discovery, use dashmap in place of mutex hashmap (#5868)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent c5f5ab60
...@@ -7,7 +7,7 @@ use std::{ ...@@ -7,7 +7,7 @@ use std::{
}; };
use dashmap::{DashMap, mapref::entry::Entry}; use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::{Mutex, RwLock}; use parking_lot::RwLock;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::discovery::KvWorkerMonitor; use crate::discovery::KvWorkerMonitor;
...@@ -70,10 +70,9 @@ pub struct ModelManager { ...@@ -70,10 +70,9 @@ pub struct ModelManager {
// Prefill models don't have engines - they're only tracked for discovery/lifecycle // Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>, prefill_engines: RwLock<ModelEngines<()>>,
// These are Mutex because we read and write rarely and equally cards: DashMap<String, ModelDeploymentCard>,
cards: Mutex<HashMap<String, ModelDeploymentCard>>, kv_choosers: DashMap<EndpointId, Arc<KvRouter>>,
kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>, prefill_router_activators: DashMap<String, PrefillActivationState>,
prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
/// Per-model worker monitors for dynamic KV cache load rejection. /// Per-model worker monitors for dynamic KV cache load rejection.
/// Key: model name, Value: cloneable monitor (all fields are Arc). /// Key: model name, Value: cloneable monitor (all fields are Arc).
...@@ -100,9 +99,9 @@ impl ModelManager { ...@@ -100,9 +99,9 @@ impl ModelManager {
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()), tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()), prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()), cards: DashMap::new(),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: DashMap::new(),
prefill_router_activators: Mutex::new(HashMap::new()), prefill_router_activators: DashMap::new(),
worker_monitors: RwLock::new(HashMap::new()), worker_monitors: RwLock::new(HashMap::new()),
runtime_configs: DashMap::new(), runtime_configs: DashMap::new(),
} }
...@@ -147,7 +146,7 @@ impl ModelManager { ...@@ -147,7 +146,7 @@ impl ModelManager {
} }
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.lock().values().cloned().collect() self.cards.iter().map(|r| r.value().clone()).collect()
} }
/// Check if a decode model (chat or completions) is registered /// Check if a decode model (chat or completions) is registered
...@@ -318,13 +317,13 @@ impl ModelManager { ...@@ -318,13 +317,13 @@ impl ModelManager {
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card); self.cards.insert(key.to_string(), card);
Ok(()) Ok(())
} }
/// Remove and return model card for this instance's key. We do this when the instance stops. /// Remove and return model card for this instance's key. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> { pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.lock().remove(key) self.cards.remove(key).map(|(_, v)| v)
} }
pub async fn kv_chooser_for( pub async fn kv_chooser_for(
...@@ -384,14 +383,12 @@ impl ModelManager { ...@@ -384,14 +383,12 @@ impl ModelManager {
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
self.kv_choosers self.kv_choosers.insert(endpoint_id, new_kv_chooser.clone());
.lock()
.insert(endpoint_id, new_kv_chooser.clone());
Ok(new_kv_chooser) Ok(new_kv_chooser)
} }
fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> { fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
self.kv_choosers.lock().get(id).cloned() self.kv_choosers.get(id).map(|r| r.value().clone())
} }
/// Register a prefill router for a decode model. Returns a receiver that will be /// Register a prefill router for a decode model. Returns a receiver that will be
...@@ -401,10 +398,8 @@ impl ModelManager { ...@@ -401,10 +398,8 @@ impl ModelManager {
&self, &self,
model_name: String, model_name: String,
) -> Option<oneshot::Receiver<Endpoint>> { ) -> Option<oneshot::Receiver<Endpoint>> {
let mut activators = self.prefill_router_activators.lock(); match self.prefill_router_activators.remove(&model_name) {
Some((_, PrefillActivationState::PrefillReady(rx))) => {
match activators.remove(&model_name) {
Some(PrefillActivationState::PrefillReady(rx)) => {
// Prefill endpoint already arrived - rx will immediately resolve // Prefill endpoint already arrived - rx will immediately resolve
tracing::debug!( tracing::debug!(
model_name = %model_name, model_name = %model_name,
...@@ -412,19 +407,20 @@ impl ModelManager { ...@@ -412,19 +407,20 @@ impl ModelManager {
); );
Some(rx) Some(rx)
} }
Some(PrefillActivationState::DecodeWaiting(tx)) => { Some((key, PrefillActivationState::DecodeWaiting(tx))) => {
// Decode already registered - this shouldn't happen, restore state and return None // Decode already registered - this shouldn't happen, restore state and return None
tracing::error!( tracing::error!(
model_name = %model_name, model_name = %model_name,
"Decode model already registered for this prefill router" "Decode model already registered for this prefill router"
); );
activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx)); self.prefill_router_activators
.insert(key, PrefillActivationState::DecodeWaiting(tx));
None None
} }
None => { None => {
// New registration: create tx/rx pair, store sender and return receiver // New registration: create tx/rx pair, store sender and return receiver
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
activators.insert( self.prefill_router_activators.insert(
model_name.clone(), model_name.clone(),
PrefillActivationState::DecodeWaiting(tx), PrefillActivationState::DecodeWaiting(tx),
); );
...@@ -444,10 +440,8 @@ impl ModelManager { ...@@ -444,10 +440,8 @@ impl ModelManager {
model_name: &str, model_name: &str,
endpoint: Endpoint, endpoint: Endpoint,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut activators = self.prefill_router_activators.lock(); match self.prefill_router_activators.remove(model_name) {
Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
match activators.remove(model_name) {
Some(PrefillActivationState::DecodeWaiting(sender)) => {
// Decode model already registered // Decode model already registered
sender.send(endpoint).map_err(|_| { sender.send(endpoint).map_err(|_| {
anyhow::anyhow!( anyhow::anyhow!(
...@@ -463,7 +457,7 @@ impl ModelManager { ...@@ -463,7 +457,7 @@ impl ModelManager {
Ok(()) Ok(())
} }
Some(PrefillActivationState::PrefillReady(_)) => { Some((_, PrefillActivationState::PrefillReady(_))) => {
// Prefill already activated - this shouldn't happen // Prefill already activated - this shouldn't happen
anyhow::bail!("Prefill router for model {} already activated", model_name); anyhow::bail!("Prefill router for model {} already activated", model_name);
} }
...@@ -476,7 +470,7 @@ impl ModelManager { ...@@ -476,7 +470,7 @@ impl ModelManager {
})?; })?;
// Store the receiver for when decode model registers // Store the receiver for when decode model registers
activators.insert( self.prefill_router_activators.insert(
model_name.to_string(), model_name.to_string(),
PrefillActivationState::PrefillReady(rx), PrefillActivationState::PrefillReady(rx),
); );
...@@ -493,11 +487,9 @@ impl ModelManager { ...@@ -493,11 +487,9 @@ impl ModelManager {
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> { pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
self.cards self.cards
.lock() .iter()
.values() .find(|r| r.value().display_name == model)
.find(|c| c.display_name == model) .and_then(|r| r.value().runtime_config.tool_call_parser.clone())
.and_then(|c| c.runtime_config.tool_call_parser.as_ref())
.map(|parser| parser.to_string())
} }
/// Creates parsing options with tool call parser and reasoning parser for the specified model. /// Creates parsing options with tool call parser and reasoning parser for the specified model.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use dashmap::DashMap;
use crate::kv_router::KV_METRICS_SUBJECT; use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::protocols::ActiveLoad; use crate::kv_router::protocols::ActiveLoad;
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
...@@ -9,9 +15,6 @@ use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field}; ...@@ -9,9 +15,6 @@ use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait}; use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventSubscriber; use dynamo_runtime::transports::event_plane::EventSubscriber;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places) /// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000; const THRESHOLD_SCALE: u32 = 10000;
...@@ -86,7 +89,7 @@ impl WorkerLoadState { ...@@ -86,7 +89,7 @@ impl WorkerLoadState {
#[derive(Clone)] #[derive(Clone)]
pub struct KvWorkerMonitor { pub struct KvWorkerMonitor {
client: Client, client: Client,
worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>, worker_load_states: Arc<DashMap<u64, WorkerLoadState>>,
/// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85) /// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
active_decode_blocks_threshold: Arc<AtomicU32>, active_decode_blocks_threshold: Arc<AtomicU32>,
/// Active prefill tokens threshold stored as literal token count (u64) /// Active prefill tokens threshold stored as literal token count (u64)
...@@ -110,7 +113,7 @@ impl KvWorkerMonitor { ...@@ -110,7 +113,7 @@ impl KvWorkerMonitor {
) -> Self { ) -> Self {
Self { Self {
client, client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())), worker_load_states: Arc::new(DashMap::new()),
active_decode_blocks_threshold: Arc::new(AtomicU32::new( active_decode_blocks_threshold: Arc::new(AtomicU32::new(
Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold), Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
)), )),
...@@ -160,7 +163,7 @@ impl KvWorkerMonitor { ...@@ -160,7 +163,7 @@ impl KvWorkerMonitor {
} }
/// Get the worker load states for external access /// Get the worker load states for external access
pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> { pub fn load_states(&self) -> Arc<DashMap<u64, WorkerLoadState>> {
self.worker_load_states.clone() self.worker_load_states.clone()
} }
} }
...@@ -219,12 +222,11 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -219,12 +222,11 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
_ = config_events_rx.changed() => { _ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone(); let runtime_configs = config_events_rx.borrow().clone();
let mut states = worker_load_states.write().unwrap(); worker_load_states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks for all dp_ranks // Update worker load states with total blocks for all dp_ranks
for (lease_id, runtime_config) in runtime_configs.iter() { for (lease_id, runtime_config) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_default(); let mut state = worker_load_states.entry(*lease_id).or_default();
// Populate total_blocks for all dp_ranks (they share the same total) // Populate total_blocks for all dp_ranks (they share the same total)
if let Some(total_blocks) = runtime_config.total_kv_blocks { if let Some(total_blocks) = runtime_config.total_kv_blocks {
...@@ -251,16 +253,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -251,16 +253,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let dp_rank = active_load.dp_rank; let dp_rank = active_load.dp_rank;
// Update worker load state per dp_rank // Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap(); {
let state = states.entry(worker_id).or_default(); let mut state = worker_load_states.entry(worker_id).or_default();
if let Some(active_blocks) = active_load.active_decode_blocks { if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks); state.active_decode_blocks.insert(dp_rank, active_blocks);
} }
if let Some(active_tokens) = active_load.active_prefill_tokens { if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens); state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
} }
drop(states);
// Load thresholds dynamically - allows runtime updates // Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold( let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
...@@ -269,16 +271,14 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -269,16 +271,14 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed); let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);
// Recalculate all busy instances and update // Recalculate all busy instances and update
let states = worker_load_states.read().unwrap(); let busy_instances: Vec<u64> = worker_load_states
let busy_instances: Vec<u64> = states
.iter() .iter()
.filter_map(|(&id, state)| { .filter_map(|r| {
state r.value()
.is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold) .is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
.then_some(id) .then_some(*r.key())
}) })
.collect(); .collect();
drop(states);
// Only update if busy_instances has changed // Only update if busy_instances has changed
if busy_instances != previous_busy_instances { if busy_instances != previous_busy_instances {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment