Unverified commit 6b231325 authored by ybyang, committed by GitHub

[PD Perf] replace Queue to FastQueue (#6649)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
parent b1c8d4e9
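
The change swaps the prefill-side KV sender from a single `queue.Queue`, polled with a 10 ms timeout, to several condition-variable-backed `FastQueue`s, each drained by its own daemon worker thread with a dedicated `ThreadPoolExecutor`, and shards chunks across the queues by destination session. The sketch below is only a minimal, self-contained illustration of that pattern; `Chunk`, `handle_chunk`, `submit`, and the constants are illustrative placeholders, not SGLang APIs (the real worker is `MooncakeKVManager.transfer_worker` in the diff that follows).

```python
# Minimal sketch of the sharded FastQueue + worker pattern in this commit.
# All names here (Chunk, handle_chunk, submit, NUM_QUEUES, POOL_SIZE) are
# illustrative placeholders, not SGLang APIs.
import concurrent.futures
import threading
from collections import deque
from dataclasses import dataclass


class FastQueue:
    """Unbounded FIFO with a blocking get(); no timeout polling."""

    def __init__(self):
        self._buf = deque()
        self._cond = threading.Condition()

    def put(self, item):
        with self._cond:
            self._buf.append(item)
            self._cond.notify()  # wake one waiting consumer

    def get(self):
        with self._cond:
            while not self._buf:  # block until a producer notifies
                self._cond.wait()
            return self._buf.popleft()


@dataclass
class Chunk:
    session: str  # e.g. "10.0.0.1:7000"; used only for sharding here
    payload: bytes


NUM_QUEUES = 4   # cf. SGLANG_DISAGGREGATION_QUEUE_SIZE
POOL_SIZE = 12   # cf. SGLANG_DISAGGREGATION_THREAD_POOL_SIZE

queues = [FastQueue() for _ in range(NUM_QUEUES)]
executors = [
    concurrent.futures.ThreadPoolExecutor(POOL_SIZE // NUM_QUEUES)
    for _ in range(NUM_QUEUES)
]


def handle_chunk(chunk: Chunk, executor: concurrent.futures.ThreadPoolExecutor):
    # Stand-in for send_kvcache(): fan per-layer work out onto the executor.
    futures = [executor.submit(len, chunk.payload) for _ in range(2)]
    return all(f.result() >= 0 for f in futures)


def worker(q: FastQueue, executor: concurrent.futures.ThreadPoolExecutor):
    # Each queue is drained by exactly one worker thread.
    while True:
        handle_chunk(q.get(), executor)


for q, ex in zip(queues, executors):
    threading.Thread(target=worker, args=(q, ex), daemon=True).start()


def submit(chunk: Chunk):
    # Shard by destination port so chunks for the same session share a queue.
    shard_idx = int(chunk.session.split(":")[1]) % NUM_QUEUES
    queues[shard_idx].put(chunk)
```

One worker per queue appears to preserve per-session ordering, while the per-queue executor parallelizes the per-layer transfers inside a chunk (as `send_kvcache` does in the diff).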
@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FastQueue,
     group_concurrent_contiguous,
 )
 from sglang.srt.server_args import ServerArgs
@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager):
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.transfer_queue = queue.Queue()
             self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager):
             self.session_failures = defaultdict(int)
             self.failed_sessions = set()
             self.session_lock = threading.Lock()
             # Determine the number of threads to use for kv sender
             cpu_count = os.cpu_count()
-            self.executor = concurrent.futures.ThreadPoolExecutor(
-                get_int_env_var(
-                    "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
-                    min(max(1, cpu_count // 8), 8),
-                )
-            )
+            transfer_thread_pool_size = get_int_env_var(
+                "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
+                min(max(4, int(0.75 * cpu_count) // 8), 12),
+            )
+            transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
+            self.transfer_queues: List[FastQueue] = [
+                FastQueue() for _ in range(transfer_queue_size)
+            ]
+            assert transfer_thread_pool_size >= transfer_queue_size, (
+                f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
+                f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
+            )
+            self.executors = [
+                concurrent.futures.ThreadPoolExecutor(
+                    transfer_thread_pool_size // transfer_queue_size
+                )
+                for _ in range(transfer_queue_size)
+            ]
+            for queue, executor in zip(self.transfer_queues, self.executors):
+                threading.Thread(
+                    target=self.transfer_worker, args=(queue, executor), daemon=True
+                ).start()
             self.bootstrap_time_out = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
             )
@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
             )
             # Heartbeat failure should be at least 1
             self.max_failures = max(
-                int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
         prefill_kv_indices: npt.NDArray[np.int64],
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
+        executor: concurrent.futures.ThreadPoolExecutor,
     ):
         # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
             return 0

         futures = [
-            self.executor.submit(
+            executor.submit(
                 process_layer,
                 src_ptr,
                 dst_ptr,
@@ -298,6 +315,123 @@ class MooncakeKVManager(BaseKVManager):
             ]
         )

+    def transfer_worker(
+        self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
+    ):
+        while True:
+            try:
+                kv_chunk: TransferKVChunk = queue.get()
+                reqs_to_be_processed = (
+                    self.transfer_infos[kv_chunk.room].values()
+                    if kv_chunk.room in self.transfer_infos
+                    else []
+                )
+                polls = []
+                dst_ranks_infos = []
+                for req in reqs_to_be_processed:
+                    if not req.is_dummy:
+                        # Early exit if the request has failed
+                        with self.session_lock:
+                            if req.mooncake_session_id in self.failed_sessions:
+                                self.record_failure(
+                                    kv_chunk.room,
+                                    f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
+                                )
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint,
+                                    req.dst_port,
+                                    req.room,
+                                    KVPoll.Failed,
+                                )
+                                break
+
+                        chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
+
+                        # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
+                        # is mismatched with the dst_kv_indices when page size > 1, this should never happen.
+                        if len(chunked_dst_kv_indice) < len(
+                            kv_chunk.prefill_kv_indices
+                        ):
+                            kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
+                                len(chunked_dst_kv_indice)
+                            ]
+                            logger.warning(
+                                f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+                            )
+
+                        ret = self.send_kvcache(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            self.decode_kv_args_table[
+                                req.mooncake_session_id
+                            ].dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            executor,
+                        )
+                        if ret != 0:
+                            with self.session_lock:
+                                self.session_failures[req.mooncake_session_id] += 1
+                                # Failures should never happen if the session is not dead, if the session fails once, mark it as failed
+                                if self.session_failures[req.mooncake_session_id] >= 1:
+                                    self.failed_sessions.add(req.mooncake_session_id)
+                                    logger.error(
+                                        f"Session {req.mooncake_session_id} failed."
+                                    )
+                            self.record_failure(
+                                kv_chunk.room,
+                                f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
+                            )
+                            self.update_status(kv_chunk.room, KVPoll.Failed)
+                            self.sync_status_to_decode_endpoint(
+                                req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                            )
+                            break
+
+                        if kv_chunk.is_last:
+                            # Only the last chunk we need to send the aux data
+                            ret = self.send_aux(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_aux_index,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_aux_ptrs,
+                                req.dst_aux_index,
+                            )
+                            polls.append(True if ret == 0 else False)
+                            dst_ranks_infos.append(
+                                (req.endpoint, req.dst_port, req.room)
+                            )
+
+                            # Only sync status when all the dst ranks have received the kvcache
+                            if len(polls) == req.required_dst_info_num:
+                                status = KVPoll.Success if all(polls) else KVPoll.Failed
+                                self.update_status(req.room, status)
+                                for endpoint, dst_port, room in dst_ranks_infos:
+                                    self.sync_status_to_decode_endpoint(
+                                        endpoint, dst_port, room, status
+                                    )
+                    else:
+                        # Dummy request means the decode instance is not used, so its status can be marked as success directly
+                        # Dummy request does not need to sync status to decode endpoint
+                        if kv_chunk.is_last and req.room in self.request_status:
+                            self.update_status(req.room, KVPoll.Success)
+
+                if (
+                    kv_chunk.room not in self.request_status
+                    or self.check_status(kv_chunk.room) == KVPoll.Success
+                ):
+                    if kv_chunk.room in self.transfer_infos:
+                        self.transfer_infos.pop(kv_chunk.room)
+
+            except queue.Empty:
+                continue
+            except Exception as e:
+                # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
+                raise RuntimeError(
+                    f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
+                )
+
     def start_prefill_thread(self):
         self.rank_port = get_free_port()
         self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
@@ -335,134 +469,7 @@ class MooncakeKVManager(BaseKVManager):
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     self.update_status(room, KVPoll.WaitingForInput)

-        def transfer_thread():
-            # TODO: Shall we use KVPoll.Transferring state?
-            while True:
-                try:
-                    kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    reqs_to_be_processed = (
-                        self.transfer_infos[kv_chunk.room].values()
-                        if kv_chunk.room in self.transfer_infos
-                        else []
-                    )
-                    polls = []
-                    dst_ranks_infos = []
-                    for req in reqs_to_be_processed:
-                        if not req.is_dummy:
-                            # Early exit if the request has failed
-                            with self.session_lock:
-                                if req.mooncake_session_id in self.failed_sessions:
-                                    self.record_failure(
-                                        kv_chunk.room,
-                                        f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
-                                    )
-                                    self.update_status(kv_chunk.room, KVPoll.Failed)
-                                    self.sync_status_to_decode_endpoint(
-                                        req.endpoint,
-                                        req.dst_port,
-                                        req.room,
-                                        KVPoll.Failed,
-                                    )
-                                    break
-
-                            chunked_dst_kv_indice = req.dst_kv_indices[
-                                kv_chunk.index_slice
-                            ]
-
-                            # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
-                            # is mismatched with the dst_kv_indices when page size > 1, this should never happen.
-                            if len(chunked_dst_kv_indice) < len(
-                                kv_chunk.prefill_kv_indices
-                            ):
-                                kv_chunk.prefill_kv_indices = (
-                                    kv_chunk.prefill_kv_indices[
-                                        len(chunked_dst_kv_indice)
-                                    ]
-                                )
-                                logger.warning(
-                                    f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
-                                )
-
-                            ret = self.send_kvcache(
-                                req.mooncake_session_id,
-                                kv_chunk.prefill_kv_indices,
-                                self.decode_kv_args_table[
-                                    req.mooncake_session_id
-                                ].dst_kv_ptrs,
-                                chunked_dst_kv_indice,
-                            )
-                            if ret != 0:
-                                with self.session_lock:
-                                    self.session_failures[req.mooncake_session_id] += 1
-                                    # Failures should never happen if the session is not dead, if the session fails once, mark it as failed
-                                    if (
-                                        self.session_failures[req.mooncake_session_id]
-                                        >= 1
-                                    ):
-                                        self.failed_sessions.add(
-                                            req.mooncake_session_id
-                                        )
-                                        logger.error(
-                                            f"Session {req.mooncake_session_id} failed."
-                                        )
-                                self.record_failure(
-                                    kv_chunk.room,
-                                    f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
-                                )
-                                self.update_status(kv_chunk.room, KVPoll.Failed)
-                                self.sync_status_to_decode_endpoint(
-                                    req.endpoint, req.dst_port, req.room, KVPoll.Failed
-                                )
-                                break
-
-                            if kv_chunk.is_last:
-                                # Only the last chunk we need to send the aux data
-                                ret = self.send_aux(
-                                    req.mooncake_session_id,
-                                    kv_chunk.prefill_aux_index,
-                                    self.decode_kv_args_table[
-                                        req.mooncake_session_id
-                                    ].dst_aux_ptrs,
-                                    req.dst_aux_index,
-                                )
-                                polls.append(True if ret == 0 else False)
-                                dst_ranks_infos.append(
-                                    (req.endpoint, req.dst_port, req.room)
-                                )
-
-                                # Only sync status when all the dst ranks have received the kvcache
-                                if len(polls) == req.required_dst_info_num:
-                                    status = (
-                                        KVPoll.Success if all(polls) else KVPoll.Failed
-                                    )
-                                    self.update_status(req.room, status)
-                                    for endpoint, dst_port, room in dst_ranks_infos:
-                                        self.sync_status_to_decode_endpoint(
-                                            endpoint, dst_port, room, status
-                                        )
-                        else:
-                            # Dummy request means the decode instance is not used, so its status can be marked as success directly
-                            # Dummy request does not need to sync status to decode endpoint
-                            if kv_chunk.is_last and req.room in self.request_status:
-                                self.update_status(req.room, KVPoll.Success)
-
-                    if (
-                        kv_chunk.room not in self.request_status
-                        or self.check_status(kv_chunk.room) == KVPoll.Success
-                    ):
-                        if kv_chunk.room in self.transfer_infos:
-                            self.transfer_infos.pop(kv_chunk.room)
-
-                except queue.Empty:
-                    continue
-                except Exception as e:
-                    # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
-                    raise RuntimeError(
-                        f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
-                    )
-
         threading.Thread(target=bootstrap_thread).start()
-        threading.Thread(target=transfer_thread).start()

     def start_decode_thread(self):
         self.rank_port = get_free_port()
@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager):
             )
             return

-        self.transfer_queue.put(
+        # NOTE(shangming): sharding according to the dst_infos to make sure
+        # requests with the same dst_sessions will be added into the same
+        # queue, which enables early abort with failed sessions.
+        dst_infos = self.transfer_infos[bootstrap_room].keys()
+        session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
+        shard_idx = session_port_sum % len(self.transfer_queues)
+
+        self.transfer_queues[shard_idx].put(
             TransferKVChunk(
                 room=bootstrap_room,
                 prefill_kv_indices=kv_indices,
@@ -3,6 +3,7 @@ from __future__ import annotations
 import dataclasses
 import os
 import random
+import threading
 import warnings
 from collections import deque
 from enum import Enum
@@ -281,6 +282,25 @@ class MetadataBuffers:
         )


+class FastQueue:
+    def __init__(self):
+        self._buf = deque()
+        self._cond = threading.Condition()
+
+    def put(self, item):
+        with self._cond:
+            self._buf.append(item)
+            # wake up one thread blocked in get()
+            self._cond.notify()
+
+    def get(self):
+        with self._cond:
+            # if the queue is empty, block until put() notifies us
+            while not self._buf:
+                self._cond.wait()
+            return self._buf.popleft()
+
+
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
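
For contrast, here is a hedged, standalone comparison of the two consumer loops: the removed code polled `queue.Queue.get(timeout=0.01)` and swallowed `queue.Empty`, while `FastQueue.get()` blocks on a condition variable until `put()` notifies. `process` is a placeholder; only the loop shape mirrors the diff above.

```python
# Illustrative comparison only; `process` is a placeholder, and FastQueue
# refers to the class added in sglang.srt.disaggregation.utils above.
import queue as stdlib_queue


def old_style_worker(q: "stdlib_queue.Queue", process):
    # Before this commit: poll with a short timeout, so an idle worker wakes
    # up roughly 100 times per second just to catch queue.Empty and loop again.
    while True:
        try:
            item = q.get(timeout=0.01)
        except stdlib_queue.Empty:
            continue
        process(item)


def fastqueue_worker(q, process):
    # After this commit: get() blocks on a condition variable until put()
    # notifies, so an idle worker sleeps instead of spinning on timeouts.
    while True:
        process(q.get())
```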