Unverified commit 2fe37a51, authored by huitian bai and committed by GitHub
Browse files

fix: sglang eagle bigram tokens kv event report. (#6872)


Signed-off-by: baihuitian <baihuitian.bht@gmail.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: PeaBrane <yanrpei@gmail.com>
parent db79f324
...@@ -194,6 +194,9 @@ async def _get_runtime_config( ...@@ -194,6 +194,9 @@ async def _get_runtime_config(
if max_prefill_tokens: if max_prefill_tokens:
runtime_config.max_num_batched_tokens = max_prefill_tokens runtime_config.max_num_batched_tokens = max_prefill_tokens
if server_args.speculative_algorithm in ("EAGLE", "NEXTN"):
runtime_config.enable_eagle = True
try: try:
# Try to check if the engine has a scheduler attribute with the computed values # Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
......
...@@ -542,6 +542,8 @@ def _try_hostname_resolution() -> str | None: ...@@ -542,6 +542,8 @@ def _try_hostname_resolution() -> str | None:
) )
for family, socktype, _, _, sockaddr in infos: for family, socktype, _, _, sockaddr in infos:
host_ip = sockaddr[0] host_ip = sockaddr[0]
if not isinstance(host_ip, str):
continue
if not _is_routable(host_ip): if not _is_routable(host_ip):
continue continue
try: try:
......
...@@ -13,10 +13,17 @@ import time ...@@ -13,10 +13,17 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Final,
Generic,
Optional,
TypeVar,
)
import ray
import ray.util.state as _ray_util_state
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
...@@ -58,6 +65,17 @@ from .multimodal_utils.hash_utils import compute_mm_uuids_from_images ...@@ -58,6 +65,17 @@ from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
if TYPE_CHECKING:
import ray
import ray.util.state as _ray_util_state
try:
import ray
import ray.util.state as _ray_util_state
except ModuleNotFoundError:
ray = None
_ray_util_state = None
# TODO(upstream-vllm): remove this patch once vLLM fixes add_dp_placement_groups in # TODO(upstream-vllm): remove this patch once vLLM fixes add_dp_placement_groups in
# vllm/v1/engine/utils.py to use ray.nodes() instead of ray.util.state.list_nodes(). # vllm/v1/engine/utils.py to use ray.nodes() instead of ray.util.state.list_nodes().
# #
...@@ -84,9 +102,10 @@ class _NodeInfo: ...@@ -84,9 +102,10 @@ class _NodeInfo:
self.node_id: str = d["NodeID"] self.node_id: str = d["NodeID"]
_ray_util_state.list_nodes = lambda **kw: [ if ray is not None and _ray_util_state is not None:
_ray_util_state.list_nodes = lambda **kw: [
_NodeInfo(n) for n in ray.nodes() if n.get("Alive", False) _NodeInfo(n) for n in ray.nodes() if n.get("Alive", False)
] ]
# Multimodal data dictionary keys # Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url" IMAGE_URL_KEY: Final = "image_url"
......
...@@ -16,6 +16,7 @@ use dynamo_kv_router::{ ...@@ -16,6 +16,7 @@ use dynamo_kv_router::{
protocols::*, protocols::*,
}; };
use dynamo_llm::kv_router::publisher::KvEventPublisher; use dynamo_llm::kv_router::publisher::KvEventPublisher;
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name}; use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
use dynamo_runtime::{DistributedRuntime, Worker}; use dynamo_runtime::{DistributedRuntime, Worker};
...@@ -33,6 +34,12 @@ static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new(); ...@@ -33,6 +34,12 @@ static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls? // [FIXME] shouldn't the publisher be instance passing between API calls?
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new(); static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
/// Bundle produced by `fetch_preprocessor_from_discovery` / `init_preprocessor`:
/// everything the router bootstrap needs from the discovered model.
///
/// Callers read router settings (`kv_cache_block_size`, `display_name`,
/// `runtime_config.enable_eagle`) directly from `card` instead of receiving
/// them as separate tuple elements.
struct DiscoveredModelBootstrap {
/// Preprocessor created from the discovered model card.
preprocessor: Arc<OpenAIPreprocessor>,
/// The model deployment card found via discovery.
card: ModelDeploymentCard,
/// Worker namespace resolved during discovery; it can differ from the
/// requested namespace (the caller logs this as a rolling-update suffix).
actual_namespace: String,
}
/// Convert a C string pointer to a Rust string, falling back to a default when: /// Convert a C string pointer to a Rust string, falling back to a default when:
/// - the pointer is NULL, /// - the pointer is NULL,
/// - the bytes are not valid UTF-8, /// - the bytes are not valid UTF-8,
...@@ -221,8 +228,10 @@ fn kv_event_create_stored_block_from_parts( ...@@ -221,8 +228,10 @@ fn kv_event_create_stored_block_from_parts(
let tokens_hash = compute_block_hash_for_seq( let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }, unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size, kv_block_size,
None, BlockHashOptions {
lora_name, lora_name,
..Default::default()
},
)[0]; )[0];
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash), block_hash: ExternalSequenceBlockHash(block_hash),
...@@ -645,19 +654,25 @@ pub unsafe extern "C" fn create_routers( ...@@ -645,19 +654,25 @@ pub unsafe extern "C" fn create_routers(
} }
}; };
let (preprocessor, block_size, model_name, actual_namespace) = let DiscoveredModelBootstrap {
match init_preprocessor(&drt, &namespace_str).await { preprocessor,
card,
actual_namespace,
} = match init_preprocessor(&drt, &namespace_str).await {
Ok(result) => result, Ok(result) => result,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "Failed to initialize preprocessor"); tracing::error!(error = %e, "Failed to initialize preprocessor");
return Err(QueryRouterResult::ErrInitFailed); return Err(QueryRouterResult::ErrInitFailed);
} }
}; };
let block_size = card.kv_cache_block_size;
let model_name = card.display_name.clone();
let enable_eagle = card.runtime_config.enable_eagle;
if actual_namespace != namespace_str { if actual_namespace != namespace_str {
tracing::info!( tracing::info!(
base_namespace = namespace_str, base_namespace = %namespace_str,
actual_namespace = actual_namespace, actual_namespace = %actual_namespace,
"Worker namespace has rolling-update suffix" "Worker namespace has rolling-update suffix"
); );
} }
...@@ -692,6 +707,7 @@ pub unsafe extern "C" fn create_routers( ...@@ -692,6 +707,7 @@ pub unsafe extern "C" fn create_routers(
Some(kv_router_config.clone()), Some(kv_router_config.clone()),
WORKER_TYPE_DECODE, WORKER_TYPE_DECODE,
Some(model_name.clone()), Some(model_name.clone()),
enable_eagle,
) )
.await .await
{ {
...@@ -762,7 +778,8 @@ pub unsafe extern "C" fn create_routers( ...@@ -762,7 +778,8 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config), Some(prefill_config),
enforce_disagg, enforce_disagg,
model_name.clone(), model_name.clone(),
namespace_str.clone(), actual_namespace.clone(),
enable_eagle,
) )
} }
None if enforce_disagg => { None if enforce_disagg => {
...@@ -782,7 +799,7 @@ pub unsafe extern "C" fn create_routers( ...@@ -782,7 +799,7 @@ pub unsafe extern "C" fn create_routers(
decode_router, decode_router,
model_manager, model_manager,
namespace_str, namespace_str,
preprocessor, Some(preprocessor),
)) ))
}); });
...@@ -848,7 +865,7 @@ pub unsafe extern "C" fn add_request( ...@@ -848,7 +865,7 @@ pub unsafe extern "C" fn add_request(
// Compute overlap_blocks using the public method // Compute overlap_blocks using the public method
let overlap_blocks = match decode_router let overlap_blocks = match decode_router
.get_overlap_blocks(&tokens, worker, None) .get_overlap_blocks(&tokens, None, worker, None)
.await .await
{ {
Ok(overlap) => overlap, Ok(overlap) => overlap,
...@@ -862,6 +879,7 @@ pub unsafe extern "C" fn add_request( ...@@ -862,6 +879,7 @@ pub unsafe extern "C" fn add_request(
.add_request( .add_request(
request_id_str.clone(), request_id_str.clone(),
&tokens, &tokens,
None,
overlap_blocks, overlap_blocks,
None, None,
worker, worker,
...@@ -1279,16 +1297,15 @@ pub unsafe extern "C" fn route_decode_request( ...@@ -1279,16 +1297,15 @@ pub unsafe extern "C" fn route_decode_request(
} }
} }
/// Initialize the preprocessor, block size, and model name. /// Initialize the preprocessor and fetch the model card used for routing.
/// ///
/// Waits for discovery to sync (model card must be available for tokenization), /// Waits for discovery to sync (model card must be available for tokenization),
/// then creates the preprocessor from the model card. The `kv_cache_block_size` /// then creates the preprocessor from the model card. Router settings are
/// and `model_name` are taken from the model card to ensure consistency with /// derived directly from the returned card by the caller.
/// the worker configuration.
async fn init_preprocessor( async fn init_preprocessor(
drt: &DistributedRuntime, drt: &DistributedRuntime,
target_namespace: &str, target_namespace: &str,
) -> anyhow::Result<(Option<Arc<OpenAIPreprocessor>>, u32, String, String)> { ) -> anyhow::Result<DiscoveredModelBootstrap> {
let instance_count = wait_for_discovery_sync(drt).await; let instance_count = wait_for_discovery_sync(drt).await;
if instance_count == 0 { if instance_count == 0 {
anyhow::bail!("Discovery sync failed: no worker instances found. Is the backend running?"); anyhow::bail!("Discovery sync failed: no worker instances found. Is the backend running?");
...@@ -1300,7 +1317,7 @@ async fn init_preprocessor( ...@@ -1300,7 +1317,7 @@ async fn init_preprocessor(
// Retry fetching the preprocessor: model card metadata may arrive after // Retry fetching the preprocessor: model card metadata may arrive after
// worker endpoints are registered. // worker endpoints are registered.
let (prep, block_size, model_name, actual_namespace) = loop { let bootstrap = loop {
match fetch_preprocessor_from_discovery(drt, target_namespace).await { match fetch_preprocessor_from_discovery(drt, target_namespace).await {
Ok(result) => break result, Ok(result) => break result,
Err(e) => { Err(e) => {
...@@ -1315,13 +1332,14 @@ async fn init_preprocessor( ...@@ -1315,13 +1332,14 @@ async fn init_preprocessor(
}; };
tracing::info!( tracing::info!(
kv_cache_block_size = block_size, kv_cache_block_size = bootstrap.card.kv_cache_block_size,
model_name = model_name, model_name = %bootstrap.card.display_name,
actual_namespace = actual_namespace, actual_namespace = %bootstrap.actual_namespace,
enable_eagle = bootstrap.card.runtime_config.enable_eagle,
"Preprocessor initialized from model card" "Preprocessor initialized from model card"
); );
Ok((Some(prep), block_size, model_name, actual_namespace)) Ok(bootstrap)
} }
/// Fetch model card via discovery and create preprocessor. /// Fetch model card via discovery and create preprocessor.
...@@ -1331,12 +1349,11 @@ async fn init_preprocessor( ...@@ -1331,12 +1349,11 @@ async fn init_preprocessor(
/// 2. Finds the first model in the target namespace (decode workers only) /// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed /// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card /// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card /// 5. Returns the preprocessor, the model card, and the resolved worker namespace
async fn fetch_preprocessor_from_discovery( async fn fetch_preprocessor_from_discovery(
drt: &DistributedRuntime, drt: &DistributedRuntime,
target_namespace: &str, target_namespace: &str,
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String, String)> { ) -> anyhow::Result<DiscoveredModelBootstrap> {
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_runtime::discovery::DiscoveryInstance; use dynamo_runtime::discovery::DiscoveryInstance;
let discovery = drt.discovery(); let discovery = drt.discovery();
...@@ -1383,12 +1400,11 @@ async fn fetch_preprocessor_from_discovery( ...@@ -1383,12 +1400,11 @@ async fn fetch_preprocessor_from_discovery(
) )
})?; })?;
let kv_cache_block_size = card.kv_cache_block_size;
let model_name = card.name().to_string();
tracing::info!( tracing::info!(
model_name = model_name, model_name = %card.display_name,
kv_cache_block_size = kv_cache_block_size, kv_cache_block_size = card.kv_cache_block_size,
actual_namespace = actual_namespace, actual_namespace = %actual_namespace,
enable_eagle = card.runtime_config.enable_eagle,
"Found model card via discovery" "Found model card via discovery"
); );
...@@ -1396,13 +1412,12 @@ async fn fetch_preprocessor_from_discovery( ...@@ -1396,13 +1412,12 @@ async fn fetch_preprocessor_from_discovery(
card.download_config().await?; card.download_config().await?;
// Create preprocessor // Create preprocessor
let preprocessor = OpenAIPreprocessor::new(card)?; let preprocessor = OpenAIPreprocessor::new(card.clone())?;
Ok(( Ok(DiscoveredModelBootstrap {
preprocessor, preprocessor,
kv_cache_block_size, card,
model_name,
actual_namespace, actual_namespace,
)) })
} }
/// Find a prefill endpoint from already-discovered instances (one-time filter). /// Find a prefill endpoint from already-discovered instances (one-time filter).
......
...@@ -161,13 +161,14 @@ fn init_standalone_logging() { ...@@ -161,13 +161,14 @@ fn init_standalone_logging() {
} }
#[pyfunction] #[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))] #[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None, is_eagle=None))]
pub fn compute_block_hash_for_seq_py( pub fn compute_block_hash_for_seq_py(
_py: Python, _py: Python,
tokens: Vec<u32>, tokens: Vec<u32>,
kv_block_size: usize, kv_block_size: usize,
block_mm_infos: Option<Bound<PyAny>>, block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>, lora_name: Option<String>,
is_eagle: Option<bool>,
) -> PyResult<Vec<u64>> { ) -> PyResult<Vec<u64>> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
...@@ -183,8 +184,11 @@ pub fn compute_block_hash_for_seq_py( ...@@ -183,8 +184,11 @@ pub fn compute_block_hash_for_seq_py(
let hashes = compute_block_hash_for_seq( let hashes = compute_block_hash_for_seq(
&tokens, &tokens,
kv_block_size as u32, kv_block_size as u32,
mm_infos.as_deref(), BlockHashOptions {
lora_name.as_deref(), block_mm_infos: mm_infos.as_deref(),
lora_name: lora_name.as_deref(),
is_eagle,
},
); );
Ok(hashes.into_iter().map(|h| h.0).collect()) Ok(hashes.into_iter().map(|h| h.0).collect())
...@@ -310,7 +314,7 @@ impl KvEventPublisher { ...@@ -310,7 +314,7 @@ impl KvEventPublisher {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))] #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None, is_eagle=None))]
fn publish_stored( fn publish_stored(
&self, &self,
py: Python, py: Python,
...@@ -320,6 +324,7 @@ impl KvEventPublisher { ...@@ -320,6 +324,7 @@ impl KvEventPublisher {
parent_hash: Option<i64>, parent_hash: Option<i64>,
block_mm_infos: Option<Bound<PyAny>>, block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>, lora_name: Option<String>,
is_eagle: Option<bool>,
) -> PyResult<()> { ) -> PyResult<()> {
let kv_block_size = self.kv_block_size as u32; let kv_block_size = self.kv_block_size as u32;
let dp_rank = self.dp_rank; let dp_rank = self.dp_rank;
...@@ -347,6 +352,7 @@ impl KvEventPublisher { ...@@ -347,6 +352,7 @@ impl KvEventPublisher {
lora_name.as_deref(), lora_name.as_deref(),
&warning_count, &warning_count,
mm_infos.as_deref(), mm_infos.as_deref(),
is_eagle,
), ),
}), }),
dp_rank, dp_rank,
...@@ -716,14 +722,13 @@ async fn create_kv_router_from_endpoint( ...@@ -716,14 +722,13 @@ async fn create_kv_router_from_endpoint(
llm_rs::discovery::WORKER_TYPE_DECODE llm_rs::discovery::WORKER_TYPE_DECODE
}; };
// Only query discovery for model_name when a remote indexer is configured, // Query discovery once so we can derive both model_name (for remote indexer)
// since model_name is only needed for the RemoteIndexer path. // and Eagle routing semantics from the model card.
let needs_model_name = kv_router_config let needs_model_name = kv_router_config
.as_ref() .as_ref()
.map(|cfg| cfg.remote_indexer_component.is_some()) .map(|cfg| cfg.remote_indexer_component.is_some())
.unwrap_or(false); .unwrap_or(false);
let (model_name, enable_eagle) = {
let model_name = if needs_model_name {
let discovery = endpoint.inner.component().drt().discovery(); let discovery = endpoint.inner.component().drt().discovery();
let instances = discovery let instances = discovery
.list(rs::discovery::DiscoveryQuery::EndpointModels { .list(rs::discovery::DiscoveryQuery::EndpointModels {
...@@ -734,23 +739,26 @@ async fn create_kv_router_from_endpoint( ...@@ -734,23 +739,26 @@ async fn create_kv_router_from_endpoint(
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Some( let maybe_card = instances.into_iter().find_map(|inst| {
instances
.into_iter()
.find_map(|inst| {
inst.deserialize_model::<llm_rs::model_card::ModelDeploymentCard>() inst.deserialize_model::<llm_rs::model_card::ModelDeploymentCard>()
.ok() .ok()
.map(|card| card.display_name) });
})
.ok_or_else(|| { match maybe_card {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!( Some(card) => {
"no model card found in discovery for endpoint {}/{}/{}", let model_name = needs_model_name.then(|| card.display_name.clone());
endpoint_id.namespace, endpoint_id.component, endpoint_id.name (model_name, card.runtime_config.enable_eagle)
)) }
})?, None => {
) tracing::warn!(
} else { namespace = %endpoint_id.namespace,
None component = %endpoint_id.component,
endpoint = %endpoint_id.name,
"No model card found in discovery; defaulting to non-Eagle routing semantics"
);
(None, false)
}
}
}; };
let kv_router = model_manager let kv_router = model_manager
...@@ -760,6 +768,7 @@ async fn create_kv_router_from_endpoint( ...@@ -760,6 +768,7 @@ async fn create_kv_router_from_endpoint(
kv_router_config, kv_router_config,
worker_type, worker_type,
model_name, model_name,
enable_eagle,
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -1083,7 +1092,7 @@ impl KvRouter { ...@@ -1083,7 +1092,7 @@ impl KvRouter {
block_mm_infos.as_deref(), block_mm_infos.as_deref(),
router_config_override.as_ref(), router_config_override.as_ref(),
update_states, update_states,
lora_name, lora_name.clone(),
0.0, 0.0,
None, None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
...@@ -1092,8 +1101,17 @@ impl KvRouter { ...@@ -1092,8 +1101,17 @@ impl KvRouter {
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
if update_indexer && !chooser.kv_router_config().use_kv_events { if update_indexer && !chooser.kv_router_config().use_kv_events {
let mut tokens_with_hashes =
TokensWithHashes::new(token_ids.clone(), chooser.block_size())
.with_is_eagle(chooser.is_eagle());
if let Some(infos) = block_mm_infos.as_ref() {
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.clone());
}
if let Some(lora_name) = lora_name.as_ref() {
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name.clone());
}
chooser chooser
.record_routing_decision(token_ids.clone(), best_worker) .record_routing_decision(tokens_with_hashes, best_worker)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
} }
...@@ -1129,18 +1147,27 @@ impl KvRouter { ...@@ -1129,18 +1147,27 @@ impl KvRouter {
}) })
} }
#[pyo3(signature = (token_ids, lora_name=None))] #[pyo3(signature = (token_ids, block_mm_infos=None, lora_name=None))]
fn get_potential_loads<'p>( fn get_potential_loads<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
block_mm_infos: Option<PyObject>,
lora_name: Option<String>, lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let block_mm_infos = block_mm_infos
.map(|obj| depythonize_block_mm_infos(obj.bind(py)))
.transpose()?;
let chooser = self.inner.chooser.clone(); let chooser = self.inner.chooser.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser let loads = chooser
.get_potential_loads(&token_ids, None, lora_name.as_deref()) .get_potential_loads(
&token_ids,
None,
block_mm_infos.as_deref(),
lora_name.as_deref(),
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -60,6 +60,11 @@ impl ModelRuntimeConfig { ...@@ -60,6 +60,11 @@ impl ModelRuntimeConfig {
self.inner.enable_local_indexer = enable_local_indexer; self.inner.enable_local_indexer = enable_local_indexer;
} }
/// Python property setter: stores `enable_eagle` on the wrapped runtime config.
#[setter]
fn set_enable_eagle(&mut self, enable_eagle: bool) {
self.inner.enable_eagle = enable_eagle;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner self.inner
...@@ -159,4 +164,9 @@ impl ModelRuntimeConfig { ...@@ -159,4 +164,9 @@ impl ModelRuntimeConfig {
.as_ref() .as_ref()
.and_then(|e| e.bootstrap_port) .and_then(|e| e.bootstrap_port)
} }
/// Python property getter: returns the `enable_eagle` flag held by the
/// wrapped runtime config.
#[getter]
fn enable_eagle(&self) -> bool {
self.inner.enable_eagle
}
} }
...@@ -279,6 +279,7 @@ def compute_block_hash_for_seq( ...@@ -279,6 +279,7 @@ def compute_block_hash_for_seq(
kv_block_size: int, kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None, block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
is_eagle: Optional[bool] = None,
) -> List[int]: ) -> List[int]:
""" """
Compute block hashes for a sequence of tokens, optionally including multimodal metadata. Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
...@@ -299,6 +300,9 @@ def compute_block_hash_for_seq( ...@@ -299,6 +300,9 @@ def compute_block_hash_for_seq(
} }
] ]
} }
lora_name: Optional LoRA adapter name for adapter-aware block hashing.
is_eagle: Optional Eagle mode flag. When true, hashes use overlapping
`kv_block_size + 1` token windows with `kv_block_size` stride.
Returns: Returns:
List of block hashes (one per block) List of block hashes (one per block)
...@@ -478,6 +482,7 @@ class ModelRuntimeConfig: ...@@ -478,6 +482,7 @@ class ModelRuntimeConfig:
data_parallel_start_rank: int data_parallel_start_rank: int
data_parallel_size: int data_parallel_size: int
enable_local_indexer: bool enable_local_indexer: bool
enable_eagle: bool
runtime_data: dict[str, Any] runtime_data: dict[str, Any]
tensor_model_config: Any | None tensor_model_config: Any | None
bootstrap_host: str | None bootstrap_host: str | None
...@@ -634,7 +639,7 @@ class KvIndexer: ...@@ -634,7 +639,7 @@ class KvIndexer:
... ...
def find_matches_for_request( def find_matches_for_request(
self, token_ids: List[int], lora_name: Optional[str] = None self, token_ids: List[int], lora_name: Optional[str] = None, is_eagle: Optional[bool] = None
) -> OverlapScores: ) -> OverlapScores:
""" """
Return the overlapping scores of workers for the given token ids. Return the overlapping scores of workers for the given token ids.
...@@ -682,7 +687,7 @@ class ApproxKvIndexer: ...@@ -682,7 +687,7 @@ class ApproxKvIndexer:
... ...
def find_matches_for_request( def find_matches_for_request(
self, token_ids: List[int], lora_name: Optional[str] = None self, token_ids: List[int], lora_name: Optional[str] = None, is_eagle: Optional[bool] = None
) -> OverlapScores: ) -> OverlapScores:
""" """
Return the overlapping scores of workers for the given token ids. Return the overlapping scores of workers for the given token ids.
...@@ -765,6 +770,7 @@ class KvEventPublisher: ...@@ -765,6 +770,7 @@ class KvEventPublisher:
parent_hash: Optional[int] = None, parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None, block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
is_eagle: Optional[bool] = None,
) -> None: ) -> None:
""" """
Publish a KV stored event. Publish a KV stored event.
...@@ -780,6 +786,8 @@ class KvEventPublisher: ...@@ -780,6 +786,8 @@ class KvEventPublisher:
Each item is either None or a dict with "mm_objects" key containing Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts. a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
lora_name: Optional LoRA adapter name for adapter-aware block hashing. lora_name: Optional LoRA adapter name for adapter-aware block hashing.
is_eagle: Optional Eagle mode flag. When true, stored blocks are
reconstructed using overlapping `kv_block_size + 1` token windows.
""" """
... ...
...@@ -1739,6 +1747,7 @@ class KvRouter: ...@@ -1739,6 +1747,7 @@ class KvRouter:
async def get_potential_loads( async def get_potential_loads(
self, self,
token_ids: List[int], token_ids: List[int],
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
) -> List[Dict[str, int]]: ) -> List[Dict[str, int]]:
""" """
...@@ -1746,6 +1755,9 @@ class KvRouter: ...@@ -1746,6 +1755,9 @@ class KvRouter:
Args: Args:
token_ids: List of token IDs to evaluate token_ids: List of token IDs to evaluate
block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in hash computation
for MM-aware potential-load estimation.
Returns: Returns:
A list of dictionaries, each containing: A list of dictionaries, each containing:
......
...@@ -449,13 +449,22 @@ impl KvIndexerInterface for KvIndexer { ...@@ -449,13 +449,22 @@ impl KvIndexerInterface for KvIndexer {
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
tracing::debug!( tracing::debug!(
"Finding matches for request tokens: {:?} / len: {}", "Finding matches for request tokens: {:?} / len: {}",
tokens, tokens,
tokens.len() tokens.len()
); );
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name); let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
tracing::debug!("Computed sequence: {:?}", sequence); tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
......
...@@ -271,9 +271,10 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -271,9 +271,10 @@ impl KvIndexerInterface for LocalKvIndexer {
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
self.indexer self.indexer
.find_matches_for_request(tokens, lora_name) .find_matches_for_request(tokens, lora_name, is_eagle)
.await .await
} }
......
...@@ -349,7 +349,7 @@ mod tests { ...@@ -349,7 +349,7 @@ mod tests {
// 1. Before routing decision there should be no matches // 1. Before routing decision there should be no matches
let pre_scores = indexer let pre_scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.expect("indexer offline"); .expect("indexer offline");
assert!(pre_scores.scores.is_empty()); assert!(pre_scores.scores.is_empty());
...@@ -367,7 +367,7 @@ mod tests { ...@@ -367,7 +367,7 @@ mod tests {
// Poll until we observe the match being registered // Poll until we observe the match being registered
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
s.scores s.scores
...@@ -380,7 +380,7 @@ mod tests { ...@@ -380,7 +380,7 @@ mod tests {
// 3. After the TTL has passed the entry should expire automatically // 3. After the TTL has passed the entry should expire automatically
time::sleep(TTL + Duration::from_millis(50)).await; time::sleep(TTL + Duration::from_millis(50)).await;
let post_scores = indexer let post_scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
assert!(post_scores.scores.is_empty()); assert!(post_scores.scores.is_empty());
...@@ -420,7 +420,7 @@ mod tests { ...@@ -420,7 +420,7 @@ mod tests {
// Wait until the worker is registered // Wait until the worker is registered
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
s.scores s.scores
...@@ -434,7 +434,7 @@ mod tests { ...@@ -434,7 +434,7 @@ mod tests {
// Ensure the worker's entries are gone // Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
!s.scores !s.scores
...@@ -488,7 +488,7 @@ mod tests { ...@@ -488,7 +488,7 @@ mod tests {
// Ensure both workers are registered // Ensure both workers are registered
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
s.scores s.scores
...@@ -508,7 +508,7 @@ mod tests { ...@@ -508,7 +508,7 @@ mod tests {
// Confirm the removed worker is gone, and the other remains. // Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
!s.scores !s.scores
...@@ -558,7 +558,7 @@ mod tests { ...@@ -558,7 +558,7 @@ mod tests {
// Ensure the indexer has registered the block // Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&seq_a, None) .find_matches_for_request(&seq_a, None, None)
.await .await
.unwrap(); .unwrap();
s.scores s.scores
...@@ -573,7 +573,7 @@ mod tests { ...@@ -573,7 +573,7 @@ mod tests {
// Query the indexer for overlaps of Sequence B (before it has been routed anywhere) // Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
let overlap = indexer let overlap = indexer
.find_matches_for_request(&seq_b, None) .find_matches_for_request(&seq_b, None, None)
.await .await
.unwrap(); .unwrap();
...@@ -631,7 +631,7 @@ mod tests { ...@@ -631,7 +631,7 @@ mod tests {
// Wait until both workers are reflected in overlap scores // Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), async || { spin_until(Duration::from_millis(100), async || {
let s = indexer let s = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
s.scores s.scores
...@@ -646,7 +646,7 @@ mod tests { ...@@ -646,7 +646,7 @@ mod tests {
.await; .await;
let scores = indexer let scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
...@@ -808,7 +808,7 @@ mod tests { ...@@ -808,7 +808,7 @@ mod tests {
for i in 0..5 { for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer let scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
...@@ -837,7 +837,7 @@ mod tests { ...@@ -837,7 +837,7 @@ mod tests {
for i in 0..4 { for i in 0..4 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer let scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
assert!( assert!(
...@@ -851,7 +851,7 @@ mod tests { ...@@ -851,7 +851,7 @@ mod tests {
for i in 4..6 { for i in 4..6 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer let scores = indexer
.find_matches_for_request(&tokens, None) .find_matches_for_request(&tokens, None, None)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
......
...@@ -424,8 +424,17 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -424,8 +424,17 @@ impl KvIndexerInterface for KvIndexerSharded {
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name); let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
......
...@@ -547,7 +547,10 @@ mod interface_tests { ...@@ -547,7 +547,10 @@ mod interface_tests {
// Empty index should return no matches // Empty index should return no matches
let tokens = vec![1, 2, 3, 4]; let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens, None).await.unwrap(); let scores = index
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
assert!(scores.scores.is_empty()); assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens // Store some data and verify we can find it via tokens
...@@ -559,7 +562,10 @@ mod interface_tests { ...@@ -559,7 +562,10 @@ mod interface_tests {
// Note: find_matches_for_request computes block hashes from tokens, // Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values. // so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error. // For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens, None).await.unwrap(); let scores = index
.find_matches_for_request(&tokens, None, None)
.await
.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes // The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens // because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty()); assert!(scores.scores.is_empty() || !scores.scores.is_empty());
...@@ -883,9 +889,16 @@ mod lora_tests { ...@@ -883,9 +889,16 @@ mod lora_tests {
// Same token sequence for both base model and LoRA adapter // Same token sequence for both base model and LoRA adapter
let tokens: Vec<u32> = (0..kv_block_size * 3).collect(); let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None); let base_hashes =
let lora_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, BlockHashOptions::default());
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter")); let lora_hashes = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("my-adapter"),
..Default::default()
},
);
// Hashes must differ despite identical tokens // Hashes must differ despite identical tokens
assert_ne!( assert_ne!(
...@@ -970,9 +983,16 @@ mod lora_tests { ...@@ -970,9 +983,16 @@ mod lora_tests {
let tokens: Vec<u32> = (0..kv_block_size * 3).collect(); let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash // With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None); let base_local =
let lora_local = compute_block_hash_for_seq(&tokens, kv_block_size, BlockHashOptions::default());
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter")); let lora_local = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("my-adapter"),
..Default::default()
},
);
assert_ne!( assert_ne!(
base_local, lora_local, base_local, lora_local,
...@@ -1044,8 +1064,22 @@ mod lora_tests { ...@@ -1044,8 +1064,22 @@ mod lora_tests {
let tokens: Vec<u32> = (0..kv_block_size * 2).collect(); let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a")); let hashes_a = compute_block_hash_for_seq(
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b")); &tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("adapter-a"),
..Default::default()
},
);
let hashes_b = compute_block_hash_for_seq(
&tokens,
kv_block_size,
BlockHashOptions {
lora_name: Some("adapter-b"),
..Default::default()
},
);
assert_ne!( assert_ne!(
hashes_a, hashes_b, hashes_a, hashes_b,
......
...@@ -159,8 +159,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -159,8 +159,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name); let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
..Default::default()
},
);
Ok(self.backend.find_matches(&sequence, false)) Ok(self.backend.find_matches(&sequence, false))
} }
......
...@@ -36,6 +36,7 @@ pub trait KvIndexerInterface { ...@@ -36,6 +36,7 @@ pub trait KvIndexerInterface {
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError>; ) -> Result<OverlapScores, KvRouterError>;
/// Apply a `RouterEvent` to the KV store. /// Apply a `RouterEvent` to the KV store.
......
...@@ -24,6 +24,13 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -24,6 +24,13 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data)) LocalBlockHash(compute_hash(data))
} }
/// Optional inputs that tune how `compute_block_hash_for_seq` hashes a token sequence.
///
/// Every field is an `Option`, so `BlockHashOptions::default()` leaves all of them `None`
/// and callers can set just the fields they need with struct-update syntax.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlockHashOptions<'a> {
/// Per-block multimodal metadata; the entry for a block is looked up by its block index
/// while hashing, and missing or `None` entries contribute nothing.
pub block_mm_infos: Option<&'a [Option<BlockExtraInfo>]>,
/// LoRA adapter name folded into the hash seed. `None` and the empty string both select
/// the unmodified base seed.
pub lora_name: Option<&'a str>,
/// Eagle hashing switch; `None` is treated as `false`. When `true`, each hashed window
/// spans `kv_block_size + 1` tokens while still advancing by `kv_block_size` tokens.
pub is_eagle: Option<bool>,
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata /// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity. /// and LoRA adapter identity.
/// ///
...@@ -39,20 +46,30 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -39,20 +46,30 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
pub fn compute_block_hash_for_seq( pub fn compute_block_hash_for_seq(
tokens: &[u32], tokens: &[u32],
kv_block_size: u32, kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>, options: BlockHashOptions<'_>,
lora_name: Option<&str>,
) -> Vec<LocalBlockHash> { ) -> Vec<LocalBlockHash> {
let seed = match lora_name.filter(|n| !n.is_empty()) { if kv_block_size == 0 {
return Vec::new();
}
let seed = match options.lora_name.filter(|n| !n.is_empty()) {
Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())), Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
None => XXH3_SEED, None => XXH3_SEED,
}; };
tokens
.chunks_exact(kv_block_size as usize) let is_eagle_flag = options.is_eagle.unwrap_or(false);
.enumerate() let stride = kv_block_size as usize;
.map(|(block_idx, chunk)| { let window_size = if is_eagle_flag { stride + 1 } else { stride };
let mut hashes = Vec::new();
let mut block_idx = 0;
let mut start = 0;
while start + window_size <= tokens.len() {
let chunk = &tokens[start..start + window_size];
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect(); let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
if let Some(mm_infos) = block_mm_infos if let Some(mm_infos) = options.block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx) && let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{ {
let mut mm_hashes: Vec<u64> = block_mm_info let mut mm_hashes: Vec<u64> = block_mm_info
...@@ -67,9 +84,13 @@ pub fn compute_block_hash_for_seq( ...@@ -67,9 +84,13 @@ pub fn compute_block_hash_for_seq(
} }
} }
LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed)) hashes.push(LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed)));
})
.collect() start += stride;
block_idx += 1;
}
hashes
} }
/// Compute rolling sequence hashes for a vector of block hashes. /// Compute rolling sequence hashes for a vector of block hashes.
...@@ -718,6 +739,7 @@ pub struct TokensWithHashes { ...@@ -718,6 +739,7 @@ pub struct TokensWithHashes {
lora_name: Option<String>, lora_name: Option<String>,
block_hashes: Option<Vec<LocalBlockHash>>, block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>, seq_hashes: Option<Vec<SequenceHash>>,
is_eagle: Option<bool>,
} }
impl TokensWithHashes { impl TokensWithHashes {
...@@ -730,6 +752,7 @@ impl TokensWithHashes { ...@@ -730,6 +752,7 @@ impl TokensWithHashes {
lora_name: None, lora_name: None,
block_hashes: None, block_hashes: None,
seq_hashes: None, seq_hashes: None,
is_eagle: None,
} }
} }
...@@ -745,6 +768,24 @@ impl TokensWithHashes { ...@@ -745,6 +768,24 @@ impl TokensWithHashes {
self self
} }
/// Sets Eagle hashing semantics for this token sequence.
pub fn with_is_eagle(mut self, is_eagle: bool) -> Self {
self.set_is_eagle(is_eagle);
self
}
/// Records whether Eagle hashing applies to this token sequence.
///
/// The cached block and sequence hashes are dropped only when the stored flag actually
/// changes; passing the value that is already stored is a no-op. An unset flag (`None`)
/// compares unequal to any explicit value, so the first explicit call always resets the
/// caches.
pub fn set_is_eagle(&mut self, is_eagle: bool) {
    let requested = Some(is_eagle);
    if self.is_eagle != requested {
        self.is_eagle = requested;
        // Anything hashed under the previous setting is stale now.
        self.block_hashes = None;
        self.seq_hashes = None;
    }
}
/// Returns a reference to the tokens. /// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] { pub fn tokens(&self) -> &[u32] {
&self.tokens &self.tokens
...@@ -776,8 +817,11 @@ impl TokensWithHashes { ...@@ -776,8 +817,11 @@ impl TokensWithHashes {
self.block_hashes = Some(compute_block_hash_for_seq( self.block_hashes = Some(compute_block_hash_for_seq(
&self.tokens, &self.tokens,
self.block_size, self.block_size,
self.block_mm_infos.as_deref(), BlockHashOptions {
self.lora_name.as_deref(), block_mm_infos: self.block_mm_infos.as_deref(),
lora_name: self.lora_name.as_deref(),
is_eagle: self.is_eagle,
},
)); ));
} }
self.block_hashes.as_ref().unwrap() self.block_hashes.as_ref().unwrap()
...@@ -858,24 +902,41 @@ mod tests { ...@@ -858,24 +902,41 @@ mod tests {
#[case(64)] #[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) { fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
let sequence = (0..kv_block_size).collect::<Vec<u32>>(); let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None); let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None); let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None); let hashes =
compute_block_hash_for_seq(&sequence, kv_block_size, BlockHashOptions::default());
assert_eq!(hashes.len(), 2); assert_eq!(hashes.len(), 2);
} }
#[test] #[test]
fn test_lora_name_produces_different_hash() { fn test_lora_name_produces_different_hash() {
let tokens: Vec<u32> = (0..4).collect(); let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None); let base = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let lora_a = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-a")); let lora_a = compute_block_hash_for_seq(
let lora_b = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-b")); &tokens,
4,
BlockHashOptions {
lora_name: Some("adapter-a"),
..Default::default()
},
);
let lora_b = compute_block_hash_for_seq(
&tokens,
4,
BlockHashOptions {
lora_name: Some("adapter-b"),
..Default::default()
},
);
assert_ne!(base[0], lora_a[0]); assert_ne!(base[0], lora_a[0]);
assert_ne!(base[0], lora_b[0]); assert_ne!(base[0], lora_b[0]);
...@@ -885,16 +946,23 @@ mod tests { ...@@ -885,16 +946,23 @@ mod tests {
#[test] #[test]
fn test_lora_name_none_matches_legacy() { fn test_lora_name_none_matches_legacy() {
let tokens: Vec<u32> = (0..8).collect(); let tokens: Vec<u32> = (0..8).collect();
let hashes_none = compute_block_hash_for_seq(&tokens, 4, None, None); let hashes_none = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, None, None); let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
assert_eq!(hashes_none, hashes_none2); assert_eq!(hashes_none, hashes_none2);
} }
#[test] #[test]
fn test_lora_name_empty_string_normalized_to_none() { fn test_lora_name_empty_string_normalized_to_none() {
let tokens: Vec<u32> = (0..4).collect(); let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None); let base = compute_block_hash_for_seq(&tokens, 4, BlockHashOptions::default());
let empty = compute_block_hash_for_seq(&tokens, 4, None, Some("")); let empty = compute_block_hash_for_seq(
&tokens,
4,
BlockHashOptions {
lora_name: Some(""),
..Default::default()
},
);
assert_eq!( assert_eq!(
base, empty, base, empty,
"empty lora_name should be treated as base model" "empty lora_name should be treated as base model"
...@@ -918,6 +986,73 @@ mod tests { ...@@ -918,6 +986,73 @@ mod tests {
} }
} }
/// Eagle mode must hash overlapping windows of `kv_block_size + 1` tokens that advance by
/// `kv_block_size`, producing different hashes from the default non-overlapping blocks.
#[test]
fn test_compute_block_hash_for_seq_eagle_windows() {
    // `BlockHashOptions` is `Copy`, so one value can be reused for every Eagle call.
    let eagle = BlockHashOptions {
        is_eagle: Some(true),
        ..Default::default()
    };
    let tokens: Vec<u32> = (0..6).collect();

    let plain = compute_block_hash_for_seq(&tokens, 2, BlockHashOptions::default());
    let windowed = compute_block_hash_for_seq(&tokens, 2, eagle);

    // Hashing each expected window on its own must reproduce the per-window hashes.
    let first_window = compute_block_hash_for_seq(&[0, 1, 2], 2, eagle);
    let second_window = compute_block_hash_for_seq(&[2, 3, 4], 2, eagle);

    assert_eq!(plain.len(), 3);
    assert_eq!(windowed.len(), 2);
    assert_eq!(windowed, vec![first_window[0], second_window[0]]);
    assert_ne!(plain[0], windowed[0]);
}
/// Switching a `TokensWithHashes` to Eagle mode after hashes were cached must force a
/// recomputation that yields the Eagle window hashes.
#[test]
fn test_tokens_with_hashes_set_is_eagle_invalidates_cache() {
    let eagle = BlockHashOptions {
        is_eagle: Some(true),
        ..Default::default()
    };
    let tokens: Vec<u32> = (0..6).collect();
    let mut cached = TokensWithHashes::new(tokens, 2);

    // Populate the cache with default hashes, then flip the flag and hash again.
    let before = cached.get_or_compute_block_hashes().to_vec();
    cached.set_is_eagle(true);
    let after = cached.get_or_compute_block_hashes().to_vec();

    let first_window = compute_block_hash_for_seq(&[0, 1, 2], 2, eagle);
    let second_window = compute_block_hash_for_seq(&[2, 3, 4], 2, eagle);

    assert_eq!(before.len(), 3);
    assert_eq!(after.len(), 2);
    assert_eq!(after, vec![first_window[0], second_window[0]]);
    assert_ne!(before[0], after[0]);
}
#[test] #[test]
fn test_local_block_hash_serialization() { fn test_local_block_hash_serialization() {
let hash = LocalBlockHash(12345); let hash = LocalBlockHash(12345);
......
...@@ -9,7 +9,7 @@ use rand::Rng; ...@@ -9,7 +9,7 @@ use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError}; use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::protocols::{BlockHashOptions, compute_block_hash_for_seq, compute_seq_hash_for_block};
const fn default_min_initial_workers() -> usize { const fn default_min_initial_workers() -> usize {
1 1
...@@ -217,7 +217,7 @@ impl KvRouterConfig { ...@@ -217,7 +217,7 @@ impl KvRouterConfig {
tokens: &[u32], tokens: &[u32],
block_size: u32, block_size: u32,
config_override: Option<&RouterConfigOverride>, config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>, hash_options: BlockHashOptions<'_>,
) -> Option<Vec<u64>> { ) -> Option<Vec<u64>> {
if !self.router_track_active_blocks { if !self.router_track_active_blocks {
return None; return None;
...@@ -233,7 +233,7 @@ impl KvRouterConfig { ...@@ -233,7 +233,7 @@ impl KvRouterConfig {
.unwrap_or(self.router_assume_kv_reuse); .unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse { if assume_kv_reuse {
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None, lora_name); let block_hashes = compute_block_hash_for_seq(tokens, block_size, hash_options);
Some(compute_seq_hash_for_block(&block_hashes)) Some(compute_seq_hash_for_block(&block_hashes))
} else { } else {
let mut rng = rand::rng(); let mut rng = rand::rng();
...@@ -257,6 +257,7 @@ impl KvRouterConfig { ...@@ -257,6 +257,7 @@ impl KvRouterConfig {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};
#[test] #[test]
fn router_queue_policy_display_and_parse_support_lcfs() { fn router_queue_policy_display_and_parse_support_lcfs() {
...@@ -288,4 +289,36 @@ mod tests { ...@@ -288,4 +289,36 @@ mod tests {
}; };
assert!(cfg.validate().is_err()); assert!(cfg.validate().is_err());
} }
/// Supplying multimodal info for a block must change the sequence hashes returned for
/// tracking compared with hashing the same tokens without it.
#[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
    let cfg = KvRouterConfig::default();
    let tokens = vec![1, 2, 3, 4];

    // Two blocks of two tokens: only the first carries a multimodal object.
    let first_block_info = BlockExtraInfo {
        mm_objects: vec![BlockMmObjectInfo {
            mm_hash: 42,
            offsets: vec![],
        }],
    };
    let mm_infos = vec![Some(first_block_info), None];
    let mm_options = BlockHashOptions {
        block_mm_infos: Some(&mm_infos),
        ..Default::default()
    };

    let without_mm = cfg
        .compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default())
        .unwrap();
    let with_mm = cfg
        .compute_seq_hashes_for_tracking(&tokens, 2, None, mm_options)
        .unwrap();

    assert_ne!(without_mm, with_mm);
}
} }
...@@ -13,7 +13,7 @@ use axum::{Json, Router}; ...@@ -13,7 +13,7 @@ use axum::{Json, Router};
use prometheus::Encoder; use prometheus::Encoder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq}; use crate::protocols::{BlockHashOptions, LocalBlockHash, WorkerId, compute_block_hash_for_seq};
use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry}; use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry};
...@@ -197,8 +197,14 @@ async fn query( ...@@ -197,8 +197,14 @@ async fn query(
let indexer = ie.indexer.clone(); let indexer = ie.indexer.clone();
drop(ie); drop(ie);
let block_hashes = let block_hashes = compute_block_hash_for_seq(
compute_block_hash_for_seq(&req.token_ids, block_size, None, req.lora_name.as_deref()); &req.token_ids,
block_size,
BlockHashOptions {
lora_name: req.lora_name.as_deref(),
..Default::default()
},
);
match indexer.find_matches(block_hashes).await { match indexer.find_matches(block_hashes).await {
Ok(overlap) => ( Ok(overlap) => (
StatusCode::OK, StatusCode::OK,
......
...@@ -17,9 +17,9 @@ use serde::Serialize; ...@@ -17,9 +17,9 @@ use serde::Serialize;
use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor}; use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{ use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, BlockExtraInfo, BlockHashOptions, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq, PlacementEvent, StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
}; };
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
...@@ -65,6 +65,13 @@ impl BlockHashValue { ...@@ -65,6 +65,13 @@ impl BlockHashValue {
} }
} }
/// Wire shape of the `token_ids` field of a `BlockStored` event.
///
/// The enum is `#[serde(untagged)]`, so the variant is picked from the shape of the
/// payload rather than from an explicit tag.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum KvTokenIds {
/// A flat list of token ids; the event deserializer keeps it as-is and marks the event
/// as non-Eagle.
Single(Vec<u32>),
/// A list of `(first, second)` token pairs. The event deserializer flattens it to the
/// first token of every pair followed by the second token of the last pair, and marks
/// the event as Eagle.
Bigram(Vec<(u32, u32)>),
}
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")] // msgspec encodes variant tag as a string when `tag=True` #[serde(tag = "type")] // msgspec encodes variant tag as a string when `tag=True`
pub enum RawKvEvent { pub enum RawKvEvent {
...@@ -83,6 +90,8 @@ pub enum RawKvEvent { ...@@ -83,6 +90,8 @@ pub enum RawKvEvent {
/// Multimodal extra info for each block (length should match block_hashes) /// Multimodal extra info for each block (length should match block_hashes)
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>, block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
#[serde(skip_serializing_if = "Option::is_none")]
is_eagle: Option<bool>,
}, },
BlockRemoved { BlockRemoved {
block_hashes: Vec<BlockHashValue>, block_hashes: Vec<BlockHashValue>,
...@@ -180,7 +189,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -180,7 +189,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut event_type: Option<String> = None; let mut event_type: Option<String> = None;
let mut block_hashes: Option<Vec<BlockHashValue>> = None; let mut block_hashes: Option<Vec<BlockHashValue>> = None;
let mut parent_block_hash: Option<Option<BlockHashValue>> = None; let mut parent_block_hash: Option<Option<BlockHashValue>> = None;
let mut token_ids: Option<Vec<u32>> = None; let mut token_ids: Option<KvTokenIds> = None;
let mut block_size: Option<usize> = None; let mut block_size: Option<usize> = None;
let mut medium: Option<Option<String>> = None; let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None; let mut lora_name: Option<Option<String>> = None;
...@@ -227,6 +236,17 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -227,6 +236,17 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_hashes = let block_hashes =
block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?; block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?; let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?;
let (raw_token_ids, is_eagle) = match token_ids {
KvTokenIds::Single(tids) => (tids, false),
KvTokenIds::Bigram(tids) => {
let mut new_tids: Vec<u32> = tids.iter().map(|&(first, _)| first).collect();
if !tids.is_empty() {
let last_token = tids.last().map(|&(_, second)| second).unwrap();
new_tids.push(last_token);
}
(new_tids, true)
}
};
let block_size = let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?; block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let block_mm_infos = block_mm_infos let block_mm_infos = block_mm_infos
...@@ -235,11 +255,12 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -235,11 +255,12 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
Ok(RawKvEvent::BlockStored { Ok(RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash: parent_block_hash.unwrap_or(None), parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids, token_ids: raw_token_ids,
block_size, block_size,
medium: medium.unwrap_or(None), medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None), lora_name: lora_name.unwrap_or(None),
block_mm_infos, block_mm_infos,
is_eagle: Some(is_eagle),
}) })
} }
Some("BlockRemoved") => { Some("BlockRemoved") => {
...@@ -277,7 +298,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -277,7 +298,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
.next_element()? .next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?; .ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
let parent_block_hash: Option<BlockHashValue> = seq.next_element()?.unwrap_or(None); let parent_block_hash: Option<BlockHashValue> = seq.next_element()?.unwrap_or(None);
let token_ids: Vec<u32> = seq let token_ids: KvTokenIds = seq
.next_element()? .next_element()?
.ok_or_else(|| de::Error::invalid_length(3, &"missing token_ids"))?; .ok_or_else(|| de::Error::invalid_length(3, &"missing token_ids"))?;
let block_size: usize = seq let block_size: usize = seq
...@@ -297,14 +318,27 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -297,14 +318,27 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_mm_infos = let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys)); block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
let (raw_token_ids, is_eagle) = match token_ids {
KvTokenIds::Single(tids) => (tids, false),
KvTokenIds::Bigram(tids) => {
let mut new_tids: Vec<u32> = tids.iter().map(|&(first, _)| first).collect();
if !tids.is_empty() {
let last_token = tids.last().map(|&(_, second)| second).unwrap();
new_tids.push(last_token);
}
(new_tids, true)
}
};
Ok(RawKvEvent::BlockStored { Ok(RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash, parent_block_hash,
token_ids, token_ids: raw_token_ids,
block_size, block_size,
medium, medium,
lora_name, lora_name,
block_mm_infos, block_mm_infos,
is_eagle: Some(is_eagle),
}) })
} }
"BlockRemoved" => { "BlockRemoved" => {
...@@ -360,6 +394,7 @@ pub fn convert_event( ...@@ -360,6 +394,7 @@ pub fn convert_event(
lora_name, lora_name,
block_mm_infos, block_mm_infos,
medium: _, medium: _,
is_eagle,
} => { } => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique. // Reject self-referencing blocks: all block hashes (including parent) must be unique.
{ {
...@@ -408,6 +443,7 @@ pub fn convert_event( ...@@ -408,6 +443,7 @@ pub fn convert_event(
lora_name.as_deref(), lora_name.as_deref(),
warning_count, warning_count,
block_mm_infos.as_deref(), block_mm_infos.as_deref(),
is_eagle,
), ),
}), }),
dp_rank, dp_rank,
...@@ -446,13 +482,17 @@ pub fn create_stored_block_from_parts( ...@@ -446,13 +482,17 @@ pub fn create_stored_block_from_parts(
token_ids: &[u32], token_ids: &[u32],
lora_name: Option<&str>, lora_name: Option<&str>,
mm_extra_info: Option<BlockExtraInfo>, mm_extra_info: Option<BlockExtraInfo>,
is_eagle: Option<bool>,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]); let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
let tokens_hash = compute_block_hash_for_seq( let tokens_hash = compute_block_hash_for_seq(
token_ids, token_ids,
kv_block_size, kv_block_size,
block_mm_infos.as_deref(), BlockHashOptions {
block_mm_infos: block_mm_infos.as_deref(),
lora_name, lora_name,
is_eagle,
},
)[0]; )[0];
tracing::trace!( tracing::trace!(
...@@ -470,6 +510,7 @@ pub fn create_stored_block_from_parts( ...@@ -470,6 +510,7 @@ pub fn create_stored_block_from_parts(
} }
} }
#[allow(clippy::too_many_arguments)]
pub fn create_stored_blocks( pub fn create_stored_blocks(
kv_block_size: u32, kv_block_size: u32,
token_ids: &[u32], token_ids: &[u32],
...@@ -478,10 +519,13 @@ pub fn create_stored_blocks( ...@@ -478,10 +519,13 @@ pub fn create_stored_blocks(
lora_name: Option<&str>, lora_name: Option<&str>,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>, block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
is_eagle: Option<bool>,
) -> Vec<KvCacheStoredBlockData> { ) -> Vec<KvCacheStoredBlockData> {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new(); let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0; let mut token_offset: usize = 0;
let append = is_eagle.unwrap_or(false) as usize;
for (block_idx, (num_tokens_it, block_hash_it)) in for (block_idx, (num_tokens_it, block_hash_it)) in
num_block_tokens.iter().zip(block_hashes.iter()).enumerate() num_block_tokens.iter().zip(block_hashes.iter()).enumerate()
{ {
...@@ -496,7 +540,19 @@ pub fn create_stored_blocks( ...@@ -496,7 +540,19 @@ pub fn create_stored_blocks(
break; break;
} }
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)]; let end = token_offset + append + *num_tokens_it as usize;
if end > token_ids.len() {
if warning_count.fetch_add(1, Ordering::Relaxed) < 3 {
tracing::warn!(
"Block not published. token_ids too short: need {}, got {}",
end,
token_ids.len()
);
}
break;
}
let tokens = &token_ids[token_offset..end];
let mm_extra_info = block_mm_infos let mm_extra_info = block_mm_infos
.and_then(|infos| infos.get(block_idx)) .and_then(|infos| infos.get(block_idx))
.and_then(|opt| opt.clone()); .and_then(|opt| opt.clone());
...@@ -507,9 +563,102 @@ pub fn create_stored_blocks( ...@@ -507,9 +563,102 @@ pub fn create_stored_blocks(
tokens, tokens,
lora_name, lora_name,
mm_extra_info, mm_extra_info,
is_eagle,
)); ));
token_offset += *num_tokens_it as usize; token_offset += *num_tokens_it as usize;
} }
blocks blocks
} }
#[cfg(test)]
mod tests {
    use super::*;

    use rmp_serde::{from_slice, to_vec};
    use std::sync::atomic::AtomicU32;
    use std::sync::Arc;

    /// A msgpack-encoded `BlockStored` tuple whose token field carries
    /// `(u32, u32)` pairs must decode into a flat token list and come back
    /// flagged with `is_eagle == Some(true)`.
    #[test]
    fn test_deserialize_bigram_block_stored_sequence() {
        // Four overlapping pairs covering the tokens 10..=14.
        let bigrams: Vec<(u32, u32)> = vec![(10, 11), (11, 12), (12, 13), (13, 14)];
        let wire_tuple = (
            "BlockStored",
            vec![BlockHashValue::Unsigned(11), BlockHashValue::Unsigned(12)],
            None::<BlockHashValue>,
            bigrams,
            2usize,
            None::<u64>,
            None::<String>,
            None::<String>,
        );

        let bytes = to_vec(&wire_tuple).unwrap();
        let decoded: RawKvEvent = from_slice(&bytes).unwrap();

        // Pull out the fields under test first, then assert on them.
        let (token_ids, block_size, is_eagle) = match decoded {
            RawKvEvent::BlockStored {
                token_ids,
                block_size,
                is_eagle,
                ..
            } => (token_ids, block_size, is_eagle),
            other => panic!("expected BlockStored, got {other:?}"),
        };

        assert_eq!(token_ids, vec![10, 11, 12, 13, 14]);
        assert_eq!(block_size, 2);
        assert_eq!(is_eagle, Some(true));
    }

    /// Converting an eagle `BlockStored` event (block size 2, five tokens)
    /// must yield two stored blocks that keep the reported block hashes and
    /// whose token hashes equal the eagle hash of the three-token windows
    /// `[10, 11, 12]` and `[12, 13, 14]`.
    #[test]
    fn test_convert_event_bigram_emits_eagle_windows() {
        let warning_count = Arc::new(AtomicU32::new(0));
        let worker = WorkerWithDpRank::new(3, 0);
        let stored_event = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(21), BlockHashValue::Unsigned(22)],
            parent_block_hash: None,
            token_ids: vec![10, 11, 12, 13, 14],
            block_size: 2,
            medium: None,
            lora_name: None,
            block_mm_infos: None,
            is_eagle: Some(true),
        };

        let converted = convert_event(stored_event, 7, 2, worker, &warning_count);

        let stored = match converted.event.data {
            KvCacheEventData::Stored(store_data) => store_data,
            other => panic!("expected Stored event, got {other:?}"),
        };

        assert_eq!(stored.blocks.len(), 2);
        assert_eq!(stored.blocks[0].block_hash, ExternalSequenceBlockHash(21));
        assert_eq!(stored.blocks[1].block_hash, ExternalSequenceBlockHash(22));

        // Each block is expected to hash its own two tokens plus one extra
        // trailing token, so consecutive windows share a boundary token.
        let windows: [&[u32]; 2] = [&[10, 11, 12], &[12, 13, 14]];
        for (block, window) in stored.blocks.iter().zip(windows) {
            let expected = compute_block_hash_for_seq(
                window,
                2,
                BlockHashOptions {
                    is_eagle: Some(true),
                    ..Default::default()
                },
            );
            assert_eq!(block.tokens_hash, expected[0]);
        }
    }
}
...@@ -562,6 +562,7 @@ impl ModelManager { ...@@ -562,6 +562,7 @@ impl ModelManager {
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
let client = endpoint.client().await?; let client = endpoint.client().await?;
...@@ -597,6 +598,7 @@ impl ModelManager { ...@@ -597,6 +598,7 @@ impl ModelManager {
kv_router_config, kv_router_config,
worker_type, worker_type,
model_name, model_name,
is_eagle,
) )
.await?; .await?;
Ok(Arc::new(chooser)) Ok(Arc::new(chooser))
......
...@@ -465,6 +465,7 @@ impl ModelWatcher { ...@@ -465,6 +465,7 @@ impl ModelWatcher {
Some(self.router_config.kv_router_config.clone()), Some(self.router_config.kv_router_config.clone()),
WORKER_TYPE_DECODE, // This is the decode router WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()), Some(card.display_name.clone()),
card.runtime_config.enable_eagle,
) )
.await?, .await?,
) )
...@@ -495,6 +496,7 @@ impl ModelWatcher { ...@@ -495,6 +496,7 @@ impl ModelWatcher {
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
model_name.clone(), model_name.clone(),
namespace.clone(), namespace.clone(),
card.runtime_config.enable_eagle,
) )
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment