Commit 2fe37a51 (unverified), authored by huitian bai and committed by GitHub
Browse files

fix: sglang eagle bigram tokens kv event report. (#6872)


Signed-off-by: baihuitian <baihuitian.bht@gmail.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: PeaBrane <yanrpei@gmail.com>
parent db79f324
......@@ -194,6 +194,9 @@ async def _get_runtime_config(
if max_prefill_tokens:
runtime_config.max_num_batched_tokens = max_prefill_tokens
if server_args.speculative_algorithm in ("EAGLE", "NEXTN"):
runtime_config.enable_eagle = True
try:
# Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
......
......@@ -542,6 +542,8 @@ def _try_hostname_resolution() -> str | None:
)
for family, socktype, _, _, sockaddr in infos:
host_ip = sockaddr[0]
if not isinstance(host_ip, str):
continue
if not _is_routable(host_ip):
continue
try:
......
......@@ -13,10 +13,17 @@ import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Final,
Generic,
Optional,
TypeVar,
)
import ray
import ray.util.state as _ray_util_state
import torch
from vllm.config import VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
......@@ -58,6 +65,17 @@ from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
if TYPE_CHECKING:
import ray
import ray.util.state as _ray_util_state
try:
import ray
import ray.util.state as _ray_util_state
except ModuleNotFoundError:
ray = None
_ray_util_state = None
# TODO(upstream-vllm): remove this patch once vLLM fixes add_dp_placement_groups in
# vllm/v1/engine/utils.py to use ray.nodes() instead of ray.util.state.list_nodes().
#
......@@ -84,9 +102,10 @@ class _NodeInfo:
self.node_id: str = d["NodeID"]
_ray_util_state.list_nodes = lambda **kw: [
if ray is not None and _ray_util_state is not None:
_ray_util_state.list_nodes = lambda **kw: [
_NodeInfo(n) for n in ray.nodes() if n.get("Alive", False)
]
]
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
......
......@@ -16,6 +16,7 @@ use dynamo_kv_router::{
protocols::*,
};
use dynamo_llm::kv_router::publisher::KvEventPublisher;
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
use dynamo_runtime::{DistributedRuntime, Worker};
......@@ -33,6 +34,12 @@ static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
/// Bundle returned by model discovery: the preprocessor, the model card it was
/// built from, and the namespace the workers were actually found in.
///
/// Callers read routing settings (`kv_cache_block_size`, `display_name`,
/// `runtime_config.enable_eagle`) directly from `card`.
struct DiscoveredModelBootstrap {
/// Preprocessor created from a clone of the discovered model card.
preprocessor: Arc<OpenAIPreprocessor>,
/// The discovered model deployment card.
card: ModelDeploymentCard,
/// Resolved worker namespace. May differ from the requested base namespace,
/// e.g. when the worker namespace carries a rolling-update suffix.
actual_namespace: String,
}
/// Convert a C string pointer to a Rust string, falling back to a default when:
/// - the pointer is NULL,
/// - the bytes are not valid UTF-8,
......@@ -221,8 +228,10 @@ fn kv_event_create_stored_block_from_parts(
let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size,
None,
BlockHashOptions {
lora_name,
..Default::default()
},
)[0];
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash),
......@@ -645,19 +654,25 @@ pub unsafe extern "C" fn create_routers(
}
};
let (preprocessor, block_size, model_name, actual_namespace) =
match init_preprocessor(&drt, &namespace_str).await {
let DiscoveredModelBootstrap {
preprocessor,
card,
actual_namespace,
} = match init_preprocessor(&drt, &namespace_str).await {
Ok(result) => result,
Err(e) => {
tracing::error!(error = %e, "Failed to initialize preprocessor");
return Err(QueryRouterResult::ErrInitFailed);
}
};
let block_size = card.kv_cache_block_size;
let model_name = card.display_name.clone();
let enable_eagle = card.runtime_config.enable_eagle;
if actual_namespace != namespace_str {
tracing::info!(
base_namespace = namespace_str,
actual_namespace = actual_namespace,
base_namespace = %namespace_str,
actual_namespace = %actual_namespace,
"Worker namespace has rolling-update suffix"
);
}
......@@ -692,6 +707,7 @@ pub unsafe extern "C" fn create_routers(
Some(kv_router_config.clone()),
WORKER_TYPE_DECODE,
Some(model_name.clone()),
enable_eagle,
)
.await
{
......@@ -762,7 +778,8 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config),
enforce_disagg,
model_name.clone(),
namespace_str.clone(),
actual_namespace.clone(),
enable_eagle,
)
}
None if enforce_disagg => {
......@@ -782,7 +799,7 @@ pub unsafe extern "C" fn create_routers(
decode_router,
model_manager,
namespace_str,
preprocessor,
Some(preprocessor),
))
});
......@@ -848,7 +865,7 @@ pub unsafe extern "C" fn add_request(
// Compute overlap_blocks using the public method
let overlap_blocks = match decode_router
.get_overlap_blocks(&tokens, worker, None)
.get_overlap_blocks(&tokens, None, worker, None)
.await
{
Ok(overlap) => overlap,
......@@ -862,6 +879,7 @@ pub unsafe extern "C" fn add_request(
.add_request(
request_id_str.clone(),
&tokens,
None,
overlap_blocks,
None,
worker,
......@@ -1279,16 +1297,15 @@ pub unsafe extern "C" fn route_decode_request(
}
}
/// Initialize the preprocessor, block size, and model name.
/// Initialize the preprocessor and fetch the model card used for routing.
///
/// Waits for discovery to sync (model card must be available for tokenization),
/// then creates the preprocessor from the model card. The `kv_cache_block_size`
/// and `model_name` are taken from the model card to ensure consistency with
/// the worker configuration.
/// then creates the preprocessor from the model card. Router settings are
/// derived directly from the returned card by the caller.
async fn init_preprocessor(
drt: &DistributedRuntime,
target_namespace: &str,
) -> anyhow::Result<(Option<Arc<OpenAIPreprocessor>>, u32, String, String)> {
) -> anyhow::Result<DiscoveredModelBootstrap> {
let instance_count = wait_for_discovery_sync(drt).await;
if instance_count == 0 {
anyhow::bail!("Discovery sync failed: no worker instances found. Is the backend running?");
......@@ -1300,7 +1317,7 @@ async fn init_preprocessor(
// Retry fetching the preprocessor: model card metadata may arrive after
// worker endpoints are registered.
let (prep, block_size, model_name, actual_namespace) = loop {
let bootstrap = loop {
match fetch_preprocessor_from_discovery(drt, target_namespace).await {
Ok(result) => break result,
Err(e) => {
......@@ -1315,13 +1332,14 @@ async fn init_preprocessor(
};
tracing::info!(
kv_cache_block_size = block_size,
model_name = model_name,
actual_namespace = actual_namespace,
kv_cache_block_size = bootstrap.card.kv_cache_block_size,
model_name = %bootstrap.card.display_name,
actual_namespace = %bootstrap.actual_namespace,
enable_eagle = bootstrap.card.runtime_config.enable_eagle,
"Preprocessor initialized from model card"
);
Ok((Some(prep), block_size, model_name, actual_namespace))
Ok(bootstrap)
}
/// Fetch model card via discovery and create preprocessor.
......@@ -1331,12 +1349,11 @@ async fn init_preprocessor(
/// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card
/// 5. Returns the preprocessor, the model card, and the resolved worker namespace
async fn fetch_preprocessor_from_discovery(
drt: &DistributedRuntime,
target_namespace: &str,
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String, String)> {
use dynamo_llm::model_card::ModelDeploymentCard;
) -> anyhow::Result<DiscoveredModelBootstrap> {
use dynamo_runtime::discovery::DiscoveryInstance;
let discovery = drt.discovery();
......@@ -1383,12 +1400,11 @@ async fn fetch_preprocessor_from_discovery(
)
})?;
let kv_cache_block_size = card.kv_cache_block_size;
let model_name = card.name().to_string();
tracing::info!(
model_name = model_name,
kv_cache_block_size = kv_cache_block_size,
actual_namespace = actual_namespace,
model_name = %card.display_name,
kv_cache_block_size = card.kv_cache_block_size,
actual_namespace = %actual_namespace,
enable_eagle = card.runtime_config.enable_eagle,
"Found model card via discovery"
);
......@@ -1396,13 +1412,12 @@ async fn fetch_preprocessor_from_discovery(
card.download_config().await?;
// Create preprocessor
let preprocessor = OpenAIPreprocessor::new(card)?;
Ok((
let preprocessor = OpenAIPreprocessor::new(card.clone())?;
Ok(DiscoveredModelBootstrap {
preprocessor,
kv_cache_block_size,
model_name,
card,
actual_namespace,
))
})
}
/// Find a prefill endpoint from already-discovered instances (one-time filter).
......
......@@ -161,13 +161,14 @@ fn init_standalone_logging() {
}
#[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None, is_eagle=None))]
pub fn compute_block_hash_for_seq_py(
_py: Python,
tokens: Vec<u32>,
kv_block_size: usize,
block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
is_eagle: Option<bool>,
) -> PyResult<Vec<u64>> {
if kv_block_size == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
......@@ -183,8 +184,11 @@ pub fn compute_block_hash_for_seq_py(
let hashes = compute_block_hash_for_seq(
&tokens,
kv_block_size as u32,
mm_infos.as_deref(),
lora_name.as_deref(),
BlockHashOptions {
block_mm_infos: mm_infos.as_deref(),
lora_name: lora_name.as_deref(),
is_eagle,
},
);
Ok(hashes.into_iter().map(|h| h.0).collect())
......@@ -310,7 +314,7 @@ impl KvEventPublisher {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None, is_eagle=None))]
fn publish_stored(
&self,
py: Python,
......@@ -320,6 +324,7 @@ impl KvEventPublisher {
parent_hash: Option<i64>,
block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
is_eagle: Option<bool>,
) -> PyResult<()> {
let kv_block_size = self.kv_block_size as u32;
let dp_rank = self.dp_rank;
......@@ -347,6 +352,7 @@ impl KvEventPublisher {
lora_name.as_deref(),
&warning_count,
mm_infos.as_deref(),
is_eagle,
),
}),
dp_rank,
......@@ -716,14 +722,13 @@ async fn create_kv_router_from_endpoint(
llm_rs::discovery::WORKER_TYPE_DECODE
};
// Only query discovery for model_name when a remote indexer is configured,
// since model_name is only needed for the RemoteIndexer path.
// Query discovery once so we can derive both model_name (for remote indexer)
// and Eagle routing semantics from the model card.
let needs_model_name = kv_router_config
.as_ref()
.map(|cfg| cfg.remote_indexer_component.is_some())
.unwrap_or(false);
let model_name = if needs_model_name {
let (model_name, enable_eagle) = {
let discovery = endpoint.inner.component().drt().discovery();
let instances = discovery
.list(rs::discovery::DiscoveryQuery::EndpointModels {
......@@ -734,23 +739,26 @@ async fn create_kv_router_from_endpoint(
.await
.map_err(to_pyerr)?;
Some(
instances
.into_iter()
.find_map(|inst| {
let maybe_card = instances.into_iter().find_map(|inst| {
inst.deserialize_model::<llm_rs::model_card::ModelDeploymentCard>()
.ok()
.map(|card| card.display_name)
})
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"no model card found in discovery for endpoint {}/{}/{}",
endpoint_id.namespace, endpoint_id.component, endpoint_id.name
))
})?,
)
} else {
None
});
match maybe_card {
Some(card) => {
let model_name = needs_model_name.then(|| card.display_name.clone());
(model_name, card.runtime_config.enable_eagle)
}
None => {
tracing::warn!(
namespace = %endpoint_id.namespace,
component = %endpoint_id.component,
endpoint = %endpoint_id.name,
"No model card found in discovery; defaulting to non-Eagle routing semantics"
);
(None, false)
}
}
};
let kv_router = model_manager
......@@ -760,6 +768,7 @@ async fn create_kv_router_from_endpoint(
kv_router_config,
worker_type,
model_name,
enable_eagle,
)
.await
.map_err(to_pyerr)?;
......@@ -1083,7 +1092,7 @@ impl KvRouter {
block_mm_infos.as_deref(),
router_config_override.as_ref(),
update_states,
lora_name,
lora_name.clone(),
0.0,
None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
......@@ -1092,8 +1101,17 @@ impl KvRouter {
.map_err(to_pyerr)?;
if update_indexer && !chooser.kv_router_config().use_kv_events {
let mut tokens_with_hashes =
TokensWithHashes::new(token_ids.clone(), chooser.block_size())
.with_is_eagle(chooser.is_eagle());
if let Some(infos) = block_mm_infos.as_ref() {
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.clone());
}
if let Some(lora_name) = lora_name.as_ref() {
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name.clone());
}
chooser
.record_routing_decision(token_ids.clone(), best_worker)
.record_routing_decision(tokens_with_hashes, best_worker)
.await
.map_err(to_pyerr)?;
}
......@@ -1129,18 +1147,27 @@ impl KvRouter {
})
}
#[pyo3(signature = (token_ids, lora_name=None))]
#[pyo3(signature = (token_ids, block_mm_infos=None, lora_name=None))]
fn get_potential_loads<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
block_mm_infos: Option<PyObject>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
let block_mm_infos = block_mm_infos
.map(|obj| depythonize_block_mm_infos(obj.bind(py)))
.transpose()?;
let chooser = self.inner.chooser.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser
.get_potential_loads(&token_ids, None, lora_name.as_deref())
.get_potential_loads(
&token_ids,
None,
block_mm_infos.as_deref(),
lora_name.as_deref(),
)
.await
.map_err(to_pyerr)?;
......
......@@ -60,6 +60,11 @@ impl ModelRuntimeConfig {
self.inner.enable_local_indexer = enable_local_indexer;
}
/// Python property setter: stores the Eagle flag on the wrapped runtime config.
#[setter]
fn set_enable_eagle(&mut self, enable_eagle: bool) {
self.inner.enable_eagle = enable_eagle;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......@@ -159,4 +164,9 @@ impl ModelRuntimeConfig {
.as_ref()
.and_then(|e| e.bootstrap_port)
}
/// Python property getter: returns the Eagle flag from the wrapped runtime config.
#[getter]
fn enable_eagle(&self) -> bool {
self.inner.enable_eagle
}
}
......@@ -279,6 +279,7 @@ def compute_block_hash_for_seq(
kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
is_eagle: Optional[bool] = None,
) -> List[int]:
"""
Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
......@@ -299,6 +300,9 @@ def compute_block_hash_for_seq(
}
]
}
lora_name: Optional LoRA adapter name for adapter-aware block hashing.
is_eagle: Optional Eagle mode flag. When true, hashes use overlapping
`kv_block_size + 1` token windows with `kv_block_size` stride.
Returns:
List of block hashes (one per block)
......@@ -478,6 +482,7 @@ class ModelRuntimeConfig:
data_parallel_start_rank: int
data_parallel_size: int
enable_local_indexer: bool
enable_eagle: bool
runtime_data: dict[str, Any]
tensor_model_config: Any | None
bootstrap_host: str | None
......@@ -634,7 +639,7 @@ class KvIndexer:
...
def find_matches_for_request(
self, token_ids: List[int], lora_name: Optional[str] = None
self, token_ids: List[int], lora_name: Optional[str] = None, is_eagle: Optional[bool] = None
) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
......@@ -682,7 +687,7 @@ class ApproxKvIndexer:
...
def find_matches_for_request(
self, token_ids: List[int], lora_name: Optional[str] = None
self, token_ids: List[int], lora_name: Optional[str] = None, is_eagle: Optional[bool] = None
) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
......@@ -765,6 +770,7 @@ class KvEventPublisher:
parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
is_eagle: Optional[bool] = None,
) -> None:
"""
Publish a KV stored event.
......@@ -780,6 +786,8 @@ class KvEventPublisher:
Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
lora_name: Optional LoRA adapter name for adapter-aware block hashing.
is_eagle: Optional Eagle mode flag. When true, stored blocks are
reconstructed using overlapping `kv_block_size + 1` token windows.
"""
...
......@@ -1739,6 +1747,7 @@ class KvRouter:
async def get_potential_loads(
self,
token_ids: List[int],
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> List[Dict[str, int]]:
"""
......@@ -1746,6 +1755,9 @@ class KvRouter:
Args:
token_ids: List of token IDs to evaluate
block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in hash computation
for MM-aware potential-load estimation.
Returns:
A list of dictionaries, each containing:
......
......@@ -449,13 +449,22 @@ impl KvIndexerInterface for KvIndexer {
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
tracing::debug!(
"Finding matches for request tokens: {:?} / len: {}",
tokens,
tokens.len()
);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await
}
......
......@@ -271,9 +271,10 @@ impl KvIndexerInterface for LocalKvIndexer {
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
self.indexer
.find_matches_for_request(tokens, lora_name)
.find_matches_for_request(tokens, lora_name, is_eagle)
.await
}
......
......@@ -349,7 +349,7 @@ mod tests {
// 1. Before routing decision there should be no matches
let pre_scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.expect("indexer offline");
assert!(pre_scores.scores.is_empty());
......@@ -367,7 +367,7 @@ mod tests {
// Poll until we observe the match being registered
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
s.scores
......@@ -380,7 +380,7 @@ mod tests {
// 3. After the TTL has passed the entry should expire automatically
time::sleep(TTL + Duration::from_millis(50)).await;
let post_scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert!(post_scores.scores.is_empty());
......@@ -420,7 +420,7 @@ mod tests {
// Wait until the worker is registered
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
s.scores
......@@ -434,7 +434,7 @@ mod tests {
// Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
!s.scores
......@@ -488,7 +488,7 @@ mod tests {
// Ensure both workers are registered
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
s.scores
......@@ -508,7 +508,7 @@ mod tests {
// Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
!s.scores
......@@ -558,7 +558,7 @@ mod tests {
// Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&seq_a, None)
.find_matches_for_request(&seq_a, None, None)
.await
.unwrap();
s.scores
......@@ -573,7 +573,7 @@ mod tests {
// Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
let overlap = indexer
.find_matches_for_request(&seq_b, None)
.find_matches_for_request(&seq_b, None, None)
.await
.unwrap();
......@@ -631,7 +631,7 @@ mod tests {
// Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), async || {
let s = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
s.scores
......@@ -646,7 +646,7 @@ mod tests {
.await;
let scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
......@@ -808,7 +808,7 @@ mod tests {
for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert_eq!(
......@@ -837,7 +837,7 @@ mod tests {
for i in 0..4 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert!(
......@@ -851,7 +851,7 @@ mod tests {
for i in 4..6 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer
.find_matches_for_request(&tokens, None)
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert_eq!(
......
......@@ -424,8 +424,17 @@ impl KvIndexerInterface for KvIndexerSharded {
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
self.find_matches(sequence).await
}
......
......@@ -547,7 +547,10 @@ mod interface_tests {
// Empty index should return no matches
let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
let scores = index
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens
......@@ -559,7 +562,10 @@ mod interface_tests {
// Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
let scores = index
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty());
......@@ -883,9 +889,16 @@ mod lora_tests {
// Same token sequence for both base model and LoRA adapter
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_hashes =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
let base_hashes =
compute_block_hash_for_seq(&tokens, kv_block_size, BlockHashOptions::default());
let lora_hashes = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("my-adapter"),
..Default::default()
},
);
// Hashes must differ despite identical tokens
assert_ne!(
......@@ -970,9 +983,16 @@ mod lora_tests {
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
let base_local =
compute_block_hash_for_seq(&tokens, kv_block_size, BlockHashOptions::default());
let lora_local = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("my-adapter"),
..Default::default()
},
);
assert_ne!(
base_local, lora_local,
......@@ -1044,8 +1064,22 @@ mod lora_tests {
let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a"));
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b"));
let hashes_a = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("adapter-a"),
..Default::default()
},
);
let hashes_b = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("adapter-b"),
..Default::default()
},
);
assert_ne!(
hashes_a, hashes_b,
......
......@@ -159,8 +159,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
Ok(self.backend.find_matches(&sequence, false))
}
......
......@@ -36,6 +36,7 @@ pub trait KvIndexerInterface {
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError>;
/// Apply a `RouterEvent` to the KV store.
......
......@@ -24,6 +24,13 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
/// Optional inputs for [`compute_block_hash_for_seq`].
///
/// `BlockHashOptions::default()` leaves every field unset, which hashes plain
/// token blocks with the base seed.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlockHashOptions<'a> {
/// Per-block multimodal metadata, looked up by block index while hashing.
pub block_mm_infos: Option<&'a [Option<BlockExtraInfo>]>,
/// LoRA adapter name mixed into the hash seed. An empty string is treated
/// the same as `None`.
pub lora_name: Option<&'a str>,
/// Eagle mode flag; `None` is treated as `false`. When true, each hashed
/// window is `kv_block_size + 1` tokens long and advances by `kv_block_size`.
pub is_eagle: Option<bool>,
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity.
///
......@@ -39,20 +46,30 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
lora_name: Option<&str>,
options: BlockHashOptions<'_>,
) -> Vec<LocalBlockHash> {
let seed = match lora_name.filter(|n| !n.is_empty()) {
if kv_block_size == 0 {
return Vec::new();
}
let seed = match options.lora_name.filter(|n| !n.is_empty()) {
Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
None => XXH3_SEED,
};
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let is_eagle_flag = options.is_eagle.unwrap_or(false);
let stride = kv_block_size as usize;
let window_size = if is_eagle_flag { stride + 1 } else { stride };
let mut hashes = Vec::new();
let mut block_idx = 0;
let mut start = 0;
while start + window_size <= tokens.len() {
let chunk = &tokens[start..start + window_size];
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
if let Some(mm_infos) = block_mm_infos
if let Some(mm_infos) = options.block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
let mut mm_hashes: Vec<u64> = block_mm_info
......@@ -67,9 +84,13 @@ pub fn compute_block_hash_for_seq(
}
}
LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed))
})
.collect()
hashes.push(LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed)));
start += stride;
block_idx += 1;
}
hashes
}
/// Compute rolling sequence hashes for a vector of block hashes.
......@@ -718,6 +739,7 @@ pub struct TokensWithHashes {
lora_name: Option<String>,
block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>,
is_eagle: Option<bool>,
}
impl TokensWithHashes {
......@@ -730,6 +752,7 @@ impl TokensWithHashes {
lora_name: None,
block_hashes: None,
seq_hashes: None,
is_eagle: None,
}
}
......@@ -745,6 +768,24 @@ impl TokensWithHashes {
self
}
/// Sets Eagle hashing semantics for this token sequence.
///
/// Builder-style wrapper around `set_is_eagle`: applies the flag and returns
/// `self` so calls can be chained.
pub fn with_is_eagle(mut self, is_eagle: bool) -> Self {
self.set_is_eagle(is_eagle);
self
}
/// Records the Eagle hashing mode and drops cached hashes if the stored value changes.
///
/// The flag is stored as `Some(is_eagle)`. When that equals the value already
/// held, nothing happens. Otherwise both `block_hashes` and `seq_hashes` are
/// reset to `None`. An unset flag (`None`) is distinct from `Some(false)`, so
/// the first explicit call on a fresh instance always clears the caches.
pub fn set_is_eagle(&mut self, is_eagle: bool) {
    let requested = Some(is_eagle);
    if self.is_eagle != requested {
        self.is_eagle = requested;
        // Cached hashes were computed under the previous mode; discard them.
        self.block_hashes = None;
        self.seq_hashes = None;
    }
}
/// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] {
&self.tokens
......@@ -776,8 +817,11 @@ impl TokensWithHashes {
self.block_hashes = Some(compute_block_hash_for_seq(
&self.tokens,
self.block_size,
self.block_mm_infos.as_deref(),
self.lora_name.as_deref(),
BlockHashOptions {
block_mm_infos: self.block_mm_infos.as_deref(),
lora_name: self.lora_name.as_deref(),
is_eagle: self.is_eagle,
},
));
}
self.block_hashes.as_ref().unwrap()
......@@ -858,24 +902,41 @@ mod tests {
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 1);
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 1);
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 2);
}
#[test]
fn test_lora_name_produces_different_hash() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let lora_a = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-a"));
let lora_b = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-b"));
let base = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let lora_a = compute_block_hash_for_seq(
&tokens,
4,
BlockHashOptions {
lora_name: Some("adapter-a"),
..Default::default()
},
);
let lora_b = compute_block_hash_for_seq(
&tokens,
4,
BlockHashOptions {
lora_name: Some("adapter-b"),
..Default::default()
},
);
assert_ne!(base[0], lora_a[0]);
assert_ne!(base[0], lora_b[0]);
......@@ -885,16 +946,23 @@ mod tests {
#[test]
fn test_lora_name_none_matches_legacy() {
let tokens: Vec<u32> = (0..8).collect();
let hashes_none = compute_block_hash_for_seq(&tokens, 4, None, None);
let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, None, None);
let hashes_none = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
assert_eq!(hashes_none, hashes_none2);
}
#[test]
fn test_lora_name_empty_string_normalized_to_none() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let empty = compute_block_hash_for_seq(&tokens, 4, None, Some(""));
let base = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let empty = compute_block_hash_for_seq(
&tokens,
4,
BlockHashOptions {
lora_name: Some(""),
..Default::default()
},
);
assert_eq!(
base, empty,
"empty lora_name should be treated as base model"
......@@ -918,6 +986,73 @@ mod tests {
}
}
#[test]
fn test_compute_block_hash_for_seq_eagle_windows() {
    // Fresh eagle-mode options for each call; every other field stays at its default.
    let eagle = || BlockHashOptions {
        is_eagle: Some(true),
        ..Default::default()
    };

    let tokens: Vec<u32> = (0..6).collect();
    let plain = compute_block_hash_for_seq(&tokens, 2, BlockHashOptions::default());
    let windowed = compute_block_hash_for_seq(&tokens, 2, eagle());

    // Reference values: each three-token slice hashed on its own in eagle mode.
    let first_window = compute_block_hash_for_seq(&[0, 1, 2], 2, eagle());
    let second_window = compute_block_hash_for_seq(&[2, 3, 4], 2, eagle());

    // Six tokens with block size 2: three plain blocks, but only two eagle blocks.
    assert_eq!(plain.len(), 3);
    assert_eq!(windowed.len(), 2);
    assert_eq!(windowed, vec![first_window[0], second_window[0]]);
    // The first eagle hash must differ from the first plain hash.
    assert_ne!(plain[0], windowed[0]);
}
#[test]
fn test_tokens_with_hashes_set_is_eagle_invalidates_cache() {
    // Fresh eagle-mode options for each reference computation.
    let eagle = || BlockHashOptions {
        is_eagle: Some(true),
        ..Default::default()
    };

    let tokens: Vec<u32> = (0..6).collect();
    let mut tracked = TokensWithHashes::new(tokens, 2);

    // Read the hashes once, switch to eagle mode, then read them again: the
    // second read must reflect the new mode rather than the earlier result.
    let before = tracked.get_or_compute_block_hashes().to_vec();
    tracked.set_is_eagle(true);
    let after = tracked.get_or_compute_block_hashes().to_vec();

    // Reference values: each three-token slice hashed on its own in eagle mode.
    let first_window = compute_block_hash_for_seq(&[0, 1, 2], 2, eagle());
    let second_window = compute_block_hash_for_seq(&[2, 3, 4], 2, eagle());

    assert_eq!(before.len(), 3);
    assert_eq!(after.len(), 2);
    assert_eq!(after, vec![first_window[0], second_window[0]]);
    assert_ne!(before[0], after[0]);
}
#[test]
fn test_local_block_hash_serialization() {
let hash = LocalBlockHash(12345);
......
......@@ -9,7 +9,7 @@ use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
use crate::protocols::{BlockHashOptions, compute_block_hash_for_seq, compute_seq_hash_for_block};
const fn default_min_initial_workers() -> usize {
1
......@@ -217,7 +217,7 @@ impl KvRouterConfig {
tokens: &[u32],
block_size: u32,
config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
hash_options: BlockHashOptions<'_>,
) -> Option<Vec<u64>> {
if !self.router_track_active_blocks {
return None;
......@@ -233,7 +233,7 @@ impl KvRouterConfig {
.unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse {
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None, lora_name);
let block_hashes = compute_block_hash_for_seq(tokens, block_size, hash_options);
Some(compute_seq_hash_for_block(&block_hashes))
} else {
let mut rng = rand::rng();
......@@ -257,6 +257,7 @@ impl KvRouterConfig {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
......@@ -288,4 +289,36 @@ mod tests {
};
assert!(cfg.validate().is_err());
}
#[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
    let cfg = KvRouterConfig::default();
    let tokens = vec![1, 2, 3, 4];

    // Two blocks' worth of multimodal info: only the first block has an object.
    let first_block_mm = BlockExtraInfo {
        mm_objects: vec![BlockMmObjectInfo {
            mm_hash: 42,
            offsets: vec![],
        }],
    };
    let mm_infos = vec![Some(first_block_mm), None];

    // NOTE(review): both `unwrap`s rely on the default config returning `Some`
    // here (i.e. active-block tracking enabled by default) — confirm.
    let plain = cfg
        .compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default())
        .unwrap();

    let mm_options = BlockHashOptions {
        block_mm_infos: Some(&mm_infos),
        ..Default::default()
    };
    let with_mm = cfg
        .compute_seq_hashes_for_tracking(&tokens, 2, None, mm_options)
        .unwrap();

    // Supplying multimodal info must change the resulting sequence hashes.
    assert_ne!(plain, with_mm);
}
}
......@@ -13,7 +13,7 @@ use axum::{Json, Router};
use prometheus::Encoder;
use serde::{Deserialize, Serialize};
use crate::protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq};
use crate::protocols::{BlockHashOptions, LocalBlockHash, WorkerId, compute_block_hash_for_seq};
use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry};
......@@ -197,8 +197,14 @@ async fn query(
let indexer = ie.indexer.clone();
drop(ie);
let block_hashes =
compute_block_hash_for_seq(&req.token_ids, block_size, None, req.lora_name.as_deref());
let block_hashes = compute_block_hash_for_seq(
&req.token_ids,
block_size,
BlockHashOptions {
lora_name: req.lora_name.as_deref(),
..Default::default()
},
);
match indexer.find_matches(block_hashes).await {
Ok(overlap) => (
StatusCode::OK,
......
......@@ -17,9 +17,9 @@ use serde::Serialize;
use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
BlockExtraInfo, BlockHashOptions, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent,
KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement,
PlacementEvent, StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
};
// -------------------------------------------------------------------------
......@@ -65,6 +65,13 @@ impl BlockHashValue {
}
}
/// Token payload of a stored-block event, in either of the two shapes it can
/// arrive in.
///
/// The enum is `#[serde(untagged)]`, so deserialization tries the variants in
/// declaration order and keeps the first one that fits; an empty sequence
/// therefore decodes as [`KvTokenIds::Single`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum KvTokenIds {
    /// Flat list of token ids.
    Single(Vec<u32>),
    /// List of `(first, second)` token pairs. The event deserializer in this
    /// file flattens it to the `first` of every pair followed by the `second`
    /// of the last pair, and marks the event as eagle.
    Bigram(Vec<(u32, u32)>),
}
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")] // msgspec encodes variant tag as a string when `tag=True`
pub enum RawKvEvent {
......@@ -83,6 +90,8 @@ pub enum RawKvEvent {
/// Multimodal extra info for each block (length should match block_hashes)
#[serde(default, skip_serializing_if = "Option::is_none")]
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
#[serde(skip_serializing_if = "Option::is_none")]
is_eagle: Option<bool>,
},
BlockRemoved {
block_hashes: Vec<BlockHashValue>,
......@@ -180,7 +189,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut event_type: Option<String> = None;
let mut block_hashes: Option<Vec<BlockHashValue>> = None;
let mut parent_block_hash: Option<Option<BlockHashValue>> = None;
let mut token_ids: Option<Vec<u32>> = None;
let mut token_ids: Option<KvTokenIds> = None;
let mut block_size: Option<usize> = None;
let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None;
......@@ -227,6 +236,17 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_hashes =
block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?;
// Normalize the decoded payload to a flat token list and record whether it
// arrived as bigram pairs (which is what flags the event as eagle).
let (raw_token_ids, is_eagle) = match token_ids {
    KvTokenIds::Single(tids) => (tids, false),
    KvTokenIds::Bigram(pairs) => {
        // n pairs flatten to n + 1 tokens: the first element of every pair,
        // then the second element of the last pair. `last()` yields `None` for
        // an empty list, so the chain contributes nothing and the result stays
        // empty without any `unwrap`.
        let flat: Vec<u32> = pairs
            .iter()
            .map(|&(first, _)| first)
            .chain(pairs.last().map(|&(_, second)| second))
            .collect();
        (flat, true)
    }
};
let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let block_mm_infos = block_mm_infos
......@@ -235,11 +255,12 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
Ok(RawKvEvent::BlockStored {
block_hashes,
parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids,
token_ids: raw_token_ids,
block_size,
medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None),
block_mm_infos,
is_eagle: Some(is_eagle),
})
}
Some("BlockRemoved") => {
......@@ -277,7 +298,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
let parent_block_hash: Option<BlockHashValue> = seq.next_element()?.unwrap_or(None);
let token_ids: Vec<u32> = seq
let token_ids: KvTokenIds = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(3, &"missing token_ids"))?;
let block_size: usize = seq
......@@ -297,14 +318,27 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
// Normalize the decoded payload to a flat token list and record whether it
// arrived as bigram pairs (which is what flags the event as eagle).
let (raw_token_ids, is_eagle) = match token_ids {
    KvTokenIds::Single(tids) => (tids, false),
    KvTokenIds::Bigram(pairs) => {
        // Take the first element of every pair and append the second element
        // of the final pair. For an empty list `last()` is `None`, so nothing
        // is appended and no `unwrap` is needed.
        let flat: Vec<u32> = pairs
            .iter()
            .map(|&(first, _)| first)
            .chain(pairs.last().map(|&(_, second)| second))
            .collect();
        (flat, true)
    }
};
Ok(RawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
token_ids,
token_ids: raw_token_ids,
block_size,
medium,
lora_name,
block_mm_infos,
is_eagle: Some(is_eagle),
})
}
"BlockRemoved" => {
......@@ -360,6 +394,7 @@ pub fn convert_event(
lora_name,
block_mm_infos,
medium: _,
is_eagle,
} => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique.
{
......@@ -408,6 +443,7 @@ pub fn convert_event(
lora_name.as_deref(),
warning_count,
block_mm_infos.as_deref(),
is_eagle,
),
}),
dp_rank,
......@@ -446,13 +482,17 @@ pub fn create_stored_block_from_parts(
token_ids: &[u32],
lora_name: Option<&str>,
mm_extra_info: Option<BlockExtraInfo>,
is_eagle: Option<bool>,
) -> KvCacheStoredBlockData {
let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
let tokens_hash = compute_block_hash_for_seq(
token_ids,
kv_block_size,
block_mm_infos.as_deref(),
BlockHashOptions {
block_mm_infos: block_mm_infos.as_deref(),
lora_name,
is_eagle,
},
)[0];
tracing::trace!(
......@@ -470,6 +510,7 @@ pub fn create_stored_block_from_parts(
}
}
#[allow(clippy::too_many_arguments)]
pub fn create_stored_blocks(
kv_block_size: u32,
token_ids: &[u32],
......@@ -478,10 +519,13 @@ pub fn create_stored_blocks(
lora_name: Option<&str>,
warning_count: &Arc<AtomicU32>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
is_eagle: Option<bool>,
) -> Vec<KvCacheStoredBlockData> {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0;
let append = is_eagle.unwrap_or(false) as usize;
for (block_idx, (num_tokens_it, block_hash_it)) in
num_block_tokens.iter().zip(block_hashes.iter()).enumerate()
{
......@@ -496,7 +540,19 @@ pub fn create_stored_blocks(
break;
}
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
let end = token_offset + append + *num_tokens_it as usize;
if end > token_ids.len() {
if warning_count.fetch_add(1, Ordering::Relaxed) < 3 {
tracing::warn!(
"Block not published. token_ids too short: need {}, got {}",
end,
token_ids.len()
);
}
break;
}
let tokens = &token_ids[token_offset..end];
let mm_extra_info = block_mm_infos
.and_then(|infos| infos.get(block_idx))
.and_then(|opt| opt.clone());
......@@ -507,9 +563,102 @@ pub fn create_stored_blocks(
tokens,
lora_name,
mm_extra_info,
is_eagle,
));
token_offset += *num_tokens_it as usize;
}
blocks
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use rmp_serde::{from_slice, to_vec};
use super::*;
#[test]
fn test_deserialize_bigram_block_stored_sequence() {
let raw_event = (
"BlockStored",
vec![BlockHashValue::Unsigned(11), BlockHashValue::Unsigned(12)],
Option::<BlockHashValue>::None,
vec![(10u32, 11u32), (11, 12), (12, 13), (13, 14)],
2usize,
Option::<u64>::None,
Option::<String>::None,
Option::<String>::None,
);
let encoded = to_vec(&raw_event).unwrap();
let event: RawKvEvent = from_slice(&encoded).unwrap();
match event {
RawKvEvent::BlockStored {
token_ids,
block_size,
is_eagle,
..
} => {
assert_eq!(token_ids, vec![10, 11, 12, 13, 14]);
assert_eq!(block_size, 2);
assert_eq!(is_eagle, Some(true));
}
other => panic!("expected BlockStored, got {other:?}"),
}
}
#[test]
fn test_convert_event_bigram_emits_eagle_windows() {
let raw_event = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(21), BlockHashValue::Unsigned(22)],
parent_block_hash: None,
token_ids: vec![10, 11, 12, 13, 14],
block_size: 2,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: Some(true),
};
let warning_count = Arc::new(AtomicU32::new(0));
let placement_event =
convert_event(raw_event, 7, 2, WorkerWithDpRank::new(3, 0), &warning_count);
match placement_event.event.data {
KvCacheEventData::Stored(store_data) => {
assert_eq!(store_data.blocks.len(), 2);
assert_eq!(
store_data.blocks[0].block_hash,
ExternalSequenceBlockHash(21)
);
assert_eq!(
store_data.blocks[1].block_hash,
ExternalSequenceBlockHash(22)
);
let expected_first = compute_block_hash_for_seq(
&[10, 11, 12],
2,
BlockHashOptions {
is_eagle: Some(true),
..Default::default()
},
);
let expected_second = compute_block_hash_for_seq(
&[12, 13, 14],
2,
BlockHashOptions {
is_eagle: Some(true),
..Default::default()
},
);
assert_eq!(store_data.blocks[0].tokens_hash, expected_first[0]);
assert_eq!(store_data.blocks[1].tokens_hash, expected_second[0]);
}
other => panic!("expected Stored event, got {other:?}"),
}
}
}
......@@ -562,6 +562,7 @@ impl ModelManager {
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
) -> anyhow::Result<Arc<KvRouter>> {
let client = endpoint.client().await?;
......@@ -597,6 +598,7 @@ impl ModelManager {
kv_router_config,
worker_type,
model_name,
is_eagle,
)
.await?;
Ok(Arc::new(chooser))
......
......@@ -465,6 +465,7 @@ impl ModelWatcher {
Some(self.router_config.kv_router_config.clone()),
WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()),
card.runtime_config.enable_eagle,
)
.await?,
)
......@@ -495,6 +496,7 @@ impl ModelWatcher {
self.router_config.enforce_disagg,
model_name.clone(),
namespace.clone(),
card.runtime_config.enable_eagle,
)
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment