Unverified Commit 2bdaf482 authored by amysaq2023, committed by GitHub
Browse files

Refactor the coding format for loading weights from a remote instance (#10941)


Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
parent 777eb538
...@@ -58,6 +58,10 @@ class LoadConfig: ...@@ -58,6 +58,10 @@ class LoadConfig:
ignore_patterns: Optional[Union[List[str], str]] = None ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1 decrypt_max_concurrency: int = -1
tp_rank: Optional[int] = None
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
......
...@@ -64,12 +64,6 @@ class ModelConfig: ...@@ -64,12 +64,6 @@ class ModelConfig:
is_draft_model: bool = False, is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None, hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None: ) -> None:
# Parse args # Parse args
self.model_path = model_path self.model_path = model_path
...@@ -78,18 +72,6 @@ class ModelConfig: ...@@ -78,18 +72,6 @@ class ModelConfig:
self.is_draft_model = is_draft_model self.is_draft_model = is_draft_model
self.model_impl = model_impl self.model_impl = model_impl
# TODO: remove these fields
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config # Get hf config
self._maybe_pull_model_tokenizer_from_remote() self._maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args) self.model_override_args = json.loads(model_override_args)
...@@ -204,9 +186,6 @@ class ModelConfig: ...@@ -204,9 +186,6 @@ class ModelConfig:
quantization=server_args.quantization, quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl, model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs, **kwargs,
) )
......
...@@ -91,7 +91,6 @@ class TpModelWorker: ...@@ -91,7 +91,6 @@ class TpModelWorker:
else server_args.speculative_draft_model_revision else server_args.speculative_draft_model_revision
), ),
is_draft_model=is_draft_worker, is_draft_model=is_draft_worker,
tp_rank=tp_rank,
) )
self.model_runner = ModelRunner( self.model_runner = ModelRunner(
......
...@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe ...@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import ( from sglang.srt.offloader import (
...@@ -112,9 +115,6 @@ from sglang.srt.offloader import ( ...@@ -112,9 +115,6 @@ from sglang.srt.offloader import (
set_offloader, set_offloader,
) )
from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
...@@ -743,6 +743,10 @@ class ModelRunner: ...@@ -743,6 +743,10 @@ class ModelRunner:
load_format=self.server_args.load_format, load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir, download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config, model_loader_extra_config=self.server_args.model_loader_extra_config,
tp_rank=self.tp_rank,
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
) )
if self.device == "cpu": if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp( self.model_config = adjust_config_with_unaligned_cpu_tp(
......
...@@ -54,6 +54,9 @@ from sglang.srt.distributed import ( ...@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.model_loader.utils import ( from sglang.srt.model_loader.utils import (
get_model_architecture, get_model_architecture,
post_load_weights, post_load_weights,
...@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import ( ...@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
safetensors_weights_iterator, safetensors_weights_iterator,
set_runai_streamer_env, set_runai_streamer_env,
) )
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
get_device_capability, get_device_capability,
...@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader): ...@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"load format {load_config.load_format}" f"load format {load_config.load_format}"
) )
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}" model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
...@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader): ...@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
def load_model_from_remote_instance( def load_model_from_remote_instance(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module: ) -> nn.Module:
load_config = self.load_config
instance_ip = socket.gethostbyname(socket.gethostname()) instance_ip = socket.gethostbyname(socket.gethostname())
start_build_group_tic = time.time() start_build_group_tic = time.time()
client.build_group( client.build_group(
gpu_id=device_config.gpu_id, gpu_id=device_config.gpu_id,
tp_rank=model_config.tp_rank, tp_rank=load_config.tp_rank,
instance_ip=instance_ip, instance_ip=instance_ip,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader): ...@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s" f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
) )
if model_config.tp_rank == 0: if load_config.tp_rank == 0:
t = threading.Thread( t = threading.Thread(
target=trigger_transferring_weights_request, target=trigger_transferring_weights_request,
args=( args=(
model_config.remote_instance_weight_loader_seed_instance_ip, load_config.remote_instance_weight_loader_seed_instance_ip,
model_config.remote_instance_weight_loader_seed_instance_service_port, load_config.remote_instance_weight_loader_seed_instance_service_port,
model_config.remote_instance_weight_loader_send_weights_group_ports, load_config.remote_instance_weight_loader_send_weights_group_ports,
instance_ip, instance_ip,
), ),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.