Commit d29f7fcc authored by Hongkuan Zhou; committed by GitHub
Browse files

feat: conditional disagg based on prefill queue size (#303)

parent d7165149
...@@ -23,14 +23,23 @@ class PyDisaggregatedRouter: ...@@ -23,14 +23,23 @@ class PyDisaggregatedRouter:
runtime, runtime,
served_model_name, served_model_name,
max_local_prefill_length=1000, max_local_prefill_length=1000,
max_prefill_queue_size=2,
): ):
self.runtime = runtime self.runtime = runtime
self.served_model_name = served_model_name self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length self.max_local_prefill_length = max_local_prefill_length
self.max_prefill_queue_size = max_prefill_queue_size
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float): def prefill_remote(
self, prompt_length: int, prefix_hit_rate: float, queue_size: int
):
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate)) absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
# TODO: consider size of each request in the queue when making the decision
decision = (
absolute_prefill_length > self.max_local_prefill_length
and queue_size < self.max_prefill_queue_size
)
vllm_logger.info( vllm_logger.info(
f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})" f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{prompt_length}, prefill queue size: {queue_size}/{self.max_prefill_queue_size})"
) )
return absolute_prefill_length > self.max_local_prefill_length return decision
...@@ -125,6 +125,7 @@ class VllmWorker: ...@@ -125,6 +125,7 @@ class VllmWorker:
runtime, runtime,
self.model_name, self.model_name,
max_local_prefill_length=self.engine_args.max_local_prefill_length, max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
) )
else: else:
self.disaggregated_router = None self.disaggregated_router = None
...@@ -148,9 +149,17 @@ class VllmWorker: ...@@ -148,9 +149,17 @@ class VllmWorker:
@dynamo_endpoint() @dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest): async def generate(self, request: vLLMGenerateRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely # TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None: if self.disaggregated_router is not None:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = self.disaggregated_router.prefill_remote( disagg_router_decision = self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate len(request.engine_prompt["prompt_token_ids"]),
request.prefix_hit_rate,
prefill_queue_size,
) )
else: else:
# always prefill remotely if no disaggregated router is provided # always prefill remotely if no disaggregated router is provided
......
...@@ -30,6 +30,7 @@ VllmWorker: ...@@ -30,6 +30,7 @@ VllmWorker:
remote-prefill: true remote-prefill: true
conditional-disagg: true conditional-disagg: true
max-local-prefill-length: 10 max-local-prefill-length: 10
max-prefill-queue-size: 2
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
......
...@@ -36,6 +36,8 @@ VllmWorker: ...@@ -36,6 +36,8 @@ VllmWorker:
max-model-len: 16384 max-model-len: 16384
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
conditional-disagg: true conditional-disagg: true
max-local-prefill-length: 10
max-prefill-queue-size: 2
tensor-parallel-size: 1 tensor-parallel-size: 1
router: kv router: kv
enable-prefix-caching: true enable-prefix-caching: true
......
...@@ -140,3 +140,16 @@ class NATSQueue: ...@@ -140,3 +140,16 @@ class NATSQueue:
return None return None
except NatsError as e: except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}") raise RuntimeError(f"Failed to dequeue task: {e}")
async def get_queue_size(self) -> int:
"""Get the number of messages currently in the queue"""
await self.ensure_connection()
try:
# Get consumer info to get pending messages count
consumer_info = await self._js.consumer_info( # type: ignore
self._stream_name, "worker-group"
)
# Return number of pending messages (real-time queue size)
return consumer_info.num_pending
except NatsError as e:
raise RuntimeError(f"Failed to get queue size: {e}")
...@@ -45,6 +45,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -45,6 +45,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default=1000, default=1000,
help="Maximum length of local prefill", help="Maximum length of local prefill",
) )
parser.add_argument(
"--max-prefill-queue-size",
type=int,
default=3,
help="Do not send remote prefill requests (prefill locally) if the queue size is greater than this value",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args) args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
...@@ -52,4 +58,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -52,4 +58,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args.remote_prefill = args.remote_prefill engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
return engine_args return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment