Unverified commit c55550cb, authored by Liangsheng Yin, committed by GitHub

[PD] Better logs (#5715)

parent 43fb95c2
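
This change adds a `num_tokens_pre_allocated` counter on the decode worker (updated in `process_decode_queue` and surfaced as `pre-allocated usage` in the decode log), makes `DecodeTransferQueue.pop_transferred` return `DecodeRequest` wrappers instead of bare `Req` objects, renames `disagg_prefill_pending_queue` to the more accurate `disagg_prefill_bootstrap_queue`, and extends the prefill log with `#unbootstrapped-req` and `#transferring-req` counts.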
@@ -307,7 +307,7 @@ class DecodeTransferQueue:
     def extend(self, req_conns) -> None:
         self.queue.extend(req_conns)

-    def pop_transferred(self) -> List[Req]:
+    def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
             return []

@@ -330,7 +330,7 @@ class DecodeTransferQueue:
                 assert len(decode_req.req.output_ids) == 0
                 assert decode_req.req.transferred_output_id is None
                 decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req.req)
+                transferred_reqs.append(decode_req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,

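The two hunks above change `pop_transferred` to hand back the `DecodeRequest` wrapper instead of the inner `Req`, so callers that want the bare request now unwrap it themselves (as `process_decode_queue` does below). A minimal sketch of the wrapper relationship, using simplified stand-in classes rather than the real sglang types:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Req:
    """Simplified stand-in for sglang's Req."""
    origin_input_ids: List[int]
    output_ids: List[int] = field(default_factory=list)
    transferred_output_id: Optional[int] = None


@dataclass
class DecodeRequest:
    """Stand-in wrapper pairing a Req with its KV-transfer bookkeeping."""
    req: Req


# pop_transferred() now yields the wrappers, so the scheduler unwraps
# them before extending its waiting queue:
transferred = [DecodeRequest(Req(origin_input_ids=[1, 2, 3]))]
waiting_queue = [decode_req.req for decode_req in transferred]
```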
@@ -454,7 +454,7 @@ class SchedulerDisaggregationDecodeMixin:
         return batch, result

     @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
+    def event_loop_normal_disagg_decode(self: Scheduler):
         """A normal scheduler loop for decode worker in disaggregation mode."""

         while True:

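Annotating `self: Scheduler` on the mixin's event loops is a typing hint rather than a behavior change: these methods reference attributes that only exist on the composed `Scheduler`, and the annotation lets type checkers resolve them. A generic sketch of the pattern, with hypothetical names unrelated to sglang:

```python
class Engine:
    def __init__(self) -> None:
        self.waiting_queue: list = []


class LoggingMixin:
    # `self: Engine` tells the type checker this method is always bound
    # to an Engine (or a subclass), so `self.waiting_queue` resolves.
    def log_queue(self: "Engine") -> None:
        print(f"#queue-req: {len(self.waiting_queue)}")


class LoggingEngine(LoggingMixin, Engine):
    pass


LoggingEngine().log_queue()  # prints: #queue-req: 0
```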
@@ -497,7 +497,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.last_batch = batch

     @torch.no_grad()
-    def event_loop_overlap_disagg_decode(self):
+    def event_loop_overlap_disagg_decode(self: Scheduler):
         result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

@@ -641,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+
+        def _num_pre_alloc(req):
+            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
+
+        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.waiting_queue.extend(alloc_reqs)
+        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
+        self.waiting_queue.extend([req.req for req in alloc_reqs])

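The nested `_num_pre_alloc` helper is the accounting rule behind the new counter: while a request's KV transfer is pending, it holds `len(origin_input_ids) + max(len(output_ids) - 1, 0)` pre-allocated KV-cache tokens. A worked example of the arithmetic (a stand-alone restatement, not the real sglang code):

```python
def num_pre_alloc(origin_input_ids, output_ids):
    # Mirrors _num_pre_alloc: prompt tokens plus all but the last
    # generated token, clamped at zero for requests with no output yet.
    return len(origin_input_ids) + max(len(output_ids) - 1, 0)


# Fresh request: 100 prompt tokens, nothing generated yet -> 100.
assert num_pre_alloc([0] * 100, []) == 100

# Request carrying the single token produced by prefill: still 100,
# because max(1 - 1, 0) == 0.
assert num_pre_alloc([0] * 100, [7]) == 100

# The counter rises when requests enter the transfer queue and falls
# once their KV has arrived, so at any instant it approximates tokens
# reserved but not yet consumable by the decode batch.
```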
@@ -176,14 +176,14 @@ class SchedulerDisaggregationPrefillMixin:
     """

     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
+    def event_loop_normal_disagg_prefill(self: Scheduler):
         """A normal scheduler loop for prefill worker in disaggregation mode."""

         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

@@ -214,14 +214,14 @@ class SchedulerDisaggregationPrefillMixin:
                 self.running_batch.batch_is_full = False

     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self):
+    def event_loop_overlap_disagg_prefill(self: Scheduler):
         self.result_queue = deque()

         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

@@ -326,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
                 raise Exception("Transferring failed")

         for req in done_reqs:
-            self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
+            self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
                 req.metadata_buffer_index
             )

@@ -342,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin:
                 # only finished requests to running_batch.
                 self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
-                if (
-                    self.enable_overlap
-                ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                if self.enable_overlap:
+                    # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                     self.chunked_req.tmp_end_idx = min(
                         len(self.chunked_req.fill_ids),
                         len(self.chunked_req.origin_input_ids),

@@ -390,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin:
                 .numpy()
             )
             if last_chunk is True:
-                self.disagg_prefill_pending_queue.store_prefill_results(
+                self.disagg_prefill_bootstrap_queue.store_prefill_results(
                     req.metadata_buffer_index, token_id
                 )
             page_indices = kv_to_page_indices(kv_indices, page_size)

@@ -578,6 +578,10 @@ class Scheduler(
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 transfer_backend=self.transfer_backend,
             )
+
+            # Metric for pre-allocation
+            self.num_tokens_pre_allocated = 0
+
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2

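The counter initialized in the DECODE branch feeds the new `pre-allocated usage` field in the decode log below, where it is reported as a fraction of `max_total_num_tokens`. With invented numbers:

```python
# Invented values, just to show how the ratio is rendered:
num_tokens_pre_allocated = 12_288
max_total_num_tokens = 65_536
print(f"pre-allocated usage: {num_tokens_pre_allocated / max_total_num_tokens:.2f}")
# prints: pre-allocated usage: 0.19
```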
@@ -593,7 +597,7 @@ class Scheduler(
             )
             metadata_buffers = [output_id_buffer]

-            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,

@@ -901,7 +905,7 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.time()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_pending_queue.add(req)
+            self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:

@@ -991,8 +995,15 @@ class Scheduler(
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
         )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+        else:
+            f += f"#queue-req: {len(self.waiting_queue)}"
+
         logger.info(f)

         if self.enable_metrics:

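On a prefill worker the log line now distinguishes three populations: requests still bootstrapping their KV connection, requests waiting to run, and requests whose KV is in flight to a decode worker. An illustrative line (values invented; the fields before `#cached-token` come from code outside this hunk):

```
Prefill batch. #new-seq: 4, #new-token: 2048, #cached-token: 512, token usage: 0.31, #running-req: 8, #unbootstrapped-req: 2, #queue-req: 5, #transferring-req: 3
```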
@@ -1028,15 +1039,14 @@ class Scheduler(
             gap_latency / self.server_args.decode_log_interval
         )

+        msg = (
+            f"Decode batch. "
+            f"#running-req: {num_running_reqs}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+        )
+
         if self.spec_algorithm.is_none():
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
             spec_accept_length = 0
         else:
             spec_accept_length = (

@@ -1045,15 +1055,15 @@ class Scheduler(
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"accept len: {spec_accept_length:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+        msg += (
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )

         logger.info(msg)

         if self.enable_metrics:

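The last two hunks rebuild the decode message incrementally instead of duplicating it across the speculative and non-speculative branches: the base fields always appear, `accept len` is appended only under speculative decoding, `pre-allocated usage` only on a disaggregated decode worker, and the throughput/queue tail always closes the line. An illustrative decode-worker line (values invented):

```
Decode batch. #running-req: 16, #token: 20480, token usage: 0.31, pre-allocated usage: 0.12, gen throughput (token/s): 1843.52, #queue-req: 4
```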