"...text-generation-inference.git" did not exist on "88aae2595d9a2011a58956391848cde869479653"
Unverified Commit 2bdaf482 authored by amysaq2023's avatar amysaq2023 Committed by GitHub
Browse files

refactor loading weights from remote instance coding format (#10941)


Signed-off-by: default avatarAnqi Shen <amy.saq@antgroup.com>
parent 777eb538
......@@ -58,6 +58,10 @@ class LoadConfig:
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1
tp_rank: Optional[int] = None
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
......
......@@ -64,12 +64,6 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None:
# Parse args
self.model_path = model_path
......@@ -78,18 +72,6 @@ class ModelConfig:
self.is_draft_model = is_draft_model
self.model_impl = model_impl
# TODO: remove these fields
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
......@@ -204,9 +186,6 @@ class ModelConfig:
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)
......
......@@ -91,7 +91,6 @@ class TpModelWorker:
else server_args.speculative_draft_model_revision
),
is_draft_model=is_draft_worker,
tp_rank=tp_rank,
)
self.model_runner = ModelRunner(
......
......@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
......@@ -112,9 +115,6 @@ from sglang.srt.offloader import (
set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
......@@ -743,6 +743,10 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config,
tp_rank=self.tp_rank,
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(
......
......@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
......@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
safetensors_weights_iterator,
set_runai_streamer_env,
)
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
......@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
......@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
def load_model_from_remote_instance(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
load_config = self.load_config
instance_ip = socket.gethostbyname(socket.gethostname())
start_build_group_tic = time.time()
client.build_group(
gpu_id=device_config.gpu_id,
tp_rank=model_config.tp_rank,
tp_rank=load_config.tp_rank,
instance_ip=instance_ip,
)
torch.cuda.synchronize()
......@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
)
if model_config.tp_rank == 0:
if load_config.tp_rank == 0:
t = threading.Thread(
target=trigger_transferring_weights_request,
args=(
model_config.remote_instance_weight_loader_seed_instance_ip,
model_config.remote_instance_weight_loader_seed_instance_service_port,
model_config.remote_instance_weight_loader_send_weights_group_ports,
load_config.remote_instance_weight_loader_seed_instance_ip,
load_config.remote_instance_weight_loader_seed_instance_service_port,
load_config.remote_instance_weight_loader_send_weights_group_ports,
instance_ip,
),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment