Unverified Commit f7b2853f authored by Guanhua Wang, committed by GitHub

[feat] support minimum token load balance in dp attention (#7379)

parent b0add2da
......@@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dp-size` | The data parallelism size. | 1 |
| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
| `--load-balance-method` | The load balancing strategy for data parallelism. Options: 'round_robin', 'minimum_tokens'. The minimum-token strategy requires DP attention to be enabled; it balances requests based on the real-time token load of each DP worker. | round_robin |
## Multi-node distributed serving
......
......@@ -732,6 +732,7 @@ def _launch_subprocesses(
pp_rank,
None,
writer,
None,
),
)
......
......@@ -16,9 +16,13 @@
import logging
import multiprocessing as mp
import signal
import struct
import sys
import threading
import time
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List
import psutil
import setproctitle
......@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
......@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
MINIMUM_TOKENS = auto()
@classmethod
def from_str(cls, method: str):
......@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
dp_balance_meta: DPBalanceMeta,
) -> None:
# for dp balance
self.global_balance_id = 0
self.balance_meta = dp_balance_meta
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
......@@ -79,6 +94,7 @@ class DataParallelController:
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
......@@ -234,6 +250,7 @@ class DataParallelController:
pp_rank,
dp_rank,
writer,
self.balance_meta,
),
)
with memory_saver_adapter.configure_subprocess():
......@@ -269,6 +286,33 @@ class DataParallelController:
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()
def minimum_tokens_scheduler(self, req):
# This id corresponds to the dp_balance_id field of TokenizedGenerateReqInput.
# We use it to track in-flight tokens (requests dispatched to workers but not yet received by their schedulers).
def get_next_global_balance_id() -> int:
INT32_MAX = 2147483647
current_id = self.global_balance_id
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
return current_id
req.dp_balance_id = get_next_global_balance_id()
with self.balance_meta.mutex:
# 1. local_tokens is the number of tokens currently being processed on each worker,
# while onfly holds the requests dispatched by this controller but not yet received by the worker's scheduler.
onfly_info = self.balance_meta.get_shared_onfly()
local_tokens = self.balance_meta.get_shared_local_tokens()
total_tokens = [
local_token + sum(onfly_dict.values())
for local_token, onfly_dict in zip(local_tokens, onfly_info)
]
target_worker = total_tokens.index(min(total_tokens))
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
# 2. Write the updated onfly info back to shared memory.
self.balance_meta.set_shared_onfly_info(onfly_info)
# logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
self.workers[target_worker].send_pyobj(req)
def event_loop(self):
while True:
while True:
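To make the dispatch rule above concrete, here is a minimal standalone sketch of the same selection criterion: a worker's effective load is its locally held tokens plus the token counts of its in-flight ("onfly") requests, and a new request goes to the worker with the smallest total. The helper name and the sample numbers are illustrative, not part of this change.

```python
from typing import Dict, List

def pick_min_token_worker(
    local_tokens: List[int], onfly_info: List[Dict[int, int]]
) -> int:
    """Return the index of the DP worker with the smallest effective token load."""
    total_tokens = [
        held + sum(onfly.values()) for held, onfly in zip(local_tokens, onfly_info)
    ]
    return total_tokens.index(min(total_tokens))

# Hypothetical values: worker 0 holds 120 tokens with one 30-token request in flight;
# worker 1 holds 90 tokens with two in-flight requests (25 and 40 tokens).
local_tokens = [120, 90]
onfly_info = [{7: 30}, {8: 25, 9: 40}]
assert pick_min_token_worker(local_tokens, onfly_info) == 0  # 150 < 155
```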
......@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
setproctitle.setproctitle("sglang::data_parallel_controller")
configure_logger(server_args)
parent_process = psutil.Process().parent()
balance_meta = DPBalanceMeta(server_args.dp_size)
try:
controller = DataParallelController(server_args, port_args)
controller = DataParallelController(
server_args, port_args, dp_balance_meta=balance_meta
)
pipe_writer.send(
{
"status": "ready",
......@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
finally:
# Shut down the mp.Manager() owned by balance_meta.
balance_meta.destructor()
......@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
@dataclass
class EmbeddingReqInput:
......@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
# For dp balance
dp_balance_id: int = -1
@dataclass
......
......@@ -126,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
......@@ -203,6 +203,7 @@ class Scheduler(
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
):
# Parse args
self.server_args = server_args
......@@ -522,6 +523,15 @@ class Scheduler(
]
)
self.balance_meta = dp_balance_meta
if (
server_args.enable_dp_attention
and server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None
self.recv_dp_balance_id_this_term = []
def init_tokenizer(self):
server_args = self.server_args
......@@ -1049,6 +1059,12 @@ class Scheduler(
self,
recv_req: TokenizedGenerateReqInput,
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
# Create a new request
if (
recv_req.session_params is None
......@@ -1459,6 +1475,11 @@ class Scheduler(
# Handle DP attention
if need_dp_attn_preparation:
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
self.handle_dp_balance_data(ret)
ret = self.prepare_mlp_sync_batch(ret)
return ret
......@@ -1786,6 +1807,86 @@ class Scheduler(
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
def handle_dp_balance_data(self, local_batch: ScheduleBatch):
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
recv_list = self.recv_dp_balance_id_this_term
assert len(recv_list) <= 511, (
"The number of requests received this round is too large. "
"Please increase gather_tensor_size and onfly_info_size."
)
# The maximum size of the tensor used for gathering data from all workers.
gather_tensor_size = 512
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
recv_tensor[0] = holding_tokens_list  # Tokens currently held by this worker.
recv_tensor[1] = len(recv_list)  # Number of balance ids that follow.
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
recv_list, dtype=torch.int32
)
if self.tp_rank == 0:
gathered_list = [
torch.zeros(gather_tensor_size, dtype=torch.int32)
for _ in range(self.balance_meta.num_workers)
]
else:
gathered_list = None
torch.distributed.gather(
recv_tensor, gathered_list, group=self.tp_cpu_group
)
gathered_id_list_per_worker = None
if self.tp_rank == 0:
gathered_id_list_per_worker = []
holding_tokens_list = []
for tensor in gathered_list:
holding_tokens_list.append(tensor[0].item())
list_length = tensor[1].item()
gathered_id_list_per_worker.append(
tensor[2 : list_length + 2].tolist()
)
return gathered_id_list_per_worker, holding_tokens_list
def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
meta = self.balance_meta
with meta.mutex:
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
assert len(new_recv_rid_lists) == len(
onfly_list
), "number of workers does not match"
# 1. For each worker, every dp_balance_id received this round should already be in its
# onfly dict; remove those entries, since the requests have now reached the scheduler.
worker_id = 0
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
for new_recv_rid in new_recv_rids:
assert (
new_recv_rid in on_fly_reqs
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
del on_fly_reqs[new_recv_rid]
worker_id += 1
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)
holding_tokens = self.get_load()
new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
holding_tokens
)
self.recv_dp_balance_id_this_term.clear()
if self.tp_rank == 0: # only first worker write info
write_shared_dp_balance_info(
new_recv_dp_balance_id_list, holding_token_list
)
@staticmethod
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
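For reference, the fixed-size gather tensor built in handle_dp_balance_data uses the layout `| holding_tokens | num_ids | ids... |`. Below is a minimal sketch of that pack/unpack round trip without the distributed gather; the helper names, the constant, and the sample values are hypothetical, mirroring gather_tensor_size above.

```python
import torch

GATHER_TENSOR_SIZE = 512  # must be at least 2 + the maximum number of ids per round

def pack(holding_tokens: int, balance_ids: list[int]) -> torch.Tensor:
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens                # tokens currently held by this worker
    t[1] = len(balance_ids)              # how many ids follow
    t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
    return t

def unpack(t: torch.Tensor) -> tuple[int, list[int]]:
    num_ids = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + num_ids].tolist()

tokens, ids = unpack(pack(4096, [11, 12, 13]))
assert (tokens, ids) == (4096, [11, 12, 13])
```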
......@@ -2394,6 +2495,7 @@ def run_scheduler_process(
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
# Generate the prefix
prefix = ""
......@@ -2427,7 +2529,14 @@ def run_scheduler_process(
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(
server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
dp_rank,
dp_balance_meta=balance_meta,
)
pipe_writer.send(
{
......
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import Optional
from typing import Dict, List, Optional
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
......@@ -38,3 +39,46 @@ def validate_input_length(
return error_msg
return None
class DPBalanceMeta:
"""
Shared metadata used by both the scheduler and the DP controller for DP load balancing.
"""
def __init__(self, num_workers: int):
self.num_workers = num_workers
self._manager = mp.Manager()
self.mutex = self._manager.Lock()
init_local_tokens = [0] * self.num_workers
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
self.shared_state = self._manager.Namespace()
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
def destructor(self):
# This object must be destructed manually to shut down the underlying mp.Manager().
self._manager.shutdown()
def get_shared_onfly(self) -> List[Dict[int, int]]:
return [dict(d) for d in self.shared_state.onfly_info]
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
self.shared_state.onfly_info = data
def get_shared_local_tokens(self) -> List[int]:
return list(self.shared_state.local_tokens)
def set_shared_local_tokens(self, data: List[int]):
self.shared_state.local_tokens = data
def __getstate__(self):
state = self.__dict__.copy()
del state["_manager"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._manager = None
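A usage sketch of DPBalanceMeta, assuming sglang is importable; the worker count and token values below are made up. Reads and writes of the shared state happen under the mutex, and destructor() must be called by the owning process, as the DP controller does in its finally block above.

```python
from sglang.srt.managers.utils import DPBalanceMeta

if __name__ == "__main__":
    meta = DPBalanceMeta(num_workers=2)
    try:
        with meta.mutex:
            onfly = meta.get_shared_onfly()         # [{}, {}] right after init
            local = meta.get_shared_local_tokens()  # [0, 0] right after init
            onfly[0][42] = 128   # a 128-token request dispatched to worker 0
            local[1] = 256       # worker 1 currently holds 256 tokens
            meta.set_shared_onfly_info(onfly)
            meta.set_shared_local_tokens(local)
    finally:
        meta.destructor()  # shuts down the underlying mp.Manager()
```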
......@@ -1171,6 +1171,7 @@ class ServerArgs:
choices=[
"round_robin",
"shortest_queue",
"minimum_tokens",
],
)
......
......@@ -137,5 +137,60 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
self.assertGreater(avg_spec_accept_length, 2.5)
class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
"--enable-torch-compile",
"--torch-compile-max-bs",
"2",
"--load-balance-method",
"minimum_tokens",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()