Unverified commit f40038fb, authored by Jimmy, committed by GitHub
Browse files

[Vulnerability]feat(conn): set bootstrap server host (#9931)

parent bebd0576
...@@ -131,4 +131,4 @@ class BaseKVReceiver(ABC): ...@@ -131,4 +131,4 @@ class BaseKVReceiver(ABC):
class BaseKVBootstrapServer(ABC): class BaseKVBootstrapServer(ABC):
@abstractmethod @abstractmethod
def __init__(self, port: int): ... def __init__(self, host: str, port: int): ...
...@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager): ...@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
self.is_mla_backend = is_mla_backend self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer # for p/d multi node infer
self.bootstrap_host = server_args.host
self.bootstrap_port = server_args.disaggregation_bootstrap_port self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
...@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager): ...@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
# multi node: bootstrap server's host is dist_init_addr
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"): if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr host = self.dist_init_addr
...@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager): ...@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
else: else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else: else:
host = get_ip() # single node: bootstrap server's host is same as http server's host
host = self.bootstrap_host
host = maybe_wrap_ipv6_address(host) host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{host}:{self.bootstrap_port}" bootstrap_server_url = f"{host}:{self.bootstrap_port}"
...@@ -308,7 +311,8 @@ class CommonKVReceiver(BaseKVReceiver): ...@@ -308,7 +311,8 @@ class CommonKVReceiver(BaseKVReceiver):
class CommonKVBootstrapServer(BaseKVBootstrapServer): class CommonKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int): def __init__(self, host: str, port: int):
self.host = host
self.port = port self.port = port
self.app = web.Application() self.app = web.Application()
self.store = dict() self.store = dict()
...@@ -412,7 +416,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): ...@@ -412,7 +416,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
self._runner = web.AppRunner(self.app) self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup()) self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port) site = web.TCPSite(self._runner, host=self.host, port=self.port)
self._loop.run_until_complete(site.start()) self._loop.run_until_complete(site.start())
self._loop.run_forever() self._loop.run_forever()
except Exception as e: except Exception as e:
......
...@@ -24,7 +24,7 @@ import logging ...@@ -24,7 +24,7 @@ import logging
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -218,8 +218,10 @@ class DecodePreallocQueue: ...@@ -218,8 +218,10 @@ class DecodePreallocQueue:
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class: Type[BaseKVManager] = get_kv_class(
kv_manager = kv_manager_class( self.transfer_backend, KVClassType.MANAGER
)
kv_manager: BaseKVManager = kv_manager_class(
kv_args, kv_args,
DisaggregationMode.DECODE, DisaggregationMode.DECODE,
self.scheduler.server_args, self.scheduler.server_args,
......
...@@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager):
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
self.init_engine() self.init_engine()
# for p/d multi node infer # for p/d multi node infer
self.bootstrap_host = server_args.host
self.bootstrap_port = server_args.disaggregation_bootstrap_port self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr self.dist_init_addr = server_args.dist_init_addr
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
...@@ -1020,6 +1021,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -1020,6 +1021,7 @@ class MooncakeKVManager(BaseKVManager):
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
# multi node case: bootstrap server's host is dist_init_addr
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"): if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr host = self.dist_init_addr
...@@ -1028,7 +1030,8 @@ class MooncakeKVManager(BaseKVManager): ...@@ -1028,7 +1030,8 @@ class MooncakeKVManager(BaseKVManager):
else: else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else: else:
host = get_ip() # single node case: bootstrap server's host is same as http server's host
host = self.bootstrap_host
host = maybe_wrap_ipv6_address(host) host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{host}:{self.bootstrap_port}" bootstrap_server_url = f"{host}:{self.bootstrap_port}"
...@@ -1545,7 +1548,8 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1545,7 +1548,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
class MooncakeKVBootstrapServer(BaseKVBootstrapServer): class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int): def __init__(self, host: str, port: int):
self.host = host
self.port = port self.port = port
self.app = web.Application() self.app = web.Application()
self.store = dict() self.store = dict()
...@@ -1673,7 +1677,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -1673,7 +1677,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self._runner = web.AppRunner(self.app, access_log=access_log) self._runner = web.AppRunner(self.app, access_log=access_log)
self._loop.run_until_complete(self._runner.setup()) self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port) site = web.TCPSite(self._runner, host=self.host, port=self.port)
self._loop.run_until_complete(site.start()) self._loop.run_until_complete(site.start())
self._loop.run_forever() self._loop.run_forever()
except Exception as e: except Exception as e:
......
...@@ -23,7 +23,7 @@ import logging ...@@ -23,7 +23,7 @@ import logging
import threading import threading
from collections import deque from collections import deque
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional, Type
import torch import torch
...@@ -140,8 +140,10 @@ class PrefillBootstrapQueue: ...@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager_class: Type[BaseKVManager] = get_kv_class(
kv_manager = kv_manager_class( self.transfer_backend, KVClassType.MANAGER
)
kv_manager: BaseKVManager = kv_manager_class(
kv_args, kv_args,
DisaggregationMode.PREFILL, DisaggregationMode.PREFILL,
self.scheduler.server_args, self.scheduler.server_args,
......
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
from collections import deque from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, List, Optional, Type, Union
import numpy as np import numpy as np
import torch import torch
...@@ -213,7 +213,9 @@ class KVClassType(Enum): ...@@ -213,7 +213,9 @@ class KVClassType(Enum):
BOOTSTRAP_SERVER = "bootstrap_server" BOOTSTRAP_SERVER = "bootstrap_server"
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): def get_kv_class(
transfer_backend: TransferBackend, class_type: KVClassType
) -> Optional[Type]:
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
if transfer_backend == TransferBackend.MOONCAKE: if transfer_backend == TransferBackend.MOONCAKE:
......
...@@ -40,6 +40,7 @@ from typing import ( ...@@ -40,6 +40,7 @@ from typing import (
List, List,
Optional, Optional,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
) )
...@@ -53,6 +54,7 @@ from fastapi import BackgroundTasks ...@@ -53,6 +54,7 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
KVClassType, KVClassType,
...@@ -479,11 +481,12 @@ class TokenizerManager: ...@@ -479,11 +481,12 @@ class TokenizerManager:
# Start kv bootstrap server on prefill # Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm # only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class( kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
) )
self.bootstrap_server = kv_bootstrap_server_class( self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port host=self.server_args.host,
port=self.server_args.disaggregation_bootstrap_port,
) )
is_create_store = ( is_create_store = (
self.server_args.node_rank == 0 self.server_args.node_rank == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment