Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2bdaf482
Unverified
Commit
2bdaf482
authored
Sep 27, 2025
by
amysaq2023
Committed by
GitHub
Sep 26, 2025
Browse files
refactor loading weights from remote instance coding format (#10941)
Signed-off-by:
Anqi Shen
<
amy.saq@antgroup.com
>
parent
777eb538
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
34 deletions
+21
-34
python/sglang/srt/configs/load_config.py
python/sglang/srt/configs/load_config.py
+4
-0
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+0
-21
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+0
-1
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+7
-3
python/sglang/srt/model_loader/loader.py
python/sglang/srt/model_loader/loader.py
+10
-9
python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py
...g/srt/model_loader/remote_instance_weight_loader_utils.py
+0
-0
No files found.
python/sglang/srt/configs/load_config.py
View file @
2bdaf482
...
@@ -58,6 +58,10 @@ class LoadConfig:
...
@@ -58,6 +58,10 @@ class LoadConfig:
ignore_patterns
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
ignore_patterns
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
decryption_key_file
:
Optional
[
str
]
=
None
decryption_key_file
:
Optional
[
str
]
=
None
decrypt_max_concurrency
:
int
=
-
1
decrypt_max_concurrency
:
int
=
-
1
tp_rank
:
Optional
[
int
]
=
None
remote_instance_weight_loader_seed_instance_ip
:
Optional
[
str
]
=
None
remote_instance_weight_loader_seed_instance_service_port
:
Optional
[
int
]
=
None
remote_instance_weight_loader_send_weights_group_ports
:
Optional
[
List
[
int
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
model_loader_extra_config
=
self
.
model_loader_extra_config
or
{}
model_loader_extra_config
=
self
.
model_loader_extra_config
or
{}
...
...
python/sglang/srt/configs/model_config.py
View file @
2bdaf482
...
@@ -64,12 +64,6 @@ class ModelConfig:
...
@@ -64,12 +64,6 @@ class ModelConfig:
is_draft_model
:
bool
=
False
,
is_draft_model
:
bool
=
False
,
hybrid_kvcache_ratio
:
Optional
[
float
]
=
None
,
hybrid_kvcache_ratio
:
Optional
[
float
]
=
None
,
model_impl
:
Union
[
str
,
ModelImpl
]
=
ModelImpl
.
AUTO
,
model_impl
:
Union
[
str
,
ModelImpl
]
=
ModelImpl
.
AUTO
,
tp_rank
:
Optional
[
int
]
=
None
,
remote_instance_weight_loader_seed_instance_ip
:
Optional
[
str
]
=
None
,
remote_instance_weight_loader_seed_instance_service_port
:
Optional
[
int
]
=
None
,
remote_instance_weight_loader_send_weights_group_ports
:
Optional
[
List
[
int
]
]
=
None
,
)
->
None
:
)
->
None
:
# Parse args
# Parse args
self
.
model_path
=
model_path
self
.
model_path
=
model_path
...
@@ -78,18 +72,6 @@ class ModelConfig:
...
@@ -78,18 +72,6 @@ class ModelConfig:
self
.
is_draft_model
=
is_draft_model
self
.
is_draft_model
=
is_draft_model
self
.
model_impl
=
model_impl
self
.
model_impl
=
model_impl
# TODO: remove these fields
self
.
tp_rank
=
tp_rank
self
.
remote_instance_weight_loader_seed_instance_ip
=
(
remote_instance_weight_loader_seed_instance_ip
)
self
.
remote_instance_weight_loader_seed_instance_service_port
=
(
remote_instance_weight_loader_seed_instance_service_port
)
self
.
remote_instance_weight_loader_send_weights_group_ports
=
(
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config
# Get hf config
self
.
_maybe_pull_model_tokenizer_from_remote
()
self
.
_maybe_pull_model_tokenizer_from_remote
()
self
.
model_override_args
=
json
.
loads
(
model_override_args
)
self
.
model_override_args
=
json
.
loads
(
model_override_args
)
...
@@ -204,9 +186,6 @@ class ModelConfig:
...
@@ -204,9 +186,6 @@ class ModelConfig:
quantization
=
server_args
.
quantization
,
quantization
=
server_args
.
quantization
,
hybrid_kvcache_ratio
=
server_args
.
hybrid_kvcache_ratio
,
hybrid_kvcache_ratio
=
server_args
.
hybrid_kvcache_ratio
,
model_impl
=
server_args
.
model_impl
,
model_impl
=
server_args
.
model_impl
,
remote_instance_weight_loader_seed_instance_ip
=
server_args
.
remote_instance_weight_loader_seed_instance_ip
,
remote_instance_weight_loader_seed_instance_service_port
=
server_args
.
remote_instance_weight_loader_seed_instance_service_port
,
remote_instance_weight_loader_send_weights_group_ports
=
server_args
.
remote_instance_weight_loader_send_weights_group_ports
,
**
kwargs
,
**
kwargs
,
)
)
...
...
python/sglang/srt/managers/tp_worker.py
View file @
2bdaf482
...
@@ -91,7 +91,6 @@ class TpModelWorker:
...
@@ -91,7 +91,6 @@ class TpModelWorker:
else
server_args
.
speculative_draft_model_revision
else
server_args
.
speculative_draft_model_revision
),
),
is_draft_model
=
is_draft_worker
,
is_draft_model
=
is_draft_worker
,
tp_rank
=
tp_rank
,
)
)
self
.
model_runner
=
ModelRunner
(
self
.
model_runner
=
ModelRunner
(
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
2bdaf482
...
@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
...
@@ -104,6 +104,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from
sglang.srt.model_executor.npu_graph_runner
import
NPUGraphRunner
from
sglang.srt.model_executor.npu_graph_runner
import
NPUGraphRunner
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.model_loader.loader
import
DefaultModelLoader
,
get_model_loader
from
sglang.srt.model_loader.loader
import
DefaultModelLoader
,
get_model_loader
from
sglang.srt.model_loader.remote_instance_weight_loader_utils
import
(
trigger_init_weights_send_group_for_remote_instance_request
,
)
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.offloader
import
(
from
sglang.srt.offloader
import
(
...
@@ -112,9 +115,6 @@ from sglang.srt.offloader import (
...
@@ -112,9 +115,6 @@ from sglang.srt.offloader import (
set_offloader
,
set_offloader
,
)
)
from
sglang.srt.patch_torch
import
monkey_patch_torch_reductions
from
sglang.srt.patch_torch
import
monkey_patch_torch_reductions
from
sglang.srt.remote_instance_weight_loader_utils
import
(
trigger_init_weights_send_group_for_remote_instance_request
,
)
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
...
@@ -743,6 +743,10 @@ class ModelRunner:
...
@@ -743,6 +743,10 @@ class ModelRunner:
load_format
=
self
.
server_args
.
load_format
,
load_format
=
self
.
server_args
.
load_format
,
download_dir
=
self
.
server_args
.
download_dir
,
download_dir
=
self
.
server_args
.
download_dir
,
model_loader_extra_config
=
self
.
server_args
.
model_loader_extra_config
,
model_loader_extra_config
=
self
.
server_args
.
model_loader_extra_config
,
tp_rank
=
self
.
tp_rank
,
remote_instance_weight_loader_seed_instance_ip
=
self
.
server_args
.
remote_instance_weight_loader_seed_instance_ip
,
remote_instance_weight_loader_seed_instance_service_port
=
self
.
server_args
.
remote_instance_weight_loader_seed_instance_service_port
,
remote_instance_weight_loader_send_weights_group_ports
=
self
.
server_args
.
remote_instance_weight_loader_send_weights_group_ports
,
)
)
if
self
.
device
==
"cpu"
:
if
self
.
device
==
"cpu"
:
self
.
model_config
=
adjust_config_with_unaligned_cpu_tp
(
self
.
model_config
=
adjust_config_with_unaligned_cpu_tp
(
...
...
python/sglang/srt/model_loader/loader.py
View file @
2bdaf482
...
@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
...
@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
sglang.srt.model_loader.remote_instance_weight_loader_utils
import
(
trigger_transferring_weights_request
,
)
from
sglang.srt.model_loader.utils
import
(
from
sglang.srt.model_loader.utils
import
(
get_model_architecture
,
get_model_architecture
,
post_load_weights
,
post_load_weights
,
...
@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
...
@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
safetensors_weights_iterator
,
safetensors_weights_iterator
,
set_runai_streamer_env
,
set_runai_streamer_env
,
)
)
from
sglang.srt.remote_instance_weight_loader_utils
import
(
trigger_transferring_weights_request
,
)
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_bool_env_var
,
get_bool_env_var
,
get_device_capability
,
get_device_capability
,
...
@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
...
@@ -1420,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f
"load format
{
load_config
.
load_format
}
"
f
"load format
{
load_config
.
load_format
}
"
)
)
model_weights
=
f
"instance://
{
model
_config
.
remote_instance_weight_loader_seed_instance_ip
}
:
{
model
_config
.
remote_instance_weight_loader_send_weights_group_ports
[
model
_config
.
tp_rank
]
}
"
model_weights
=
f
"instance://
{
load
_config
.
remote_instance_weight_loader_seed_instance_ip
}
:
{
load
_config
.
remote_instance_weight_loader_send_weights_group_ports
[
load
_config
.
tp_rank
]
}
"
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
...
@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
...
@@ -1442,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
def
load_model_from_remote_instance
(
def
load_model_from_remote_instance
(
self
,
model
,
client
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
self
,
model
,
client
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
)
->
nn
.
Module
:
)
->
nn
.
Module
:
load_config
=
self
.
load_config
instance_ip
=
socket
.
gethostbyname
(
socket
.
gethostname
())
instance_ip
=
socket
.
gethostbyname
(
socket
.
gethostname
())
start_build_group_tic
=
time
.
time
()
start_build_group_tic
=
time
.
time
()
client
.
build_group
(
client
.
build_group
(
gpu_id
=
device_config
.
gpu_id
,
gpu_id
=
device_config
.
gpu_id
,
tp_rank
=
model
_config
.
tp_rank
,
tp_rank
=
load
_config
.
tp_rank
,
instance_ip
=
instance_ip
,
instance_ip
=
instance_ip
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
...
@@ -1455,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f
"finish building group for remote instance, time used:
{
(
end_build_group_tic
-
start_build_group_tic
):.
4
f
}
s"
f
"finish building group for remote instance, time used:
{
(
end_build_group_tic
-
start_build_group_tic
):.
4
f
}
s"
)
)
if
model
_config
.
tp_rank
==
0
:
if
load
_config
.
tp_rank
==
0
:
t
=
threading
.
Thread
(
t
=
threading
.
Thread
(
target
=
trigger_transferring_weights_request
,
target
=
trigger_transferring_weights_request
,
args
=
(
args
=
(
model
_config
.
remote_instance_weight_loader_seed_instance_ip
,
load
_config
.
remote_instance_weight_loader_seed_instance_ip
,
model
_config
.
remote_instance_weight_loader_seed_instance_service_port
,
load
_config
.
remote_instance_weight_loader_seed_instance_service_port
,
model
_config
.
remote_instance_weight_loader_send_weights_group_ports
,
load
_config
.
remote_instance_weight_loader_send_weights_group_ports
,
instance_ip
,
instance_ip
,
),
),
)
)
...
...
python/sglang/srt/remote_instance_weight_loader_utils.py
→
python/sglang/srt/
model_loader/
remote_instance_weight_loader_utils.py
View file @
2bdaf482
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment