Unverified commit f1569876, authored by ishandhanani, committed by GitHub

feat: add direct routing strategy to DP worker (#6884)

parent 3465d7ae
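
This change threads an optional `data_parallel_rank` from the public `generate` APIs down to the data-parallel controller and the disaggregation KV receivers, so a caller can pin a request to a specific DP worker instead of relying on round-robin or `bootstrap_room`-based dispatch. A minimal usage sketch, assuming an offline `sglang.Engine`; the launch arguments and model path are illustrative, and only `data_parallel_rank` is the new piece from this PR:

```python
import sglang as sgl

# Illustrative launch; dp_size and enable_dp_attention correspond to the
# server args checked in Engine.generate() in the hunks below.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    dp_size=2,
    enable_dp_attention=True,
)

# Pin this request to DP rank 1; omitting data_parallel_rank keeps the
# previous dispatch behavior (round-robin / bootstrap_room hashing).
out = llm.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 8},
    data_parallel_rank=1,
)
print(out["text"])  # single-prompt calls return a dict with the generated text
```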
@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
+        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
-        self.target_dp_group = bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
...
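On the receiver side, the selection of the target prefill DP group now prefers an explicitly provided rank and only falls back to hashing `bootstrap_room`; the same rule reappears in `MooncakeKVReceiver` further down. A standalone sketch of that rule, with hypothetical names:

```python
from typing import Optional

def select_target_dp_group(
    data_parallel_rank: Optional[int],
    bootstrap_room: int,
    prefill_dp_size: int,
) -> int:
    """Hypothetical helper mirroring the selection logic in the diff:
    an explicit DP rank takes precedence; otherwise hash the bootstrap
    room onto the prefill DP groups."""
    if data_parallel_rank is not None:
        return data_parallel_rank
    return bootstrap_room % prefill_dp_size

# Example: with 4 prefill DP groups, room 10 lands on group 2 unless pinned.
assert select_target_dp_group(None, 10, 4) == 2
assert select_target_dp_group(3, 10, 4) == 3
```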
@@ -156,6 +156,7 @@ class DecodePreallocQueue:
                 mgr=self.kv_manager,
                 bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                 bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
             )
             self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
...
@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.has_init = False
...
@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         mgr: MooncakeKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
-        self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
...
@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
-        super().__init__(mgr, bootstrap_addr, bootstrap_room)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
...
@@ -23,6 +23,12 @@ class EngineBase(ABC):
         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
         lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: Optional[bool] = None,
+        stream: Optional[bool] = None,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """Generate outputs based on given inputs."""
         pass
...
@@ -167,11 +167,22 @@ class Engine(EngineBase):
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
+                )
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -188,6 +199,7 @@ class Engine(EngineBase):
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -237,11 +249,24 @@ class Engine(EngineBase):
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
+                )
+
+            logger.info(f"data_parallel_rank: {data_parallel_rank}")
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -257,6 +282,7 @@ class Engine(EngineBase):
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
...
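Both `generate` and `async_generate` gate the new argument behind `enable_dp_attention` and reject out-of-range values (note that the two paths currently word the error message slightly differently). A condensed sketch of the check as a hypothetical free function:

```python
from typing import Optional

def validate_dp_rank(data_parallel_rank: Optional[int], dp_size: int) -> None:
    """Hypothetical helper summarizing the checks added to the two generate
    paths: None means "use the default dispatch"; anything else must lie in
    [0, dp_size)."""
    if data_parallel_rank is None:
        return  # controller falls back to round-robin / bootstrap_room routing
    if data_parallel_rank < 0 or data_parallel_rank >= dp_size:
        raise ValueError(
            f"data_parallel_rank must be in range [0, {dp_size - 1}], "
            f"got {data_parallel_rank}"
        )

validate_dp_rank(None, 2)   # ok: default dispatch
validate_dp_rank(1, 2)      # ok: pinned to rank 1
# validate_dp_rank(2, 2)    # would raise ValueError
```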
@@ -248,12 +248,20 @@ class DataParallelController:
     def round_robin_scheduler(self, req: Req):
         if self.server_args.disaggregation_mode == "null":
-            self.workers[self.round_robin_counter].send_pyobj(req)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                self.workers
-            )
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[self.round_robin_counter].send_pyobj(req)
+                self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                    self.workers
+                )
         else:
-            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
...
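The controller change is the core of the direct-routing strategy: a request that carries a `data_parallel_rank` is sent straight to that worker's socket, and only rank-less requests go through the existing round-robin counter (or `bootstrap_room` hashing under disaggregation). A self-contained sketch of that decision with stub workers, using hypothetical names:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StubReq:
    data_parallel_rank: Optional[int] = None
    bootstrap_room: Optional[int] = None

class StubWorker:
    """Stands in for a per-rank worker send socket."""
    def __init__(self, rank: int):
        self.rank = rank
        self.received: List[StubReq] = []

    def send_pyobj(self, req: StubReq) -> None:
        self.received.append(req)

def route(workers: List[StubWorker], req: StubReq, rr_counter: int) -> int:
    """Mirror of the non-disaggregation branch: direct routing when a DP
    rank is pinned, round-robin otherwise. Returns the updated counter."""
    if req.data_parallel_rank is not None:
        workers[req.data_parallel_rank].send_pyobj(req)
        return rr_counter
    workers[rr_counter].send_pyobj(req)
    return (rr_counter + 1) % len(workers)

workers = [StubWorker(i) for i in range(4)]
counter = 0
counter = route(workers, StubReq(data_parallel_rank=2), counter)  # pinned -> rank 2
counter = route(workers, StubReq(), counter)                      # round-robin -> rank 0
assert len(workers[2].received) == 1 and len(workers[0].received) == 1
```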
@@ -106,6 +106,9 @@ class GenerateReqInput:
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -417,6 +420,9 @@ class GenerateReqInput:
                 bootstrap_room=(
                     self.bootstrap_room[i] if self.bootstrap_room is not None else None
                 ),
+                data_parallel_rank=(
+                    self.data_parallel_rank if self.data_parallel_rank is not None else None
+                ),
             )
@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+

 @dataclass
 class EmbeddingReqInput:
...
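One detail worth noting from the `GenerateReqInput` normalization hunk above: unlike `bootstrap_room`, which is indexed per sub-request (`self.bootstrap_room[i]`), `data_parallel_rank` is a single scalar copied unchanged into every sub-request, so a batched call is pinned to one DP worker as a whole. A small sketch of that distinction, with made-up values:

```python
# Per-item fields are indexed when a batched request is split;
# the scalar data_parallel_rank is shared by every sub-request.
bootstrap_room = [101, 102, 103]   # one room per sub-request
data_parallel_rank = 1             # one rank for the whole batch

sub_requests = [
    {
        "bootstrap_room": bootstrap_room[i],
        "data_parallel_rank": data_parallel_rank,  # same value for all i
    }
    for i in range(3)
]
assert all(r["data_parallel_rank"] == 1 for r in sub_requests)
```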
@@ -451,6 +451,7 @@ class Req:
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -605,6 +606,9 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

+        # For data parallel rank routing
+        self.data_parallel_rank: Optional[int] = data_parallel_rank
+
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
...
@@ -949,6 +949,7 @@ class Scheduler(
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
+            data_parallel_rank=recv_req.data_parallel_rank,
         )
         req.tokenizer = self.tokenizer
...
@@ -570,6 +570,7 @@ class TokenizerManager:
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
+                data_parallel_rank=obj.data_parallel_rank,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
...
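Taken together, the hunks thread the new field along a single path: `Engine.generate` / `Engine.async_generate` validate it, `GenerateReqInput` carries it, `TokenizerManager` copies it into `TokenizedGenerateReqInput`, the `Scheduler` stores it on `Req`, the `DataParallelController` uses it for direct routing, and, for disaggregated serving, `DecodePreallocQueue` passes it to the KV receivers, where it overrides the `bootstrap_room`-based choice of `target_dp_group`.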