Unverified commit 27e8ffed, authored by Liangsheng Yin, committed by GitHub

[1/N] DP-refactor: move dp balance code into scheduler's mixin class (#10004)

parent 4dbb34fe
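
The diff below does the move in two steps: the inline DP-balance blocks in the scheduler are replaced by calls to maybe_update_dp_balance_data / maybe_handle_dp_balance_data, and the bodies land in SchedulerMetricsMixin with their self parameter annotated as the concrete Scheduler. A minimal, self-contained sketch of that mixin-typing pattern (the class bodies and attributes here are illustrative stand-ins, not the real scheduler):

```python
from __future__ import annotations  # all annotations become lazy strings


class SchedulerMetricsMixin:
    # Annotating `self` with the concrete class lets a type checker resolve
    # attributes that live on Scheduler rather than on the mixin itself.
    def init_dp_balance(self: Scheduler, dp_balance_meta) -> None:
        self.balance_meta = dp_balance_meta
        self.recv_dp_balance_id_this_term: list[int] = []


class Scheduler(SchedulerMetricsMixin):
    def __init__(self, dp_balance_meta=None) -> None:
        self.init_dp_balance(dp_balance_meta)


print(Scheduler().recv_dp_balance_id_this_term)  # -> []
```

In the real code the two classes live in different modules, so Scheduler is imported only under TYPE_CHECKING; thanks to the lazy annotations the signatures can still name it without creating a circular import at runtime.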
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(
         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1. Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
...
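
The removed handle_dp_balance_data above (re-added as mixin methods in the next file) packs each rank's report into one fixed-size int32 tensor so torch.distributed.gather receives equal-shaped tensors from every worker. A sketch of just that wire format, using the same | holding_tokens | len | ids | layout and 512-slot cap as the diff (the pack/unpack helper names are mine, not part of the codebase):

```python
from __future__ import annotations

import torch

GATHER_TENSOR_SIZE = 512  # fixed size so every rank contributes an equal-shaped tensor


def pack_balance_tensor(holding_tokens: int, recv_ids: list[int]) -> torch.Tensor:
    """Layout: | holding_tokens | len(recv_ids) | recv_ids... | zero padding |."""
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2  # slots 0 and 1 hold the header
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack_balance_tensor(t: torch.Tensor) -> tuple[int, list[int]]:
    """Inverse of pack_balance_tensor; padding beyond the id count is ignored."""
    n = int(t[1].item())
    return int(t[0].item()), t[2 : n + 2].tolist()


packed = pack_balance_tensor(1024, [7, 8, 9])
print(unpack_balance_tensor(packed))  # -> (1024, [7, 8, 9])
```

The fixed shape is also why the code asserts len(recv_list) <= 511: ids past slot 511 simply would not fit.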
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:

 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+            self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )
     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
         self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only the first worker writes the info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Tuple[Optional[List[List[int]]], Union[int, List[int]]]:
+        """Gather recv_dp_balance_id_this_term and the holding tokens of each worker for DP balance."""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor layout: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The second element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1. Check that every rid received by a worker this round is present in its
+            #    onfly dict, then remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex.
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
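
For readers without the surrounding code: write_shared_dp_balance_info reconciles a shared "onfly" table, one dict of request-id to token count per worker, which the dispatching side presumably fills when it routes a request; ids that workers have now confirmed receiving are deleted, and the measured per-worker loads are published, all under one mutex. A rough single-process sketch with a hypothetical DPBalanceMetaSketch stand-in (the real DPBalanceMeta imported from sglang.srt.managers.utils keeps this state in shared memory across processes):

```python
import threading


class DPBalanceMetaSketch:
    """Hypothetical stand-in for DPBalanceMeta, backed by plain dicts."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self.mutex = threading.Lock()
        self._onfly = [{} for _ in range(num_workers)]  # per-worker {rid: tokens}
        self._local_tokens = [0] * num_workers

    def get_shared_onfly(self):
        return [dict(d) for d in self._onfly]  # copy, as if read from shm

    def set_shared_onfly_info(self, onfly) -> None:
        self._onfly = onfly

    def set_shared_local_tokens(self, tokens) -> None:
        self._local_tokens = tokens


meta = DPBalanceMetaSketch(num_workers=2)
meta._onfly[0][42] = 128  # dispatch side recorded request 42 (128 tokens) for worker 0

# Scheduler-side reconciliation, mirroring write_shared_dp_balance_info:
acknowledged = [[42], []]  # ids each worker confirmed receiving this round
with meta.mutex:
    onfly = meta.get_shared_onfly()
    for worker_id, recv_ids in enumerate(acknowledged):
        for rid in recv_ids:
            assert rid in onfly[worker_id], "data consistency is wrong"
            del onfly[worker_id][rid]  # request landed; no longer on the fly
    meta.set_shared_onfly_info(onfly)
    meta.set_shared_local_tokens([1024, 512])  # measured holding tokens per worker

print(meta._onfly)  # -> [{}, {}]
```

Doing the read-reconcile-write cycle inside a single mutex hold is what keeps the onfly table and the token counts mutually consistent for whoever reads them next.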