Unverified commit c30c6990, authored by ishandhanani, committed by GitHub
Browse files

fix: direct clients vs dependencies (#704)


Co-authored-by: Ziqi Fan <ziqif@nvidia.com>
parent f4780e85
......@@ -67,6 +67,7 @@ class Processor(ProcessMixIn):
self.tokenizer, self.model_config
)
self.min_workers = 1
print(f"Processor init: {self.engine_args.router}")
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
......@@ -93,6 +94,15 @@ class Processor(ProcessMixIn):
.client()
)
if self.engine_args.router == "kv":
router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = (
await runtime.namespace(router_ns)
.component(router_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
self.etcd_kv_cache = await EtcdKvCache.create(
......@@ -117,15 +127,16 @@ class Processor(ProcessMixIn):
) = await self._parse_raw_request(raw_request)
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
async for route_response in self.router.generate(
router_generator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
):
worker_id, prefix_hit_rate = route_response.split("_")
)
decision = await router_generator.__anext__()
decision = decision.data()
worker_id, prefix_hit_rate = decision.split("_")
prefix_hit_rate = float(prefix_hit_rate)
logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
break
if worker_id == "":
engine_generator = await self.worker_client.generate(
......
......@@ -52,6 +52,7 @@ class Processor(ChatProcessorMixin):
self.remote_prefill = args.remote_prefill
self.router_mode = args.router
self.min_workers = 1
self.args = args
super().__init__(engine_config)
......@@ -65,6 +66,16 @@ class Processor(ChatProcessorMixin):
.endpoint("generate")
.client()
)
if self.args.router == "kv":
router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = (
await runtime.namespace(router_ns)
.component(router_name)
.endpoint("generate")
.client()
)
while len(self.worker_client.endpoint_ids()) < self.min_workers:
logger.info(
f"Waiting for workers to be ready.\n"
......@@ -88,15 +99,16 @@ class Processor(ChatProcessorMixin):
worker_id = ""
if self.router_mode == "kv":
async for route_response in self.router.generate(
router_generator = await self.router_client.generate(
preprocessed_request.tokens.model_dump_json()
):
worker_id, prefix_hit_rate = route_response.split("_")
)
decision = await router_generator.__anext__()
decision = decision.data()
worker_id, prefix_hit_rate = decision.split("_")
prefix_hit_rate = float(prefix_hit_rate)
logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
break
if worker_id == "":
if self.router_mode == "round-robin":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment