Unverified commit 8dd1fc07, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: in discovery, use dashmap in place of mutex hashmap (#5868)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent c5f5ab60
......@@ -7,7 +7,7 @@ use std::{
};
use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::{Mutex, RwLock};
use parking_lot::RwLock;
use tokio::sync::oneshot;
use crate::discovery::KvWorkerMonitor;
......@@ -70,10 +70,9 @@ pub struct ModelManager {
// Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
// These are DashMaps (previously Mutex<HashMap>): each operation synchronizes internally, so no outer lock is held across calls
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
cards: DashMap<String, ModelDeploymentCard>,
kv_choosers: DashMap<EndpointId, Arc<KvRouter>>,
prefill_router_activators: DashMap<String, PrefillActivationState>,
/// Per-model worker monitors for dynamic KV cache load rejection.
/// Key: model name, Value: cloneable monitor (all fields are Arc).
......@@ -100,9 +99,9 @@ impl ModelManager {
embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()),
cards: DashMap::new(),
kv_choosers: DashMap::new(),
prefill_router_activators: DashMap::new(),
worker_monitors: RwLock::new(HashMap::new()),
runtime_configs: DashMap::new(),
}
......@@ -147,7 +146,7 @@ impl ModelManager {
}
/// Returns an owned copy of every stored model deployment card.
///
/// Each card is cloned out of the map, so the returned `Vec` does not borrow from `self`.
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.lock().values().cloned().collect()
self.cards.iter().map(|r| r.value().clone()).collect()
}
/// Check if a decode model (chat or completions) is registered
......@@ -318,13 +317,13 @@ impl ModelManager {
/// Save a ModelDeploymentCard under an instance's key so we can fetch it later when the key is
/// deleted.
///
/// The key is copied into an owned `String`. Inserting under a key that is already present
/// replaces the stored card. This function always returns `Ok(())`.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card);
self.cards.insert(key.to_string(), card);
Ok(())
}
/// Remove and return the model card stored under this instance's key, or `None` if no card is
/// stored under it. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.lock().remove(key)
// `DashMap::remove` yields the `(key, value)` pair; only the card is handed back.
self.cards.remove(key).map(|(_, v)| v)
}
pub async fn kv_chooser_for(
......@@ -384,14 +383,12 @@ impl ModelManager {
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.insert(endpoint_id, new_kv_chooser.clone());
self.kv_choosers.insert(endpoint_id, new_kv_chooser.clone());
Ok(new_kv_chooser)
}
/// Look up the KV router stored for `id`.
///
/// Returns a clone of the `Arc` handle (not a copy of the router), or `None` when nothing is
/// stored under `id`.
fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
self.kv_choosers.lock().get(id).cloned()
self.kv_choosers.get(id).map(|r| r.value().clone())
}
/// Register a prefill router for a decode model. Returns a receiver that will be
......@@ -401,10 +398,8 @@ impl ModelManager {
&self,
model_name: String,
) -> Option<oneshot::Receiver<Endpoint>> {
let mut activators = self.prefill_router_activators.lock();
match activators.remove(&model_name) {
Some(PrefillActivationState::PrefillReady(rx)) => {
match self.prefill_router_activators.remove(&model_name) {
Some((_, PrefillActivationState::PrefillReady(rx))) => {
// Prefill endpoint already arrived - rx will immediately resolve
tracing::debug!(
model_name = %model_name,
......@@ -412,19 +407,20 @@ impl ModelManager {
);
Some(rx)
}
Some(PrefillActivationState::DecodeWaiting(tx)) => {
Some((key, PrefillActivationState::DecodeWaiting(tx))) => {
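// Decode already registered - this shouldn't happen, restore state and return None.
// NOTE(review): the Mutex version held one guard across the remove and this restoring insert;
// with DashMap they are two separate operations, so a concurrent caller can observe the key as
// absent in between. Consider the `entry` API if the pair must stay atomic.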
tracing::error!(
model_name = %model_name,
"Decode model already registered for this prefill router"
);
activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx));
self.prefill_router_activators
.insert(key, PrefillActivationState::DecodeWaiting(tx));
None
}
None => {
// New registration: create tx/rx pair, store sender and return receiver
let (tx, rx) = oneshot::channel();
activators.insert(
self.prefill_router_activators.insert(
model_name.clone(),
PrefillActivationState::DecodeWaiting(tx),
);
......@@ -444,10 +440,8 @@ impl ModelManager {
model_name: &str,
endpoint: Endpoint,
) -> anyhow::Result<()> {
let mut activators = self.prefill_router_activators.lock();
match activators.remove(model_name) {
Some(PrefillActivationState::DecodeWaiting(sender)) => {
match self.prefill_router_activators.remove(model_name) {
Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
// Decode model already registered
sender.send(endpoint).map_err(|_| {
anyhow::anyhow!(
......@@ -463,7 +457,7 @@ impl ModelManager {
Ok(())
}
Some(PrefillActivationState::PrefillReady(_)) => {
Some((_, PrefillActivationState::PrefillReady(_))) => {
// Prefill already activated - this shouldn't happen
anyhow::bail!("Prefill router for model {} already activated", model_name);
}
......@@ -476,7 +470,7 @@ impl ModelManager {
})?;
// Store the receiver for when decode model registers
activators.insert(
self.prefill_router_activators.insert(
model_name.to_string(),
PrefillActivationState::PrefillReady(rx),
);
......@@ -493,11 +487,9 @@ impl ModelManager {
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
// Scan the saved cards for one whose `display_name` equals `model` and return its configured
// tool call parser as an owned `String`. Yields `None` when no card matches or the matching
// card has no parser set. Only the first match found is consulted.
self.cards
.lock()
.values()
.find(|c| c.display_name == model)
.and_then(|c| c.runtime_config.tool_call_parser.as_ref())
.map(|parser| parser.to_string())
.iter()
.find(|r| r.value().display_name == model)
.and_then(|r| r.value().runtime_config.tool_call_parser.clone())
}
/// Creates parsing options with tool call parser and reasoning parser for the specified model.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use dashmap::DashMap;
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::protocols::ActiveLoad;
use crate::model_card::ModelDeploymentCard;
......@@ -9,9 +15,6 @@ use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000;
......@@ -86,7 +89,7 @@ impl WorkerLoadState {
#[derive(Clone)]
pub struct KvWorkerMonitor {
client: Client,
worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
worker_load_states: Arc<DashMap<u64, WorkerLoadState>>,
/// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
active_decode_blocks_threshold: Arc<AtomicU32>,
/// Active prefill tokens threshold stored as literal token count (u64)
......@@ -110,7 +113,7 @@ impl KvWorkerMonitor {
) -> Self {
Self {
client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())),
worker_load_states: Arc::new(DashMap::new()),
active_decode_blocks_threshold: Arc::new(AtomicU32::new(
Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
)),
......@@ -160,7 +163,7 @@ impl KvWorkerMonitor {
}
/// Get the worker load states for external access.
///
/// Returns a clone of the `Arc` handle, so the caller shares the same underlying map rather
/// than receiving a copy of its contents.
pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
pub fn load_states(&self) -> Arc<DashMap<u64, WorkerLoadState>> {
self.worker_load_states.clone()
}
}
......@@ -219,12 +222,11 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
_ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone();
let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
worker_load_states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks for all dp_ranks
for (lease_id, runtime_config) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_default();
let mut state = worker_load_states.entry(*lease_id).or_default();
// Populate total_blocks for all dp_ranks (they share the same total)
if let Some(total_blocks) = runtime_config.total_kv_blocks {
......@@ -251,8 +253,8 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let dp_rank = active_load.dp_rank;
// Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_default();
{
let mut state = worker_load_states.entry(worker_id).or_default();
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
......@@ -260,7 +262,7 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
drop(states);
}
// Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
......@@ -269,16 +271,14 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);
// Recalculate all busy instances and update
let states = worker_load_states.read().unwrap();
let busy_instances: Vec<u64> = states
let busy_instances: Vec<u64> = worker_load_states
.iter()
.filter_map(|(&id, state)| {
state
.filter_map(|r| {
r.value()
.is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
.then_some(id)
.then_some(*r.key())
})
.collect();
drop(states);
// Only update if busy_instances has changed
if busy_instances != previous_busy_instances {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment