Unverified Commit 73b13e69 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Optimize DP attn scheduling for speculative decoding (#7285)

parent 8609e637
......@@ -1399,29 +1399,6 @@ class Scheduler(
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
"""Coordinate the DP attention batch."""
local_info = torch.tensor(
[
(new_batch is not None),
],
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 1),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
)
any_new_batch = any(
global_info[:, 0, 0].tolist()
) # Any DP worker has forward batch
return any_new_batch
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
......@@ -1456,13 +1433,15 @@ class Scheduler(
new_batch = self.get_new_batch_prefill()
# TODO(ch-wan): minor refactor is needed here to improve readability
any_new_batch = (
self.server_args.enable_dp_attention
and not self.spec_algorithm.is_none()
and self.coordinate_spec_dp_attn_batch(new_batch)
)
if new_batch is not None or any_new_batch:
need_dp_attn_preparation = require_mlp_sync(self.server_args)
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
new_batch, _ = self.prepare_dp_attn_batch(new_batch)
need_dp_attn_preparation = new_batch is None
if new_batch is not None:
# Run prefill first if possible
ret = new_batch
else:
......@@ -1473,8 +1452,9 @@ class Scheduler(
else:
ret = None
if require_mlp_sync(self.server_args):
ret, _ = self.prepare_mlp_sync_batch(ret)
# Handle DP attention
if need_dp_attn_preparation:
ret, _ = self.prepare_dp_attn_batch(ret)
return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment