Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9f81d741
Unverified
Commit
9f81d741
authored
Aug 29, 2025
by
wangyu
Committed by
GitHub
Aug 28, 2025
Browse files
fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)
Signed-off-by:
wangyu
<
wangyu.steph@bytedance.com
>
parent
a38c1497
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
37 additions
and
35 deletions
+37
-35
examples/runtime/engine/save_remote_state.py
examples/runtime/engine/save_remote_state.py
+1
-2
python/sglang/srt/connector/__init__.py
python/sglang/srt/connector/__init__.py
+1
-1
python/sglang/srt/connector/base_connector.py
python/sglang/srt/connector/base_connector.py
+1
-2
python/sglang/srt/connector/redis.py
python/sglang/srt/connector/redis.py
+2
-2
python/sglang/srt/connector/serde/__init__.py
python/sglang/srt/connector/serde/__init__.py
+1
-1
python/sglang/srt/connector/serde/safe_serde.py
python/sglang/srt/connector/serde/safe_serde.py
+4
-3
python/sglang/srt/model_loader/loader.py
python/sglang/srt/model_loader/loader.py
+15
-24
python/sglang/srt/model_loader/utils.py
python/sglang/srt/model_loader/utils.py
+12
-0
No files found.
examples/runtime/engine/save_remote_state.py
View file @
9f81d741
...
...
@@ -14,8 +14,7 @@ python save_remote_state.py \
Then, the model can be loaded with
llm = Engine(
model_path="/path/to/save",
--remote-model-url [protocol]://[host]:[port]/[model_name],
model_path="[protocol]://[host]:[port]/[model_name]",
tensor_parallel_size=8,
)
"""
...
...
python/sglang/srt/connector/__init__.py
View file @
9f81d741
...
...
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
KV
=
"KV"
def
create_remote_connector
(
url
,
device
=
"cpu"
)
->
BaseConnector
:
def
create_remote_connector
(
url
,
**
kwargs
)
->
BaseConnector
:
connector_type
=
parse_connector_type
(
url
)
if
connector_type
==
"redis"
:
return
RedisConnector
(
url
)
...
...
python/sglang/srt/connector/base_connector.py
View file @
9f81d741
...
...
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
<connector_type://<host>:<port>/<model_name>/files/<filename>
"""
def
__init__
(
self
,
url
:
str
,
device
:
torch
.
device
=
"cpu"
):
def
__init__
(
self
,
url
:
str
):
self
.
url
=
url
self
.
device
=
device
self
.
closed
=
False
self
.
local_dir
=
tempfile
.
mkdtemp
()
for
sig
in
(
signal
.
SIGINT
,
signal
.
SIGTERM
):
...
...
python/sglang/srt/connector/redis.py
View file @
9f81d741
...
...
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
class
RedisConnector
(
BaseKVConnector
):
def
__init__
(
self
,
url
:
str
,
device
:
torch
.
device
=
"cpu"
):
def
__init__
(
self
,
url
:
str
):
import
redis
super
().
__init__
(
url
,
device
)
super
().
__init__
(
url
)
parsed_url
=
urlparse
(
url
)
self
.
connection
=
redis
.
Redis
(
host
=
parsed_url
.
hostname
,
port
=
parsed_url
.
port
)
self
.
model_name
=
parsed_url
.
path
.
lstrip
(
"/"
)
...
...
python/sglang/srt/connector/serde/__init__.py
View file @
9f81d741
...
...
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
if
serde_type
==
"safe"
:
s
=
SafeSerializer
()
d
=
SafeDeserializer
(
torch
.
uint8
)
d
=
SafeDeserializer
()
else
:
raise
ValueError
(
f
"Unknown serde type:
{
serde_type
}
"
)
...
...
python/sglang/srt/connector/serde/safe_serde.py
View file @
9f81d741
...
...
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
class SafeDeserializer(Deserializer):
    """Deserializer counterpart of SafeSerializer.

    Decodes a safetensors-encoded byte buffer back into a tensor. The
    payload is expected to be a safetensors blob holding a single entry
    under the key "tensor_bytes".
    """

    def __init__(self):
        # TODO: dtype options
        # NOTE(review): the base-class dtype is fixed to float32 here, but
        # from_bytes_normal no longer applies any dtype conversion — the
        # tensor is returned with whatever dtype it was stored with.
        super().__init__(torch.float32)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        # Load the safetensors blob and return the stored tensor as-is
        # (no `.to(dtype=...)` cast; preserves the serialized dtype).
        return load(bytes(b))["tensor_bytes"]

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        # Thin dispatch wrapper; normal (non-fast-path) deserialization.
        return self.from_bytes_normal(b)
python/sglang/srt/model_loader/loader.py
View file @
9f81d741
...
...
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.model_loader.utils
import
(
get_model_architecture
,
post_load_weights
,
set_default_torch_dtype
,
)
from
sglang.srt.model_loader.weight_utils
import
(
...
...
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
# random values to the weights.
initialize_dummy_weights
(
model
)
# Model weight loading consists of two stages:
# 1. Initial weight loading.
# 2. Post-processing of weights, including assigning specific member variables.
# For `dummy_init`, only the second stage is required.
if
hasattr
(
model
,
"post_load_weights"
):
if
(
model_config
.
hf_config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLMNextN"
):
model
.
post_load_weights
(
is_nextn
=
True
)
else
:
model
.
post_load_weights
()
post_load_weights
(
model
,
model_config
)
return
model
.
eval
()
...
...
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
post_load_weights
(
model
,
model_config
)
return
model
.
eval
()
@
staticmethod
...
...
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
# ignore hidden files
if
file_name
.
startswith
(
"."
):
continue
if
os
.
path
.
splitext
(
file_name
)[
1
]
not
in
(
".bin"
,
".pt"
,
".safetensors"
,
):
if
os
.
path
.
splitext
(
file_name
)[
1
]
in
(
".json"
,
".py"
):
file_path
=
os
.
path
.
join
(
root
,
file_name
)
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
file
:
file_content
=
file
.
read
()
f_key
=
f
"
{
model_name
}
/files/
{
file_name
}
"
client
.
setstr
(
f_key
,
file_content
)
def
_load_model_from_remote_kv
(
self
,
model
:
nn
.
Module
,
client
):
def
_load_model_from_remote_kv
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
client
):
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
...
...
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
post_load_weights
(
model
,
model_config
)
def
_load_model_from_remote_fs
(
self
,
model
,
client
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
)
->
nn
.
Module
:
...
...
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
with
create_remote_connector
(
model_weights
,
device_config
.
device
)
as
client
:
with
create_remote_connector
(
model_weights
,
device
=
device_config
.
device
)
as
client
:
connector_type
=
get_connector_type
(
client
)
if
connector_type
==
ConnectorType
.
KV
:
self
.
_load_model_from_remote_kv
(
model
,
client
)
self
.
_load_model_from_remote_kv
(
model
,
model_config
,
client
)
elif
connector_type
==
ConnectorType
.
FS
:
self
.
_load_model_from_remote_fs
(
model
,
client
,
model_config
,
device_config
...
...
python/sglang/srt/model_loader/utils.py
View file @
9f81d741
...
...
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
return
get_model_architecture
(
model_config
)[
1
]
def post_load_weights(model: nn.Module, model_config: ModelConfig) -> None:
    """Run a model's post-weight-loading hook, if it defines one.

    Shared by the model loaders (Dummy/Sharded/Remote) so the
    post-processing stage is applied consistently after weights land.
    """
    # Model weight loading consists of two stages:
    # 1. Initial weight loading.
    # 2. Post-processing of weights, including assigning specific member variables.
    # For `dummy_init`, only the second stage is required.
    if hasattr(model, "post_load_weights"):
        # The NextN (MTP) DeepSeek-V3 variant needs to know it is the
        # draft/NextN head so it post-processes the right layers.
        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
            model.post_load_weights(is_nextn=True)
        else:
            model.post_load_weights()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment