Unverified commit 55a6e644, authored by Liangsheng Yin, committed by GitHub
Browse files

[Hack] Add pd-disaggregation decode polling interval (#10411)

parent 6897e06b
...@@ -886,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -886,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin:
# if there are still retracted requests, we do not allocate new requests # if there are still retracted requests, we do not allocate new requests
return return
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() if not hasattr(self, "polling_count"):
self.disagg_decode_transfer_queue.extend(req_conns) self.polling_count = 0
alloc_reqs = ( self.polling_interval = (
self.disagg_decode_transfer_queue.pop_transferred() self.server_args.disaggregation_decode_polling_interval
) # the requests which kv has arrived )
self.waiting_queue.extend(alloc_reqs)
self.polling_count = (self.polling_count + 1) % self.polling_interval
if self.polling_count % self.polling_interval == 0:
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
...@@ -394,6 +394,9 @@ class ServerArgs: ...@@ -394,6 +394,9 @@ class ServerArgs:
disaggregation_ib_device: Optional[str] = None disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
# For model weight update # For model weight update
custom_weight_loader: Optional[List[str]] = None custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False weight_loader_disable_mmap: bool = False
...@@ -2245,6 +2248,12 @@ class ServerArgs: ...@@ -2245,6 +2248,12 @@ class ServerArgs:
default=ServerArgs.num_reserved_decode_tokens, default=ServerArgs.num_reserved_decode_tokens,
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.", help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
) )
parser.add_argument(
"--disaggregation-decode-polling-interval",
type=int,
default=ServerArgs.disaggregation_decode_polling_interval,
help="The interval to poll requests in decode server. Can be set to >1 to reduce the overhead of this.",
)
# Custom weight loader # Custom weight loader
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment