Unverified Commit 305c9e8c authored by Liangsheng Yin, committed by GitHub

[4/N] DP refactor: support watching mode `get_load` and shortest queue strategy (#10201)

parent ca63f075
......@@ -27,7 +27,7 @@ import tempfile
import threading
import time
from http import HTTPStatus
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
import setproctitle
......@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
MultiTokenizerRouter,
get_main_process_id,
monkey_patch_uvicorn_multiprocessing,
read_from_shared_memory,
......@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
# Store global states
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
tokenizer_manager: Union[
TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
]
template_manager: TemplateManager
scheduler_info: Dict
......
......@@ -21,6 +21,7 @@ import struct
import sys
import threading
import time
from collections import deque
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List
......@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
BlockReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
WatchLoadUpdateReq,
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
......@@ -46,7 +48,7 @@ from sglang.srt.utils import (
get_zmq_socket,
kill_itself_when_parent_died,
)
from sglang.utils import get_exception_traceback
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
raise ValueError(f"Invalid load balance method: {method}") from exc
class DPBudget:
def __init__(self):
# TODO: support minimum tokens method
self.budget_queue = deque()
def update_budget(self, load_update: WatchLoadUpdateReq):
"""Update the budget queue.
Use num_reqs instead of num_waiting_reqs so that the decode running batch is balanced as well.
"""
loads = load_update.loads
self.budget_queue.clear()
num_reqs = [load.num_reqs for load in loads]
if not num_reqs:
return
max_num_reqs = max(num_reqs)
if all(x == max_num_reqs for x in num_reqs):
return
while any(x != num_reqs[0] for x in num_reqs):
min_load = min(num_reqs)
min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
second_min_load = min(x for x in num_reqs if x > min_load)
self.budget_queue.extend(
[loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
)
for idx in min_indices:
num_reqs[idx] = second_min_load
def dispatch(self):
if self.budget_queue:
return self.budget_queue.popleft()
return None
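The loop above levels the fleet bottom-up: each pass lifts the currently least-loaded ranks to the next-smallest load and enqueues one dispatch slot per lifted request. A minimal, self-contained rerun of the same logic on plain (dp_rank, num_reqs) pairs, with made-up numbers, shows the resulting dispatch order:

from collections import deque

def build_budget(loads):
    # Illustrative sketch, not part of this commit.
    # loads: list of (dp_rank, num_reqs) pairs; mirrors DPBudget.update_budget.
    budget = deque()
    num_reqs = [n for _, n in loads]
    if not num_reqs or all(n == num_reqs[0] for n in num_reqs):
        return budget  # empty or already balanced: nothing to hand out
    while any(n != num_reqs[0] for n in num_reqs):
        min_load = min(num_reqs)
        min_indices = [i for i, n in enumerate(num_reqs) if n == min_load]
        second_min = min(n for n in num_reqs if n > min_load)
        # one slot per request needed to lift each minimum to the next level
        budget.extend([loads[i][0] for i in min_indices] * (second_min - min_load))
        for i in min_indices:
            num_reqs[i] = second_min
    return budget

# Ranks 0..2 hold 5, 2, and 4 requests: rank 1 receives the next two
# requests (2 -> 4), then ranks 1 and 2 one each (4 -> 5).
print(list(build_budget([(0, 5), (1, 2), (2, 4)])))  # prints [1, 1, 1, 2]

Once the queue drains, shortest_queue_scheduler below falls back to round robin until the next WatchLoadUpdateReq refills it.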
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
......@@ -104,6 +142,9 @@ class DataParallelController:
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Load balance budget
self.dp_budget = DPBudget()
# Launch data parallel workers
self.scheduler_procs = []
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
......@@ -127,6 +168,31 @@ class DataParallelController:
self.max_req_input_len = None
self.init_dispatcher()
def send_to_all_workers(self, obj):
for worker in self.workers:
worker.send_pyobj(obj)
def send_control_message(self, obj):
# Send control messages to the first worker of each tp group
for worker in self.workers[:: self.control_message_step]:
worker.send_pyobj(obj)
def handle_load_update_req(self, obj):
self.dp_budget.update_budget(obj)
def init_dispatcher(self):
self._request_dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.dispatching),
(TokenizedEmbeddingReqInput, self.dispatching),
(BlockReqInput, self.send_to_all_workers),
(WatchLoadUpdateReq, self.handle_load_update_req),
]
)
self._request_dispatcher.add_fallback_fn(self.send_control_message)
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
......@@ -291,10 +357,14 @@ class DataParallelController:
else:
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
def shortest_queue_scheduler(self, input_requests):
    raise NotImplementedError()
def shortest_queue_scheduler(self, req):
    if self.maybe_external_dp_rank_routing(req):
        return
    target_worker = self.dp_budget.dispatch()
    if target_worker is None:
        self.round_robin_scheduler(req)
    else:
        self.workers[target_worker].send_pyobj(req)
def minimum_tokens_scheduler(self, req):
if self.maybe_external_dp_rank_routing(req):
......@@ -333,22 +403,7 @@ class DataParallelController:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
),
):
self.dispatching(recv_req)
elif isinstance(recv_req, BlockReqInput):
for worker in self.workers:
worker.send_pyobj(recv_req)
else:
# Send other control messages to first worker of tp group
for worker in self.workers[:: self.control_message_step]:
worker.send_pyobj(recv_req)
self._request_dispatcher(recv_req)
def run_data_parallel_controller_process(
......
......@@ -297,7 +297,7 @@ def run_detokenizer_process(
else:
manager.event_loop()
except Exception:
manager.socket_mapping.clear_all_sockets()
manager.maybe_clear_socket_mapping()
traceback = get_exception_traceback()
logger.error(f"DetokenizerManager hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
......@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
@dataclass
class BlockReqInput:
type: BlockReqType
@dataclass
class GetLoadReqInput:
pass
@dataclass
class GetLoadReqOutput:
dp_rank: int
num_reqs: int
num_waiting_reqs: int
num_tokens: int
@dataclass
class WatchLoadUpdateReq:
loads: List[GetLoadReqOutput]
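For orientation, this is the payload shape that flows from the schedulers back to the DP controller; a hypothetical two-rank snapshot (all numbers illustrative):

from sglang.srt.managers.io_struct import GetLoadReqOutput, WatchLoadUpdateReq

# One GetLoadReqOutput per dp rank, fanned in by the tokenizer and then
# re-broadcast to the DataParallelController as a single update.
update = WatchLoadUpdateReq(
    loads=[
        GetLoadReqOutput(dp_rank=0, num_reqs=5, num_waiting_reqs=1, num_tokens=2300),
        GetLoadReqOutput(dp_rank=1, num_reqs=2, num_waiting_reqs=0, num_tokens=900),
    ]
)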
......@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
worker_ids = []
return worker_ids
def maybe_clear_socket_mapping(self):
if hasattr(self, "socket_mapping"):
self.socket_mapping.clear_all_sockets()
def multi_http_worker_event_loop(self):
"""The event loop that handles requests, for multi multi-http-worker mode"""
self.socket_mapping = SocketMapping()
......
......@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
FreezeGCReq,
GetInternalStateReq,
GetInternalStateReqOutput,
GetLoadReqInput,
GetLoadReqOutput,
GetWeightsByNameReqInput,
HealthCheckOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
......@@ -577,6 +579,7 @@ class Scheduler(
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
(GetLoadReqInput, self.get_load),
]
)
......@@ -2279,39 +2282,50 @@ class Scheduler(
if_success = False
return if_success
def get_load(self):
def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
# TODO(lsyin): use dynamically maintained num_waiting_tokens
if self.is_hybrid:
load_full = (
num_tokens_full = (
self.full_tokens_per_layer
- self.token_to_kv_pool_allocator.full_available_size()
- self.tree_cache.full_evictable_size()
)
load_swa = (
num_tokens_swa = (
self.swa_tokens_per_layer
- self.token_to_kv_pool_allocator.swa_available_size()
- self.tree_cache.swa_evictable_size()
)
load = max(load_full, load_swa)
num_tokens = max(num_tokens_full, num_tokens_swa)
else:
load = (
num_tokens = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
# Tokens in waiting queue, bootstrap queue, prealloc queue
num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
num_waiting_reqs = len(self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
load += sum(
num_tokens += sum(
len(req.origin_input_ids)
for req in self.disagg_prefill_bootstrap_queue.queue
)
num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
load += sum(
num_tokens += sum(
len(req.req.origin_input_ids)
for req in self.disagg_decode_prealloc_queue.queue
)
num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
return load
return GetLoadReqOutput(
dp_rank=self.dp_rank,
num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
num_waiting_reqs=num_waiting_reqs,
num_tokens=num_tokens,
)
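Concretely, the non-hybrid branch composes num_tokens from the KV-cache occupancy plus queued input tokens; a sketch with made-up numbers:

# Illustrative arithmetic only; every number below is invented.
max_total_num_tokens = 10_000
available = 7_000   # token_to_kv_pool_allocator.available_size()
evictable = 1_000   # tree_cache.evictable_size()
kv_held = max_total_num_tokens - available - evictable  # 2_000 tokens pinned in KV cache
queued = 300        # sum(len(req.origin_input_ids) for req in waiting_queue)
num_tokens = kv_held + queued  # 2_300; PREFILL/DECODE disaggregation queues add more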
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
......@@ -2337,8 +2351,6 @@ class Scheduler(
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
ret["load"] = self.get_load()
return GetInternalStateReqOutput(internal_state=ret)
def set_internal_state(self, recv_req: SetInternalStateReq):
......
......@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
holding_tokens = self.get_load()
holding_tokens = self.get_load().num_tokens
new_recv_dp_balance_id_list, holding_token_list = (
self.gather_dp_balance_info(holding_tokens)
......
from __future__ import annotations
import asyncio
import copy
import logging
import os
import time
......@@ -18,6 +19,7 @@ from typing import (
)
import fastapi
import zmq
from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
......@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
FlushCacheReqOutput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetLoadReqInput,
GetLoadReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
......@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
self._sender = sender
self._fan_out = fan_out
self._mode = mode
self._result_event: Optional[asyncio.Event] = None
self._result_values: Optional[List[T]] = None
self._ready_queue: Deque[asyncio.Future] = deque()
assert mode in ["queueing", "watching"]
async def __call__(self, obj):
async def queueing_call(self, obj: T):
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
......@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):
return result_values
async def watching_call(self, obj):
if self._result_event is None:
assert self._result_values is None
self._result_values = []
self._result_event = asyncio.Event()
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
await self._result_event.wait()
result_values = copy.deepcopy(self._result_values)
self._result_event = self._result_values = None
return result_values
async def __call__(self, obj):
if self._mode == "queueing":
return await self.queueing_call(obj)
else:
return await self.watching_call(obj)
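Both modes share the same fan-in mechanics: replies accumulate in _result_values until fan_out of them have arrived, then the Event fires. A standalone sketch of that pattern (hypothetical names, outside sglang):

import asyncio

class FanInWaiter:
    """Illustrative sketch of the Event-based fan-in used by _Communicator."""

    def __init__(self, fan_out: int):
        self._fan_out = fan_out
        self._values = []
        self._event = asyncio.Event()

    def handle_recv(self, value):
        # Mirror _Communicator.handle_recv: fire once all replies are in.
        self._values.append(value)
        if len(self._values) == self._fan_out:
            self._event.set()

    async def wait_all(self):
        await self._event.wait()
        values, self._values = self._values, []
        self._event.clear()
        return values

async def main():
    waiter = FanInWaiter(fan_out=2)
    task = asyncio.create_task(waiter.wait_all())
    waiter.handle_recv("rank0 load")  # e.g. one GetLoadReqOutput per rank
    waiter.handle_recv("rank1 load")
    print(await task)  # ['rank0 load', 'rank1 load']

asyncio.run(main())

The difference is on the caller side: queueing_call serializes concurrent callers through _ready_queue, while watching_call deep-copies the snapshot before resetting the shared slots, which suits a single periodic watcher.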
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
......@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_load_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size, mode="watching"
)
self._result_dispatcher += self._get_communicator_dispatcher()
......@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
(
GetLoadReqOutput,
self.get_load_communicator.handle_recv,
),
]
)
......@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
)
return [res.updated for res in responses]
async def get_load(self: TokenizerManager) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
req = GetLoadReqInput()
return await self.get_load_communicator(req)
......@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FreezeGCReq,
GenerateReqInput,
GetLoadReqInput,
HealthCheckOutput,
MultiTokenizerWrapper,
OpenSessionReqInput,
......@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
WatchLoadUpdateReq,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
......@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.watch_load_thread))
)
def dump_requests_before_crash(self):
if self.crash_dump_performed:
......@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return scores
async def watch_load_thread(self):
# Only used by the dp controller when dp_size > 1 and the balance method is load-aware
if (
self.server_args.dp_size == 1
or self.server_args.load_balance_method == "round_robin"
):
return
while True:
await asyncio.sleep(self.server_args.load_watch_interval)
loads = await self.get_load_communicator(GetLoadReqInput())
load_update_req = WatchLoadUpdateReq(loads=loads)
self.send_to_scheduler.send_pyobj(load_update_req)
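Taken together, the pieces added in this commit form the following loop (a simplified sketch; arrows are ZMQ messages):

TokenizerManager.watch_load_thread                (every load_watch_interval seconds)
    -- GetLoadReqInput ----------------------->   Scheduler.get_load (one per dp rank)
    <- GetLoadReqOutput -----------------------   fan-in via get_load_communicator ("watching" mode)
    -- WatchLoadUpdateReq(loads) ------------->   DataParallelController
DataParallelController.handle_load_update_req ->  DPBudget.update_budget
shortest_queue_scheduler --------------------->   DPBudget.dispatch(), round robin when the budget is empty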
class ServerStatus(Enum):
Up = "Up"
......
......@@ -233,6 +233,7 @@ class ServerArgs:
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
load_watch_interval: float = 0.1
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
prefill_round_robin_balance: bool = False
......@@ -663,6 +664,7 @@ class ServerArgs:
if self.dp_size == 1:
self.enable_dp_attention = False
self.enable_dp_lm_head = False
# Data parallelism attention
if self.enable_dp_attention:
......@@ -1488,6 +1490,12 @@ class ServerArgs:
"minimum_tokens",
],
)
parser.add_argument(
"--load-watch-interval",
type=float,
default=ServerArgs.load_watch_interval,
help="The interval of load watching in seconds.",
)
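A hypothetical launch line exercising the new knob, assuming shortest_queue is among the accepted --load-balance-method choices (the dispatcher above implements it):

python -m sglang.launch_server --model-path <model> --dp-size 4 \
    --load-balance-method shortest_queue --load-watch-interval 0.2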
parser.add_argument(
"--prefill-round-robin-balance",
default=ServerArgs.prefill_round_robin_balance,
......
......@@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
) -> zmq.Socket:
mem = psutil.virtual_memory()
total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3
......
......@@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
class TypeBasedDispatcher:
def __init__(self, mapping: List[Tuple[Type, Callable]]):
self._mapping = mapping
self._fallback_fn = None
def add_fallback_fn(self, fallback_fn: Callable):
self._fallback_fn = fallback_fn
def __iadd__(self, other: "TypeBasedDispatcher"):
self._mapping.extend(other._mapping)
......@@ -481,6 +485,9 @@ class TypeBasedDispatcher:
for ty, fn in self._mapping:
if isinstance(obj, ty):
return fn(obj)
if self._fallback_fn is not None:
return self._fallback_fn(obj)
raise ValueError(f"Invalid object: {obj}")
......