Unverified Commit 73b13e69 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Optimize DP attn scheduling for speculative decoding (#7285)

parent 8609e637
...@@ -1399,29 +1399,6 @@ class Scheduler( ...@@ -1399,29 +1399,6 @@ class Scheduler(
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
self._publish_kv_events() self._publish_kv_events()
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
    """Coordinate the DP attention batch.

    Gathers a single "has a new batch" flag from every rank and reports
    whether any DP worker in the group has a forward (prefill) batch, so
    all workers can agree on how to schedule under DP attention.
    """
    # One int64 flag per rank: 1 if this worker produced a new batch, else 0.
    has_batch_flag = torch.tensor([int(new_batch is not None)], dtype=torch.int64)

    # Destination buffer: one flag slot per (dp_rank, attn_tp_rank) pair.
    gathered_flags = torch.empty(
        (self.server_args.dp_size, self.attn_tp_size, 1),
        dtype=torch.int64,
    )
    torch.distributed.all_gather_into_tensor(
        gathered_flags.flatten(),
        has_batch_flag,
        group=self.tp_cpu_group,
    )

    # Only the first attn-TP rank of each DP group needs to be inspected;
    # true if any DP worker has a forward batch.
    return any(gathered_flags[:, 0, 0].tolist())
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch # Merge the prefill batch into the running batch
chunked_req_to_exclude = set() chunked_req_to_exclude = set()
...@@ -1456,13 +1433,15 @@ class Scheduler( ...@@ -1456,13 +1433,15 @@ class Scheduler(
new_batch = self.get_new_batch_prefill() new_batch = self.get_new_batch_prefill()
# TODO(ch-wan): minor refactor is needed here to improve readability need_dp_attn_preparation = require_mlp_sync(self.server_args)
any_new_batch = (
self.server_args.enable_dp_attention if need_dp_attn_preparation and not self.spec_algorithm.is_none():
and not self.spec_algorithm.is_none() # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
and self.coordinate_spec_dp_attn_batch(new_batch) # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
) new_batch, _ = self.prepare_dp_attn_batch(new_batch)
if new_batch is not None or any_new_batch: need_dp_attn_preparation = new_batch is None
if new_batch is not None:
# Run prefill first if possible # Run prefill first if possible
ret = new_batch ret = new_batch
else: else:
...@@ -1473,8 +1452,9 @@ class Scheduler( ...@@ -1473,8 +1452,9 @@ class Scheduler(
else: else:
ret = None ret = None
if require_mlp_sync(self.server_args): # Handle DP attention
ret, _ = self.prepare_mlp_sync_batch(ret) if need_dp_attn_preparation:
ret, _ = self.prepare_dp_attn_batch(ret)
return ret return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment