"vscode:/vscode.git/clone" did not exist on "88970260f34b12fab5cc112812e233c99ef21e8e"
Unverified Commit a13dd1e4 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD] Improve disaggregation common backend and refactor mooncake backend (#10273)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent d500eb91
from __future__ import annotations
import asyncio
import dataclasses
import logging
import queue
import socket
import struct
import threading
import uuid
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
from typing import Dict, List, Optional, Set
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
CommonKVSender,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -134,16 +123,9 @@ class NixlKVManager(CommonKVManager):
"to run SGLang with NixlTransferEngine."
) from e
self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.request_status: Dict[int, KVPoll] = {}
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
......@@ -166,6 +148,9 @@ class NixlKVManager(CommonKVManager):
self.request_status[bootstrap_room], status
)
def record_failure(self, bootstrap_room: int, failure_reason: str):
pass
def register_buffer_to_engine(self):
kv_addrs = []
for kv_data_ptr, kv_data_len in zip(
......@@ -438,7 +423,7 @@ class NixlKVManager(CommonKVManager):
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
if decode_tp_size == self.tp_size:
if decode_tp_size == self.attn_tp_size:
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
......@@ -455,7 +440,7 @@ class NixlKVManager(CommonKVManager):
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.tp_size,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
......@@ -505,9 +490,6 @@ class NixlKVManager(CommonKVManager):
return False
return self.transfer_statuses[room].is_done()
def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
def _start_bootstrap_thread(self):
self._bind_server_socket()
......@@ -548,7 +530,7 @@ class NixlKVManager(CommonKVManager):
threading.Thread(target=bootstrap_thread).start()
class NixlKVSender(BaseKVSender):
class NixlKVSender(CommonKVSender):
def __init__(
self,
......@@ -558,20 +540,10 @@ class NixlKVSender(BaseKVSender):
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.xfer_handles = []
self.has_sent = False
self.chunk_id = 0
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment