Unverified Commit 7d121448 authored by popsiclexu's avatar popsiclexu Committed by GitHub
Browse files

[Bug fix][PD Dissaggregation] fix prefill hanging issue with PP and DP Attention, (#12368)

parent 6a63a985
......@@ -588,7 +588,7 @@ class SchedulerDisaggregationPrefillMixin:
"""
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.tp_worker.get_tp_group().cpu_group,
self.tp_worker.get_attention_tp_cpu_group(),
)
transferred_rids: List[str] = []
......@@ -722,8 +722,11 @@ class SchedulerDisaggregationPrefillMixin:
else:
data = None
if self.tp_size != 1:
if self.attn_tp_size != 1:
data = broadcast_pyobj(
data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
data,
self.attn_tp_group.rank,
self.attn_tp_cpu_group,
src=self.attn_tp_group.ranks[0],
)
return data
......@@ -4,7 +4,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj, require_mlp_sync
class SchedulerPPMixin:
......@@ -236,7 +236,12 @@ class SchedulerPPMixin:
tmbs[mb_id] = transferred_rids
self.process_prefill_chunk()
mbs[mb_id] = self.get_new_batch_prefill()
batch = self.get_new_batch_prefill()
if require_mlp_sync(self.server_args):
batch = self.prepare_mlp_sync_batch(batch)
mbs[mb_id] = batch
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment