Unverified Commit c9eb6a83 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: expose estimated kv cache hit in dynamo-run (#1246)


Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent b889948c
...@@ -24,6 +24,8 @@ class ServiceConfig(dict): ...@@ -24,6 +24,8 @@ class ServiceConfig(dict):
"""Configuration store that inherits from dict for simpler access patterns""" """Configuration store that inherits from dict for simpler access patterns"""
_instance = None _instance = None
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"
@classmethod @classmethod
def get_instance(cls): def get_instance(cls):
...@@ -49,6 +51,33 @@ class ServiceConfig(dict): ...@@ -49,6 +51,33 @@ class ServiceConfig(dict):
raise ValueError(f"{service_name}.{key} must be specified in configuration") raise ValueError(f"{service_name}.{key} must be specified in configuration")
return self[service_name][key] return self[service_name][key]
@classmethod
def get_parsed_config(cls, service_name):
"""Get parsed config for a service with common configs applied, returned as dict"""
instance = cls.get_instance()
if service_name not in instance:
return {}
# Get service config excluding ServiceArgs if it exists
service_config = instance[service_name].copy()
if "ServiceArgs" in service_config:
del service_config["ServiceArgs"]
# Apply common configs if they exist
if (common := instance.get(cls.COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(cls.COMMON_CONFIG_KEY)
) is not None:
for key in common_config_keys:
if key in common and key not in service_config:
service_config[key] = common[key]
# Remove the common-configs key itself from the final config
if cls.COMMON_CONFIG_KEY in service_config:
del service_config[cls.COMMON_CONFIG_KEY]
return service_config
def as_args(self, service_name, prefix=""): def as_args(self, service_name, prefix=""):
"""Extract configs as CLI args for a service, with optional prefix filtering. """Extract configs as CLI args for a service, with optional prefix filtering.
...@@ -57,8 +86,6 @@ class ServiceConfig(dict): ...@@ -57,8 +86,6 @@ class ServiceConfig(dict):
the component's `common-configs` setting, and that key has not been overriden by the the component's `common-configs` setting, and that key has not been overriden by the
component's config. component's config.
""" """
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"
if service_name not in self: if service_name not in self:
return [] return []
...@@ -69,7 +96,7 @@ class ServiceConfig(dict): ...@@ -69,7 +96,7 @@ class ServiceConfig(dict):
if prefix and not key.startswith(prefix): if prefix and not key.startswith(prefix):
return return
if key.endswith(COMMON_CONFIG_KEY): if key.endswith(self.COMMON_CONFIG_KEY):
return return
# Strip prefix if needed # Strip prefix if needed
...@@ -90,8 +117,8 @@ class ServiceConfig(dict): ...@@ -90,8 +117,8 @@ class ServiceConfig(dict):
if "ServiceArgs" in service_config: if "ServiceArgs" in service_config:
del service_config["ServiceArgs"] del service_config["ServiceArgs"]
if (common := self.get(COMMON_CONFIG_SERVICE)) is not None and ( if (common := self.get(self.COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(COMMON_CONFIG_KEY) common_config_keys := service_config.get(self.COMMON_CONFIG_KEY)
) is not None: ) is not None:
for key in common_config_keys: for key in common_config_keys:
if key in common and key not in service_config: if key in common and key not in service_config:
......
...@@ -64,8 +64,7 @@ class Frontend: ...@@ -64,8 +64,7 @@ class Frontend:
def __init__(self): def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration.""" """Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance() frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = frontend_config self.frontend_config = frontend_config
self.process = None self.process = None
self.setup_model() self.setup_model()
......
...@@ -60,8 +60,7 @@ class Frontend: ...@@ -60,8 +60,7 @@ class Frontend:
def __init__(self): def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration.""" """Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance() frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = frontend_config self.frontend_config = frontend_config
self.process = None self.process = None
......
...@@ -61,8 +61,7 @@ class Frontend: ...@@ -61,8 +61,7 @@ class Frontend:
processor = depends(Processor) processor = depends(Processor)
def __init__(self): def __init__(self):
config = ServiceConfig.get_instance() frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
frontend_config = FrontendConfig(**config.get("Frontend", {}))
# Chat/completions Endpoint # Chat/completions Endpoint
subprocess.run( subprocess.run(
......
...@@ -63,10 +63,13 @@ class Frontend: ...@@ -63,10 +63,13 @@ class Frontend:
def __init__(self): def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration.""" """Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance() self.frontend_config = FrontendConfig(
self.frontend_config = FrontendConfig(**config.get("Frontend", {})) **ServiceConfig.get_parsed_config("Frontend")
)
self.process = None self.process = None
logger.warning(f"Frontend config: {self.frontend_config}")
self.start_ingress_and_processor() self.start_ingress_and_processor()
def start_ingress_and_processor(self): def start_ingress_and_processor(self):
...@@ -87,6 +90,8 @@ class Frontend: ...@@ -87,6 +90,8 @@ class Frontend:
self.frontend_config.router, self.frontend_config.router,
] ]
logger.info(f"Frontend cmd: {cmd}")
self.process = subprocess.Popen( self.process = subprocess.Popen(
cmd, cmd,
stdout=None, stdout=None,
......
...@@ -212,7 +212,7 @@ class VllmWorker: ...@@ -212,7 +212,7 @@ class VllmWorker:
prefill_queue_size = await prefill_queue.get_queue_size() prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = await self.disaggregated_router.prefill_remote( disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.token_ids), len(request.token_ids),
0, # TODO: return prefix hit rate from dynamo-run router request.estimated_prefix_hit_num_blocks * self.engine_args.block_size,
prefill_queue_size, prefill_queue_size,
) )
else: else:
...@@ -225,12 +225,12 @@ class VllmWorker: ...@@ -225,12 +225,12 @@ class VllmWorker:
remote_prefill_request_callback=self.get_remote_prefill_request_callback(), remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
) )
logger.info( logger.info(
f"Prefilling remotely for request {request_id} with length {len(request.token_ids)}" f"Prefilling remotely for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {(request.estimated_prefix_hit_num_blocks or 0) * self.engine_args.block_size})"
) )
else: else:
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.info(
f"Prefilling locally for request {request_id} with length {len(request.token_ids)}" f"Prefilling locally for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {request.estimated_prefix_hit_num_blocks * self.engine_args.block_size})"
) )
sampling_params = SamplingParams(**self.default_sampling_params) sampling_params = SamplingParams(**self.default_sampling_params)
......
...@@ -16,13 +16,13 @@ Common: ...@@ -16,13 +16,13 @@ Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64 block-size: 64
max-model-len: 16384 max-model-len: 16384
router: kv
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate endpoint: dynamo.VllmWorker.generate
port: 8000 port: 8000
router: kv common-configs: [block-size, router]
common-configs: [block-size]
VllmWorker: VllmWorker:
enforce-eager: true enforce-eager: true
...@@ -32,4 +32,4 @@ VllmWorker: ...@@ -32,4 +32,4 @@ VllmWorker:
workers: 1 workers: 1
resources: resources:
gpu: '1' gpu: '1'
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len, router]
\ No newline at end of file \ No newline at end of file
...@@ -17,13 +17,13 @@ Common: ...@@ -17,13 +17,13 @@ Common:
block-size: 64 block-size: 64
max-model-len: 16384 max-model-len: 16384
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}' kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
router: kv
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate endpoint: dynamo.VllmWorker.generate
port: 8000 port: 8000
router: kv common-configs: [block-size, router]
common-configs: [block-size]
VllmWorker: VllmWorker:
remote-prefill: true remote-prefill: true
...@@ -35,7 +35,7 @@ VllmWorker: ...@@ -35,7 +35,7 @@ VllmWorker:
workers: 1 workers: 1
resources: resources:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, max-model-len, kv-transfer-config, router]
PrefillWorker: PrefillWorker:
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
......
...@@ -52,6 +52,7 @@ class PreprocessedRequest(BaseModel): ...@@ -52,6 +52,7 @@ class PreprocessedRequest(BaseModel):
eos_token_ids: List[TokenIdType] = Field(default_factory=list) eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list) annotations: List[str] = Field(default_factory=list)
estimated_prefix_hit_num_blocks: Optional[int] = None
class DisaggPreprocessedRequest(BaseModel): class DisaggPreprocessedRequest(BaseModel):
......
...@@ -62,8 +62,7 @@ class Frontend: ...@@ -62,8 +62,7 @@ class Frontend:
def __init__(self): def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration.""" """Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance() frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = frontend_config self.frontend_config = frontend_config
self.process = None self.process = None
......
...@@ -129,7 +129,8 @@ impl KvRouter { ...@@ -129,7 +129,8 @@ impl KvRouter {
} }
/// Give these tokens, find the worker with the best match in it's KV cache. /// Give these tokens, find the worker with the best match in it's KV cache.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<i64> { /// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size; let block_size = self.block_size;
...@@ -141,8 +142,12 @@ impl KvRouter { ...@@ -141,8 +142,12 @@ impl KvRouter {
.map(|block| LocalBlockHash(block.block_hash())) .map(|block| LocalBlockHash(block.block_hash()))
.collect(); .collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?; let worker_id = self
Ok(worker_id) .scheduler
.schedule(overlap_scores.clone(), isl_tokens)
.await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount))
} }
/// Get the block size this router was configured with /// Get the block size this router was configured with
...@@ -158,7 +163,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -158,7 +163,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>, request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> { ) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts(); let (request, ctx) = request.into_parts();
let worker_id = self.find_best_match(&request.tokens).await?; let (worker_id, _) = self.find_best_match(&request.tokens).await?;
let response = RouterResponse { worker_id }; let response = RouterResponse { worker_id };
let response = Annotated::from_data(response); let response = Annotated::from_data(response);
...@@ -192,8 +197,13 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er ...@@ -192,8 +197,13 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
match self.inner.client.instance_source.as_ref() { match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await, InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => { InstanceSource::Dynamic(_) => {
let instance_id = self.chooser.find_best_match(&request.token_ids).await?; let (instance_id, overlap_amount) =
self.inner.direct(request, instance_id).await self.chooser.find_best_match(&request.token_ids).await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
} }
} }
} }
......
...@@ -177,6 +177,7 @@ impl OpenAIPreprocessor { ...@@ -177,6 +177,7 @@ impl OpenAIPreprocessor {
builder.stop_conditions(stop_conditions); builder.stop_conditions(stop_conditions);
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
Ok((builder.build()?, annotations)) Ok((builder.build()?, annotations))
} }
......
...@@ -47,6 +47,10 @@ pub struct PreprocessedRequest { ...@@ -47,6 +47,10 @@ pub struct PreprocessedRequest {
/// User requested annotations for the request /// User requested annotations for the request
#[builder(default)] #[builder(default)]
pub annotations: Vec<String>, pub annotations: Vec<String>,
/// Estimated number of prefix hit tokens (only used in kv aware routing)
#[builder(default)]
pub estimated_prefix_hit_num_blocks: Option<u32>,
} }
impl PreprocessedRequest { impl PreprocessedRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment