Unverified Commit f7b2853f authored by Guanhua Wang, committed by GitHub

[feat] support minimum token load balance in dp attention (#7379)

parent b0add2da
@@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s

| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dp-size` | The data parallelism size. | 1 |
| `--load-balance-method` | The load balancing strategy for data parallelism. Options include: 'round_robin', 'minimum_tokens'. The minimum-tokens algorithm can only be used when DP attention is enabled; it balances load based on the real-time token load of the DP workers. | round_robin |
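For example, a server using this strategy could be launched as follows (the model path is a placeholder; the remaining flags match the test added in this PR):

```bash
python -m sglang.launch_server --model-path <your-model> \
  --trust-remote-code --tp 2 --enable-dp-attention --dp 2 \
  --load-balance-method minimum_tokens
```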
## Multi-node distributed serving
...
@@ -732,6 +732,7 @@ def _launch_subprocesses(
                pp_rank,
                None,
                writer,
                None,
            ),
        )
...
@@ -16,9 +16,13 @@
import logging
import multiprocessing as mp
import signal
import struct
import sys
import threading
import time
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List

import psutil
import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket

@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()
    MINIMUM_TOKENS = auto()

    @classmethod
    def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        dp_balance_meta: DPBalanceMeta,
    ) -> None:
        # For DP balance
        self.global_balance_id = 0
        self.balance_meta = dp_balance_meta

        # Parse args
        self.max_total_num_tokens = None
        self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

@@ -234,6 +250,7 @@ class DataParallelController:
                pp_rank,
                dp_rank,
                writer,
                self.balance_meta,
            ),
        )
        with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()

    def minimum_tokens_scheduler(self, req):
        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
        # We use it to control the number of onfly tokens (requests dispatched to
        # workers but not yet received by their schedulers).
        def get_next_global_balance_id() -> int:
            INT32_MAX = 2147483647
            current_id = self.global_balance_id
            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
            return current_id

        req.dp_balance_id = get_next_global_balance_id()
        with self.balance_meta.mutex:
            # 1. local_tokens counts the tokens currently being processed on each
            # worker, while onfly holds the requests dispatched by the dispatcher
            # but not yet received by the scheduler.
            onfly_info = self.balance_meta.get_shared_onfly()
            local_tokens = self.balance_meta.get_shared_local_tokens()
            total_tokens = [
                local_token + sum(onfly_dict.values())
                for local_token, onfly_dict in zip(local_tokens, onfly_info)
            ]
            target_worker = total_tokens.index(min(total_tokens))
            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
            # 2. Write the new onfly info back to the shm.
            self.balance_meta.set_shared_onfly_info(onfly_info)
        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
        self.workers[target_worker].send_pyobj(req)
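To make the dispatch rule concrete, here is a minimal standalone sketch of the same argmin-over-total-tokens decision; the variable names mirror the method above, but the worker loads are illustrative:

```python
# Illustrative worker state: tokens being processed locally, plus onfly
# requests (balance_id -> input length) dispatched but not yet received.
local_tokens = [1200, 800, 950]
onfly_info = [
    {7: 100, 8: 40},  # worker 0: 140 onfly tokens
    {9: 500},         # worker 1: 500 onfly tokens
    {},               # worker 2: no onfly tokens
]

# Each worker's effective load is its local tokens plus its onfly tokens.
total_tokens = [
    local + sum(onfly.values()) for local, onfly in zip(local_tokens, onfly_info)
]
assert total_tokens == [1340, 1300, 950]

# The next request goes to the least-loaded worker.
target_worker = total_tokens.index(min(total_tokens))
assert target_worker == 2
```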
    def event_loop(self):
        while True:
            while True:
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
    setproctitle.setproctitle("sglang::data_parallel_controller")
    configure_logger(server_args)
    parent_process = psutil.Process().parent()
    balance_meta = DPBalanceMeta(server_args.dp_size)

    try:
        controller = DataParallelController(
            server_args, port_args, dp_balance_meta=balance_meta
        )
        pipe_writer.send(
            {
                "status": "ready",
@@ -323,3 +370,6 @@
        traceback = get_exception_traceback()
        logger.error(f"DataParallelController hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)
    finally:
        # We must manually shut down the mp.Manager() held by balance_meta.
        balance_meta.destructor()
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1

@dataclass
class EmbeddingReqInput:

@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For dp balance
    dp_balance_id: int = -1

@dataclass
...
@@ -126,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache

@@ -203,6 +203,7 @@ class Scheduler(
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        dp_balance_meta: Optional[DPBalanceMeta] = None,
    ):
        # Parse args
        self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
            ]
        )

        self.balance_meta = dp_balance_meta
        if (
            server_args.enable_dp_attention
            and server_args.load_balance_method == "minimum_tokens"
        ):
            assert dp_balance_meta is not None
            self.recv_dp_balance_id_this_term = []

    def init_tokenizer(self):
        server_args = self.server_args
@@ -1049,6 +1059,12 @@ class Scheduler(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        if (
            self.server_args.enable_dp_attention
            and self.server_args.load_balance_method == "minimum_tokens"
        ):
            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

        # Create a new request
        if (
            recv_req.session_params is None
@@ -1459,6 +1475,11 @@
        # Handle DP attention
        if need_dp_attn_preparation:
            if (
                self.server_args.load_balance_method == "minimum_tokens"
                and self.forward_ct % 40 == 0
            ):
                # Periodically sync balance data to shared memory
                # (once every 40 forward passes).
                self.handle_dp_balance_data(ret)
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
@@ -1786,6 +1807,86 @@ class Scheduler(
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
        )

    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
            """Gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance."""
            recv_list = self.recv_dp_balance_id_this_term
            # Two metadata slots plus at most 510 ids fit in gather_tensor_size.
            assert len(recv_list) <= 510, (
                "The number of requests received this round is too large. "
                "Please increase gather_tensor_size and onfly_info_size."
            )
            # The maximum size of the tensor used for gathering data from all workers.
            gather_tensor_size = 512

            # recv_tensor layout: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids |
            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
            recv_tensor[0] = holding_tokens_list
            # The second slot stores the length of the id list that follows.
            recv_tensor[1] = len(recv_list)
            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
                recv_list, dtype=torch.int32
            )

            if self.tp_rank == 0:
                gathered_list = [
                    torch.zeros(gather_tensor_size, dtype=torch.int32)
                    for _ in range(self.balance_meta.num_workers)
                ]
            else:
                gathered_list = None

            torch.distributed.gather(
                recv_tensor, gathered_list, group=self.tp_cpu_group
            )

            gathered_id_list_per_worker = None
            if self.tp_rank == 0:
                gathered_id_list_per_worker = []
                holding_tokens_list = []
                for tensor in gathered_list:
                    holding_tokens_list.append(tensor[0].item())
                    list_length = tensor[1].item()
                    gathered_id_list_per_worker.append(
                        tensor[2 : list_length + 2].tolist()
                    )

            return gathered_id_list_per_worker, holding_tokens_list

        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
            meta = self.balance_meta
            with meta.mutex:
                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
                assert len(new_recv_rid_lists) == len(
                    onfly_list
                ), "num_worker not equal"
                # 1. Check that every rid received by a worker this round is
                # present in its onfly dict, then remove the corresponding entry.
                worker_id = 0
                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
                    for new_recv_rid in new_recv_rids:
                        assert (
                            new_recv_rid in on_fly_reqs
                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                        del on_fly_reqs[new_recv_rid]
                    worker_id += 1
                # 2. Atomically write local_tokens and onfly into shm under the mutex.
                meta.set_shared_onfly_info(onfly_list)
                meta.set_shared_local_tokens(local_tokens)

        holding_tokens = self.get_load()

        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
            holding_tokens
        )

        self.recv_dp_balance_id_this_term.clear()
        if self.tp_rank == 0:  # only the first TP rank writes the info
            write_shared_dp_balance_info(
                new_recv_dp_balance_id_list, holding_token_list
            )
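As a sanity check on the fixed-size gather layout above, here is a self-contained pack/unpack round trip; `pack` and `unpack` are ad hoc helper names for this sketch, not SGLang APIs, and the distributed gather itself is omitted:

```python
import torch

GATHER_TENSOR_SIZE = 512  # 2 metadata slots + up to 510 balance ids

def pack(holding_tokens: int, balance_ids: list) -> torch.Tensor:
    # Layout: | holding_tokens | len(balance_ids) | balance_ids... | zero padding |
    assert len(balance_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(balance_ids)
    t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
    return t

def unpack(t: torch.Tensor):
    n = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + n].tolist()

# Each TP rank packs one tensor; torch.distributed.gather collects them on
# rank 0, which then unpacks every worker's holding tokens and received ids.
assert unpack(pack(1340, [7, 8, 11])) == (1340, [7, 8, 11])
```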
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,

@@ -2394,6 +2495,7 @@ def run_scheduler_process(
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    # Generate the prefix
    prefix = ""
@@ -2427,7 +2529,14 @@
    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
...
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -38,3 +39,46 @@ def validate_input_length(
        return error_msg
    return None
class DPBalanceMeta:
    """
    This class is used by both the scheduler and the DP controller.
    """

    def __init__(self, num_workers: int):
        self.num_workers = num_workers
        self._manager = mp.Manager()
        self.mutex = self._manager.Lock()

        init_local_tokens = [0] * self.num_workers
        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]

        self.shared_state = self._manager.Namespace()
        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
        self.shared_state.onfly_info = self._manager.list(init_onfly_info)

    def destructor(self):
        # The mp.Manager() must be shut down manually.
        self._manager.shutdown()

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        return [dict(d) for d in self.shared_state.onfly_info]

    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
        self.shared_state.onfly_info = data

    def get_shared_local_tokens(self) -> List[int]:
        return list(self.shared_state.local_tokens)

    def set_shared_local_tokens(self, data: List[int]):
        self.shared_state.local_tokens = data

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_manager"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._manager = None
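A minimal usage sketch of the shared-state protocol (worker count and token numbers are illustrative; in this PR the instance is created once in run_data_parallel_controller_process and handed to the controller and every scheduler subprocess):

```python
meta = DPBalanceMeta(num_workers=2)
try:
    # The dispatcher reads and updates the shared state under the mutex.
    with meta.mutex:
        onfly = meta.get_shared_onfly()  # [{}, {}] initially
        onfly[0][42] = 128               # balance_id 42 -> 128 input tokens
        meta.set_shared_onfly_info(onfly)
        meta.set_shared_local_tokens([512, 0])

    assert meta.get_shared_onfly() == [{42: 128}, {}]
    assert meta.get_shared_local_tokens() == [512, 0]
finally:
    # The mp.Manager() behind the shared state must be shut down explicitly.
    meta.destructor()
```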
@@ -1171,6 +1171,7 @@ class ServerArgs:
            choices=[
                "round_robin",
                "shortest_queue",
                "minimum_tokens",
            ],
        )
...

@@ -137,5 +137,60 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
        self.assertGreater(avg_spec_accept_length, 2.5)
class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--enable-dp-attention",
                "--dp",
                "2",
                "--enable-torch-compile",
                "--torch-compile-max-bs",
                "2",
                "--load-balance-method",
                "minimum_tokens",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.5)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
    unittest.main()