Unverified Commit 7ff740a6 authored by Liangsheng Yin, committed by GitHub

Remove dp balance metadata and minimum token balance. (#11170)

parent bfcd9b24
......@@ -812,7 +812,6 @@ def _launch_subprocesses(
pp_rank,
None,
writer,
None,
),
)
......
......@@ -120,11 +120,8 @@ message GenerateRequest {
// Data parallel routing
int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
bool stream = 17;
}
message TokenizedInput {
......
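Note that removing dp_balance_id = 17 and renumbering stream from 18 to 17 reuses a wire field number instead of reserving it. If any old-schema traffic were still in flight (likely not a concern for this internal proto), the stale bytes would silently reparse; a minimal, purely illustrative sketch of the hazard using hand-encoded wire bytes:

# Old schema: int32 dp_balance_id = 17.  New schema: bool stream = 17.
# A varint field's tag is (field_number << 3) | 0; for field 17 that is
# 136, which itself encodes as the two varint bytes 0x88 0x01.
old_payload = bytes([0x88, 0x01, 0x05])  # dp_balance_id = 5 under the old schema
# A new-schema parser reads the same tag as field 17 = stream, so the
# stale integer 5 decodes as stream = True; `reserved 17;` would avoid this.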
......@@ -82,7 +82,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
......@@ -99,7 +99,6 @@ class GenerateRequest(_message.Message):
INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
......@@ -117,9 +116,8 @@ class GenerateRequest(_message.Message):
input_embeds: _containers.RepeatedScalarFieldContainer[float]
lora_id: str
data_parallel_rank: int
dp_balance_id: int
stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
......
......@@ -17,14 +17,11 @@ import faulthandler
import logging
import multiprocessing as mp
import signal
import struct
import sys
import threading
import time
from collections import deque
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List
from typing import List
import psutil
import setproctitle
......@@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
......@@ -108,15 +104,9 @@ class DPBudget:
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
dp_balance_meta: DPBalanceMeta,
) -> None:
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
# for dp balance
self.global_balance_id = 0
self.balance_meta = dp_balance_meta
# Parse args
self.max_total_num_tokens = None
......@@ -322,7 +312,6 @@ class DataParallelController:
pp_rank,
dp_rank,
writer,
self.balance_meta,
),
)
with memory_saver_adapter.configure_subprocess():
......@@ -370,31 +359,11 @@ class DataParallelController:
if self.maybe_external_dp_rank_routing(req):
return
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
# We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
def get_next_global_balance_id() -> int:
INT32_MAX = 2147483647
current_id = self.global_balance_id
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
return current_id
req.dp_balance_id = get_next_global_balance_id()
with self.balance_meta.mutex:
# 1. local_tokens is the number of tokens currently being processed on each worker,
# while onfly tracks requests dispatched by the dispatcher but not yet received by the scheduler.
onfly_info = self.balance_meta.get_shared_onfly()
local_tokens = self.balance_meta.get_shared_local_tokens()
total_tokens = [
local_token + sum(onfly_dict.values())
for local_token, onfly_dict in zip(local_tokens, onfly_info)
]
target_worker = total_tokens.index(min(total_tokens))
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
# 2. write the new onfly info to the shm
self.balance_meta.set_shared_onfly_info(onfly_info)
# logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
self.workers[target_worker].send_pyobj(req)
logger.warning(
"The 'minimum_tokens' load balancing method is deprecated for now and will be reintroduced later. "
"Falling back to 'round_robin_scheduler'."
)
self.round_robin_scheduler(req)
def event_loop(self):
while True:
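For context, the removed branch implemented a least-loaded dispatch: pick the worker whose resident tokens plus on-fly tokens is smallest. A self-contained sketch of that rule, with toy numbers standing in for the shared-memory state:

# Sketch of the removed "minimum_tokens" dispatch rule (toy data).
# local_tokens[i]: tokens currently being processed on worker i.
# onfly[i]: {balance_id: num_input_tokens} for requests dispatched to
# worker i but not yet received by its scheduler.
local_tokens = [120, 40, 95]
onfly = [{7: 30}, {8: 50, 9: 20}, {}]
total_tokens = [
    local + sum(d.values()) for local, d in zip(local_tokens, onfly)
]
target_worker = total_tokens.index(min(total_tokens))
print(total_tokens, target_worker)  # [150, 110, 95] -> worker 2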
......@@ -416,12 +385,9 @@ def run_data_parallel_controller_process(
faulthandler.enable()
configure_logger(server_args)
parent_process = psutil.Process().parent()
balance_meta = DPBalanceMeta(server_args.dp_size)
try:
controller = DataParallelController(
server_args, port_args, dp_balance_meta=balance_meta
)
controller = DataParallelController(server_args, port_args)
pipe_writer.send(
{
"status": "ready",
......@@ -440,6 +406,3 @@ def run_data_parallel_controller_process(
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
finally:
# we must shut down the mp.Manager() owned by balance_meta
balance_meta.destructor()
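The removed finally block existed because DPBalanceMeta owns an mp.Manager(), which runs a separate server process and does not exit with its parent unless shut down explicitly. A minimal standalone illustration of that lifecycle (a sketch, not this codebase's API):

import multiprocessing as mp

if __name__ == "__main__":
    manager = mp.Manager()          # starts a background server process
    shared = manager.list([0, 0])   # data lives in that process; this is a proxy
    try:
        shared[0] = 42
    finally:
        manager.shutdown()          # without this, the server process lingers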
......@@ -606,9 +606,6 @@ class TokenizedGenerateReqInput:
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request
priority: Optional[int] = None
......@@ -778,8 +775,6 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request
priority: Optional[int] = None
......
......@@ -145,7 +145,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
......@@ -271,7 +271,6 @@ class Scheduler(
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
):
# Parse args
self.server_args = server_args
......@@ -600,7 +599,6 @@ class Scheduler(
# Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_dp_balance(dp_balance_meta)
if self.enable_kv_cache_events:
self.init_kv_events(server_args.kv_events_config)
......@@ -1270,8 +1268,6 @@ class Scheduler(
self,
recv_req: TokenizedGenerateReqInput,
):
self.maybe_update_dp_balance_data(recv_req)
# Create a new request
if (
recv_req.session_params is None
......@@ -1797,7 +1793,6 @@ class Scheduler(
# Handle DP attention
if need_dp_attn_preparation:
self.maybe_handle_dp_balance_data()
ret = self.prepare_mlp_sync_batch(ret)
return ret
......@@ -2803,7 +2798,6 @@ def run_scheduler_process(
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
# Generate the logger prefix
prefix = ""
......@@ -2852,7 +2846,6 @@ def run_scheduler_process(
moe_ep_rank,
pp_rank,
dp_rank,
dp_balance_meta=balance_meta,
)
pipe_writer.send(
{
......
......@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var
......@@ -64,16 +63,6 @@ class SchedulerMetricsMixin:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
self.balance_meta = dp_balance_meta
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None
self.recv_dp_balance_id_this_term = []
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
......@@ -319,91 +308,6 @@ class SchedulerMetricsMixin:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def maybe_update_dp_balance_data(
self: Scheduler, recv_req: TokenizedGenerateReqInput
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
def maybe_handle_dp_balance_data(self: Scheduler):
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
holding_tokens = self.get_load().num_tokens
new_recv_dp_balance_id_list, holding_token_list = (
self.gather_dp_balance_info(holding_tokens)
)
self.recv_dp_balance_id_this_term.clear()
if self.tp_rank == 0:  # only the first worker writes the info
self.write_shared_dp_balance_info(
new_recv_dp_balance_id_list, holding_token_list
)
def gather_dp_balance_info(
self: Scheduler, holding_tokens_list
) -> Union[None, List[List[int]]]:
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
recv_list = self.recv_dp_balance_id_this_term
assert len(recv_list) <= 511, (
"The number of requests received this round is too large. "
"Please increase gather_tensor_size and onfly_info_size."
)
# The maximum size of the tensor used for gathering data from all workers.
gather_tensor_size = 512
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
recv_tensor[0] = holding_tokens_list
recv_tensor[1] = len(recv_list)  # element 1 stores the length of the id list
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
if self.tp_rank == 0:
gathered_list = [
torch.zeros(gather_tensor_size, dtype=torch.int32)
for _ in range(self.balance_meta.num_workers)
]
else:
gathered_list = None
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
gathered_id_list_per_worker = None
if self.tp_rank == 0:
gathered_id_list_per_worker = []
holding_tokens_list = []
for tensor in gathered_list:
holding_tokens_list.append(tensor[0].item())
list_length = tensor[1].item()
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
return gathered_id_list_per_worker, holding_tokens_list
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
meta = self.balance_meta
with meta.mutex:
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
# 1. Check that every rid received by a worker this round is present in onfly;
# if it is, remove the corresponding onfly item.
worker_id = 0
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
for new_recv_rid in new_recv_rids:
assert (
new_recv_rid in on_fly_reqs
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}; data consistency violated"
del on_fly_reqs[new_recv_rid]
worker_id += 1
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)
def calculate_utilization(self):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.utilization = -1
......
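The gather path above uses a fixed-size tensor so every rank contributes an equally shaped buffer to torch.distributed.gather; the layout is | holding_tokens | len(ids) | ids... | zero padding |. A standalone round-trip of that packing (toy helper names, no distributed setup):

import torch

GATHER_TENSOR_SIZE = 512  # equal shape on every rank, as gather requires

def pack(holding_tokens, balance_ids):
    # layout: [holding_tokens, len(balance_ids), *balance_ids, 0, 0, ...]
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(balance_ids)
    t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
    return t

def unpack(t):
    n = int(t[1])
    return int(t[0]), t[2 : 2 + n].tolist()

assert unpack(pack(96, [3, 5, 8])) == (96, [3, 5, 8])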
......@@ -96,46 +96,3 @@ def get_logprob_from_pp_outputs(
]
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
class DPBalanceMeta:
"""
This class is used by both the scheduler and the dp controller.
"""
def __init__(self, num_workers: int):
self.num_workers = num_workers
self._manager = mp.Manager()
self.mutex = self._manager.Lock()
init_local_tokens = [0] * self.num_workers
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
self.shared_state = self._manager.Namespace()
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
def destructor(self):
# the mp.Manager must be shut down manually; it does not exit on its own
self._manager.shutdown()
def get_shared_onfly(self) -> List[Dict[int, int]]:
return [dict(d) for d in self.shared_state.onfly_info]
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
self.shared_state.onfly_info = data
def get_shared_local_tokens(self) -> List[int]:
return list(self.shared_state.local_tokens)
def set_shared_local_tokens(self, data: List[int]):
self.shared_state.local_tokens = data
def __getstate__(self):
state = self.__dict__.copy()
del state["_manager"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._manager = None
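The __getstate__/__setstate__ pair at the end of the removed class is what let a DPBalanceMeta instance be pickled into scheduler subprocesses: the unpicklable Manager handle is dropped, while the manager-backed proxies (mutex, shared_state) are serialized and reconnect on the other side. A stripped-down sketch of the same pattern (hypothetical class, not this codebase's):

import multiprocessing as mp
import pickle

class SharedMeta:
    def __init__(self):
        self._manager = mp.Manager()       # owns the server process
        self.mutex = self._manager.Lock()  # proxy: picklable, reconnects

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_manager"]              # the Manager itself cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._manager = None               # children hold proxies, never the Manager

if __name__ == "__main__":
    meta = SharedMeta()
    clone = pickle.loads(pickle.dumps(meta))  # only the proxies are serialized
    with clone.mutex:                         # proxy reconnects to the live manager
        pass
    meta._manager.shutdown()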
......@@ -120,11 +120,8 @@ message GenerateRequest {
// Data parallel routing
int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
bool stream = 17;
}
message TokenizedInput {
......
......@@ -124,47 +124,5 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
self.assertGreater(avg_spec_accept_length, 2.5)
class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
"--enable-torch-compile",
"--torch-compile-max-bs",
"2",
"--load-balance-method",
"minimum_tokens",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()