Unverified commit 83d55ac5, authored by Liangsheng Yin, committed by GitHub

[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)

parent 2fe17735
@@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
-        self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
 
-        if self.data_parallel_rank is not None:
-            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
-            self.target_dp_group = self.data_parallel_rank
+        if prefill_dp_rank is not None:
+            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
+            self.prefill_dp_rank = prefill_dp_rank
         else:
-            self.target_dp_group = bootstrap_room % self.prefill_dp_size
+            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
+
+        # FIXME: alias here: target_dp_group -> prefill_dp_rank
+        self.target_dp_group = self.prefill_dp_rank
 
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
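Note (not part of the diff): a minimal standalone sketch of the rank-selection rule above, using hypothetical values. An explicitly supplied prefill_dp_rank wins; otherwise the rank is derived deterministically from bootstrap_room, so every receiver maps the same request to the same prefill DP rank.

from typing import Optional

def pick_prefill_dp_rank(
    bootstrap_room: int,
    prefill_dp_size: int,
    prefill_dp_rank: Optional[int] = None,
) -> int:
    # An explicitly requested rank takes priority; otherwise fall back to a
    # deterministic modulo of the bootstrap room so all parties agree.
    if prefill_dp_rank is not None:
        return prefill_dp_rank
    return bootstrap_room % prefill_dp_size

# Example: room 10169 with 4 prefill DP ranks always lands on rank 1.
assert pick_prefill_dp_rank(10169, 4) == 1
assert pick_prefill_dp_rank(10169, 4, prefill_dp_rank=3) == 3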
@@ -250,7 +250,7 @@ class DecodePreallocQueue:
                 mgr=self.kv_manager,
                 bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                 bootstrap_room=req.bootstrap_room,
-                data_parallel_rank=req.data_parallel_rank,
+                prefill_dp_rank=req.data_parallel_rank,
             )
 
             self.queue.append(
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.has_init = False
@@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         mgr: MooncakeKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
@@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
         self.init_time = None
-        self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             (
@@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
         ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
 
-        if self.data_parallel_rank is not None:
-            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
-            self.target_dp_group = self.data_parallel_rank
+        if prefill_dp_rank is not None:
+            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
+            self.prefill_dp_rank = prefill_dp_rank
         else:
-            self.target_dp_group = bootstrap_room % self.prefill_dp_size
+            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
+
+        # FIXME: alias here: target_dp_group -> prefill_dp_rank
+        self.target_dp_group = self.prefill_dp_rank
 
         self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
             self.required_prefill_response_num
@@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
@@ -106,7 +106,7 @@ class DataParallelController:
         # Launch data parallel workers
         self.scheduler_procs = []
-        self.workers = [None] * server_args.dp_size
+        self.workers: List[zmq.Socket] = [None] * server_args.dp_size
 
         if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
@@ -272,27 +272,34 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
+    def maybe_external_dp_rank_routing(self, req: Req):
+        if req.data_parallel_rank is not None:
+            logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+            self.workers[req.data_parallel_rank].send_pyobj(req)
+            return True
+        return False
+
     def round_robin_scheduler(self, req: Req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         if self.server_args.disaggregation_mode == "null":
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[self.round_robin_counter].send_pyobj(req)
-                self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                    self.workers
-                )
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
         else:
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
     def shortest_queue_scheduler(self, input_requests):
+        if self.maybe_external_dp_rank_routing(req):
+            return
         raise NotImplementedError()
 
     def minimum_tokens_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
         # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
         def get_next_global_balance_id() -> int:
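Note (not part of the diff): a toy sketch of the factored-out routing path, with hypothetical names and plain lists standing in for the ZMQ worker sockets. Every scheduling policy now first honors an externally supplied data_parallel_rank via maybe_external_dp_rank_routing and only falls back to its own policy when none is given.

from typing import List, Optional

class TinyRouter:
    """Illustrative stand-in for the DataParallelController dispatch logic."""

    def __init__(self, num_workers: int):
        self.outboxes: List[list] = [[] for _ in range(num_workers)]
        self.round_robin_counter = 0

    def maybe_external_dp_rank_routing(self, dp_rank: Optional[int], req) -> bool:
        # An externally pinned DP rank short-circuits every policy.
        if dp_rank is not None:
            self.outboxes[dp_rank].append(req)
            return True
        return False

    def round_robin(self, req, dp_rank: Optional[int] = None, bootstrap_room: Optional[int] = None):
        if self.maybe_external_dp_rank_routing(dp_rank, req):
            return
        if bootstrap_room is None:  # analogous to disaggregation_mode == "null"
            self.outboxes[self.round_robin_counter].append(req)
            self.round_robin_counter = (self.round_robin_counter + 1) % len(self.outboxes)
        else:  # PD disaggregation: derive the target rank from the bootstrap room
            self.outboxes[bootstrap_room % len(self.outboxes)].append(req)

router = TinyRouter(num_workers=2)
router.round_robin("req-a")                    # round robin -> worker 0
router.round_robin("req-b", dp_rank=1)         # pinned      -> worker 1
router.round_robin("req-c", bootstrap_room=7)  # 7 % 2 == 1  -> worker 1
assert [len(q) for q in router.outboxes] == [1, 2]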
@@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager):
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
-        setproctitle.setproctitle(
-            f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
-        )
+        setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
 
         # prevent init prefill bootstrapserver again
         disaggregation_mode = server_args.disaggregation_mode
         server_args.disaggregation_mode = "null"
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     is_valid_ipv6_address,
     nullable_str,
 )
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
@@ -223,6 +224,8 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
+    prefill_round_robin_balance: bool = False
 
     # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
@@ -623,12 +626,12 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
 
+        if self.dp_size == 1:
+            self.enable_dp_attention = False
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            assert (
-                self.dp_size > 1
-            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
             assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
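Note (not part of the diff): a minimal sketch of the new fallback as a hypothetical standalone function (the real check runs during ServerArgs post-processing). With dp_size == 1, DP attention is now silently disabled instead of being rejected by the old assertion.

def resolve_dp_attention(dp_size: int, enable_dp_attention: bool) -> bool:
    # New behavior: a single DP rank simply turns DP attention off.
    if dp_size == 1:
        return False
    return enable_dp_attention

assert resolve_dp_attention(1, True) is False  # previously tripped the dp_size > 1 assertion
assert resolve_dp_attention(2, True) is True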
@@ -807,6 +810,13 @@ class ServerArgs:
                 self.disable_radix_cache = True
                 logger.warning("KV cache is forced as chunk cache for decode server")
 
+            if self.dp_size > 1 and not is_in_ci():
+                assert self.prefill_round_robin_balance, (
+                    "Prefill round robin balance is required when dp size > 1. "
+                    "Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
+                    " and `--prefill-round-robin-balance` is set for decode server."
+                )
+
         elif self.disaggregation_mode == "prefill":
             if self.disaggregation_decode_tp is None:
                 self.disaggregation_decode_tp = self.tp_size
@@ -1384,6 +1394,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
+            "--prefill-round-robin-balance",
+            default=ServerArgs.prefill_round_robin_balance,
+            action="store_true",
+            help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
+        )
 
         # Multi-node distributed serving
         parser.add_argument(
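Note (not part of the diff): a minimal sketch of how the new flag and the decode-side assertion pair up, using a simplified stand-in dataclass rather than the real ServerArgs (whose other fields and the is_in_ci() escape hatch are omitted). A decode server with dp_size > 1 refuses to start unless the operator confirms that the prefill instance was launched with round-robin balancing.

from dataclasses import dataclass

@dataclass
class DecodeArgsSketch:
    dp_size: int = 1
    disaggregation_mode: str = "null"
    prefill_round_robin_balance: bool = False

    def check(self) -> None:
        # Mirrors the assertion added for the decode server (CI skip omitted).
        if self.disaggregation_mode == "decode" and self.dp_size > 1:
            assert self.prefill_round_robin_balance, (
                "launch prefill with --load-balance-method round_robin and pass "
                "--prefill-round-robin-balance to the decode server"
            )

DecodeArgsSketch(dp_size=2, disaggregation_mode="decode",
                 prefill_round_robin_balance=True).check()  # passes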