Unverified commit dca90f1d, authored by shangmingc and committed by GitHub
Browse files

[PD] Remove the requirement of config file for mooncake backend (#5460)

parent 0961feef
...@@ -121,7 +121,7 @@ class DecodePreallocQueue: ...@@ -121,7 +121,7 @@ class DecodePreallocQueue:
kv_args.aux_item_lens = [ kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class( kv_manager = kv_manager_class(
......
...@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager): ...@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
disaggregation_mode: DisaggregationMode, disaggregation_mode: DisaggregationMode,
server_args: ServerArgs, server_args: ServerArgs,
): ):
self.engine = MooncakeTransferEngine()
self.kv_args = args self.kv_args = args
self.engine = MooncakeTransferEngine(
hostname=get_local_ip_by_remote(),
gpu_id=self.kv_args.gpu_id,
ib_device=self.kv_args.ib_device,
)
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer # for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port self.bootstrap_port = server_args.disaggregation_bootstrap_port
...@@ -503,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -503,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.thread.start() self.thread.start()
def _setup_routes(self): def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/route", self._handle_route) self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
if request.method == "GET":
return await self._handle_metadata_get(key)
elif request.method == "PUT":
return await self._handle_metadata_put(key, request)
elif request.method == "DELETE":
return await self._handle_metadata_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_metadata_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
return web.Response(
text="metadata not found", status=404, content_type="application/json"
)
return web.Response(body=value, status=200, content_type="application/json")
async def _handle_metadata_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
return web.Response(
text="metadata updated", status=200, content_type="application/json"
)
async def _handle_metadata_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
text="metadata not found",
status=404,
content_type="application/json",
)
del self.store[key]
return web.Response(
text="metadata deleted", status=200, content_type="application/json"
)
async def _handle_route(self, request: web.Request): async def _handle_route(self, request: web.Request):
method = request.method method = request.method
if method == "PUT": if method == "PUT":
......
import json import json
import logging import logging
import os
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class MooncakeTransferEngineConfig:
local_hostname: str
metadata_server: str
protocol: str
device_name: str
@staticmethod
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeTransferEngineConfig(
local_hostname=config.get("local_hostname", None),
metadata_server=config.get("metadata_server"),
protocol=config.get("protocol", "rdma"),
device_name=config.get("device_name", ""),
)
@staticmethod
def load_from_env() -> "MooncakeTransferEngineConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeTransferEngineConfig.from_file(config_file_path)
class MooncakeTransferEngine: class MooncakeTransferEngine:
def __init__(self): def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
try: try:
from mooncake.engine import TransferEngine from mooncake.engine import TransferEngine
except ImportError as e: except ImportError as e:
...@@ -50,43 +19,43 @@ class MooncakeTransferEngine: ...@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
) from e ) from e
self.engine = TransferEngine() self.engine = TransferEngine()
self.hostname = hostname
self.gpu_id = gpu_id
self.ib_device = ib_device
try:
self.config = MooncakeTransferEngineConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
except ValueError as e:
logger.error(e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
self.config = MooncakeTransferEngineConfig.load_from_env()
session_suffix = "_" + str(uuid.uuid4())
self.session_id = self.config.local_hostname + session_suffix
self.initialize( self.initialize(
self.session_id, hostname=self.hostname,
self.config.metadata_server, device_name=self.ib_device,
self.config.protocol,
self.config.device_name,
) )
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
def register(self, ptr, length): def register(self, ptr, length):
self.engine.register_memory(ptr, length) ret_value = self.engine.register_memory(ptr, length)
if ret_value != 0:
logger.error("Mooncake memory registration failed.")
raise RuntimeError("Mooncake memory registration failed.")
def deregister(self, ptr): def deregister(self, ptr):
self.engine.unregister_memory(ptr) ret_value = self.engine.unregister_memory(ptr)
if ret_value != 0:
logger.error("Mooncake memory deregistration failed.")
raise RuntimeError("Mooncake memory deregistration failed.")
def initialize( def initialize(
self, self,
local_hostname: str, hostname: str,
metadata_server: str, device_name: Optional[str],
protocol: str,
device_name: str,
) -> None: ) -> None:
"""Initialize the mooncake instance.""" """Initialize the mooncake instance."""
self.engine.initialize(local_hostname, metadata_server, protocol, device_name) ret_value = self.engine.initialize(
hostname,
"P2PHANDSHAKE",
"rdma",
device_name if device_name is not None else "",
)
if ret_value != 0:
logger.error("Mooncake Transfer Engine initialization failed.")
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
def transfer_sync( def transfer_sync(
self, session_id: str, buffer: int, peer_buffer_address: int, length: int self, session_id: str, buffer: int, peer_buffer_address: int, length: int
...@@ -97,12 +66,12 @@ class MooncakeTransferEngine: ...@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
session_id, buffer, peer_buffer_address, length session_id, buffer, peer_buffer_address, length
) )
if ret < 0: if ret < 0:
logger.error("Transfer Return Error") logger.error("Mooncake Transfer Engine Return Error.")
raise Exception("Transfer Return Error") raise RuntimeError("Mooncake Transfer Engine Return Error.")
return ret return ret
def get_localhost(self): def get_localhost(self):
return self.config.local_hostname return self.hostname
def get_session_id(self): def get_session_id(self):
return self.session_id return self.session_id
...@@ -103,7 +103,7 @@ class PrefillBootstrapQueue: ...@@ -103,7 +103,7 @@ class PrefillBootstrapQueue:
kv_args.aux_item_lens = [ kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class( kv_manager = kv_manager_class(
......
...@@ -196,6 +196,7 @@ class ServerArgs: ...@@ -196,6 +196,7 @@ class ServerArgs:
disaggregation_mode: str = "null" disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake" disaggregation_transfer_backend: str = "mooncake"
disaggregation_ib_device: Optional[str] = None
def __post_init__(self): def __post_init__(self):
# Expert parallelism # Expert parallelism
...@@ -1193,6 +1194,12 @@ class ServerArgs: ...@@ -1193,6 +1194,12 @@ class ServerArgs:
default=ServerArgs.disaggregation_transfer_backend, default=ServerArgs.disaggregation_transfer_backend,
help="The backend for disaggregation transfer. Default is mooncake.", help="The backend for disaggregation transfer. Default is mooncake.",
) )
parser.add_argument(
"--disaggregation-ib-device",
type=str,
default=ServerArgs.disaggregation_ib_device,
help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment