"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "41d77dad92c3d2502139adae540a00bf868a8805"
Unverified commit a13dd1e4, authored by Shangming Cai, committed by GitHub
Browse files

[PD] Improve disaggregation common backend and refactor mooncake backend (#10273)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent d500eb91
from __future__ import annotations from __future__ import annotations
import asyncio
import dataclasses import dataclasses
import logging import logging
import queue
import socket
import struct import struct
import threading import threading
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import cache from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import ( from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer, CommonKVBootstrapServer,
CommonKVManager, CommonKVManager,
CommonKVReceiver, CommonKVReceiver,
CommonKVSender,
) )
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -134,16 +123,9 @@ class NixlKVManager(CommonKVManager): ...@@ -134,16 +123,9 @@ class NixlKVManager(CommonKVManager):
"to run SGLang with NixlTransferEngine." "to run SGLang with NixlTransferEngine."
) from e ) from e
self.agent = nixl_agent(str(uuid.uuid4())) self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.request_status: Dict[int, KVPoll] = {}
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self._start_bootstrap_thread() self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
...@@ -166,6 +148,9 @@ class NixlKVManager(CommonKVManager): ...@@ -166,6 +148,9 @@ class NixlKVManager(CommonKVManager):
self.request_status[bootstrap_room], status self.request_status[bootstrap_room], status
) )
def record_failure(self, bootstrap_room: int, failure_reason: str):
pass
def register_buffer_to_engine(self): def register_buffer_to_engine(self):
kv_addrs = [] kv_addrs = []
for kv_data_ptr, kv_data_len in zip( for kv_data_ptr, kv_data_len in zip(
...@@ -438,7 +423,7 @@ class NixlKVManager(CommonKVManager): ...@@ -438,7 +423,7 @@ class NixlKVManager(CommonKVManager):
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
if decode_tp_size == self.tp_size: if decode_tp_size == self.attn_tp_size:
kv_xfer_handle = self.send_kvcache( kv_xfer_handle = self.send_kvcache(
req.agent_name, req.agent_name,
kv_indices, kv_indices,
...@@ -455,7 +440,7 @@ class NixlKVManager(CommonKVManager): ...@@ -455,7 +440,7 @@ class NixlKVManager(CommonKVManager):
chunked_dst_kv_indice, chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id, self.decode_kv_args_table[req.agent_name].gpu_id,
notif, notif,
prefill_tp_size=self.tp_size, prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size, decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[ decode_tp_rank=self.decode_kv_args_table[
req.agent_name req.agent_name
...@@ -505,9 +490,6 @@ class NixlKVManager(CommonKVManager): ...@@ -505,9 +490,6 @@ class NixlKVManager(CommonKVManager):
return False return False
return self.transfer_statuses[room].is_done() return self.transfer_statuses[room].is_done()
def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
def _start_bootstrap_thread(self): def _start_bootstrap_thread(self):
self._bind_server_socket() self._bind_server_socket()
...@@ -548,7 +530,7 @@ class NixlKVManager(CommonKVManager): ...@@ -548,7 +530,7 @@ class NixlKVManager(CommonKVManager):
threading.Thread(target=bootstrap_thread).start() threading.Thread(target=bootstrap_thread).start()
class NixlKVSender(BaseKVSender): class NixlKVSender(CommonKVSender):
def __init__( def __init__(
self, self,
...@@ -558,20 +540,10 @@ class NixlKVSender(BaseKVSender): ...@@ -558,20 +540,10 @@ class NixlKVSender(BaseKVSender):
dest_tp_ranks: List[int], dest_tp_ranks: List[int],
pp_rank: int, pp_rank: int,
): ):
self.kv_mgr = mgr super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.bootstrap_room = bootstrap_room
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.xfer_handles = [] self.xfer_handles = []
self.has_sent = False self.has_sent = False
self.chunk_id = 0 self.chunk_id = 0
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send( def send(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment