Unverified Commit 711efe78 authored by Cheng Wan, committed by GitHub

Integrating PD disaggregation with DP attention and DeepEP (#5435)


Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
parent fbb5f229
@@ -444,6 +444,15 @@ class ScheduleBatchDisaggregationDecodeMixin:


 class SchedulerDisaggregationDecodeMixin:
+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     @torch.no_grad()
     def event_loop_normal_disagg_decode(self):
         """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        self._prepare_idle_batch_and_run(None)
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
+            elif prepare_dp_attn_flag:
+                batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
                 len(self.disagg_decode_transfer_queue.queue)
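Note on the hunk above: with DP attention enabled, the forward pass involves collective communication across data-parallel ranks, so a decode rank with nothing to run cannot simply skip the iteration; it calls `_prepare_idle_batch_and_run(None)` to join with an idle batch. A minimal, self-contained sketch of that control flow (using a hypothetical `FakeScheduler`, not the sglang API):

```python
# Minimal sketch (hypothetical FakeScheduler, not the sglang API) of the branch
# structure added above: when DP attention is enabled, a rank with no decode
# batch of its own still runs an "idle" batch so that the collective ops issued
# during the forward pass see every data-parallel rank.
class FakeScheduler:
    def __init__(self, enable_dp_attention: bool):
        self.enable_dp_attention = enable_dp_attention

    def prepare_dp_attn_batch(self, batch):
        # In sglang this coordinates batch info across DP ranks; here it just
        # turns "no local work" into an explicit idle batch.
        return (batch if batch is not None else "idle-batch"), None

    def run_batch(self, batch):
        return f"ran {batch}"

    def step(self, batch):
        if batch is not None:
            if self.enable_dp_attention:
                batch, _ = self.prepare_dp_attn_batch(batch)
            return self.run_batch(batch)
        if self.enable_dp_attention:
            # Mirrors _prepare_idle_batch_and_run(None): stay in lockstep.
            batch, _ = self.prepare_dp_attn_batch(None)
            return self.run_batch(batch)
        return None  # without DP attention an empty rank simply skips the step


print(FakeScheduler(True).step(None))   # -> ran idle-batch
print(FakeScheduler(False).step(None))  # -> None
```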
@@ -480,7 +500,7 @@ class SchedulerDisaggregationDecodeMixin:
     def event_loop_overlap_disagg_decode(self):
         result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
-        self.last_batch_is_extend = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+        self.last_batch_in_queue = False  # last batch is modifed in-place, so we need another variable to track if it's extend

         while True:
             recv_reqs = self.recv_requests()
@@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin:
             self.process_decode_queue()

             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
-            last_batch_is_extend = False
+            last_batch_in_queue = False

+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
-                    last_batch_is_extend = True
+                    if prepare_dp_attn_flag:
+                        batch_, result = self._prepare_idle_batch_and_run(
+                            None, delay_process=True
+                        )
+                        if batch_:
+                            result_queue.append((batch_.copy(), result))
+                            last_batch_in_queue = True
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+            elif prepare_dp_attn_flag:
+                batch, result = self._prepare_idle_batch_and_run(
+                    None, delay_process=True
+                )
+                if batch:
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
-            if self.last_batch and not self.last_batch_is_extend:
+            if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)
@@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin:
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch
-            self.last_batch_is_extend = last_batch_is_extend
+            self.last_batch_in_queue = last_batch_in_queue

     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
...
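Note on the overlap-loop changes: `last_batch_is_extend` is renamed to `last_batch_in_queue` because what matters when popping `result_queue` on the next iteration is whether the previous iteration actually enqueued a (batch, result) pair; with DP attention, even a fake-extend iteration may enqueue an idle-batch result. A small sketch of this deferred-processing pattern (hypothetical names, assuming the one-iteration delay described above):

```python
# Minimal sketch (hypothetical names, not the sglang API) of the deferred
# result handling in the overlap loop: a batch's result is consumed one
# iteration later, so the flag must record whether the previous iteration
# actually pushed something onto the queue.
from collections import deque

result_queue = deque()
prev_in_queue = False  # plays the role of self.last_batch_in_queue


def iteration(batch, enqueue: bool):
    """One loop body: maybe enqueue this batch's result, then handle the previous one."""
    global prev_in_queue
    if enqueue:
        result_queue.append((batch, f"result-of-{batch}"))
    # Pop only if the *previous* iteration queued a result.
    if prev_in_queue:
        print("processing", result_queue.popleft())
    prev_in_queue = enqueue


iteration("decode-batch-0", enqueue=True)   # nothing to process yet
iteration("fake-extend", enqueue=False)     # processes decode-batch-0
iteration("decode-batch-1", enqueue=True)   # fake-extend queued nothing, so skip
iteration(None, enqueue=False)              # processes decode-batch-1
```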
@@ -187,6 +187,14 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch

             if batch:
@@ -217,6 +225,14 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch

             if batch:
...
@@ -23,11 +23,13 @@ import psutil
 import setproctitle
 import zmq

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -226,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

-    def round_robin_scheduler(self, req):
-        self.workers[self.round_robin_counter].send_pyobj(req)
-        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
...
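Note on the controller change: in disaggregation mode, requests are routed by `req.bootstrap_room % len(self.workers)` instead of a per-process round-robin counter, which makes the worker choice a deterministic function of the request; presumably this lets the prefill and decode controllers pick matching DP ranks for the same request. A tiny illustration (hypothetical worker names and helper):

```python
# Hypothetical illustration (worker names and helper are made up): routing by
# bootstrap_room is deterministic per request, so independent controllers
# (e.g. on the prefill and decode side) agree on the target worker regardless
# of the order in which requests arrive, unlike two separate round-robin
# counters that would drift apart.
workers = [f"dp-rank-{i}" for i in range(4)]


def pick_worker(bootstrap_room: int) -> str:
    # Mirrors: self.workers[req.bootstrap_room % len(self.workers)]
    return workers[bootstrap_room % len(workers)]


rooms = [1017, 42, 7, 1018]
print([pick_worker(r) for r in rooms])            # one arrival order
print([pick_worker(r) for r in reversed(rooms)])  # another order, same mapping per room
```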