sglang · commit 1ce4878d (unverified)

feat(remote_model): support variable remote backend for model loader (#3964)

Authored Mar 14, 2025 by wangyu; committed by GitHub on Mar 14, 2025
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
Parent: 977d7cd2
Showing 20 changed files with 1005 additions and 7 deletions (+1005 / -7):
  examples/runtime/engine/save_remote_state.py      +51   -0
  examples/runtime/engine/save_sharded_state.py     +74   -0
  python/sglang/__init__.py                           +2   -0
  python/sglang/srt/configs/load_config.py            +1   -0
  python/sglang/srt/configs/model_config.py          +25   -1
  python/sglang/srt/connector/__init__.py            +51   -0
  python/sglang/srt/connector/base_connector.py     +112   -0
  python/sglang/srt/connector/redis.py               +85   -0
  python/sglang/srt/connector/s3.py                 +122   -0
  python/sglang/srt/connector/serde/__init__.py      +31   -0
  python/sglang/srt/connector/serde/safe_serde.py    +29   -0
  python/sglang/srt/connector/serde/serde.py         +43   -0
  python/sglang/srt/connector/utils.py               +35   -0
  python/sglang/srt/entrypoints/engine.py            +40   -4
  python/sglang/srt/hf_transformers_utils.py         +10   -0
  python/sglang/srt/managers/io_struct.py            +12   -0
  python/sglang/srt/managers/scheduler.py            +62   -1
  python/sglang/srt/model_executor/model_runner.py   +16   -0
  python/sglang/srt/model_loader/loader.py          +159   -1
  python/sglang/srt/model_loader/weight_utils.py     +45   -0
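Taken together, the commit adds a connector layer (Redis as a key-value backend, S3 as a file backend), a tensor serde layer, a RemoteModelLoader, and an RPC path from the Engine down to the scheduler so that model weights can be pushed to and later streamed back from a remote backend. A minimal end-to-end sketch of the intended round trip, assuming a reachable Redis instance at localhost:6379 and placeholder paths (keyword names follow ServerArgs and the example scripts below):

# Sketch only: push a local checkpoint to Redis, then reload from it.
from sglang import Engine

# 1) Save: stream each rank's state dict plus config/tokenizer files to Redis.
saver = Engine(model_path="/path/to/local/checkpoint")
saver.save_remote_model(url="redis://localhost:6379/my-model")
saver.shutdown()

# 2) Load: point model_path at the remote URL and select the new load format.
llm = Engine(
    model_path="redis://localhost:6379/my-model",
    load_format="remote",
)
print(llm.generate("Hello"))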
examples/runtime/engine/save_remote_state.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:
python save_remote_state.py \
    --model-path /path/to/load \
    --tensor-parallel-size 8 \
    --remote-model-save-url [protocol]://[host]:[port]/[model_name] \

Then, the model can be loaded with
llm = Engine(
    model_path="/path/to/save",
    --remote-model-url [protocol]://[host]:[port]/[model_name],
    tensor_parallel_size=8,
)
"""
import dataclasses
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
    "--remote-model-save-url",
    required=True,
    type=str,
    help="remote address to store model weights",
)


def main(args):
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    llm.save_remote_model(url=args.remote_model_save_url)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
examples/runtime/engine/save_sharded_state.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:
python save_sharded_state.py \
    --model-path /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with
llm = Engine(
    model_path="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
    "--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
    "--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
    "--max-file-size",
    type=str,
    default=5 * 1024**3,
    help="max size (in bytes) of each safetensors file",
)


def main(args):
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    Path(args.output).mkdir(exist_ok=True)
    llm.save_sharded_model(
        path=args.output,
        pattern=args.file_pattern,
        max_size=args.max_file_size,
    )
    # Copy metadata files to output directory
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            if os.path.isdir(os.path.join(model_path, file)):
                shutil.copytree(
                    os.path.join(model_path, file), os.path.join(args.output, file)
                )
            else:
                shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
python/sglang/__init__.py

@@ -32,6 +32,7 @@ from sglang.lang.choices import (
 )
 from sglang.utils import LazyImport

+ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
@@ -67,6 +68,7 @@ __all__ = [
     "greedy_token_selection",
     "token_length_normalized",
     "unconditional_likelihood_normalized",
+    "ServerArgs",
     "Anthropic",
     "LiteLLM",
     "OpenAI",
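For context, this new top-level export is what lets the example scripts above rely on a plain `from sglang import Engine, ServerArgs`. A small sketch of the deferred import in action, assuming LazyImport proxies attribute access the same way it already does for the other backends listed here:

# Sketch: ServerArgs is only materialized from sglang.srt.server_args on first use.
from argparse import ArgumentParser

from sglang import ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)  # same pattern as save_remote_state.py above
engine_args = ServerArgs.from_cli_args(
    parser.parse_args(["--model-path", "/path/to/checkpoint"])
)
print(engine_args.model_path)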
python/sglang/srt/configs/load_config.py

@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
     MISTRAL = "mistral"
     LAYERED = "layered"
     JAX = "jax"
+    REMOTE = "remote"


 @dataclass
python/sglang/srt/configs/model_config.py

@@ -51,13 +51,14 @@ class ModelConfig:
         self.quantization = quantization

         # Parse args
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()

         self.hf_config = get_config(
-            model_path,
+            self.model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -318,6 +319,29 @@ class ModelConfig:
             eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         return eos_ids

+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+        """
+        Pull the model config files to a temporary
+        directory in case of remote.
+
+        Args:
+            model: The model name or path.
+        """
+        from sglang.srt.connector import create_remote_connector
+        from sglang.srt.utils import is_remote_url
+
+        if is_remote_url(self.model_path):
+            logger.info("Pulling model configs from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # Since config files need to exist all the time, so we DO NOT use
+            # with statement to avoid closing the client.
+            client = create_remote_connector(self.model_path)
+
+        if is_remote_url(self.model_path):
+            client.pull_files(allow_pattern=["*config.json"])
+            self.model_weights = self.model_path
+            self.model_path = client.get_local_dir()
+

 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
python/sglang/srt/connector/__init__.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import enum
import logging

from sglang.srt.connector.base_connector import (
    BaseConnector,
    BaseFileConnector,
    BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

logger = logging.getLogger(__name__)


class ConnectorType(str, enum.Enum):
    FS = "filesystem"
    KV = "KV"


def create_remote_connector(url, device="cpu") -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    elif connector_type == "s3":
        return S3Connector(url)
    else:
        raise ValueError(f"Invalid connector type: {url}")


def get_connector_type(client: BaseConnector) -> ConnectorType:
    if isinstance(client, BaseKVConnector):
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS

    raise ValueError(f"Invalid connector type: {client}")


__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]
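create_remote_connector dispatches on parse_connector_type, which lives in sglang.srt.utils and is not part of this diff. A hypothetical stand-in showing the scheme-based parsing the factory needs (the helper name here is a placeholder, not the real function):

from urllib.parse import urlparse


def parse_connector_type_sketch(url: str) -> str:
    # Hypothetical stand-in for sglang.srt.utils.parse_connector_type:
    # the factory above only needs the URL scheme ("redis", "s3", ...).
    return urlparse(url).scheme


assert parse_connector_type_sketch("redis://10.0.0.1:6379/llama-3-8b") == "redis"
assert parse_connector_type_sketch("s3://my-bucket/llama-3-8b") == "s3"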
python/sglang/srt/connector/base_connector.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple

import torch


class BaseConnector(ABC):
    """
    For fs connector such as s3:
      <connector_type>://<path>/<filename>

    For kv connector such as redis:
      <connector_type>://<host>:<port>/<model_name>/keys/<key>
      <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        raise NotImplementedError()

    def close(self):
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):

        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler


class BaseKVConnector(BaseConnector):

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        raise NotImplementedError()


class BaseFileConnector(BaseConnector):
    """
    List full file names from remote fs path and filter by allow pattern.

    Args:
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full paths allowed by the pattern
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        raise NotImplementedError()
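A minimal sketch of the lifecycle BaseConnector manages, using a throwaway concrete subclass (illustrative only, not part of the commit): every connector gets a temporary local directory that is cleaned up on close(), on context-manager exit, and via the SIGINT/SIGTERM handlers installed above.

import os

from sglang.srt.connector.base_connector import BaseConnector


class NullConnector(BaseConnector):
    # Smallest possible concrete connector: no weights, nothing to pull.
    def weight_iterator(self, rank: int = 0):
        yield from ()

    def pull_files(self, allow_pattern=None, ignore_pattern=None) -> None:
        pass


with NullConnector("null://demo") as conn:
    scratch = conn.get_local_dir()  # per-connector temp dir from tempfile.mkdtemp()
    assert os.path.isdir(scratch)
# __exit__ -> close() removed the temp dir.
assert not os.path.isdir(scratch)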
python/sglang/srt/connector/redis.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch

from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db

logger = logging.getLogger(__name__)


class RedisConnector(BaseKVConnector):

    def __init__(self, url: str, device: torch.device = "cpu"):
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)

            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, bytes], None, None]:
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        self.connection.close()
        super().close()
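For reference, a sketch of the keyspace that RedisConnector reads and RemoteModelLoader.save_model (further down) writes, inspected with plain redis-py; it assumes a local Redis instance and a model already saved under the placeholder name "my-model":

import redis  # redis-py, the same client RedisConnector wraps

r = redis.Redis(host="localhost", port=6379)

# Tensor shards, one namespace per tensor-parallel rank:
#   <model_name>/keys/rank_<rank>/<param_name> -> safetensors-encoded bytes
# Auxiliary text files (config.json, tokenizer files, ...):
#   <model_name>/files/<file_name>             -> UTF-8 text
for key in r.scan_iter(match="my-model/keys/rank_0/*"):
    print(key.decode("utf-8"))
for key in r.scan_iter(match="my-model/files/*"):
    print(key.decode("utf-8"))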
python/sglang/srt/connector/s3.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple

import torch

from sglang.srt.connector import BaseFileConnector


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is string represent the bucket
              and the prefix as a dir like string
            - The third element is a list of files allowed or
              disallowed by pattern
    """
    parts = path.removeprefix("s3://").split("/")
    prefix = "/".join(parts[1:])
    bucket_name = parts[0]

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj["Key"] for obj in objects.get("Contents", [])]

    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths


class S3Connector(BaseFileConnector):

    def __init__(self, url: str) -> None:
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            s3_model_path: The S3 path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            destination_file = os.path.join(
                self.local_dir, file.removeprefix(base_dir)
            )
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        self.client.close()
        super().close()
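The allow/ignore filters above are plain fnmatch globs over S3 object keys. A small self-contained illustration with placeholder keys showing which files each code path would touch:

import fnmatch

keys = [
    "llama-3-8b/config.json",
    "llama-3-8b/model-00001-of-00002.safetensors",
    "llama-3-8b/model-00002-of-00002.safetensors",
    "llama-3-8b/tokenizer.json",
]

# What weight_iterator() streams (glob with allow_pattern=["*.safetensors"]):
weights = [k for k in keys if fnmatch.fnmatch(k, "*.safetensors")]

# What pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) would fetch:
configs = [
    k
    for k in keys
    if not any(fnmatch.fnmatch(k, p) for p in ("*.pt", "*.safetensors", "*.bin"))
]

print(weights)  # the two shard files
print(configs)  # config.json and tokenizer.json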
python/sglang/srt/connector/serde/__init__.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

# inspired by LMCache
from typing import Optional, Tuple

import torch

from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer


def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    s: Optional[Serializer] = None
    d: Optional[Deserializer] = None

    if serde_type == "safe":
        s = SafeSerializer()
        d = SafeDeserializer(torch.uint8)
    else:
        raise ValueError(f"Unknown serde type: {serde_type}")

    return s, d


__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]
python/sglang/srt/connector/serde/safe_serde.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch
from safetensors.torch import load, save

from sglang.srt.connector.serde.serde import Deserializer, Serializer


class SafeSerializer(Serializer):

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        return save({"tensor_bytes": t.cpu().contiguous()})


class SafeDeserializer(Deserializer):

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)
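SafeSerializer/SafeDeserializer lean entirely on safetensors' in-memory save/load. A quick round trip using those primitives directly (dtype left unchanged here, whereas the Redis connector constructs SafeDeserializer(torch.uint8)):

import torch
from safetensors.torch import load, save

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
blob = save({"tensor_bytes": t.contiguous()})  # bytes, shape/dtype metadata included
restored = load(blob)["tensor_bytes"]          # back to a tensor
assert torch.equal(t, restored)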
python/sglang/srt/connector/serde/serde.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import abc
from abc import ABC, abstractmethod

import torch


class Serializer(ABC):

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes should contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            t: the input pytorch tensor, can be on any device, in any shape,
               with any dtype

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError


class Deserializer(metaclass=abc.ABCMeta):

    def __init__(self, dtype):
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bytes: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError
python/sglang/srt/connector/utils.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

from sglang.srt.connector import BaseConnector


def parse_model_name(url: str) -> str:
    """
    Parse the model name from the url.
    Only used for db connector
    """
    parsed_url = urlparse(url)
    return parsed_url.path.lstrip("/")


def pull_files_from_db(
    connector: BaseConnector,
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    files = connector.list(prefix)

    for file in files:
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        local_dir = Path(destination_file).parent
        os.makedirs(local_dir, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))
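parse_model_name keeps only the URL path, so for a KV backend the model name doubles as the key prefix that pull_files_from_db scans. A tiny stdlib-only illustration with a placeholder URL:

from urllib.parse import urlparse

url = "redis://10.0.0.1:6379/llama-3-8b"
model_name = urlparse(url).path.lstrip("/")
assert model_name == "llama-3-8b"

print(f"{model_name}/files/")        # prefix scanned for config/tokenizer files
print(f"{model_name}/keys/rank_0/")  # prefix scanned for rank-0 tensors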
python/sglang/srt/entrypoints/engine.py

@@ -27,6 +27,9 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union

+import zmq
+import zmq.asyncio
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -102,15 +108,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)

+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
         tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
+            server_args=server_args,
+            port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -350,6 +366,23 @@ class Engine:
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )

+    """
+    Execute an RPC call on all scheduler processes.
+    """
+
+    def collective_rpc(self, method: str, **kwargs):
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)
+

 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     _set_envs_and_config(server_args)

     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")

     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
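Engine.collective_rpc ships an RpcReqInput over the new DEALER socket and the scheduler dispatches it by method name with getattr. A self-contained sketch of that dispatch pattern, with the ZMQ transport and the distributed barrier stripped out (toy class, not the real Scheduler):

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class RpcReqInput:  # mirrors the new dataclass in io_struct.py
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str


class SchedulerSketch:
    def save_remote_model(self, params):
        print("would push weights to", params["url"])

    def handle_rpc_request(self, recv_req: RpcReqInput) -> RpcReqOutput:
        try:
            getattr(self, recv_req.method)(recv_req.parameters)
            return RpcReqOutput(True, "")
        except Exception as e:
            return RpcReqOutput(False, str(e))


out = SchedulerSketch().handle_rpc_request(
    RpcReqInput(method="save_remote_model", parameters={"url": "redis://localhost:6379/my-model"})
)
assert out.success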
python/sglang/srt/hf_transformers_utils.py

@@ -37,6 +37,8 @@ from sglang.srt.configs import (
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent

+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
python/sglang/srt/managers/io_struct.py

@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str
python/sglang/srt/managers/scheduler.py

@@ -32,6 +32,7 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
             ]
         )
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )

+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
python/sglang/srt/model_executor/model_runner.py

@@ -1009,6 +1009,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"

+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
python/sglang/srt/model_loader/loader.py

@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """

     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model


+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)

+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
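With the new branch in get_model_loader, selecting the remote path is just a matter of the load format. A minimal sketch, assuming LoadConfig's remaining fields keep their defaults:

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.loader import RemoteModelLoader, get_model_loader

loader = get_model_loader(LoadConfig(load_format=LoadFormat.REMOTE))
assert isinstance(loader, RemoteModelLoader)
# Whether the KV or FS path is taken is decided later, inside load_model(),
# from the connector created for model_config.model_weights.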
python/sglang/srt/model_loader/weight_utils.py

@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader


+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
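set_runai_streamer_env translates loader extra-config keys into the Run:ai streamer's environment variables. A sketch of that mapping, assuming LoadConfig exposes model_loader_extra_config as a constructor argument (the values here are arbitrary):

import os

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

set_runai_streamer_env(
    LoadConfig(
        load_format=LoadFormat.REMOTE,
        model_loader_extra_config={"concurrency": 8, "memory_limit": 2 * 1024**3},
    )
)
print(os.environ["RUNAI_STREAMER_CONCURRENCY"])   # "8"
print(os.environ["RUNAI_STREAMER_MEMORY_LIMIT"])  # "2147483648"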