Unverified Commit 83d55ac5 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)

parent 2fe17735
......@@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
......@@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
if prefill_dp_rank is not None:
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
self.prefill_dp_rank = prefill_dp_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
# FIXME: alias here: target_dp_group -> prefill_dp_rank
self.target_dp_group = self.prefill_dp_rank
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
......
......@@ -250,7 +250,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
prefill_dp_rank=req.data_parallel_rank,
)
self.queue.append(
......
......@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.has_init = False
......
......@@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
......@@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.conclude_state = None
self.init_time = None
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
(
......@@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
if prefill_dp_rank is not None:
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
self.prefill_dp_rank = prefill_dp_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
# FIXME: alias here: target_dp_group -> prefill_dp_rank
self.target_dp_group = self.prefill_dp_rank
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
......
......@@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver):
mgr: NixlKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.started_transfer = False
self.conclude_state = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
......
......@@ -106,7 +106,7 @@ class DataParallelController:
# Launch data parallel workers
self.scheduler_procs = []
self.workers = [None] * server_args.dp_size
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
if server_args.enable_dp_attention:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
......@@ -272,27 +272,34 @@ class DataParallelController:
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
def maybe_external_dp_rank_routing(self, req: Req):
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
return True
return False
def round_robin_scheduler(self, req: Req):
if self.maybe_external_dp_rank_routing(req):
return
if self.server_args.disaggregation_mode == "null":
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
else:
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
def shortest_queue_scheduler(self, input_requests):
if self.maybe_external_dp_rank_routing(req):
return
raise NotImplementedError()
def minimum_tokens_scheduler(self, req):
if self.maybe_external_dp_rank_routing(req):
return
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
# We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
def get_next_global_balance_id() -> int:
......
......@@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager):
server_args: ServerArgs,
port_args: PortArgs,
):
setproctitle.setproctitle(
f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
)
setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
# prevent init prefill bootstrapserver again
disaggregation_mode = server_args.disaggregation_mode
server_args.disaggregation_mode = "null"
......
......@@ -44,6 +44,7 @@ from sglang.srt.utils import (
is_valid_ipv6_address,
nullable_str,
)
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
......@@ -223,6 +224,8 @@ class ServerArgs:
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
prefill_round_robin_balance: bool = False
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
......@@ -623,12 +626,12 @@ class ServerArgs:
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
if self.dp_size == 1:
self.enable_dp_attention = False
# Data parallelism attention
if self.enable_dp_attention:
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
assert (
self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
logger.warning(
......@@ -807,6 +810,13 @@ class ServerArgs:
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")
if self.dp_size > 1 and not is_in_ci():
assert self.prefill_round_robin_balance, (
"Prefill round robin balance is required when dp size > 1. "
"Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
" and `--prefill-round-robin-balance` is set for decode server."
)
elif self.disaggregation_mode == "prefill":
if self.disaggregation_decode_tp is None:
self.disaggregation_decode_tp = self.tp_size
......@@ -1384,6 +1394,12 @@ class ServerArgs:
"minimum_tokens",
],
)
parser.add_argument(
"--prefill-round-robin-balance",
default=ServerArgs.prefill_round_robin_balance,
action="store_true",
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
)
# Multi-node distributed serving
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment