"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "093efb9a81af77aac0f396b157f11cd5a197fa74"
Unverified Commit 2be5e8f5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: reduce code repetition in processor (#919)

parent 0086ebc6
......@@ -18,17 +18,19 @@ import argparse
import logging
import random
from argparse import Namespace
from typing import AsyncIterator
from typing import AsyncIterator, Tuple
from components.worker import VllmWorker
from utils.logging import check_required_workers
from utils.protocol import Tokens
from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str
fallback_msg = "Will fallback to random routing."
logger = logging.getLogger(__name__)
......@@ -60,6 +62,12 @@ def parse_args(service_name, prefix) -> Namespace:
default=False,
help="Whether to use custom router or not",
)
parser.add_argument(
"--router",
type=str,
default="kv",
help="The router type",
)
config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args)
......@@ -101,11 +109,14 @@ class Router:
.client()
)
self.router_type = self.args.router
await check_required_workers(self.workers_client, self.args.min_workers)
kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
await kv_listener.create_service()
self.indexer = KvIndexer(kv_listener, self.args.block_size)
if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
logger.info("KV Router initialized")
......@@ -182,7 +193,8 @@ class Router:
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
if not worker_logits or not any(worker_logits.values()):
logger.warning(f"All worker logits are zero. {fallback_msg}.")
return "", 0.0
# Select the worker with the highest logit
......@@ -211,8 +223,47 @@ class Router:
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
    """Choose the worker whose GPU KV cache is least utilized.

    Returns a ``(worker_id, load)`` pair. Falls back to ``("", 0.0)`` —
    signalling the caller to use random routing — when metrics are
    unavailable or every reported load is zero/falsy. Ties at the
    minimum load are broken uniformly at random.
    """
    if not metrics:
        logger.warning(f"Cannot get metrics. {fallback_msg}")
        return "", 0.0

    # Endpoints that do not report cache usage are treated as unloaded.
    kv_load: dict = {}
    for endpoint in metrics.endpoints:
        kv_load[endpoint.worker_id] = getattr(endpoint, "gpu_cache_usage_perc", 0.0)

    if not kv_load or not any(kv_load.values()):
        logger.warning(f"All KV loads are zero. {fallback_msg}")
        return "", 0.0

    lowest = min(kv_load.values())
    candidates = [wid for wid, load in kv_load.items() if load == lowest]
    # Break ties uniformly at random among the least-loaded workers.
    best_worker_id = random.choice(candidates)
    logger.info(
        f"Selected worker: {best_worker_id}, KV load: {kv_load[best_worker_id]:.3f}"
    )
    return best_worker_id, kv_load[best_worker_id]
@dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]:
metrics = await self.metrics_aggregator.get_metrics()
# Quick return for KV_LOAD mode
if self.router_type == RouterType.KV_LOAD:
try:
yield self._get_underloaded_worker(metrics)
except Exception as e:
logger.exception(
f"Error finding underloaded worker: {e}. {fallback_msg}"
)
yield "", 0.0
return
# Existing KV routing logic
lora_id = 0
try:
scores = await self.indexer.find_matches_for_request(
......@@ -220,14 +271,17 @@ class Router:
)
except Exception as e:
scores = {}
logger.exception(f"Error finding matches: {e}")
logger.exception(f"Error finding matches: {e}. {fallback_msg}")
yield "", 0.0
return
metrics = await self.metrics_aggregator.get_metrics()
worker_id, prefix_hit_rate = self._cost_function(
scores, metrics, len(request.tokens)
)
logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield f"{worker_id}_{prefix_hit_rate}"
if worker_id:
logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield worker_id, prefix_hit_rate
......@@ -95,7 +95,8 @@ class Processor(ProcessMixIn):
.client()
)
if self.engine_args.router == RouterType.KV:
self.use_router = self.engine_args.router in (RouterType.KV, RouterType.KV_LOAD)
if self.use_router:
router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = (
await runtime.namespace(router_ns)
......@@ -116,22 +117,6 @@ class Processor(ProcessMixIn):
{"router": self.engine_args.router},
)
async def _get_kv_load(self):
    """Return a mapping of worker_id -> GPU KV-cache usage fraction.

    Workers whose endpoint does not report ``gpu_cache_usage_perc``
    default to ``0.0`` (i.e. they are considered unloaded).
    """
    metrics = await self.metrics_aggregator.get_metrics()
    return {
        endpoint.worker_id: getattr(endpoint, "gpu_cache_usage_perc", 0.0)
        for endpoint in metrics.endpoints
    }
async def _get_pending_requests(self):
    """Return a mapping of worker_id -> number of queued (waiting) requests.

    Workers whose endpoint does not report ``num_requests_waiting``
    default to ``0``.
    """
    metrics = await self.metrics_aggregator.get_metrics()
    return {
        endpoint.worker_id: getattr(endpoint, "num_requests_waiting", 0)
        for endpoint in metrics.endpoints
    }
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
......@@ -146,81 +131,38 @@ class Processor(ProcessMixIn):
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
# TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == RouterType.KV:
prefix_hit_rate = 0.0
if self.use_router:
router_generator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
)
decision = await router_generator.__anext__()
decision = decision.data()
worker_id, prefix_hit_rate = decision.split("_")
worker_id, prefix_hit_rate = decision.data()
prefix_hit_rate = float(prefix_hit_rate)
logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
# Create request object once with default prefix_hit_rate
request_obj = vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
if self.use_router:
if worker_id == "":
engine_generator = await self.worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
)
engine_generator = await self.worker_client.generate(request_obj)
else:
engine_generator = await self.worker_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json(),
int(worker_id),
request_obj, int(worker_id)
)
elif router_mode == RouterType.RANDOM:
engine_generator = await self.worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
engine_generator = await self.worker_client.generate(request_obj)
elif router_mode == RouterType.ROUND_ROBIN:
engine_generator = await self.worker_client.round_robin(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
elif router_mode == RouterType.KV_LOAD:
# route to worker with least kv load
# TODO: move the router to a separate file and clean up processor.py
try:
kv_load = await self._get_kv_load()
best_worker_id = min(kv_load, key=kv_load.get)
logger.info(f"Routing to worker {best_worker_id} (kv load: {kv_load})")
engine_generator = await self.worker_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
int(best_worker_id),
)
except Exception as e:
logger.info(
f"Error finding worker with least kv load: {e}, fallback to random"
)
engine_generator = await self.worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
engine_generator = await self.worker_client.round_robin(request_obj)
output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response(
......
......@@ -29,7 +29,7 @@ Processor:
Router:
min-workers: 1
common-configs: [model]
common-configs: [model, router]
VllmWorker:
enforce-eager: true
......
......@@ -29,7 +29,7 @@ Processor:
Router:
min-workers: 1
common-configs: [model]
common-configs: [model, router]
VllmWorker:
max-num-batched-tokens: 16384
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment