Unverified Commit ffde65a0 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[PD] Fix dynamic port support and MLA buffer for Mooncake (#5415)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: default avatarybyang <ybyang7@iflytek.com>
parent 471650de
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
class KVArgs: class KVArgs:
...@@ -16,6 +17,7 @@ class KVArgs: ...@@ -16,6 +17,7 @@ class KVArgs:
aux_data_lens: list[int] aux_data_lens: list[int]
aux_item_lens: list[int] aux_item_lens: list[int]
ib_device: str ib_device: str
gpu_id: int
class KVPoll: class KVPoll:
...@@ -30,7 +32,12 @@ class BaseKVManager(ABC): ...@@ -30,7 +32,12 @@ class BaseKVManager(ABC):
"""Base class for managing transfers states""" """Base class for managing transfers states"""
@abstractmethod @abstractmethod
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ... def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
): ...
class BaseKVSender(ABC): class BaseKVSender(ABC):
......
...@@ -128,8 +128,11 @@ class DecodePreallocQueue: ...@@ -128,8 +128,11 @@ class DecodePreallocQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE) kv_manager = kv_manager_class(
kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
)
return kv_manager return kv_manager
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
......
...@@ -2,10 +2,9 @@ from __future__ import annotations ...@@ -2,10 +2,9 @@ from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import json
import logging import logging
import queue import queue
import random import socket
import struct import struct
import threading import threading
from functools import cache from functools import cache
...@@ -27,24 +26,12 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -27,24 +26,12 @@ from sglang.srt.disaggregation.base.conn import (
) )
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.utils import is_port_available from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port
while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += random.randint(100, 1000)
return available_ports
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
...@@ -82,10 +69,10 @@ class TransferKVChunk: ...@@ -82,10 +69,10 @@ class TransferKVChunk:
@dataclasses.dataclass @dataclasses.dataclass
class TransferInfo: class TransferInfo:
room: int
endpoint: str endpoint: str
decode_port: int dst_port: int
mooncake_session_id: str mooncake_session_id: str
room: int
dst_kv_ptrs: list[int] dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int] dst_aux_ptrs: list[int]
...@@ -94,10 +81,10 @@ class TransferInfo: ...@@ -94,10 +81,10 @@ class TransferInfo:
@classmethod @classmethod
def from_zmq(cls, msg: List[bytes]): def from_zmq(cls, msg: List[bytes]):
return cls( return cls(
endpoint=msg[0].decode("ascii"), room=int(msg[0].decode("ascii")),
decode_port=int(msg[1].decode("ascii")), endpoint=msg[1].decode("ascii"),
mooncake_session_id=msg[2].decode("ascii"), dst_port=int(msg[2].decode("ascii")),
room=int(msg[3].decode("ascii")), mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
...@@ -106,12 +93,20 @@ class TransferInfo: ...@@ -106,12 +93,20 @@ class TransferInfo:
class MooncakeKVManager(BaseKVManager): class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
):
self.engine = MooncakeTransferEngine() self.engine = MooncakeTransferEngine()
self.kv_args = args self.kv_args = args
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.request_status: Dict[int, KVPoll] = {} self.request_status: Dict[int, KVPoll] = {}
self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.rank_port = None self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
...@@ -119,6 +114,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -119,6 +114,7 @@ class MooncakeKVManager(BaseKVManager):
self.transfer_queue = queue.Queue() self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {} self.transfer_infos: Dict[int, TransferInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread() self.start_decode_thread()
else: else:
...@@ -150,54 +146,29 @@ class MooncakeKVManager(BaseKVManager): ...@@ -150,54 +146,29 @@ class MooncakeKVManager(BaseKVManager):
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int64],
): ):
layer_num = int(len(self.kv_args.kv_data_ptrs) / 2) # group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices prefill_kv_indices, dst_kv_indices
) )
for layer_id in range(layer_num):
prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id]
key_item_len = self.kv_args.kv_item_lens[layer_id]
prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]
decode_key_layer_ptr = dst_kv_ptrs[layer_id] num_layers = len(self.kv_args.kv_data_ptrs)
decode_value_layer_ptr = dst_kv_ptrs[layer_num + layer_id] for layer_id in range(num_layers):
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
prefill_key_addr = ( src_addr = src_ptr + int(prefill_index[0]) * item_len
prefill_key_layer_ptr + int(prefill_index[0]) * key_item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len
) length = item_len * len(prefill_index)
decode_key_addr = (
decode_key_layer_ptr + int(decode_index[0]) * key_item_len
)
# TODO: mooncake transfer engine can do async transfer. Do async later # TODO: make async later
status = self.engine.transfer_sync( status = self.engine.transfer_sync(
mooncake_session_id, mooncake_session_id, src_addr, dst_addr, length
prefill_key_addr,
decode_key_addr,
key_item_len * len(prefill_index),
) )
if status != 0: if status != 0:
return status return status
prefill_value_addr = (
prefill_value_layer_ptr + int(prefill_index[0]) * value_item_len
)
decode_value_addr = (
decode_value_layer_ptr + int(decode_index[0]) * value_item_len
)
# TODO: mooncake transfer engine can do async transfer. Do async later
status = self.engine.transfer_sync(
mooncake_session_id,
prefill_value_addr,
decode_value_addr,
value_item_len * len(prefill_index),
)
if status != 0:
return status
return 0 return 0
def send_aux( def send_aux(
...@@ -230,16 +201,15 @@ class MooncakeKVManager(BaseKVManager): ...@@ -230,16 +201,15 @@ class MooncakeKVManager(BaseKVManager):
) )
def start_prefill_thread(self): def start_prefill_thread(self):
# Find available port for prefill tp self.rank_port = get_free_port()
self.rank_port = find_available_ports(20000, 1)[0] self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
self.server_socket.bind("tcp://*:" + str(self.rank_port))
def bootstrap_thread(): def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine""" """This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput # KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True: while True:
waiting_req_bytes = self.server_socket.recv_multipart() waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[3].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None": if room == "None":
continue continue
room = int(room) room = int(room)
...@@ -295,8 +265,8 @@ class MooncakeKVManager(BaseKVManager): ...@@ -295,8 +265,8 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=transfer_thread).start() threading.Thread(target=transfer_thread).start()
def start_decode_thread(self): def start_decode_thread(self):
self.rank_port = find_available_ports(25000, 1)[0] self.rank_port = get_free_port()
self.server_socket.bind("tcp://*:" + str(self.rank_port)) self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
def decode_thread(): def decode_thread():
while True: while True:
...@@ -343,54 +313,48 @@ class MooncakeKVManager(BaseKVManager): ...@@ -343,54 +313,48 @@ class MooncakeKVManager(BaseKVManager):
self.request_status[bootstrap_room], status self.request_status[bootstrap_room], status
) )
def get_localhost(self):
return self.engine.get_localhost()
def get_session_id(self): def get_session_id(self):
return self.engine.get_session_id() return self.engine.get_session_id()
class MooncakeKVSender(BaseKVSender):
def __init__(
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
# Register to bootstrap server
self._register_to_bootstrap()
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
url = f"http://{self.bootstrap_server_url}/kv_route" if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = { payload = {
"identity": self.session_id,
"role": "Prefill", "role": "Prefill",
"serve_ip": self.kv_mgr.get_localhost(), "rank_ip": get_local_ip_by_remote(),
"serve_port": self.kv_mgr.rank_port, "rank_port": self.rank_port,
"tp_rank": self.kv_mgr.kv_args.engine_rank, "bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}",
} }
logger.info(
f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
try: try:
response = requests.put(url, json=payload) response = requests.put(url, json=payload)
if response.status_code == 200: if response.status_code == 200:
logger.info(f"Prefill successfully registered to bootstrap server.") logger.debug("Prefill successfully registered to bootstrap server.")
else: else:
logger.info( logger.error(
f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}" f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
) )
except Exception as e: except Exception as e:
logger.info(f"Prefill Failed to register to bootstrap server: {e}") logger.error(f"Prefill Failed to register to bootstrap server: {e}")
class MooncakeKVSender(BaseKVSender):
def __init__(
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices self.num_kv_indices = num_kv_indices
...@@ -433,21 +397,35 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -433,21 +397,35 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr self.kv_mgr = mgr
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id() self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if self.bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.bootstrap_key
)
if self.bootstrap_info is None:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info
else:
self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key]
assert self.bootstrap_info is not None
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.prefill_engine_rank = None
self.decode_port = self.kv_mgr.rank_port
self.dealer_socket = None
def _get_prefill_info_from_bootstrap(self, tp_rank: int): def _get_bootstrap_info_from_server(self, bootstrap_key: str):
"""Fetch the prefill server port corresponding to tp_rank from the bootstrap server.""" """Fetch the bootstrap info from the bootstrap server."""
try: try:
url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}" url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}"
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
prefill_info = response.json() bootstrap_info = response.json()
return prefill_info return bootstrap_info
else: else:
logger.error( logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}" f"Failed to get prefill server info: {response.status_code}, {response.text}"
...@@ -464,39 +442,13 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -464,39 +442,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
return socket return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
prefill_info = None
logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")
if self.kv_mgr.kv_args.engine_rank not in self.kv_mgr.connection_pool:
prefill_info = self._get_prefill_info_from_bootstrap(
self.kv_mgr.kv_args.engine_rank
)
if prefill_info is None:
logger.error(
logger.error(
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
)
else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
prefill_info
)
else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
if prefill_info:
self.prefill_server_url = ( self.prefill_server_url = (
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}" f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
) )
logger.debug(
logger.info( f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
) )
self.handshake_prefill_server(kv_indices, aux_index)
def handshake_prefill_server(
self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None
):
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
...@@ -505,10 +457,10 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -505,10 +457,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
self._connect("tcp://" + self.prefill_server_url).send_multipart( self._connect("tcp://" + self.prefill_server_url).send_multipart(
[ [
self.decode_ip.encode("ascii"),
str(self.decode_port).encode("ascii"),
self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
kv_indices.tobytes(), kv_indices.tobytes(),
packed_aux_data_ptrs, packed_aux_data_ptrs,
...@@ -530,10 +482,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -530,10 +482,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.store = dict() self.store = dict()
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self._setup_routes() self._setup_routes()
# prefill_engine_rank -> prefill_info self.prefill_port_table: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
self.context = zmq.Context()
self.prefill_engine_rank = None self.prefill_engine_rank = None
...@@ -546,7 +495,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -546,7 +495,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self): def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata) self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/kv_route", self._handle_kv_route) self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_metadata(self, request: web.Request): async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "") key = request.query.get("key", "")
...@@ -591,54 +540,47 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -591,54 +540,47 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata deleted", status=200, content_type="application/json" text="metadata deleted", status=200, content_type="application/json"
) )
async def _handle_kv_route(self, request: web.Request): async def _handle_route(self, request: web.Request):
method = request.method method = request.method
if method == "PUT": if method == "PUT":
return await self._handle_kv_route_put(request) return await self._handle_route_put(request)
elif method == "GET": elif method == "GET":
return await self._handle_kv_route_get(request) return await self._handle_route_get(request)
else: else:
return web.Response( return web.Response(
text="Method not allowed", status=405, content_type="application/json" text="Method not allowed", status=405, content_type="application/json"
) )
async def _handle_kv_route_put(self, request: web.Request): async def _handle_route_put(self, request: web.Request):
data = await request.json() data = await request.json()
identity = data["identity"]
role = data["role"] role = data["role"]
serve_ip = data["serve_ip"] rank_ip = data["rank_ip"]
serve_port = int(data["serve_port"]) # Assuming serve_port is an integer rank_port = int(data["rank_port"])
tp_rank = int(data["tp_rank"]) bootstrap_key = data["bootstrap_key"]
# Add lock to make sure thread-safe # Add lock to make sure thread-safe
if role == "Prefill": if role == "Prefill":
async with self.lock: self.prefill_port_table[bootstrap_key] = {
self.prefill_port_table[tp_rank] = { "rank_ip": rank_ip,
"serve_ip": serve_ip, "rank_port": rank_port,
"serve_port": serve_port,
} }
logger.info( logger.debug(
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}" f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}"
) )
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)
async def _handle_kv_route_get(self, request: web.Request): async def _handle_route_get(self, request: web.Request):
tp_rank = request.query.get("tp_rank") bootstrap_key = request.query.get("bootstrap_key")
if not tp_rank: if not bootstrap_key:
return web.Response(text="Missing tp_rank", status=400) return web.Response(text="Missing bootstrap_key", status=400)
try:
tp_rank = int(tp_rank)
except ValueError:
return web.Response(text="tp_rank must be int", status=400)
# Find corresponding prefill info # Find corresponding prefill info
async with self.lock: async with self.lock:
prefill_info = self.prefill_port_table.get(tp_rank) bootstrap_info = self.prefill_port_table.get(bootstrap_key)
if prefill_info is not None:
return web.json_response(prefill_info, status=200)
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else: else:
return web.Response(text="Not Found", status=404) return web.Response(text="Not Found", status=404)
......
...@@ -67,6 +67,7 @@ class PrefillBootstrapQueue: ...@@ -67,6 +67,7 @@ class PrefillBootstrapQueue:
bootstrap_port: int, bootstrap_port: int,
gloo_group: ProcessGroup, gloo_group: ProcessGroup,
transfer_backend: TransferBackend, transfer_backend: TransferBackend,
scheduler: Scheduler,
): ):
self.token_to_kv_pool = token_to_kv_pool self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype self.aux_dtype = aux_dtype
...@@ -76,6 +77,7 @@ class PrefillBootstrapQueue: ...@@ -76,6 +77,7 @@ class PrefillBootstrapQueue:
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.transfer_backend = transfer_backend self.transfer_backend = transfer_backend
self.scheduler = scheduler
self.kv_manager = self._init_kv_manager() self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = [] self.queue: List[Req] = []
self.gloo_group = gloo_group self.gloo_group = gloo_group
...@@ -108,8 +110,11 @@ class PrefillBootstrapQueue: ...@@ -108,8 +110,11 @@ class PrefillBootstrapQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL) kv_manager = kv_manager_class(
kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
)
return kv_manager return kv_manager
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
......
...@@ -599,6 +599,7 @@ class Scheduler( ...@@ -599,6 +599,7 @@ class Scheduler(
bootstrap_port=self.server_args.disaggregation_bootstrap_port, bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(), gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
transfer_backend=self.transfer_backend, transfer_backend=self.transfer_backend,
scheduler=self,
) )
# The prefill requests that are in the middle of kv sending # The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = [] self.disagg_prefill_inflight_queue: List[Req] = []
......
...@@ -1872,3 +1872,36 @@ def is_hopper_with_cuda_12_3(): ...@@ -1872,3 +1872,36 @@ def is_hopper_with_cuda_12_3():
cuda_version = torch.version.cuda.split(".") cuda_version = torch.version.cuda.split(".")
is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3 is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
return is_hopper and is_cuda_compatible return is_hopper and is_cuda_compatible
def get_free_port():
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip_by_remote() -> str:
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
raise ValueError(f"Can not get local ip")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment