Commit a954a1c6 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: kv aware disagg router (#98)


Co-authored-by: alec-flowers <aflowers@nvidia.com>
Co-authored-by: hongkuanz <hongkuanz@nvidia.com>
Co-authored-by: Alec <35311602+alec-flowers@users.noreply.github.com>
parent d99b188d
......@@ -124,13 +124,11 @@ There are three steps needed to enable the kv router:
3. Launch the kv router in a separate terminal.
```
RUST_LOG=info python3 kv_router.py \
--routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--min-workers 1
```
where `--min-workers` is the number of (decode) workers.
There is also a Python-based customized router that can be enabled by `--custom-router`.
You can choose only the prefix strategy for now:
- `prefix`: Route requests to the worker that has the longest prefix match.
......@@ -139,37 +137,18 @@ You can choose only the prefix strategy for now:
The disaggregated router determines whether a request should be sent to a
remote prefill engine or a local prefill engine for prefilling based on the
prefill length. When prefilling locally, the vllm scheduler will prioritize
prefill requests and pause any ongoing decode requests.
There are two types of disaggregated router implementations:
* Rust native: provide a simple heuristic to route to prefill engine
if the prefill length (including the prefix cache hit) is greater than a threshold.
This threshold can be dynamically adjusted at runtime through etcd.
To check the current threshold (this will print out all kv pairs in etcd):
```
curl -s -L http://localhost:2379/v3/kv/range -X POST -d '{"key":"AA==", "range_end":"AA=="}' | jq -r '.kvs[] | "KEY: \(.key | @base64d)\nVALUE: \(.value | @base64d)\n---"'
```
prefill length. If kv router is enabled, the disaggregated router will use
the absolute prefill length (actual prefill length - prefix hit length) to make
the decision.
To update the threshold:
```
ETCDCTL_API=3 etcdctl --endpoints=http://localhost:2379 put 'public/components/disagg_router/models/chat/<vllm.served_model_name(default to "vllm")>' '{"max_local_prefill_length": <new_threshold>}'
```
* Python customized: provide a python implementation that can be easily customized.
However, it does not support dynamic threshold adjustment through etcd.
It is recommended to use the custom disaggregated router together with the custom
kv router as the rust kv router does not report kv cache hit ratio.
To use the python disaggregated router, add the following commands when launching
the decode worker:
When prefilling locally, the vllm scheduler will prioritize
prefill requests and pause any ongoing decode requests.
To enable the disaggregated router, add the following arguments when launching the decode workers:
```
python worker.py \
...
--conditional-disagg \
<optional: --custom-disagg-router> \
--max-local-prefill-length <length>
```
......@@ -214,7 +193,7 @@ CUDA_VISIBLE_DEVICES=1 python3 worker.py \
--max-num-batched-tokens 16384 \
--max-model-len 16384 \
<optional kv router args: --router kv --enable-prefix-caching>
<optional disaggregated router args: --conditional-disagg --custom-disagg-router --max-local-prefill-length <length>>
<optional disaggregated router args: --conditional-disagg --max-local-prefill-length <length>>
```
### Multi-Node Deployment
......
......@@ -14,7 +14,7 @@
# limitations under the License.
from dynamo.llm import DisaggregatedRouter
from vllm.logger import logger as vllm_logger
class PyDisaggregatedRouter:
......@@ -22,27 +22,15 @@ class PyDisaggregatedRouter:
self,
runtime,
served_model_name,
custom_disagg_router=False,
max_local_prefill_length=1000,
max_remote_prefill_cache_hit_ratio=0.5,
):
self.runtime = runtime
self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length
self.max_remote_prefill_cache_hit_ratio = max_remote_prefill_cache_hit_ratio
self.custom_disagg_router = custom_disagg_router
if not self.custom_disagg_router:
# TODO: add max_remote_prefill_cache_hit_ratio to rust router
self.disagg_router = DisaggregatedRouter(
runtime,
served_model_name,
max_local_prefill_length,
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
vllm_logger.info(
f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
)
def prefill_remote(self, prompt_length, cache_hit_length=0):
if self.custom_disagg_router:
# TODO: add max_remote_prefill_cache_hit_ratio to python router
return prompt_length > self.max_local_prefill_length
else:
return self.disagg_router.prefill_remote(prompt_length, cache_hit_length)
return absolute_prefill_length > self.max_local_prefill_length
......@@ -15,98 +15,138 @@
import asyncio
import random
from argparse import Namespace
from enum import Enum
from typing import AsyncIterator
import uvloop
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvRouter
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
WorkerId = str
class RoutingStrategy(Enum):
    """Strategies the KV router supports for selecting a worker."""

    # Route each request to the worker holding the longest KV prefix match.
    PREFIX = "prefix"
    ROUND_ROBIN = "round_robin"
    RANDOM = "random"
class Router:
class CustomRouter:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router: KvRouter,
routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
workers_client,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
vllm_logger.info(
f"Initializing KV Router with strategy: {routing_strategy.value}"
vllm_logger.info("Initializing Custom Router")
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
self.workers_client = workers_client
def _cost_function(
self,
scores: OverlapScores | None,
metrics: AggregatedMetrics | None,
token_length: int,
):
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
# score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
self.router = router
self.routing_strategy = routing_strategy
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
yield str(worker_id)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
class CustomRouter:
"""
Request handler for the generate endpoint
"""
# Select the worker with the highest logit
if worker_logits:
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
else:
best_worker_id = ""
def __init__(
self,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
def _cost_function(self, scores, metrics):
# naive cost function for demonstration purposes
current_best = ("", 0)
for worker_id, score in scores.scores.items():
if score > current_best[1]:
current_best = (worker_id, score)
for endpoint in metrics.endpoints:
if endpoint.worker_id == current_best[0]:
print(f"Metrics of endpoint: {endpoint.worker_id}")
print(
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
print(
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return current_best[0]
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
......@@ -116,20 +156,24 @@ class CustomRouter:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
scores = {}
vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length)
if schedule_result == "":
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
vllm_logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield str(worker_id)
yield f"{worker_id}_{prefix_hit_rate}"
@dynamo_worker()
......@@ -144,14 +188,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
.endpoint("generate")
.client()
)
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
vllm_logger.info("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info(
......@@ -172,16 +208,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint = router_component.endpoint("generate")
if args.custom_router:
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
CustomRouter(workers_client, indexer, metrics_aggregator).generate
)
else:
# TODO Read block_size from MDC
router = KvRouter(runtime, kv_listener, args.block_size)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__":
......@@ -190,13 +221,6 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--routing-strategy",
type=RoutingStrategy,
default=RoutingStrategy.PREFIX,
choices=list(RoutingStrategy),
help="Routing strategy to use",
)
parser.add_argument(
"--min-workers",
type=int,
......@@ -209,9 +233,11 @@ if __name__ == "__main__":
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
type=int,
default=64,
help="KV block size",
)
parser.add_argument(
......
......@@ -98,11 +98,14 @@ class Processor(ProcessMixIn):
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
)
worker_id = (
route_response = (
await worker_id_generator.__anext__()
) # only one worker id is returned
worker_id = worker_id.data()
vllm_logger.info(f"Worker ID: {worker_id}")
worker_id, prefix_hit_rate = route_response.data().split("_")
prefix_hit_rate = float(prefix_hit_rate)
vllm_logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
if worker_id == "":
engine_generator = await self.workers_client.random(
......@@ -110,6 +113,7 @@ class Processor(ProcessMixIn):
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
)
else:
......@@ -118,6 +122,7 @@ class Processor(ProcessMixIn):
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json(),
int(worker_id),
)
......
......@@ -75,6 +75,7 @@ class vLLMGenerateRequest(BaseModel):
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
......
......@@ -35,33 +35,17 @@ def parse_vllm_args() -> AsyncEngineArgs:
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--custom-disagg-router",
action="store_true",
help="Use custom python implementation of disaggregated router instead of the default rust one",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser.add_argument(
"--max-remote-prefill-cache-hit-ratio",
type=float,
default=0.5,
help="Maximum cache hit ratio for remote prefill "
"(only applicable to custom python implementation of disaggregated router)",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.custom_disagg_router = args.custom_disagg_router
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_remote_prefill_cache_hit_ratio = (
args.max_remote_prefill_cache_hit_ratio
)
return engine_args
......@@ -81,7 +81,7 @@ class RequestHandler:
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
disagg_router_decision = self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]), 0
len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
)
else:
# always prefill remotely if no disaggregated router is provided
......@@ -164,10 +164,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
if engine_args.remote_prefill:
......@@ -179,9 +182,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
disaggregated_router = PyDisaggregatedRouter(
runtime,
served_model_name,
custom_disagg_router=engine_args.custom_disagg_router,
max_local_prefill_length=engine_args.max_local_prefill_length,
max_remote_prefill_cache_hit_ratio=engine_args.max_remote_prefill_cache_hit_ratio,
)
else:
disaggregated_router = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment