Unverified Commit 2fe37a51 authored by huitian bai's avatar huitian bai Committed by GitHub
Browse files

fix: sglang eagle bigram tokens kv event report. (#6872)


Signed-off-by: default avatarbaihuitian <baihuitian.bht@gmail.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
parent db79f324
......@@ -12,8 +12,9 @@ use dynamo_kv_router::{
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError},
protocols::KV_EVENT_SUBJECT,
protocols::{
BlockExtraInfo, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest,
RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, OverlapScores, RouterEvent,
RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
},
};
use dynamo_runtime::{
......@@ -301,6 +302,7 @@ where
kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
is_eagle: bool,
}
impl<Sel> KvRouter<Sel>
......@@ -317,6 +319,7 @@ where
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
kv_router_config.validate()?;
......@@ -370,6 +373,7 @@ where
kv_router_config,
cancellation_token,
client,
is_eagle,
})
}
......@@ -386,12 +390,15 @@ where
&self.kv_router_config
}
pub fn is_eagle(&self) -> bool {
self.is_eagle
}
pub async fn record_routing_decision(
&self,
tokens: Vec<u32>,
mut tokens_with_hashes: TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let mut tokens_with_hashes = TokensWithHashes::new(tokens, self.block_size);
self.indexer
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
......@@ -427,8 +434,11 @@ where
compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos,
lora_name.as_deref(),
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
},
)
});
let hash_elapsed = start.elapsed();
......@@ -446,7 +456,11 @@ where
tokens,
self.block_size,
router_config_override,
lora_name.as_deref(),
BlockHashOptions {
block_mm_infos,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
},
)
});
let seq_hash_elapsed = start.elapsed();
......@@ -502,6 +516,7 @@ where
&self,
request_id: String,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
overlap_blocks: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
......@@ -514,7 +529,11 @@ where
tokens,
self.block_size,
router_config_override,
lora_name.as_deref(),
BlockHashOptions {
block_mm_infos,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
},
);
if let Err(e) = self
......@@ -570,10 +589,19 @@ where
pub async fn get_overlap_blocks(
&self,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<u32, KvRouterError> {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let block_hashes = compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos,
lora_name,
is_eagle: Some(self.is_eagle),
},
);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
}
......@@ -583,17 +611,30 @@ where
&self,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
lora_name: Option<&str>,
) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let block_hashes = compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos,
lora_name,
is_eagle: Some(self.is_eagle),
},
);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
BlockHashOptions {
block_mm_infos,
lora_name,
is_eagle: Some(self.is_eagle),
},
);
Ok(self
......
......@@ -117,6 +117,7 @@ pub struct PrefillRouter {
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
is_eagle: bool,
}
impl PrefillRouter {
......@@ -135,6 +136,7 @@ impl PrefillRouter {
enforce_disagg,
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
is_eagle: false,
})
}
......@@ -148,6 +150,7 @@ impl PrefillRouter {
enforce_disagg: bool,
model_name: String,
namespace: String,
is_eagle: bool,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
......@@ -161,6 +164,7 @@ impl PrefillRouter {
enforce_disagg,
model_name,
namespace,
is_eagle,
});
// Spawn background task to wait for activation
......@@ -222,6 +226,7 @@ impl PrefillRouter {
kv_router_config,
WORKER_TYPE_PREFILL,
Some(self.model_name.clone()),
self.is_eagle,
)
.await?;
......
......@@ -1048,7 +1048,7 @@ impl WorkerMetricsPublisher {
#[cfg(test)]
mod test_event_processing {
use super::*;
use dynamo_kv_router::protocols::compute_block_hash_for_seq;
use dynamo_kv_router::protocols::{BlockHashOptions, compute_block_hash_for_seq};
// ---------------------------------------------------------------------
// create_stored_block_from_parts --------------------------------------
......@@ -1060,10 +1060,11 @@ mod test_event_processing {
let blk_hash = 0xdead_beef;
let stored =
create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None);
create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None, None);
assert_eq!(stored.block_hash.0, blk_hash);
let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None, None)[0];
let expected_hash =
compute_block_hash_for_seq(&token_ids, 4, BlockHashOptions::default())[0];
assert_eq!(stored.tokens_hash, expected_hash);
assert!(stored.mm_extra_info.is_none());
}
......@@ -1087,6 +1088,7 @@ mod test_event_processing {
None,
&Arc::new(AtomicU32::new(0)),
None,
None,
);
assert_eq!(blocks.len(), 2);
......@@ -1110,6 +1112,7 @@ mod test_event_processing {
None,
&warning_count,
None,
None,
);
// should early-exit as second has mismatch
......@@ -1131,6 +1134,7 @@ mod test_event_processing {
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let out = convert_event(
......@@ -1156,6 +1160,7 @@ mod test_event_processing {
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let lora_evt = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
......@@ -1165,6 +1170,7 @@ mod test_event_processing {
medium: None,
lora_name: Some("my-lora".to_string()),
block_mm_infos: None,
is_eagle: None,
};
let wc = Arc::new(AtomicU32::new(0));
......@@ -1211,6 +1217,7 @@ mod test_event_processing {
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let evt2 = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
......@@ -1220,6 +1227,7 @@ mod test_event_processing {
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let out1 = convert_event(
......@@ -1883,6 +1891,7 @@ mod tests_startup_helpers {
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
}];
let batch = KvEventBatch {
......
......@@ -4,7 +4,7 @@
use std::sync::Arc;
use anyhow::Result;
use dynamo_kv_router::protocols::WorkerWithDpRank;
use dynamo_kv_router::protocols::{TokensWithHashes, WorkerWithDpRank};
use dynamo_runtime::{
dynamo_nvtx_range,
pipeline::{
......@@ -298,7 +298,12 @@ impl KvPushRouter {
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(routing_token_ids, worker, lora_name.as_deref())
.get_overlap_blocks(
routing_token_ids,
block_mm_infos,
worker,
lora_name.as_deref(),
)
.await?;
if !is_query_only {
......@@ -306,6 +311,7 @@ impl KvPushRouter {
.add_request(
context_id.to_string(),
routing_token_ids,
block_mm_infos,
overlap_blocks,
expected_output_tokens,
worker,
......@@ -385,10 +391,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(routing_token_ids.to_vec(), self.chooser.block_size())
.with_is_eagle(self.chooser.is_eagle());
if let Some(infos) = block_mm_infos {
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.to_vec());
}
if let Some(lora_name) = lora_name {
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name);
}
if let Err(e) = self
.chooser
.record_routing_decision(request.token_ids.clone(), worker)
.record_routing_decision(tokens_with_hashes, worker)
.await
{
tracing::warn!(
......
......@@ -57,6 +57,9 @@ pub struct ModelRuntimeConfig {
/// Bootstrap endpoint for disaggregated serving (prefill workers publish this)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
#[serde(default = "default_eagle")]
pub enable_eagle: bool,
}
const fn default_data_parallel_start_rank() -> u32 {
......@@ -71,6 +74,10 @@ const fn default_local_indexer() -> bool {
true
}
const fn default_eagle() -> bool {
false
}
impl Default for ModelRuntimeConfig {
fn default() -> Self {
Self {
......@@ -85,6 +92,7 @@ impl Default for ModelRuntimeConfig {
runtime_data: HashMap::new(),
tensor_model_config: None,
disaggregated_endpoint: None,
enable_eagle: false,
}
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
use dynamo_kv_router::protocols::{
BlockHashOptions, compute_block_hash_for_seq, compute_seq_hash_for_block,
};
use tempfile::NamedTempFile;
use uuid::Uuid;
......@@ -111,7 +113,8 @@ fn test_turn_replay_hashes_match_full_blocks_only() {
.to_direct_request(4, Uuid::from_u128(1), Some(5.0))
.unwrap();
let replay_hashes = turn.to_replay_hashes(4).unwrap();
let expected_local = compute_block_hash_for_seq(&request.tokens, 4, None, None);
let expected_local =
compute_block_hash_for_seq(&request.tokens, 4, BlockHashOptions::default());
assert_eq!(replay_hashes.local_block_hashes, expected_local);
assert_eq!(
......
......@@ -10,7 +10,7 @@ use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
BlockHashOptions, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
};
use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
......@@ -68,7 +68,14 @@ impl SyncReplayIndexer {
}
fn find_matches_for_request(&self, tokens: &[u32], lora_name: Option<&str>) -> OverlapScores {
let sequence = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let sequence = compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
lora_name,
..Default::default()
},
);
self.tree.find_matches(sequence, false)
}
......@@ -340,7 +347,7 @@ impl OfflineReplayRouter {
&request.tokens,
self.block_size,
None,
None,
BlockHashOptions::default(),
)
};
(overlaps, token_seq)
......@@ -351,7 +358,7 @@ impl OfflineReplayRouter {
&request.tokens,
self.block_size,
None,
None,
BlockHashOptions::default(),
);
(overlaps, token_seq)
}
......
......@@ -11,7 +11,7 @@ use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, ThreadPoolIndexer,
};
use dynamo_kv_router::protocols::{OverlapScores, RouterEvent, WorkerId};
use dynamo_kv_router::protocols::{BlockHashOptions, OverlapScores, RouterEvent, WorkerId};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
......@@ -46,11 +46,11 @@ impl ReplayIndexer {
) -> Result<OverlapScores> {
match self {
Self::Single(indexer) => indexer
.find_matches_for_request(tokens, lora_name)
.find_matches_for_request(tokens, lora_name, None)
.await
.map_err(Into::into),
Self::Concurrent(indexer) => indexer
.find_matches_for_request(tokens, lora_name)
.find_matches_for_request(tokens, lora_name, None)
.await
.map_err(Into::into),
}
......@@ -187,7 +187,7 @@ impl KvReplayRouter {
&request.tokens,
self.block_size,
None,
None,
BlockHashOptions::default(),
);
let response = self
.scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment