Unverified Commit 33b16ad1 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Distinguish bootstrap key only in decode server (#5422)

parent ffde65a0
...@@ -28,13 +28,7 @@ import numpy as np ...@@ -28,13 +28,7 @@ import numpy as np
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import ( from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
KVClassType, KVClassType,
......
...@@ -329,7 +329,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -329,7 +329,7 @@ class MooncakeKVManager(BaseKVManager):
"role": "Prefill", "role": "Prefill",
"rank_ip": get_local_ip_by_remote(), "rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port, "rank_port": self.rank_port,
"bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}", "engine_rank": self.kv_args.engine_rank,
} }
try: try:
...@@ -400,28 +400,29 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -400,28 +400,29 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.session_id = self.kv_mgr.get_session_id() self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" # NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if self.bootstrap_key not in self.kv_mgr.connection_pool: if bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server( self.bootstrap_info = self._get_bootstrap_info_from_server(
self.bootstrap_key self.kv_mgr.kv_args.engine_rank
) )
if self.bootstrap_info is None: if self.bootstrap_info is None:
logger.error( logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
) )
else: else:
self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
else: else:
self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key] self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
assert self.bootstrap_info is not None assert self.bootstrap_info is not None
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
def _get_bootstrap_info_from_server(self, bootstrap_key: str): def _get_bootstrap_info_from_server(self, engine_rank):
"""Fetch the bootstrap info from the bootstrap server.""" """Fetch the bootstrap info from the bootstrap server."""
try: try:
url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}" url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
bootstrap_info = response.json() bootstrap_info = response.json()
...@@ -556,28 +557,28 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -556,28 +557,28 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
role = data["role"] role = data["role"]
rank_ip = data["rank_ip"] rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"]) rank_port = int(data["rank_port"])
bootstrap_key = data["bootstrap_key"] engine_rank = int(data["engine_rank"])
# Add lock to make sure thread-safe # Add lock to make sure thread-safe
if role == "Prefill": if role == "Prefill":
self.prefill_port_table[bootstrap_key] = { self.prefill_port_table[engine_rank] = {
"rank_ip": rank_ip, "rank_ip": rank_ip,
"rank_port": rank_port, "rank_port": rank_port,
} }
logger.debug( logger.debug(
f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}" f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
) )
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request): async def _handle_route_get(self, request: web.Request):
bootstrap_key = request.query.get("bootstrap_key") engine_rank = request.query.get("engine_rank")
if not bootstrap_key: if not engine_rank:
return web.Response(text="Missing bootstrap_key", status=400) return web.Response(text="Missing rank", status=400)
# Find corresponding prefill info # Find corresponding prefill info
async with self.lock: async with self.lock:
bootstrap_info = self.prefill_port_table.get(bootstrap_key) bootstrap_info = self.prefill_port_table.get(int(engine_rank))
if bootstrap_info is not None: if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200) return web.json_response(bootstrap_info, status=200)
......
...@@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, List, Optional ...@@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, List, Optional
import torch import torch
from sglang.srt.disaggregation.base import ( from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
KVClassType, KVClassType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment