Unverified Commit c9eb6a83 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: expose estimated kv cache hit in dynamo-run (#1246)


Signed-off-by: Hongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent b889948c
......@@ -24,6 +24,8 @@ class ServiceConfig(dict):
"""Configuration store that inherits from dict for simpler access patterns"""
_instance = None
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"
@classmethod
def get_instance(cls):
......@@ -49,6 +51,33 @@ class ServiceConfig(dict):
raise ValueError(f"{service_name}.{key} must be specified in configuration")
return self[service_name][key]
@classmethod
def get_parsed_config(cls, service_name):
    """Build a plain dict of one service's settings, merged with shared defaults.

    The singleton returned by ``cls.get_instance()`` is consulted. When
    ``service_name`` has no section there, an empty dict is returned.
    Otherwise a shallow copy of the section is produced in which:

    * the ``"ServiceArgs"`` entry is removed if present;
    * every key listed under the section's ``common-configs`` entry is
      filled in from the ``Common`` section, but only when ``Common``
      defines it and the service section does not already set it;
    * the ``common-configs`` entry itself is removed.
    """
    store = cls.get_instance()
    if service_name not in store:
        return {}

    # Work on a copy so the stored section is left untouched.
    parsed = store[service_name].copy()
    parsed.pop("ServiceArgs", None)

    shared = store.get(cls.COMMON_CONFIG_SERVICE)
    if shared is not None:
        wanted = parsed.get(cls.COMMON_CONFIG_KEY)
        if wanted is not None:
            for name in wanted:
                # Service-level values take precedence over shared ones.
                if name in shared and name not in parsed:
                    parsed[name] = shared[name]

    # The list of shared keys is bookkeeping, not a real setting.
    parsed.pop(cls.COMMON_CONFIG_KEY, None)
    return parsed
def as_args(self, service_name, prefix=""):
"""Extract configs as CLI args for a service, with optional prefix filtering.
......@@ -57,8 +86,6 @@ class ServiceConfig(dict):
the component's `common-configs` setting, and that key has not been overridden by the
component's config.
"""
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"
if service_name not in self:
return []
......@@ -69,7 +96,7 @@ class ServiceConfig(dict):
if prefix and not key.startswith(prefix):
return
if key.endswith(COMMON_CONFIG_KEY):
if key.endswith(self.COMMON_CONFIG_KEY):
return
# Strip prefix if needed
......@@ -90,8 +117,8 @@ class ServiceConfig(dict):
if "ServiceArgs" in service_config:
del service_config["ServiceArgs"]
if (common := self.get(COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(COMMON_CONFIG_KEY)
if (common := self.get(self.COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(self.COMMON_CONFIG_KEY)
) is not None:
for key in common_config_keys:
if key in common and key not in service_config:
......
......@@ -64,8 +64,7 @@ class Frontend:
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None
self.setup_model()
......
......@@ -60,8 +60,7 @@ class Frontend:
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None
......
......@@ -61,8 +61,7 @@ class Frontend:
processor = depends(Processor)
def __init__(self):
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
# Chat/completions Endpoint
subprocess.run(
......
......@@ -63,10 +63,13 @@ class Frontend:
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
self.frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = FrontendConfig(
**ServiceConfig.get_parsed_config("Frontend")
)
self.process = None
logger.warning(f"Frontend config: {self.frontend_config}")
self.start_ingress_and_processor()
def start_ingress_and_processor(self):
......@@ -87,6 +90,8 @@ class Frontend:
self.frontend_config.router,
]
logger.info(f"Frontend cmd: {cmd}")
self.process = subprocess.Popen(
cmd,
stdout=None,
......
......@@ -212,7 +212,7 @@ class VllmWorker:
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.token_ids),
0, # TODO: return prefix hit rate from dynamo-run router
request.estimated_prefix_hit_num_blocks * self.engine_args.block_size,
prefill_queue_size,
)
else:
......@@ -225,12 +225,12 @@ class VllmWorker:
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
logger.info(
f"Prefilling remotely for request {request_id} with length {len(request.token_ids)}"
f"Prefilling remotely for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {(request.estimated_prefix_hit_num_blocks or 0) * self.engine_args.block_size})"
)
else:
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request_id} with length {len(request.token_ids)}"
f"Prefilling locally for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {request.estimated_prefix_hit_num_blocks * self.engine_args.block_size})"
)
sampling_params = SamplingParams(**self.default_sampling_params)
......
......@@ -16,13 +16,13 @@ Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate
port: 8000
router: kv
common-configs: [block-size]
common-configs: [block-size, router]
VllmWorker:
enforce-eager: true
......@@ -32,4 +32,4 @@ VllmWorker:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
\ No newline at end of file
common-configs: [model, block-size, max-model-len, router]
\ No newline at end of file
......@@ -17,13 +17,13 @@ Common:
block-size: 64
max-model-len: 16384
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
router: kv
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate
port: 8000
router: kv
common-configs: [block-size]
common-configs: [block-size, router]
VllmWorker:
remote-prefill: true
......@@ -35,7 +35,7 @@ VllmWorker:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config]
common-configs: [model, block-size, max-model-len, kv-transfer-config, router]
PrefillWorker:
max-num-batched-tokens: 16384
......
......@@ -52,6 +52,7 @@ class PreprocessedRequest(BaseModel):
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
estimated_prefix_hit_num_blocks: Optional[int] = None
class DisaggPreprocessedRequest(BaseModel):
......
......@@ -62,8 +62,7 @@ class Frontend:
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None
......
......@@ -129,7 +129,8 @@ impl KvRouter {
}
/// Given these tokens, find the worker with the best match in its KV cache.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<i64> {
/// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len();
let block_size = self.block_size;
......@@ -141,8 +142,12 @@ impl KvRouter {
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
let worker_id = self
.scheduler
.schedule(overlap_scores.clone(), isl_tokens)
.await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount))
}
/// Get the block size this router was configured with
......@@ -158,7 +163,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let worker_id = self.find_best_match(&request.tokens).await?;
let (worker_id, _) = self.find_best_match(&request.tokens).await?;
let response = RouterResponse { worker_id };
let response = Annotated::from_data(response);
......@@ -192,8 +197,13 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let instance_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, instance_id).await
let (instance_id, overlap_amount) =
self.chooser.find_best_match(&request.token_ids).await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
}
}
}
......
......@@ -177,6 +177,7 @@ impl OpenAIPreprocessor {
builder.stop_conditions(stop_conditions);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
Ok((builder.build()?, annotations))
}
......
......@@ -47,6 +47,10 @@ pub struct PreprocessedRequest {
/// User requested annotations for the request
#[builder(default)]
pub annotations: Vec<String>,
/// Estimated number of prefix hit blocks (only used in kv aware routing)
#[builder(default)]
pub estimated_prefix_hit_num_blocks: Option<u32>,
}
impl PreprocessedRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment