Commit a954a1c6 authored by Hongkuan Zhou, committed by GitHub

feat: kv aware disagg router (#98)


Co-authored-by: alec-flowers <aflowers@nvidia.com>
Co-authored-by: hongkuanz <hongkuanz@nvidia.com>
Co-authored-by: Alec <35311602+alec-flowers@users.noreply.github.com>
parent d99b188d
@@ -124,13 +124,11 @@ There are three steps needed to enable the kv router:
 3. Launch the kv router in a separate terminal.
 ```
 RUST_LOG=info python3 kv_router.py \
-  --routing-strategy prefix \
   --model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
   --block-size 64 \
   --min-workers 1
 ```
 where `--min-workers` is the number of (decode) workers.
-There is also python-based customized router that can be enabled by `--custom-router`.
 
 You can choose only the prefix strategy for now:
 - `prefix`: Route requests to the worker that has the longest prefix match.
@@ -139,37 +137,18 @@ You can choose only the prefix strategy for now:
 
 The disaggregated router determines whether a request should be sent to a
 remote prefill engine or a local prefill engine for prefilling based on the
-prefill length. When prefilling locally, the vllm scheduler will prioritize
-prefill request and pause any ongoing decode requests.
-
-There are two types of disaggregated router implementations:
-* Rust native: provide a simple heuristic to route to prefill engine
-if prefill length (including prefix catch hit) is greater than a threshold.
-This threshold can by dynamically adjusted at runtime through etcd.
-To check the current threshold (this will print out all kv pairs in etcd):
-```
-curl -s -L http://localhost:2379/v3/kv/range -X POST -d '{"key":"AA==", "range_end":"AA=="}' | jq -r '.kvs[] | "KEY: \(.key | @base64d)\nVALUE: \(.value | @base64d)\n---"'
-```
-To update the threshold:
-```
-ETCDCTL_API=3 etcdctl --endpoints=http://localhost:2379 put 'public/components/disagg_router/models/chat/<vllm.served_model_name(default to "vllm")>' '{"max_local_prefill_length": <new_threshold>}'
-```
-* Python customized: provide a python implementation that can be easily customized.
-However, it does not support dynamic threshold adjustment through etcd.
-It is recommended to use the custom disaggregated router together with the custom
-kv router as the rust kv router does not report kv cache hit ratio.
-To use the python disaggregated router, add the following commands when launching
-the decode worker:
+prefill length. If the kv router is enabled, the disaggregated router will use
+the absolute prefill length (actual prefill length - prefix hit length) to make
+the decision.
+
+When prefilling locally, the vllm scheduler will prioritize the
+prefill request and pause any ongoing decode requests.
 
 To enable the disaggregated router, add the following flags when launching the decode workers:
 ```
 python worker.py \
   ...
   --conditional-disagg \
-  <optional: --custom-disagg-router> \
   --max-local-prefill-length <length>
 ```
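
As a concrete sanity check of the new rule (the numbers below are hypothetical; the computation mirrors `PyDisaggregatedRouter.prefill_remote` later in this commit):

```python
max_local_prefill_length = 1000  # value of --max-local-prefill-length
prompt_length = 3000             # tokens in the incoming request
prefix_hit_rate = 0.75           # fraction already cached on the chosen worker

# Only the tokens that miss the prefix cache still need to be prefilled.
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))  # 750

# 750 <= 1000, so this request is prefilled locally despite its long prompt.
prefill_remote = absolute_prefill_length > max_local_prefill_length  # False
```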
@@ -214,7 +193,7 @@ CUDA_VISIBLE_DEVICES=1 python3 worker.py \
   --max-num-batched-tokens 16384 \
   --max-model-len 16384 \
   <optional kv router args: --router kv --enable-prefix-caching>
-  <optional disaggregated router args: --conditional-disagg --custom-disagg-router --max-local-prefill-length <length>>
+  <optional disaggregated router args: --conditional-disagg --max-local-prefill-length <length>>
 ```
 
 ### Multi-Node Deployment
...
@@ -14,7 +14,7 @@
 # limitations under the License.
 
-from dynamo.llm import DisaggregatedRouter
+from vllm.logger import logger as vllm_logger
 
 
 class PyDisaggregatedRouter:
@@ -22,27 +22,15 @@ class PyDisaggregatedRouter:
         self,
         runtime,
         served_model_name,
-        custom_disagg_router=False,
         max_local_prefill_length=1000,
-        max_remote_prefill_cache_hit_ratio=0.5,
     ):
         self.runtime = runtime
         self.served_model_name = served_model_name
         self.max_local_prefill_length = max_local_prefill_length
-        self.max_remote_prefill_cache_hit_ratio = max_remote_prefill_cache_hit_ratio
-        self.custom_disagg_router = custom_disagg_router
-        if not self.custom_disagg_router:
-            # TODO: add max_remote_prefill_cache_hit_ratio to rust router
-            self.disagg_router = DisaggregatedRouter(
-                runtime,
-                served_model_name,
-                max_local_prefill_length,
-            )
 
-    def prefill_remote(self, prompt_length, cache_hit_length=0):
-        if self.custom_disagg_router:
-            # TODO: add max_remote_prefill_cache_hit_ratio to python router
-            return prompt_length > self.max_local_prefill_length
-        else:
-            return self.disagg_router.prefill_remote(prompt_length, cache_hit_length)
+    def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
+        absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
+        vllm_logger.info(
+            f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
+        )
+        return absolute_prefill_length > self.max_local_prefill_length
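
A minimal usage sketch of the updated class, assuming `runtime` is an already-initialized `DistributedRuntime` and using an example model name:

```python
router = PyDisaggregatedRouter(
    runtime,
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # example served_model_name
    max_local_prefill_length=1000,
)

# 4000-token prompt, half already cached: int(4000 * 0.5) = 2000 > 1000 -> remote.
assert router.prefill_remote(4000, 0.5)

# With a 90% prefix hit, only ~400 tokens remain to prefill -> local.
assert not router.prefill_remote(4000, 0.9)
```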
@@ -15,98 +15,138 @@
 import asyncio
+import random
 from argparse import Namespace
-from enum import Enum
 from typing import AsyncIterator
 
 import uvloop
 from utils.protocol import Tokens
 from vllm.logger import logger as vllm_logger
 
-from dynamo.llm import KvIndexer, KvMetricsAggregator, KvRouter
+from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
 from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
 
 WorkerId = str
 
 
-class RoutingStrategy(Enum):
-    PREFIX = "prefix"
-    ROUND_ROBIN = "round_robin"
-    RANDOM = "random"
-
-
-class Router:
+class CustomRouter:
     """
     Request handler for the generate endpoint
    """
 
     def __init__(
         self,
-        router: KvRouter,
-        routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
+        workers_client,
+        indexer: KvIndexer,
+        metrics_aggregator: KvMetricsAggregator,
     ):
-        vllm_logger.info(
-            f"Initializing KV Router with strategy: {routing_strategy.value}"
-        )
-        self.router = router
-        self.routing_strategy = routing_strategy
+        vllm_logger.info("Initializing Custom Router")
+        self.indexer = indexer
+        self.metrics_aggregator = metrics_aggregator
+        self.workers_client = workers_client
 
-    @dynamo_endpoint(Tokens, WorkerId)
-    async def generate(self, request) -> AsyncIterator[WorkerId]:
-        lora_id = 0
-        worker_id = None
-        if self.routing_strategy == RoutingStrategy.PREFIX:
-            try:
-                worker_id = await self.router.schedule(request.tokens, lora_id)
-            # [NOTE][TODO] Now that the scheduler may return more error messages,
-            # now we are catching all exceptions and logging them. Should have
-            # catch specific router exceptions once we have dedicated types.
-            except Exception as e:
-                vllm_logger.info(f"{e}")
-                worker_id = ""
-                vllm_logger.exception(f"Error during worker selection: {e}")
-
-            vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
-
-            yield str(worker_id)
-
-        else:
-            # TODO: Do we implement round_robin and random here?
-            # or just skip this router and directly enable in preprocess?
-            raise NotImplementedError(
-                f"Routing strategy {self.routing_strategy} not implemented"
-            )
-
-
-class CustomRouter:
-    """
-    Request handler for the generate endpoint
-    """
-
-    def __init__(
-        self,
-        indexer: KvIndexer,
-        metrics_aggregator: KvMetricsAggregator,
-    ):
-        self.indexer = indexer
-        self.metrics_aggregator = metrics_aggregator
-
-    def _cost_function(self, scores, metrics):
-        # naive cost function for demonstration purposes
-        current_best = ("", 0)
-        for worker_id, score in scores.scores.items():
-            if score > current_best[1]:
-                current_best = (worker_id, score)
-
-        for endpoint in metrics.endpoints:
-            if endpoint.worker_id == current_best[0]:
-                print(f"Metrics of endpoint: {endpoint.worker_id}")
-                print(
-                    f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
-                )
-                print(
-                    f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
-                )
-
-        return current_best[0]
+    def _cost_function(
+        self,
+        scores: OverlapScores | None,
+        metrics: AggregatedMetrics | None,
+        token_length: int,
+    ):
+        worker_scores = {}
+        if scores:
+            for worker_id, score in scores.scores.items():
+                # score is the number of matching blocks; we multiply by
+                # block_size to get tokens and compare to token_length.
+                # The larger the cache hit, the better.
+                worker_scores[worker_id] = (
+                    score * self.indexer.block_size() / token_length
+                )
+
+        worker_metrics = {}
+        # pull metrics for each worker
+        max_waiting = 0.0
+        if metrics:
+            for endpoint in metrics.endpoints:
+                worker_id = endpoint.worker_id
+                worker_metrics[worker_id] = {
+                    "gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
+                    if hasattr(endpoint, "gpu_cache_usage_perc")
+                    else 0.0,
+                    "num_requests_waiting": endpoint.num_requests_waiting
+                    if hasattr(endpoint, "num_requests_waiting")
+                    else 0.0,
+                    "gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
+                    if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
+                    else 0.0,
+                }
+                max_waiting = max(
+                    max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
+                )
+
+        # Get all worker IDs from the client. This is needed because scores /
+        # metrics may not have values for all workers, and we want all workers
+        # to be considered in the logit calculation.
+        worker_ids = self.workers_client.endpoint_ids()
+
+        worker_logits = {}
+        for worker_id in worker_ids:
+            # Use default values if worker not in scores or metrics
+            score = worker_scores.get(worker_id, 0.0)
+            metrics_dict = worker_metrics.get(
+                worker_id,
+                {
+                    "gpu_cache_usage_perc": 0.0,
+                    "num_requests_waiting": 0.0,
+                    "gpu_prefix_cache_hit_rate": 0.0,
+                },
+            )
+            normalized_waiting = (
+                metrics_dict["num_requests_waiting"] / max_waiting
+                if max_waiting > 0
+                else 0.0
+            )
+            # One metric that weights towards cache hits,
+            # two metrics that penalize an overloaded worker and queuing.
+            worker_logits[worker_id] = (
+                2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
+            )
+            vllm_logger.info(
+                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
+            )
+
+        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
+            return ""
+
+        # Select the worker with the highest logit, breaking ties randomly
+        if worker_logits:
+            max_logit = max(worker_logits.values())
+            best_workers = [
+                wid for wid, logit in worker_logits.items() if logit == max_logit
+            ]
+            best_worker_id = random.choice(best_workers)
+        else:
+            best_worker_id = ""
+
+        # Log the metrics for the selected worker
+        if best_worker_id:
+            vllm_logger.info(
+                f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
+            )
+            vllm_logger.info(
+                f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
+            )
+            metrics_dict = worker_metrics.get(best_worker_id, {})
+            vllm_logger.info(
+                f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
+            )
+            vllm_logger.info(
+                f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
+            )
+            vllm_logger.info(
+                f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
+            )
+
+        return best_worker_id, worker_scores.get(best_worker_id, 0.0)
 
     @dynamo_endpoint(Tokens, WorkerId)
     async def generate(self, request) -> AsyncIterator[WorkerId]:
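
To see how this logit behaves, a small worked example with two hypothetical workers (block size 64, a 1280-token request; all metric values are made up):

```python
# worker A: 10 matching blocks -> score = 10 * 64 / 1280 = 0.50
# worker B:  2 matching blocks -> score =  2 * 64 / 1280 = 0.10
# logit = 2 * score - gpu_cache_usage_perc - normalized_waiting
logit_a = 2 * 0.50 - 0.30 - 1.0  # busy worker with the longest queue -> -0.30
logit_b = 2 * 0.10 - 0.05 - 0.0  # nearly idle worker                 ->  0.15
# B wins despite the weaker prefix match: the load penalties outweigh
# the cache-hit bonus.
```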
@@ -116,20 +156,24 @@ class CustomRouter:
             scores = await self.indexer.find_matches_for_request(
                 request.tokens, lora_id
             )
-            metrics = await self.metrics_aggregator.get_metrics()
-            worker_id = self._cost_function(scores, metrics)
-        # [NOTE][TODO] Now that the scheduler may return more error messages,
-        # now we are catching all exceptions and logging them. Should have
-        # catch specific router exceptions once we have dedicated types.
         except Exception as e:
-            vllm_logger.info(f"{e}")
+            scores = {}
+            vllm_logger.exception(f"Error finding matches: {e}")
+
+        token_length = len(request.tokens)
+        metrics = await self.metrics_aggregator.get_metrics()
+
+        schedule_result = self._cost_function(scores, metrics, token_length)
+        if schedule_result == "":
             worker_id = ""
-            vllm_logger.exception(f"Error during worker selection: {e}")
+            prefix_hit_rate = 0.0
+        else:
+            worker_id, prefix_hit_rate = schedule_result
 
-        vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
+        vllm_logger.info(
+            f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
+        )
 
-        yield str(worker_id)
+        yield f"{worker_id}_{prefix_hit_rate}"
 
 
 @dynamo_worker()
@@ -144,14 +188,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
         .endpoint("generate")
         .client()
     )
-    wait_task = workers_client.wait_for_endpoints()
-    await asyncio.sleep(1)
-
-    while not wait_task.done():
-        vllm_logger.info("Waiting for workers to be ready...")
-        await asyncio.sleep(5)
-
-    wait_task.result()
 
     while len(workers_client.endpoint_ids()) < args.min_workers:
         vllm_logger.info(
@@ -172,16 +208,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
     endpoint = router_component.endpoint("generate")
 
-    if args.custom_router:
-        indexer = KvIndexer(kv_listener, args.block_size)
-        metrics_aggregator = KvMetricsAggregator(kv_listener)
-        await endpoint.serve_endpoint(
-            CustomRouter(indexer, metrics_aggregator).generate
-        )
-    else:
-        # TODO Read block_size from MDC
-        router = KvRouter(runtime, kv_listener, args.block_size)
-        await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
+    indexer = KvIndexer(kv_listener, args.block_size)
+    metrics_aggregator = KvMetricsAggregator(kv_listener)
+
+    await endpoint.serve_endpoint(
+        CustomRouter(workers_client, indexer, metrics_aggregator).generate
+    )
 
 
 if __name__ == "__main__":
@@ -190,13 +221,6 @@ if __name__ == "__main__":
     import argparse
 
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--routing-strategy",
-        type=RoutingStrategy,
-        default=RoutingStrategy.PREFIX,
-        choices=list(RoutingStrategy),
-        help="Routing strategy to use",
-    )
     parser.add_argument(
         "--min-workers",
         type=int,
@@ -209,9 +233,11 @@ if __name__ == "__main__":
         default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         help="Model that is being served",
     )
+    # TODO: Read block size
     parser.add_argument(
         "--block-size",
        type=int,
+        default=64,
         help="KV block size",
     )
     parser.add_argument(
...
@@ -98,11 +98,14 @@ class Processor(ProcessMixIn):
                 Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
             )
-            worker_id = (
+            route_response = (
                 await worker_id_generator.__anext__()
             )  # only one worker id is returned
-            worker_id = worker_id.data()
-            vllm_logger.info(f"Worker ID: {worker_id}")
+            worker_id, prefix_hit_rate = route_response.data().split("_")
+            prefix_hit_rate = float(prefix_hit_rate)
+            vllm_logger.info(
+                f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
+            )
 
             if worker_id == "":
                 engine_generator = await self.workers_client.random(
@@ -110,6 +113,7 @@ class Processor(ProcessMixIn):
                         engine_prompt=engine_prompt,
                         sampling_params=sampling_params,
                         request_id=request_id,
+                        prefix_hit_rate=prefix_hit_rate,
                     ).model_dump_json()
                 )
             else:
@@ -118,6 +122,7 @@ class Processor(ProcessMixIn):
                         engine_prompt=engine_prompt,
                         sampling_params=sampling_params,
                         request_id=request_id,
+                        prefix_hit_rate=prefix_hit_rate,
                     ).model_dump_json(),
                     int(worker_id),
                 )
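
The route response yielded by the router is a plain `"{worker_id}_{prefix_hit_rate}"` string; a round-trip sketch of that contract with illustrative values:

```python
route_response = f"{7465}_{0.75}"  # e.g. "7465_0.75", as yielded by the kv router

worker_id, prefix_hit_rate = route_response.split("_")
prefix_hit_rate = float(prefix_hit_rate)

assert int(worker_id) == 7465 and prefix_hit_rate == 0.75
```

A single `split("_")` is unambiguous here because worker ids are integers (the processor converts them with `int(worker_id)` before dispatching).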
...
@@ -75,6 +75,7 @@ class vLLMGenerateRequest(BaseModel):
     engine_prompt: PatchedTokensPrompt
     sampling_params: SamplingParams
     request_id: str
+    prefix_hit_rate: Optional[float] = 0.0
 
     @field_validator("sampling_params", mode="before")
     @classmethod
...
@@ -35,33 +35,17 @@ def parse_vllm_args() -> AsyncEngineArgs:
         action="store_true",
         help="Use disaggregated router to decide whether to prefill locally or remotely",
     )
-    parser.add_argument(
-        "--custom-disagg-router",
-        action="store_true",
-        help="Use custom python implementation of disaggregated router instead of the default rust one",
-    )
     parser.add_argument(
         "--max-local-prefill-length",
         type=int,
         default=1000,
         help="Maximum length of local prefill",
     )
-    parser.add_argument(
-        "--max-remote-prefill-cache-hit-ratio",
-        type=float,
-        default=0.5,
-        help="Maximum cache hit ratio for remote prefill "
-        "(only applicable to custom python implementation of disaggregated router)",
-    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine_args.router = args.router
     engine_args.remote_prefill = args.remote_prefill
     engine_args.conditional_disagg = args.conditional_disagg
-    engine_args.custom_disagg_router = args.custom_disagg_router
     engine_args.max_local_prefill_length = args.max_local_prefill_length
-    engine_args.max_remote_prefill_cache_hit_ratio = (
-        args.max_remote_prefill_cache_hit_ratio
-    )
     return engine_args
@@ -81,7 +81,7 @@ class RequestHandler:
         # TODO: consider prefix hit when deciding prefill locally or remotely
         if self.disaggregated_router is not None:
             disagg_router_decision = self.disaggregated_router.prefill_remote(
-                len(request.engine_prompt["prompt_token_ids"]), 0
+                len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
             )
         else:
             # always prefill remotely if no disaggregated router is provided
@@ -164,10 +164,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
     # Initially send dummy metrics to kick start,
     # vLLM will not update stats until a forward pass is triggered
     metrics_publisher.publish(
-        0,
-        1024,
-        0,
-        1024,
+        0,  # request_active_slots
+        1024,  # request_total_slots
+        0,  # kv_active_blocks
+        1024,  # kv_total_blocks
+        0,  # num_requests_waiting
+        0.0,  # gpu_cache_usage_perc
+        0.0,  # gpu_prefix_cache_hit_rate
     )
 
     if engine_args.remote_prefill:
@@ -179,9 +182,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
         disaggregated_router = PyDisaggregatedRouter(
             runtime,
             served_model_name,
-            custom_disagg_router=engine_args.custom_disagg_router,
             max_local_prefill_length=engine_args.max_local_prefill_length,
-            max_remote_prefill_cache_hit_ratio=engine_args.max_remote_prefill_cache_hit_ratio,
         )
     else:
         disaggregated_router = None
...