Unverified commit 30d20ce8, authored by amysaq2023 and committed by GitHub

Support loading weights from remote instance (#8215)


Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
parent 1b1701f1
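
For orientation before the diff: this commit lets a new "destination" instance pull its weights over NCCL from an already-running "seed" instance instead of reading them from disk. The sketch below shows the destination side, using the `sgl.Engine` keyword arguments exercised by the test added in this commit; the model name, seed address, service port, and group ports are placeholders, not defaults.

import sglang as sgl

# Launch a destination instance that receives weights from a running seed
# instance. One rendezvous port is needed per TP rank of the seed/dst pair.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    load_format="remote_instance",
    remote_instance_weight_loader_seed_instance_ip="10.0.0.1",
    remote_instance_weight_loader_seed_instance_service_port=30000,
    remote_instance_weight_loader_send_weights_group_ports=[60000],
)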
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)

class DeviceConfig:
    device: Optional[torch.device]
+   gpu_id: Optional[int]

-   def __init__(self, device: str = "cuda") -> None:
+   def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
+       self.gpu_id = gpu_id
@@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
+   REMOTE_INSTANCE = "remote_instance"

@dataclass
...
@@ -64,12 +64,28 @@ class ModelConfig:
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+       tp_rank: Optional[int] = None,
+       remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
+       remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
+       remote_instance_weight_loader_send_weights_group_ports: Optional[
+           List[int]
+       ] = None,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl
+       self.tp_rank = tp_rank
+       self.remote_instance_weight_loader_seed_instance_ip = (
+           remote_instance_weight_loader_seed_instance_ip
+       )
+       self.remote_instance_weight_loader_seed_instance_service_port = (
+           remote_instance_weight_loader_seed_instance_service_port
+       )
+       self.remote_instance_weight_loader_send_weights_group_ports = (
+           remote_instance_weight_loader_send_weights_group_ports
+       )
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)

@@ -329,6 +345,9 @@ class ModelConfig:
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
+           remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
+           remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
+           remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
            **kwargs,
        )
...
@@ -9,6 +9,7 @@ from sglang.srt.connector.base_connector import (
    BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
+from sglang.srt.connector.remote_instance import RemoteInstanceConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

@@ -18,14 +19,17 @@ logger = logging.getLogger(__name__)

class ConnectorType(str, enum.Enum):
    FS = "filesystem"
    KV = "KV"
+   INSTANCE = "instance"

-def create_remote_connector(url, **kwargs) -> BaseConnector:
+def create_remote_connector(url, device, **kwargs) -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    elif connector_type == "s3":
        return S3Connector(url)
+   elif connector_type == "instance":
+       return RemoteInstanceConnector(url, device)
    else:
        raise ValueError(f"Invalid connector type: {url}")

@@ -35,6 +39,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType:
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS
+   if isinstance(client, RemoteInstanceConnector):
+       return ConnectorType.INSTANCE
    raise ValueError(f"Invalid connector type: {client}")

@@ -44,6 +50,7 @@ __all__ = [
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
+   "RemoteInstanceConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
...
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist

from sglang.srt.connector import BaseConnector
from sglang.srt.utils import init_custom_process_group

logger = logging.getLogger(__name__)


class RemoteInstanceConnector(BaseConnector):

    def __init__(self, url: str, device: torch.device = "cpu"):
        assert (
            device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        super().__init__(url)
        self.url = url
        self.device = device

    def build_group(
        self,
        gpu_id: int = -1,
        tp_rank: int = -1,
        instance_ip: str = None,
        group_rank: int = 1,
        world_size: int = 2,
    ):
        assert (
            self.device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        assert (
            gpu_id != -1 and tp_rank != -1
        ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector."

        self.device_id = torch.device(self.device.type, gpu_id)
        parsed_url = urlparse(self.url)
        master_address = parsed_url.hostname
        master_port = parsed_url.port
        group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"
        backend = "nccl"

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=self.device_id,
            )
            dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    # Implemented as a no-op to make the BaseConnector interface consistent.
    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        return

    # Implemented as a no-op to make the BaseConnector interface consistent.
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        return
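
To see how the loader drives this connector (mirroring `RemoteInstanceModelLoader.load_model_from_remote_instance` later in this diff), here is a rough sketch; the URL, GPU id, and TP rank are placeholders:

import socket

import torch

from sglang.srt.connector import ConnectorType, create_remote_connector, get_connector_type

# An "instance://<seed_ip>:<group_port>" URL selects RemoteInstanceConnector.
url = "instance://10.0.0.1:60000"  # placeholder seed address and group port
with create_remote_connector(url, torch.device("cuda", 0)) as client:
    assert get_connector_type(client) == ConnectorType.INSTANCE
    # Join the world-size-2 send group as the receiving rank (group_rank=1).
    client.build_group(
        gpu_id=0,
        tp_rank=0,
        instance_ip=socket.gethostbyname(socket.gethostname()),
    )
    # Weights then arrive via torch.distributed.broadcast(src=0) on
    # client._model_update_group, parameter by parameter.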
@@ -73,6 +73,7 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
+   InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    OpenSessionReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
    ProfileReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
+   SendWeightsToRemoteInstanceReqInput,
    SeparateReasoningReqInput,
    SetInternalStateReq,
    SlowDownReqInput,
@@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
    )

+@app.post("/init_weights_send_group_for_remote_instance")
+async def init_weights_send_group_for_remote_instance(
+   obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request
+):
+   success, message = (
+       await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance(
+           obj, request
+       )
+   )
+   content = {"success": success, "message": message}
+   if success:
+       return ORJSONResponse(content, status_code=200)
+   else:
+       return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)

+@app.post("/send_weights_to_remote_instance")
+async def send_weights_to_remote_instance(
+   obj: SendWeightsToRemoteInstanceReqInput, request: Request
+):
+   success, message = (
+       await _global_state.tokenizer_manager.send_weights_to_remote_instance(
+           obj, request
+       )
+   )
+   content = {"success": success, "message": message}
+   if success:
+       return ORJSONResponse(content, status_code=200)
+   else:
+       return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)

@app.post("/init_weights_update_group")
async def init_weights_update_group(
    obj: InitWeightsUpdateGroupReqInput, request: Request
...
@@ -1020,6 +1020,44 @@ class UpdateWeightsFromTensorReqOutput:
    message: str

+@dataclass
+class InitWeightsSendGroupForRemoteInstanceReqInput:
+   # The master address
+   master_address: str
+   # The ports for each rank's communication group
+   ports: str
+   # The rank in the communication group
+   group_rank: int
+   # The world size
+   world_size: int
+   # The group name
+   group_name: str = "weight_send_group"
+   # The backend
+   backend: str = "nccl"

+@dataclass
+class InitWeightsSendGroupForRemoteInstanceReqOutput:
+   success: bool
+   message: str

+@dataclass
+class SendWeightsToRemoteInstanceReqInput:
+   # The master address
+   master_address: str
+   # The ports for each rank's communication group
+   ports: str
+   # The group name
+   group_name: str = "weight_send_group"

+@dataclass
+class SendWeightsToRemoteInstanceReqOutput:
+   success: bool
+   message: str

@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
...
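
Given the routes and dataclasses above, the seed side can be driven by plain HTTP. A sketch of the first call (host and ports are placeholders; the payload keys mirror `InitWeightsSendGroupForRemoteInstanceReqInput`):

import requests

seed_url = "http://10.0.0.1:30000"  # placeholder seed-instance service URL
requests.post(
    f"{seed_url}/init_weights_send_group_for_remote_instance",
    json={
        "master_address": "10.0.0.1",  # rendezvous host
        "ports": "60000",              # one comma-separated port per TP rank
        "group_rank": 0,               # the seed is rank 0 of each pair
        "world_size": 2,               # seed rank + destination rank
    },
)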
@@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import (
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    HealthCheckOutput,
+   InitWeightsSendGroupForRemoteInstanceReqInput,
+   InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
@@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import (
    ResumeMemoryOccupationReqInput,
    RpcReqInput,
    RpcReqOutput,
+   SendWeightsToRemoteInstanceReqInput,
+   SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
@@ -538,6 +542,14 @@ class Scheduler(
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+               (
+                   InitWeightsSendGroupForRemoteInstanceReqInput,
+                   self.init_weights_send_group_for_remote_instance,
+               ),
+               (
+                   SendWeightsToRemoteInstanceReqInput,
+                   self.send_weights_to_remote_instance,
+               ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
@@ -2429,6 +2441,22 @@ class Scheduler(
        self.send_to_detokenizer.send_pyobj(recv_req)
        return recv_req

+   def init_weights_send_group_for_remote_instance(
+       self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+   ):
+       """Init the seed and client instance communication group."""
+       success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+           recv_req
+       )
+       return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

+   def send_weights_to_remote_instance(
+       self, recv_req: SendWeightsToRemoteInstanceReqInput
+   ):
+       """Send the seed instance weights to the destination instance."""
+       success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+       return SendWeightsToRemoteInstanceReqOutput(success, message)

    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:
...
@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
+   InitWeightsSendGroupForRemoteInstanceReqInput,
+   InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,
@@ -43,6 +45,8 @@ from sglang.srt.managers.io_struct import (
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
+   SendWeightsToRemoteInstanceReqInput,
+   SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
@@ -119,6 +123,12 @@ class TokenizerCommunicatorMixin:
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
+       self.init_weights_send_group_for_remote_instance_communicator = _Communicator(
+           self.send_to_scheduler, server_args.dp_size
+       )
+       self.send_weights_to_remote_instance_communicator = _Communicator(
+           self.send_to_scheduler, server_args.dp_size
+       )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
@@ -169,6 +179,14 @@ class TokenizerCommunicatorMixin:
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
+               (
+                   InitWeightsSendGroupForRemoteInstanceReqOutput,
+                   self.init_weights_send_group_for_remote_instance_communicator.handle_recv,
+               ),
+               (
+                   SendWeightsToRemoteInstanceReqOutput,
+                   self.send_weights_to_remote_instance_communicator.handle_recv,
+               ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
@@ -310,6 +328,34 @@ class TokenizerCommunicatorMixin:
        result = (await self.update_weights_from_distributed_communicator(obj))[0]
        return result.success, result.message

+   async def init_weights_send_group_for_remote_instance(
+       self,
+       obj: InitWeightsSendGroupForRemoteInstanceReqInput,
+       request: Optional[fastapi.Request] = None,
+   ) -> Tuple[bool, str]:
+       self.auto_create_handle_loop()
+       # TODO: support DP
+       assert (
+           self.server_args.dp_size == 1
+       ), "dp_size must be 1 for init_weights_send_group_for_remote_instance"
+       result = (
+           await self.init_weights_send_group_for_remote_instance_communicator(obj)
+       )[0]
+       return result.success, result.message

+   async def send_weights_to_remote_instance(
+       self,
+       obj: SendWeightsToRemoteInstanceReqInput,
+       request: Optional[fastapi.Request] = None,
+   ) -> Tuple[bool, str]:
+       self.auto_create_handle_loop()
+       # TODO: support DP
+       assert (
+           self.server_args.dp_size == 1
+       ), "dp_size must be 1 for send_weights_to_remote_instance"
+       result = (await self.send_weights_to_remote_instance_communicator(obj))[0]
+       return result.success, result.message

    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,
...
@@ -30,8 +30,10 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
+   InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
+   SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
@@ -88,6 +90,7 @@ class TpModelWorker:
                else server_args.speculative_draft_model_revision
            ),
            is_draft_model=is_draft_worker,
+           tp_rank=tp_rank,
        )

        self.model_runner = ModelRunner(
@@ -292,6 +295,31 @@ class TpModelWorker:
        )
        return success, message

+   def init_weights_send_group_for_remote_instance(
+       self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+   ):
+       success, message = (
+           self.model_runner.init_weights_send_group_for_remote_instance(
+               recv_req.master_address,
+               recv_req.ports,
+               recv_req.group_rank,
+               recv_req.world_size,
+               recv_req.group_name,
+               recv_req.backend,
+           )
+       )
+       return success, message

+   def send_weights_to_remote_instance(
+       self, recv_req: SendWeightsToRemoteInstanceReqInput
+   ):
+       success, message = self.model_runner.send_weights_to_remote_instance(
+           recv_req.master_address,
+           recv_req.ports,
+           recv_req.group_name,
+       )
+       return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
...
@@ -26,8 +26,10 @@ import torch

from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
+   InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
+   SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
@@ -267,6 +269,20 @@ class TpModelWorkerClient:
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

+   def init_weights_send_group_for_remote_instance(
+       self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+   ):
+       success, message = self.worker.init_weights_send_group_for_remote_instance(
+           recv_req
+       )
+       return success, message

+   def send_weights_to_remote_instance(
+       self, recv_req: SendWeightsToRemoteInstanceReqInput
+   ):
+       success, message = self.worker.send_weights_to_remote_instance(recv_req)
+       return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
...
@@ -19,18 +19,23 @@ import inspect
import json
import logging
import os
+import socket
+import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
+from urllib.parse import urlparse

+import requests
import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
+from sglang.srt.connector import ConnectorType
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_pp_group,
@@ -106,6 +111,9 @@ from sglang.srt.offloader import (
    set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.remote_instance_weight_loader_utils import (
+   trigger_init_weights_send_group_for_remote_instance_request,
+)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -128,6 +136,7 @@ from sglang.srt.utils import (
    is_sm100_supported,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
+   parse_connector_type,
    set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
@@ -256,6 +265,7 @@ class ModelRunner:
        # For weight updates
        self._model_update_group = {}
+       self._weights_send_group = {}

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
@@ -726,6 +736,20 @@ class ModelRunner:
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

+       if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+           if self.tp_rank == 0:
+               instance_ip = socket.gethostbyname(socket.gethostname())
+               t = threading.Thread(
+                   target=trigger_init_weights_send_group_for_remote_instance_request,
+                   args=(
+                       self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                       self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                       self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                       instance_ip,
+                   ),
+               )
+               t.start()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
@@ -735,7 +759,7 @@ class ModelRunner:
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
-               device_config=DeviceConfig(self.device),
+               device_config=DeviceConfig(self.device, self.gpu_id),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -867,6 +891,103 @@ class ModelRunner:
        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

+   def init_weights_send_group_for_remote_instance(
+       self,
+       master_address,
+       ports,
+       group_rank,
+       world_size,
+       group_name,
+       backend="nccl",
+   ):
+       assert (
+           torch.distributed.is_initialized()
+       ), "Default torch process group must be initialized"
+       assert group_name != "", "Group name cannot be empty"

+       ports_list = ports.split(",")
+       assert (
+           len(ports_list) == self.tp_size
+       ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+       group_port = ports_list[self.tp_rank]
+       group_name = f"{group_name}_{group_port}_{self.tp_rank}"

+       logger.info(
+           f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+           f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+       )

+       torch.cuda.empty_cache()
+       success = False
+       message = ""
+       try:
+           self._weights_send_group[group_name] = init_custom_process_group(
+               backend=backend,
+               init_method=f"tcp://{master_address}:{group_port}",
+               world_size=world_size,
+               rank=group_rank,
+               group_name=group_name,
+               device_id=torch.device("cuda", self.gpu_id),
+           )
+           dist.barrier(group=self._weights_send_group[group_name])
+           success = True
+           message = (
+               f"Succeeded to init group through {master_address}:{group_port} group."
+           )
+       except Exception as e:
+           message = f"Failed to init group: {e}."
+           logger.error(message)

+       torch.cuda.empty_cache()
+       return success, message

+   def send_weights_to_remote_instance(
+       self,
+       master_address,
+       ports,
+       group_name,
+   ):
+       assert (
+           torch.distributed.is_initialized()
+       ), "Default torch process group must be initialized"
+       assert group_name != "", "Group name cannot be empty"

+       ports_list = ports.split(",")
+       assert (
+           len(ports_list) == self.tp_size
+       ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+       group_port = ports_list[self.tp_rank]
+       group_name = f"{group_name}_{group_port}_{self.tp_rank}"

+       if self._weights_send_group[group_name] is not None:
+           send_group = self._weights_send_group[group_name]
+       else:
+           message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+           logger.error(message)
+           return False, message

+       torch.cuda.empty_cache()
+       success = False
+       message = ""
+       try:
+           for _, weights in self.model.named_parameters():
+               torch.distributed.broadcast(
+                   weights,
+                   src=0,
+                   group=send_group,
+               )
+           success = True
+           message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+       except Exception as e:
+           message = f"Failed to send weights: {e}."
+           logger.error(message)

+       # destroy the process group after sending weights
+       del self._weights_send_group[group_name]
+       torch.distributed.distributed_c10d.destroy_process_group(send_group)
+       torch.cuda.empty_cache()
+       return success, message

    def init_weights_update_group(
        self,
        master_address,
...
@@ -12,6 +12,9 @@ import json
import logging
import math
import os
+import re
+import socket
+import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
@@ -27,9 +30,11 @@ from typing import (
    Tuple,
    cast,
)
+from urllib.parse import urlparse

import huggingface_hub
import numpy as np
+import requests
import safetensors.torch
import torch
from huggingface_hub import HfApi, hf_hub_download
@@ -56,6 +61,7 @@ from sglang.srt.model_loader.utils import (
)
from sglang.srt.model_loader.weight_utils import (
    _BAR_FORMAT,
+   default_weight_loader,
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    filter_duplicate_safetensors_files,
@@ -71,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
    safetensors_weights_iterator,
    set_runai_streamer_env,
)
+from sglang.srt.remote_instance_weight_loader_utils import (
+   trigger_transferring_weights_request,
+)
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_capability,
@@ -1380,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader):
        return model

+class RemoteInstanceModelLoader(BaseModelLoader):
+   """Model loader that can load Tensors from a remote sglang instance."""

+   def __init__(self, load_config: LoadConfig):
+       super().__init__(load_config)
+       if load_config.model_loader_extra_config:
+           raise ValueError(
+               f"Model loader extra config is not supported for "
+               f"load format {load_config.load_format}"
+           )

+   def download_model(self, model_config: ModelConfig) -> None:
+       raise NotImplementedError

+   def load_model(
+       self,
+       *,
+       model_config: ModelConfig,
+       device_config: DeviceConfig,
+   ) -> nn.Module:
+       logger.info("Loading weights from remote instance ...")
+       load_config = self.load_config

+       assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
+           f"Model loader {self.load_config.load_format} is not supported for "
+           f"load format {load_config.load_format}"
+       )

+       model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"

+       with set_default_torch_dtype(model_config.dtype):
+           with torch.device(device_config.device):
+               model = _initialize_model(model_config, self.load_config)

+           with create_remote_connector(model_weights, device_config.device) as client:
+               connector_type = get_connector_type(client)
+               if connector_type == ConnectorType.INSTANCE:
+                   self.load_model_from_remote_instance(
+                       model, client, model_config, device_config
+                   )
+               else:
+                   raise ValueError(
+                       f"Unsupported connector type {connector_type} for "
+                       f"remote tensor model loading."
+                   )
+       return model.eval()

+   def load_model_from_remote_instance(
+       self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+   ) -> nn.Module:
+       instance_ip = socket.gethostbyname(socket.gethostname())
+       start_build_group_tic = time.time()
+       client.build_group(
+           gpu_id=device_config.gpu_id,
+           tp_rank=model_config.tp_rank,
+           instance_ip=instance_ip,
+       )
+       torch.cuda.synchronize()
+       end_build_group_tic = time.time()
+       logger.debug(
+           f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
+       )

+       if model_config.tp_rank == 0:
+           t = threading.Thread(
+               target=trigger_transferring_weights_request,
+               args=(
+                   model_config.remote_instance_weight_loader_seed_instance_ip,
+                   model_config.remote_instance_weight_loader_seed_instance_service_port,
+                   model_config.remote_instance_weight_loader_send_weights_group_ports,
+                   instance_ip,
+               ),
+           )
+           t.start()

+       start_get_weights_tic = time.time()
+       with set_default_torch_dtype(model_config.dtype):
+           for _, tensor in model.named_parameters():
+               torch.distributed.broadcast(
+                   tensor.data,
+                   src=0,
+                   group=client._model_update_group,
+               )
+           torch.cuda.synchronize()

+           if hasattr(model, "post_load_weights"):
+               model.post_load_weights()
+       end_get_weights_tic = time.time()
+       logger.debug(
+           f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
+       )

+       # destroy the process group after loading weights
+       torch.distributed.distributed_c10d.destroy_process_group(
+           client._model_update_group
+       )
+       torch.cuda.empty_cache()

class RemoteModelLoader(BaseModelLoader):
    """Model loader that can load Tensors from remote database."""

@@ -1581,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    if load_config.load_format == LoadFormat.REMOTE:
        return RemoteModelLoader(load_config)

+   if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
+       return RemoteInstanceModelLoader(load_config)

    return DefaultModelLoader(load_config)
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List

import requests

logger = logging.getLogger(__name__)


def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Only loading weights from an instance with the same parallelism strategy
    # is supported. Each TP-rank pair between the seed and dst instances builds
    # a communication group for sending weights,
    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group has a world size of 2.
    try:
        requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": (
                    ",".join(
                        str(p)
                        for p in remote_instance_weight_loader_send_weights_group_ports
                    )
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise


def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": (
                    ",".join(
                        str(p)
                        for p in remote_instance_weight_loader_send_weights_group_ports
                    )
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise
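
To make the port fan-out concrete: these helpers send one comma-separated port list covering all TP ranks, and each rank on both sides picks its own entry with the same `ports.split(",")[tp_rank]` logic as in `ModelRunner` above. An illustrative computation with placeholder values:

# With tp_size == 2 and group ports [60000, 60001]:
ports = "60000,60001"                    # the "ports" field in the request
tp_rank = 1                              # this worker's TP rank
group_port = ports.split(",")[tp_rank]   # -> "60001"
base_name = "send_weights_10.0.0.2"      # "group_name" field; client id is a placeholder
group_name = f"{base_name}_{group_port}_{tp_rank}"  # per-rank name used by ModelRunner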
@@ -19,10 +19,12 @@ import json
import logging
import os
import random
+import socket
import sys
import tempfile
from typing import List, Literal, Optional, Union

+from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef
@@ -42,7 +44,9 @@ from sglang.srt.utils import (
    is_sm100_supported,
    is_triton_kernels_available,
    is_valid_ipv6_address,
+   json_list_type,
    nullable_str,
+   parse_connector_type,
)
from sglang.utils import is_in_ci
@@ -61,6 +65,7 @@ LOAD_FORMAT_CHOICES = [
    "bitsandbytes",
    "layered",
    "remote",
+   "remote_instance",
]

QUANTIZATION_CHOICES = [
@@ -387,6 +392,11 @@ class ServerArgs:
    custom_weight_loader: Optional[List[str]] = None
    weight_loader_disable_mmap: bool = False

+   # Remote instance weight loading
+   remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
+   remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
+   remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

    # For PD-Multiplexing
    enable_pdmux: bool = False
    sm_group_num: int = 3
@@ -445,6 +455,7 @@ class ServerArgs:
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
+
        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.device is None:
@@ -538,7 +549,8 @@ class ServerArgs:
            self.sampling_backend = "pytorch"

        # Model-specific adjustments
-       self.model_specific_adjustments()
+       if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
+           self.model_specific_adjustments()

        # Set kernel backends
        if self.device == "cpu":
@@ -818,12 +830,19 @@ class ServerArgs:
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

+       # Model loading
        if is_remote_url(self.model_path):
            self.load_format = "remote"

        if self.custom_weight_loader is None:
            self.custom_weight_loader = []

+       if self.load_format == "remote_instance":
+           if (
+               self.remote_instance_weight_loader_seed_instance_ip is None
+               or self.remote_instance_weight_loader_seed_instance_service_port is None
+               or self.remote_instance_weight_loader_send_weights_group_ports is None
+           ):
+               self.load_format = "auto"

        # PD disaggregation
        if self.disaggregation_mode == "decode":
            assert (
@@ -881,6 +900,24 @@ class ServerArgs:
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
+       parser.add_argument(
+           "--remote-instance-weight-loader-seed-instance-ip",
+           type=str,
+           default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
+           help="The IP of the seed instance when loading weights from a remote instance.",
+       )
+       parser.add_argument(
+           "--remote-instance-weight-loader-seed-instance-service-port",
+           type=int,
+           default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
+           help="The service port of the seed instance when loading weights from a remote instance.",
+       )
+       parser.add_argument(
+           "--remote-instance-weight-loader-send-weights-group-ports",
+           type=json_list_type,
+           default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
+           help="The communication group ports for loading weights from a remote instance.",
+       )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
...
@@ -15,6 +15,7 @@
from __future__ import annotations

+import argparse
import asyncio
import builtins
import ctypes
@@ -1431,6 +1432,7 @@ def init_custom_process_group(
    store=None,
    group_name=None,
    pg_options=None,
+   device_id=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
@@ -1484,6 +1486,7 @@ def init_custom_process_group(
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
+       device_id=device_id,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
@@ -3046,3 +3049,12 @@ def numa_bind_to_node(node: int):
    libnuma.numa_run_on_node(ctypes.c_int(node))
    libnuma.numa_set_localalloc()

+def json_list_type(value):
+   try:
+       return json.loads(value)
+   except json.JSONDecodeError:
+       raise argparse.ArgumentTypeError(
+           f"Invalid JSON list: {value}. Please provide a valid JSON list."
+       )
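
A quick note on this new argparse type: it parses the raw CLI value as JSON, which is why the test below passes the group ports as a bracketed string. A sketch of the behavior (the second call raises):

from sglang.srt.utils import json_list_type

json_list_type("[60000, 60001]")  # -> [60000, 60001]
json_list_type("60000,60001")     # not valid JSON -> argparse.ArgumentTypeError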
@@ -123,6 +123,7 @@ suites = {
        TestFile("rl/test_update_weights_from_distributed.py", 103),
        TestFile("test_data_parallelism.py", 73),
        TestFile("test_dp_attention.py", 277),
+       TestFile("test_load_weights_from_remote_instance.py", 72),
        TestFile("test_patch_torch.py", 19),
        TestFile("test_release_memory_occupation.py", 127),
        TestFile("hicache/test_hicache_storage_file_backend.py", 400),
@@ -251,6 +252,7 @@ suite_amd = {
        TestFile("lora/test_lora_tp.py", 116),
        TestFile("rl/test_update_weights_from_distributed.py", 103),
        TestFile("test_data_parallelism.py", 73),
+       TestFile("test_load_weights_from_remote_instance.py", 72),
        TestFile("test_patch_torch.py", 19),
    ],
    "per-commit-4-gpu-amd": [
...
"""Test loading weights from remote instance.
This test suite simulates loading weights from a remote instance.
Rank 0 represents the seed instance, while ranks 1 represents the
new instance that needs to loading weights from the seed instance.
Seed instance must be started in `Server` mode, while the dst instance
can be either `Engine` mode or `Server` mode.
Seed instance does not support concurrently serving multiple dst instances.
User has to guarantee that there is only one dst instance trying to load
weights from the seed instance at any time.
"""
import gc
import os
import random
import unittest

import numpy as np
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)
from sglang.utils import terminate_process

mp.set_start_method("spawn", force=True)


def verify_params_close(params1, params2, error_msg):
    """Verify that two parameter arrays are close enough."""
    try:
        assert np.allclose(np.array(params1), np.array(params2)), error_msg
    except Exception as e:
        print(f"Parameters not close for {error_msg}")
        print("Params1:", np.array(params1))
        print("Params2:", np.array(params2))
        raise e


def init_process(
    rank,
    param_queue,
    truncate_size,
    tp_size,
    model_name,
    backends,
    checking_parameters,
    seed_instance_ip,
    seed_instance_service_port,
    seed_instance_group_base_port,
    event_seed_ready,
    event_dst_ready_list,
):
    torch.cuda.set_device(rank)
    if rank == 0:
        init_process_seed(
            rank,
            param_queue,
            truncate_size,
            model_name,
            checking_parameters,
            tp_size,
            event_seed_ready,
            event_dst_ready_list,
        )
    elif rank in [1, 2]:
        init_process_dst(
            rank,
            param_queue,
            truncate_size,
            model_name,
            seed_instance_ip,
            seed_instance_service_port,
            seed_instance_group_base_port,
            checking_parameters,
            backends[rank - 1],
            tp_size,
            event_seed_ready,
            event_dst_ready_list,
        )


def init_process_seed(
    rank,
    param_queue,
    truncate_size,
    model_name,
    checking_parameters,
    tp_size,
    event_seed_ready,
    event_dst_ready_list,
):
    # These two environment variables are very important
    # to avoid unexpected behaviors of CUDA and NCCL.
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"

    # Load the model and get its parameters
    torch.cuda.set_device(rank)
    torch.cuda.synchronize()
    url = DEFAULT_URL_FOR_TEST
    process = popen_launch_server(
        model_name,
        url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=(
            "--base-gpu-id",
            str(rank),
            "--tp-size",
            str(tp_size),
        ),
    )
    torch.cuda.synchronize()

    seed_params = []
    # Get the weights of the seed instance for the correctness check.
    for parameter_name in checking_parameters:
        seed_params.append(
            requests.get(
                f"{url}/get_weights_by_name",
                json={
                    "name": parameter_name,
                    "truncate_size": truncate_size,
                },
            ).json()
        )
    param_queue.put(("seed_params", seed_params))

    event_seed_ready.set()
    for i in range(len(event_dst_ready_list)):
        event_dst_ready_list[i].wait()

    terminate_process(process)


def init_process_dst(
    rank,
    param_queue,
    truncate_size,
    model_name,
    seed_instance_ip,
    seed_instance_service_port,
    seed_instance_group_base_port,
    checking_parameters,
    backend,
    tp_size,
    event_seed_ready,
    event_dst_ready_list,
):
    torch.cuda.set_device(rank * tp_size)
    torch.cuda.synchronize()
    base_gpu_id = rank * tp_size

    event_seed_ready.wait()
    print(f"rank {rank}, seed ready")
    for i in range(rank - 1):
        print(f"rank {rank}, wait dst {i}")
        event_dst_ready_list[i].wait()

    ports = []
    for i in range(tp_size):
        ports.append(seed_instance_group_base_port + (rank - 1) * tp_size + i)

    if backend == "Engine":
        print(f"[sgl] rank {rank} init engine")
        engine = sgl.Engine(
            model_path=model_name,
            base_gpu_id=base_gpu_id,
            tp_size=tp_size,
            cuda_graph_max_bs=2,
            tokenizer_path=model_name,
            remote_instance_weight_loader_seed_instance_ip=seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=ports,
            load_format="remote_instance",
        )
    else:
        host, _, port = DEFAULT_URL_FOR_TEST.rpartition(":")
        url = ":".join([host, str(int(port) + 10000 + rank)])
        print(f"[sgl] rank {rank} init server on url: {url}")
        process = popen_launch_server(
            model_name,
            url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--base-gpu-id",
                str(base_gpu_id),
                "--tp-size",
                str(tp_size),
                "--cuda-graph-max-bs",
                2,
                "--tokenizer-path",
                model_name,
                "--remote-instance-weight-loader-seed-instance-ip",
                seed_instance_ip,
                "--remote-instance-weight-loader-seed-instance-service-port",
                seed_instance_service_port,
                "--remote-instance-weight-loader-send-weights-group-ports",
                f"[{','.join(str(port) for port in ports)}]",
                "--load-format",
                "remote_instance",
            ),
        )
    torch.cuda.synchronize()
    event_dst_ready_list[rank - 1].set()

    # Get the weights that the destination instance loaded from the remote instance.
    dst_params = []
    for parameter_name in checking_parameters:
        dst_params.append(
            engine.get_weights_by_name(parameter_name, truncate_size)
            if backend == "Engine"
            else requests.get(
                f"{url}/get_weights_by_name",
                json={"name": parameter_name, "truncate_size": truncate_size},
            ).json()
        )
    param_queue.put((f"sgl_dp_{rank}_dst_params", dst_params))

    # Shut down the engine or terminate the server process.
    if backend == "Engine":
        engine.shutdown()
    else:
        terminate_process(process)


def test_load_weights_from_remote_instance(
    tp_size,
    dp_size,
    model_name,
    backends,
    truncate_size,
    checking_parameters,
    seed_instance_ip,
    seed_instance_service_port,
    seed_instance_group_base_port,
):
    print(
        f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backends}"
    )
    param_queue = mp.Queue()
    results = {}

    event_seed_ready = mp.Event()
    event_dst_ready_list = []
    for i in range(dp_size):
        event_dst_ready = mp.Event()
        event_dst_ready_list.append(event_dst_ready)

    context = mp.spawn(
        init_process,
        args=(
            param_queue,
            truncate_size,
            tp_size,
            model_name,
            backends,
            checking_parameters,
            seed_instance_ip,
            seed_instance_service_port,
            seed_instance_group_base_port,
            event_seed_ready,
            event_dst_ready_list,
        ),
        nprocs=1 + dp_size,
        join=False,
    )

    while len(results) < (1 + dp_size):
        try:
            key, value = param_queue.get(timeout=5)
            results[key] = value
        except Exception as e:
            if all(not p.is_alive() for p in context.processes):
                break

    context.join()
    if len(results) != (1 + dp_size):
        raise RuntimeError(
            f"Expected {(1 + dp_size)} parameters but got {len(results)}"
        )

    params = {
        "seed": results.get("seed_params"),
        "sgl_dp_1_dest": results.get("sgl_dp_1_dst_params"),
    }
    if dp_size == 2:
        dp2_params = {
            "sgl_dp_2_dest": results.get("sgl_dp_2_dst_params"),
        }
        assert all(v is not None for v in dp2_params.values())
        params.update(dp2_params)

    # Check the correctness of the weights loaded from the remote instance
    # by comparing the weights of the seed instance and the destination instance.
    for i in range(len(params["seed"])):
        verify_params_close(
            params["seed"][i],
            params["sgl_dp_1_dest"][i],
            f"sgl_dp_1_dst_params rank {i}",
        )
        if dp_size == 2:
            verify_params_close(
                params["seed"][i],
                params["sgl_dp_2_dest"][i],
                f"sgl_dp_2_dst_params rank {i}",
            )

    # Delete the context and close the parameter queue.
    del context
    param_queue.close()
    param_queue.join_thread()
    gc.collect()
    torch.cuda.empty_cache()


class TestLoadWeightsFromRemoteInstance(CustomTestCase):
    def test_load_weights_from_remote_instance(self):
        assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
        # test suites: tp_size, dp_size, model_name, backends
        if is_in_ci():
            mode = random.choice(["Engine", "Server"])
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, [mode]),
            ]
        else:
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine"]),
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Server"]),
                (2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["Engine", "Server"]),
            ]

        truncate_size = 10
        checking_parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
        ]

        for tp_size, dp_size, model_name, backends in test_suits:
            test_load_weights_from_remote_instance(
                tp_size,
                dp_size,
                model_name,
                backends,
                truncate_size,
                checking_parameters,
                "127.0.0.1",
                DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000,
                60000,
            )


if __name__ == "__main__":
    unittest.main()