Commit d29f7fcc authored by Hongkuan Zhou, committed by GitHub

feat: conditional disagg based on prefill queue size (#303)

parent d7165149
@@ -23,14 +23,23 @@ class PyDisaggregatedRouter:
         runtime,
         served_model_name,
         max_local_prefill_length=1000,
+        max_prefill_queue_size=2,
     ):
         self.runtime = runtime
         self.served_model_name = served_model_name
         self.max_local_prefill_length = max_local_prefill_length
+        self.max_prefill_queue_size = max_prefill_queue_size
 
-    def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
+    def prefill_remote(
+        self, prompt_length: int, prefix_hit_rate: float, queue_size: int
+    ):
         absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
+        # TODO: consider size of each request in the queue when making the decision
+        decision = (
+            absolute_prefill_length > self.max_local_prefill_length
+            and queue_size < self.max_prefill_queue_size
+        )
         vllm_logger.info(
-            f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
+            f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{prompt_length}, prefill queue size: {queue_size}/{self.max_prefill_queue_size})"
         )
-        return absolute_prefill_length > self.max_local_prefill_length
+        return decision
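To make the new policy concrete, here is a small, self-contained sketch of the decision rule from this hunk (class scaffolding removed; the example calls at the bottom are illustrative and not part of the commit):

```python
# Minimal sketch of the conditional-disagg decision, using the defaults from the diff
# (max_local_prefill_length=1000, max_prefill_queue_size=2).
MAX_LOCAL_PREFILL_LENGTH = 1000
MAX_PREFILL_QUEUE_SIZE = 2


def prefill_remote(prompt_length: int, prefix_hit_rate: float, queue_size: int) -> bool:
    # Tokens that still need prefilling after prefix-cache hits.
    absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
    # Offload only when the remaining prefill is long AND the prefill queue is not backed up.
    return (
        absolute_prefill_length > MAX_LOCAL_PREFILL_LENGTH
        and queue_size < MAX_PREFILL_QUEUE_SIZE
    )


print(prefill_remote(4000, 0.5, 0))  # True: 2000 uncached tokens, empty queue
print(prefill_remote(4000, 0.5, 2))  # False: queue already at the limit, prefill locally
print(prefill_remote(800, 0.0, 0))   # False: short prefill, cheaper to do locally
```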
@@ -125,6 +125,7 @@ class VllmWorker:
                 runtime,
                 self.model_name,
                 max_local_prefill_length=self.engine_args.max_local_prefill_length,
+                max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
             )
         else:
             self.disaggregated_router = None
@@ -148,9 +149,17 @@ class VllmWorker:
     @dynamo_endpoint()
     async def generate(self, request: vLLMGenerateRequest):
         # TODO: consider prefix hit when deciding prefill locally or remotely
         if self.disaggregated_router is not None:
+            async with PrefillQueue.get_instance(
+                nats_server=self._prefill_queue_nats_server,
+                stream_name=self._prefill_queue_stream_name,
+            ) as prefill_queue:
+                prefill_queue_size = await prefill_queue.get_queue_size()
             disagg_router_decision = self.disaggregated_router.prefill_remote(
-                len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
+                len(request.engine_prompt["prompt_token_ids"]),
+                request.prefix_hit_rate,
+                prefill_queue_size,
             )
         else:
             # always prefill remotely if no disaggregated router is provided
@@ -30,6 +30,7 @@ VllmWorker:
   remote-prefill: true
   conditional-disagg: true
   max-local-prefill-length: 10
+  max-prefill-queue-size: 2
   ServiceArgs:
     workers: 1
     resources:
@@ -36,6 +36,8 @@ VllmWorker:
   max-model-len: 16384
   max-num-batched-tokens: 16384
   conditional-disagg: true
+  max-local-prefill-length: 10
+  max-prefill-queue-size: 2
   tensor-parallel-size: 1
   router: kv
   enable-prefix-caching: true
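Both example configs above exercise the new knob with deliberately low thresholds (a 10-token local-prefill cap and a queue limit of 2) compared with the CLI defaults added further below (1000 and 3). A worker section reduced to just the conditional-disaggregation keys would look roughly like this (other worker settings elided; comments paraphrase the decision rule, which offloads only when the uncached prefill exceeds the cap and the queue size is below the limit):

```yaml
# Sketch of the conditional-disagg knobs in a VllmWorker config section;
# model, parallelism, router, and resource settings are omitted.
VllmWorker:
  remote-prefill: true
  conditional-disagg: true
  max-local-prefill-length: 10   # prefill locally when <= 10 uncached tokens remain
  max-prefill-queue-size: 2      # prefill locally when the prefill queue has 2+ pending requests
```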
@@ -140,3 +140,16 @@ class NATSQueue:
             return None
         except NatsError as e:
             raise RuntimeError(f"Failed to dequeue task: {e}")
+
+    async def get_queue_size(self) -> int:
+        """Get the number of messages currently in the queue"""
+        await self.ensure_connection()
+        try:
+            # Get consumer info to get pending messages count
+            consumer_info = await self._js.consumer_info(  # type: ignore
+                self._stream_name, "worker-group"
+            )
+            # Return number of pending messages (real-time queue size)
+            return consumer_info.num_pending
+        except NatsError as e:
+            raise RuntimeError(f"Failed to get queue size: {e}")
@@ -45,6 +45,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
         default=1000,
         help="Maximum length of local prefill",
     )
+    parser.add_argument(
+        "--max-prefill-queue-size",
+        type=int,
+        default=3,
+        help="Do not send remote prefill requests (prefill locally) if the queue size is greater than this value",
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args(vllm_args)
     engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -52,4 +58,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
     engine_args.remote_prefill = args.remote_prefill
     engine_args.conditional_disagg = args.conditional_disagg
     engine_args.max_local_prefill_length = args.max_local_prefill_length
+    engine_args.max_prefill_queue_size = args.max_prefill_queue_size
     return engine_args
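Finally, a minimal sketch of how the new flag behaves in an argparse parser of this shape (only the two conditional-disagg flags are shown; the real parser also carries the vLLM engine arguments, and the command-line values here are illustrative):

```python
# Minimal sketch: parsing the conditional-disagg flags added in this commit.
# Defaults match the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-local-prefill-length", type=int, default=1000)
parser.add_argument("--max-prefill-queue-size", type=int, default=3)

args = parser.parse_args(["--max-local-prefill-length", "10", "--max-prefill-queue-size", "2"])
print(args.max_local_prefill_length, args.max_prefill_queue_size)  # 10 2
```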