Unverified Commit 6b231325 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[PD Perf] replace Queue with FastQueue (#6649)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent b1c8d4e9
...@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
FastQueue,
group_concurrent_contiguous, group_concurrent_contiguous,
) )
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager): ...@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager):
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
...@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager): ...@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager):
self.session_failures = defaultdict(int) self.session_failures = defaultdict(int)
self.failed_sessions = set() self.failed_sessions = set()
self.session_lock = threading.Lock() self.session_lock = threading.Lock()
# Determine the number of threads to use for kv sender # Determine the number of threads to use for kv sender
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor( transfer_thread_pool_size = get_int_env_var(
get_int_env_var(
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
min(max(1, cpu_count // 8), 8), min(max(4, int(0.75 * cpu_count) // 8), 12),
)
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
assert transfer_thread_pool_size >= transfer_queue_size, (
f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
) )
self.executors = [
concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
) )
for _ in range(transfer_queue_size)
]
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
).start()
self.bootstrap_time_out = get_int_env_var( self.bootstrap_time_out = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
) )
...@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
) )
# Heartbeat failure should be at least 1 # Heartbeat failure should be at least 1
self.max_failures = max( self.max_failures = max(
int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1 get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
) )
self.start_decode_thread() self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
...@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
prefill_kv_indices: npt.NDArray[np.int64], prefill_kv_indices: npt.NDArray[np.int64],
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int64],
executor: concurrent.futures.ThreadPoolExecutor,
): ):
# Group by indices # Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
...@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
return 0 return 0
futures = [ futures = [
self.executor.submit( executor.submit(
process_layer, process_layer,
src_ptr, src_ptr,
dst_ptr, dst_ptr,
...@@ -298,48 +315,12 @@ class MooncakeKVManager(BaseKVManager): ...@@ -298,48 +315,12 @@ class MooncakeKVManager(BaseKVManager):
] ]
) )
def start_prefill_thread(self): def transfer_worker(
self.rank_port = get_free_port() self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") ):
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
if room == "None":
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
with self.session_lock:
if mooncake_session_id in self.failed_sessions:
self.failed_sessions.remove(mooncake_session_id)
if mooncake_session_id in self.session_failures:
del self.session_failures[mooncake_session_id]
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue
else:
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
room = int(room)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
self.transfer_infos[room][mooncake_session_id] = (
TransferInfo.from_zmq(waiting_req_bytes)
)
# NOTE: after bootstrapping we can mark the req as waiting for input
if len(self.transfer_infos[room]) == required_dst_info_num:
self.update_status(room, KVPoll.WaitingForInput)
def transfer_thread():
# TODO: Shall we use KVPoll.Transferring state?
while True: while True:
try: try:
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01) kv_chunk: TransferKVChunk = queue.get()
reqs_to_be_processed = ( reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values() self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos if kv_chunk.room in self.transfer_infos
...@@ -365,20 +346,16 @@ class MooncakeKVManager(BaseKVManager): ...@@ -365,20 +346,16 @@ class MooncakeKVManager(BaseKVManager):
) )
break break
chunked_dst_kv_indice = req.dst_kv_indices[ chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
kv_chunk.index_slice
]
# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen. # is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len( if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices kv_chunk.prefill_kv_indices
): ):
kv_chunk.prefill_kv_indices = ( kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
kv_chunk.prefill_kv_indices[
len(chunked_dst_kv_indice) len(chunked_dst_kv_indice)
] ]
)
logger.warning( logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
) )
...@@ -390,18 +367,14 @@ class MooncakeKVManager(BaseKVManager): ...@@ -390,18 +367,14 @@ class MooncakeKVManager(BaseKVManager):
req.mooncake_session_id req.mooncake_session_id
].dst_kv_ptrs, ].dst_kv_ptrs,
chunked_dst_kv_indice, chunked_dst_kv_indice,
executor,
) )
if ret != 0: if ret != 0:
with self.session_lock: with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1 self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed # Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if ( if self.session_failures[req.mooncake_session_id] >= 1:
self.session_failures[req.mooncake_session_id] self.failed_sessions.add(req.mooncake_session_id)
>= 1
):
self.failed_sessions.add(
req.mooncake_session_id
)
logger.error( logger.error(
f"Session {req.mooncake_session_id} failed." f"Session {req.mooncake_session_id} failed."
) )
...@@ -432,9 +405,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -432,9 +405,7 @@ class MooncakeKVManager(BaseKVManager):
# Only sync status when all the dst ranks have received the kvcache # Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num: if len(polls) == req.required_dst_info_num:
status = ( status = KVPoll.Success if all(polls) else KVPoll.Failed
KVPoll.Success if all(polls) else KVPoll.Failed
)
self.update_status(req.room, status) self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos: for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint( self.sync_status_to_decode_endpoint(
...@@ -461,8 +432,44 @@ class MooncakeKVManager(BaseKVManager): ...@@ -461,8 +432,44 @@ class MooncakeKVManager(BaseKVManager):
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
) )
def start_prefill_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
if room == "None":
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
with self.session_lock:
if mooncake_session_id in self.failed_sessions:
self.failed_sessions.remove(mooncake_session_id)
if mooncake_session_id in self.session_failures:
del self.session_failures[mooncake_session_id]
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue
else:
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
room = int(room)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
self.transfer_infos[room][mooncake_session_id] = (
TransferInfo.from_zmq(waiting_req_bytes)
)
# NOTE: after bootstrapping we can mark the req as waiting for input
if len(self.transfer_infos[room]) == required_dst_info_num:
self.update_status(room, KVPoll.WaitingForInput)
threading.Thread(target=bootstrap_thread).start() threading.Thread(target=bootstrap_thread).start()
threading.Thread(target=transfer_thread).start()
def start_decode_thread(self): def start_decode_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
...@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager): ...@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager):
) )
return return
self.transfer_queue.put( # NOTE(shangming): sharding according to the dst_infos to make sure
# requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
TransferKVChunk( TransferKVChunk(
room=bootstrap_room, room=bootstrap_room,
prefill_kv_indices=kv_indices, prefill_kv_indices=kv_indices,
......
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
import dataclasses import dataclasses
import os import os
import random import random
import threading
import warnings import warnings
from collections import deque from collections import deque
from enum import Enum from enum import Enum
...@@ -281,6 +282,25 @@ class MetadataBuffers: ...@@ -281,6 +282,25 @@ class MetadataBuffers:
) )
class FastQueue:
    """Minimal unbounded FIFO queue backed by a deque and a condition variable.

    Unlike ``queue.Queue``, this skips task-tracking bookkeeping
    (``task_done``/``join``) and timeout handling, trading those features
    for lower per-operation overhead on the KV-transfer hot path.
    """

    def __init__(self):
        self._items = deque()
        self._not_empty = threading.Condition()

    def put(self, item):
        """Append *item* to the tail and wake one blocked consumer, if any."""
        with self._not_empty:
            self._items.append(item)
            self._not_empty.notify()

    def get(self):
        """Remove and return the oldest item, blocking while the queue is empty."""
        with self._not_empty:
            # wait_for re-checks the predicate after every wakeup, so
            # spurious notifications cannot pop from an empty deque.
            self._not_empty.wait_for(lambda: len(self._items) > 0)
            return self._items.popleft()
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment