Unverified Commit f1569876 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: add direct routing strategy to DP worker (#6884)

parent 3465d7ae
......@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
......@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = bootstrap_room % self.prefill_dp_size
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
......
......@@ -156,6 +156,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
......
......@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.has_init = False
......
......@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
......@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.conclude_state = None
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
......@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
......
......@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
mgr: NixlKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room)
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
......
......@@ -23,6 +23,12 @@ class EngineBase(ABC):
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: Optional[bool] = None,
stream: Optional[bool] = None,
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
) -> Union[Dict, Iterator[Dict]]:
"""Generate outputs based on given inputs."""
pass
......
......@@ -167,11 +167,22 @@ class Engine(EngineBase):
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
) -> Union[Dict, Iterator[Dict]]:
"""
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.info("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
)
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
......@@ -188,6 +199,7 @@ class Engine(EngineBase):
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
......@@ -237,11 +249,24 @@ class Engine(EngineBase):
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
) -> Union[Dict, AsyncIterator[Dict]]:
"""
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.info("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
)
logger.info(f"data_parallel_rank: {data_parallel_rank}")
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
......@@ -257,6 +282,7 @@ class Engine(EngineBase):
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
generator = self.tokenizer_manager.generate_request(obj, None)
......
......@@ -248,12 +248,20 @@ class DataParallelController:
def round_robin_scheduler(self, req: Req):
if self.server_args.disaggregation_mode == "null":
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
else:
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
if req.data_parallel_rank is not None:
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
self.workers[req.data_parallel_rank].send_pyobj(req)
else:
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()
......
......@@ -106,6 +106,9 @@ class GenerateReqInput:
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
......@@ -417,6 +420,9 @@ class GenerateReqInput:
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
data_parallel_rank=(
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
)
......@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
@dataclass
class EmbeddingReqInput:
......
......@@ -451,6 +451,7 @@ class Req:
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
# Input and output info
self.rid = rid
......@@ -605,6 +606,9 @@ class Req:
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[BaseKVSender] = None
# For data parallel rank routing
self.data_parallel_rank: Optional[int] = data_parallel_rank
# the start index of the sent kv cache
# We want to send it chunk by chunk for chunked prefill.
# After every chunk forward, we do the following:
......
......@@ -949,6 +949,7 @@ class Scheduler(
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
)
req.tokenizer = self.tokenizer
......
......@@ -570,6 +570,7 @@ class TokenizerManager:
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment