Unverified commit 2be5e8f5, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: reduce code repetition in processor (#919)

parent 0086ebc6
...@@ -18,17 +18,19 @@ import argparse ...@@ -18,17 +18,19 @@ import argparse
import logging import logging
import random import random
from argparse import Namespace from argparse import Namespace
from typing import AsyncIterator from typing import AsyncIterator, Tuple
from components.worker import VllmWorker from components.worker import VllmWorker
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.protocol import Tokens from utils.protocol import Tokens
from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str WorkerId = str
fallback_msg = "Will fallback to random routing."
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -60,6 +62,12 @@ def parse_args(service_name, prefix) -> Namespace: ...@@ -60,6 +62,12 @@ def parse_args(service_name, prefix) -> Namespace:
default=False, default=False,
help="Whether to use custom router or not", help="Whether to use custom router or not",
) )
parser.add_argument(
"--router",
type=str,
default="kv",
help="The router type",
)
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix) config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args) args = parser.parse_args(config_args)
...@@ -101,10 +109,13 @@ class Router: ...@@ -101,10 +109,13 @@ class Router:
.client() .client()
) )
self.router_type = self.args.router
await check_required_workers(self.workers_client, self.args.min_workers) await check_required_workers(self.workers_client, self.args.min_workers)
kv_listener = self.runtime.namespace("dynamo").component("VllmWorker") kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
await kv_listener.create_service() await kv_listener.create_service()
if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size) self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener) self.metrics_aggregator = KvMetricsAggregator(kv_listener)
logger.info("KV Router initialized") logger.info("KV Router initialized")
...@@ -182,7 +193,8 @@ class Router: ...@@ -182,7 +193,8 @@ class Router:
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}" f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
) )
if not worker_logits or all(logit == 0 for logit in worker_logits.values()): if not worker_logits or not any(worker_logits.values()):
logger.warning(f"All worker logits are zero. {fallback_msg}.")
return "", 0.0 return "", 0.0
# Select the worker with the highest logit # Select the worker with the highest logit
...@@ -211,8 +223,47 @@ class Router: ...@@ -211,8 +223,47 @@ class Router:
return best_worker_id, worker_scores.get(best_worker_id, 0.0) return best_worker_id, worker_scores.get(best_worker_id, 0.0)
def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
    """Pick a worker with the lowest reported KV cache usage.

    Args:
        metrics: Aggregated worker metrics, or a falsy value when none
            could be obtained. Each item of ``metrics.endpoints`` is read
            for ``worker_id`` and, if present, ``gpu_cache_usage_perc``
            (treated as 0.0 when the attribute is missing).

    Returns:
        A ``(worker_id, kv_load)`` tuple for the selected worker, or
        ``("", 0.0)`` when no metrics are available or every reported
        load is zero/absent.
    """
    if not metrics:
        logger.warning(f"Cannot get metrics. {fallback_msg}")
        return "", 0.0

    # Map each worker to its reported cache usage.
    load_by_worker = {}
    for ep in metrics.endpoints:
        load_by_worker[ep.worker_id] = getattr(ep, "gpu_cache_usage_perc", 0.0)

    # any() over an empty mapping is False, so this single check covers both
    # "no endpoints" and "every load is zero".
    if not any(load_by_worker.values()):
        logger.warning(f"All KV loads are zero. {fallback_msg}")
        return "", 0.0

    # Several workers may tie for the minimum; pick one of them at random.
    lowest = min(load_by_worker.values())
    candidates = [wid for wid, load in load_by_worker.items() if load == lowest]
    chosen = random.choice(candidates)
    chosen_load = load_by_worker[chosen]

    logger.info(f"Selected worker: {chosen}, KV load: {chosen_load:.3f}")
    return chosen, chosen_load
@dynamo_endpoint() @dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]: async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]:
metrics = await self.metrics_aggregator.get_metrics()
# Quick return for KV_LOAD mode
if self.router_type == RouterType.KV_LOAD:
try:
yield self._get_underloaded_worker(metrics)
except Exception as e:
logger.exception(
f"Error finding underloaded worker: {e}. {fallback_msg}"
)
yield "", 0.0
return
# Existing KV routing logic
lora_id = 0 lora_id = 0
try: try:
scores = await self.indexer.find_matches_for_request( scores = await self.indexer.find_matches_for_request(
...@@ -220,14 +271,17 @@ class Router: ...@@ -220,14 +271,17 @@ class Router:
) )
except Exception as e: except Exception as e:
scores = {} scores = {}
logger.exception(f"Error finding matches: {e}") logger.exception(f"Error finding matches: {e}. {fallback_msg}")
yield "", 0.0
return
metrics = await self.metrics_aggregator.get_metrics()
worker_id, prefix_hit_rate = self._cost_function( worker_id, prefix_hit_rate = self._cost_function(
scores, metrics, len(request.tokens) scores, metrics, len(request.tokens)
) )
if worker_id:
logger.info( logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}" f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
) )
yield f"{worker_id}_{prefix_hit_rate}"
yield worker_id, prefix_hit_rate
...@@ -95,7 +95,8 @@ class Processor(ProcessMixIn): ...@@ -95,7 +95,8 @@ class Processor(ProcessMixIn):
.client() .client()
) )
if self.engine_args.router == RouterType.KV: self.use_router = self.engine_args.router in (RouterType.KV, RouterType.KV_LOAD)
if self.use_router:
router_ns, router_name = Router.dynamo_address() # type: ignore router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = ( self.router_client = (
await runtime.namespace(router_ns) await runtime.namespace(router_ns)
...@@ -116,22 +117,6 @@ class Processor(ProcessMixIn): ...@@ -116,22 +117,6 @@ class Processor(ProcessMixIn):
{"router": self.engine_args.router}, {"router": self.engine_args.router},
) )
async def _get_kv_load(self):
metrics = await self.metrics_aggregator.get_metrics()
kv_load = {}
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
kv_load[worker_id] = getattr(endpoint, "gpu_cache_usage_perc", 0.0)
return kv_load
async def _get_pending_requests(self):
metrics = await self.metrics_aggregator.get_metrics()
pending_requests = {}
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
pending_requests[worker_id] = getattr(endpoint, "num_requests_waiting", 0)
return pending_requests
async def _generate( async def _generate(
self, self,
raw_request: Union[CompletionRequest, ChatCompletionRequest], raw_request: Union[CompletionRequest, ChatCompletionRequest],
...@@ -146,81 +131,38 @@ class Processor(ProcessMixIn): ...@@ -146,81 +131,38 @@ class Processor(ProcessMixIn):
engine_prompt, engine_prompt,
sampling_params, sampling_params,
) = await self._parse_raw_request(raw_request) ) = await self._parse_raw_request(raw_request)
# TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode() router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == RouterType.KV: prefix_hit_rate = 0.0
if self.use_router:
router_generator = await self.router_client.generate( router_generator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json() Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
) )
decision = await router_generator.__anext__() decision = await router_generator.__anext__()
decision = decision.data() worker_id, prefix_hit_rate = decision.data()
worker_id, prefix_hit_rate = decision.split("_")
prefix_hit_rate = float(prefix_hit_rate) prefix_hit_rate = float(prefix_hit_rate)
logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
if worker_id == "": # Create request object once with default prefix_hit_rate
engine_generator = await self.worker_client.generate( request_obj = vLLMGenerateRequest(
vLLMGenerateRequest(
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
prefix_hit_rate=prefix_hit_rate, prefix_hit_rate=prefix_hit_rate,
).model_dump_json() ).model_dump_json()
)
if self.use_router:
if worker_id == "":
engine_generator = await self.worker_client.generate(request_obj)
else: else:
engine_generator = await self.worker_client.direct( engine_generator = await self.worker_client.direct(
vLLMGenerateRequest( request_obj, int(worker_id)
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json(),
int(worker_id),
) )
elif router_mode == RouterType.RANDOM: elif router_mode == RouterType.RANDOM:
engine_generator = await self.worker_client.generate( engine_generator = await self.worker_client.generate(request_obj)
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
elif router_mode == RouterType.ROUND_ROBIN: elif router_mode == RouterType.ROUND_ROBIN:
engine_generator = await self.worker_client.round_robin( engine_generator = await self.worker_client.round_robin(request_obj)
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
elif router_mode == RouterType.KV_LOAD:
# route to worker with least kv load
# TODO: move the router to a separate file and clean up processor.py
try:
kv_load = await self._get_kv_load()
best_worker_id = min(kv_load, key=kv_load.get)
logger.info(f"Routing to worker {best_worker_id} (kv load: {kv_load})")
engine_generator = await self.worker_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
int(best_worker_id),
)
except Exception as e:
logger.info(
f"Error finding worker with least kv load: {e}, fallback to random"
)
engine_generator = await self.worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
output = self._generate_responses(engine_generator, request_type) output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response( async for response in await self._stream_response(
......
...@@ -29,7 +29,7 @@ Processor: ...@@ -29,7 +29,7 @@ Processor:
Router: Router:
min-workers: 1 min-workers: 1
common-configs: [model] common-configs: [model, router]
VllmWorker: VllmWorker:
enforce-eager: true enforce-eager: true
......
...@@ -29,7 +29,7 @@ Processor: ...@@ -29,7 +29,7 @@ Processor:
Router: Router:
min-workers: 1 min-workers: 1
common-configs: [model] common-configs: [model, router]
VllmWorker: VllmWorker:
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment