sglang · Commit 1ce4878d (unverified)
Authored Mar 14, 2025 by wangyu; committed by GitHub on Mar 14, 2025

feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>

Parent: 977d7cd2
Changes: 22 files in total; this page shows 20 changed files with 1005 additions and 7 deletions (+1005, -7).
examples/runtime/engine/save_remote_state.py (+51, -0)
examples/runtime/engine/save_sharded_state.py (+74, -0)
python/sglang/__init__.py (+2, -0)
python/sglang/srt/configs/load_config.py (+1, -0)
python/sglang/srt/configs/model_config.py (+25, -1)
python/sglang/srt/connector/__init__.py (+51, -0)
python/sglang/srt/connector/base_connector.py (+112, -0)
python/sglang/srt/connector/redis.py (+85, -0)
python/sglang/srt/connector/s3.py (+122, -0)
python/sglang/srt/connector/serde/__init__.py (+31, -0)
python/sglang/srt/connector/serde/safe_serde.py (+29, -0)
python/sglang/srt/connector/serde/serde.py (+43, -0)
python/sglang/srt/connector/utils.py (+35, -0)
python/sglang/srt/entrypoints/engine.py (+40, -4)
python/sglang/srt/hf_transformers_utils.py (+10, -0)
python/sglang/srt/managers/io_struct.py (+12, -0)
python/sglang/srt/managers/scheduler.py (+62, -1)
python/sglang/srt/model_executor/model_runner.py (+16, -0)
python/sglang/srt/model_loader/loader.py (+159, -1)
python/sglang/srt/model_loader/weight_utils.py (+45, -0)
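Taken together, these files let model weights be saved to and loaded from a remote backend (Redis as a key-value store, or S3 as a remote filesystem) instead of a local directory. A minimal sketch of the round trip, assuming a reachable Redis instance; the checkpoint path, host, and model name below are placeholders, not values from this commit:

# Sketch only: push a local checkpoint's weights to Redis, then load them back
# through the new "remote" load format. All paths and URLs are placeholders.
from sglang import Engine

# 1) Save: start from a local checkpoint and push weights plus config files.
saver = Engine(model_path="/path/to/local/checkpoint", tp_size=1)
saver.save_remote_model(url="redis://localhost:6379/my-model")
saver.shutdown()

# 2) Load: point model_path at the remote URL and select load_format="remote";
#    ModelConfig pulls the config files locally and RemoteModelLoader streams
#    the weights from the connector.
engine = Engine(
    model_path="redis://localhost:6379/my-model",
    load_format="remote",
    tp_size=1,
)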
No files found.
examples/runtime/engine/save_remote_state.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_remote_state.py \
    --model-path /path/to/load \
    --tensor-parallel-size 8 \
    --remote-model-save-url [protocol]://[host]:[port]/[model_name]

Then, the model can be loaded with

llm = Engine(
    model_path="/path/to/save",
    --remote-model-url [protocol]://[host]:[port]/[model_name],
    tensor_parallel_size=8,
)
"""
import dataclasses
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
    "--remote-model-save-url",
    required=True,
    type=str,
    help="remote address to store model weights",
)


def main(args):
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    llm.save_remote_model(url=args.remote_model_save_url)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
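For a key-value backend, save_remote_model lays weights out per tensor-parallel rank under <model_name>/keys/rank_<rank>/<param_name> and stores config/tokenizer files under <model_name>/files/<file_name>. A short sketch of inspecting that layout directly in Redis; the host and model name are placeholders:

# Sketch only: list the keys written by save_remote_model (placeholder names).
import redis

r = redis.Redis(host="localhost", port=6379)
# Weight tensors, one key per parameter and tensor-parallel rank.
print(r.scan(cursor=0, match="my-model/keys/rank_0/*"))
# Config and tokenizer files stored as strings (e.g. config.json).
print(r.scan(cursor=0, match="my-model/files/*"))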
examples/runtime/engine/save_sharded_state.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model-path /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = Engine(
    model_path="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
    "--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
    "--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
    "--max-file-size",
    type=str,
    default=5 * 1024**3,
    help="max size (in bytes) of each safetensors file",
)


def main(args):
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    Path(args.output).mkdir(exist_ok=True)
    llm.save_sharded_model(
        path=args.output,
        pattern=args.file_pattern,
        max_size=args.max_file_size,
    )
    # Copy metadata files to output directory
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            if os.path.isdir(os.path.join(model_path, file)):
                shutil.copytree(
                    os.path.join(model_path, file),
                    os.path.join(args.output, file),
                )
            else:
                shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
python/sglang/__init__.py

@@ -32,6 +32,7 @@ from sglang.lang.choices import (
 )
 from sglang.utils import LazyImport
 
+ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
@@ -67,6 +68,7 @@ __all__ = [
     "greedy_token_selection",
     "token_length_normalized",
     "unconditional_likelihood_normalized",
+    "ServerArgs",
     "Anthropic",
     "LiteLLM",
     "OpenAI",
python/sglang/srt/configs/load_config.py

@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
     MISTRAL = "mistral"
     LAYERED = "layered"
     JAX = "jax"
+    REMOTE = "remote"
 
 
 @dataclass
python/sglang/srt/configs/model_config.py

@@ -51,13 +51,14 @@ class ModelConfig:
         self.quantization = quantization
 
         # Parse args
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
 
         self.hf_config = get_config(
-            model_path,
+            self.model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -318,6 +319,29 @@ class ModelConfig:
             eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         return eos_ids
 
+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+        """
+        Pull the model config files to a temporary
+        directory in case of remote.
+
+        Args:
+            model: The model name or path.
+        """
+        from sglang.srt.connector import create_remote_connector
+        from sglang.srt.utils import is_remote_url
+
+        if is_remote_url(self.model_path):
+            logger.info("Pulling model configs from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # Since config files need to exist all the time, so we DO NOT use
+            # with statement to avoid closing the client.
+            client = create_remote_connector(self.model_path)
+            if is_remote_url(self.model_path):
+                client.pull_files(allow_pattern=["*config.json"])
+                self.model_weights = self.model_path
+                self.model_path = client.get_local_dir()
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
python/sglang/srt/connector/__init__.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import enum
import logging

from sglang.srt.connector.base_connector import (
    BaseConnector,
    BaseFileConnector,
    BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

logger = logging.getLogger(__name__)


class ConnectorType(str, enum.Enum):
    FS = "filesystem"
    KV = "KV"


def create_remote_connector(url, device="cpu") -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    elif connector_type == "s3":
        return S3Connector(url)
    else:
        raise ValueError(f"Invalid connector type: {url}")


def get_connector_type(client: BaseConnector) -> ConnectorType:
    if isinstance(client, BaseKVConnector):
        return ConnectorType.KV
    if isinstance(client, BaseFileConnector):
        return ConnectorType.FS

    raise ValueError(f"Invalid connector type: {client}")


__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]
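The connector type is derived from the URL scheme through parse_connector_type, so a single entry point serves both backends. A minimal sketch of the dispatch; the Redis endpoint and S3 bucket names are placeholders:

# Sketch only: create_remote_connector picks the backend from the URL scheme.
from sglang.srt.connector import ConnectorType, create_remote_connector, get_connector_type

kv_client = create_remote_connector("redis://localhost:6379/my-model")
assert get_connector_type(kv_client) == ConnectorType.KV

fs_client = create_remote_connector("s3://my-bucket/my-model")
assert get_connector_type(fs_client) == ConnectorType.FS

# Both kinds expose pull_files(); here only the config files are fetched into
# the connector's temporary local directory.
fs_client.pull_files(allow_pattern=["*config.json"])
print(fs_client.get_local_dir())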
python/sglang/srt/connector/base_connector.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple

import torch


class BaseConnector(ABC):
    """
    For fs connector such as s3:
      <connector_type>://<path>/<filename>

    For kv connector such as redis:
      <connector_type>://<host>:<port>/<model_name>/keys/<key>
      <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        raise NotImplementedError()

    def close(self):
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):
        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler


class BaseKVConnector(BaseConnector):
    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        raise NotImplementedError()


class BaseFileConnector(BaseConnector):
    """
    List full file names from remote fs path and filter by allow pattern.

    Args:
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full paths allowed by the pattern
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        raise NotImplementedError()
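Concrete backends subclass BaseKVConnector or BaseFileConnector and implement the abstract methods; the temporary directory, signal-based cleanup, and context-manager behavior all come from BaseConnector. A hypothetical local-directory connector, shown only to illustrate the interface (it is not part of this change, and create_remote_connector/parse_connector_type would also need to recognize its scheme), might look like:

# Illustrative sketch only; LocalDirConnector is not part of this commit.
import glob as _glob
import os
import shutil
from typing import Generator, List, Optional, Tuple

import torch
from safetensors.torch import load_file

from sglang.srt.connector import BaseFileConnector


class LocalDirConnector(BaseFileConnector):
    """Serve weights from a plain local directory, e.g. url = "file:///ckpt"."""

    def glob(self, allow_pattern: str) -> List[str]:
        root = self.url.removeprefix("file://")
        return sorted(_glob.glob(os.path.join(root, allow_pattern)))

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        # Copy matching files into the temporary dir created by BaseConnector.
        for pattern in allow_pattern or ["*"]:
            for path in self.glob(pattern):
                shutil.copy(path, self.local_dir)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Stream tensors out of every safetensors shard in the directory.
        for st_file in self.glob("*.safetensors"):
            for name, tensor in load_file(st_file).items():
                yield name, tensor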
python/sglang/srt/connector/redis.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch

from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db

logger = logging.getLogger(__name__)


class RedisConnector(BaseKVConnector):

    def __init__(self, url: str, device: torch.device = "cpu"):
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, bytes], None, None]:
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        self.connection.close()
        super().close()
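Because the weights are keyed as <model_name>/keys/rank_<rank>/<param_name>, the iterator can stream back exactly one tensor-parallel rank's shard, with the prefix already stripped from the yielded names. A short sketch against a placeholder Redis URL:

# Sketch only: iterate rank-0 weights previously stored by save_remote_model.
from sglang.srt.connector.redis import RedisConnector

with RedisConnector("redis://localhost:6379/my-model") as conn:
    for name, tensor in conn.weight_iterator(rank=0):
        # `name` has the "<model_name>/keys/rank_0/" prefix removed.
        print(name, tuple(tensor.shape), tensor.dtype)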
python/sglang/srt/connector/s3.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple

import torch

from sglang.srt.connector import BaseFileConnector


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
    return [
        path
        for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is string represent the bucket
              and the prefix as a dir like string
            - The third element is a list of files allowed or
              disallowed by pattern
    """
    parts = path.removeprefix("s3://").split("/")
    prefix = "/".join(parts[1:])
    bucket_name = parts[0]

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj["Key"] for obj in objects.get("Contents", [])]

    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths


class S3Connector(BaseFileConnector):

    def __init__(self, url: str) -> None:
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            s3_model_path: The S3 path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            destination_file = os.path.join(
                self.local_dir, file.removeprefix(base_dir)
            )
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        self.client.close()
        super().close()
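The S3 connector treats the bucket prefix as a model directory: glob resolves full s3:// paths, pull_files mirrors non-weight files into the temporary directory, and weight_iterator streams the safetensors shards. A short sketch with a placeholder bucket; credentials come from the usual boto3 environment:

# Sketch only: bucket and prefix are placeholders.
from sglang.srt.connector.s3 import S3Connector

with S3Connector("s3://my-bucket/my-model") as conn:
    # Full s3:// paths of every safetensors shard under the prefix.
    print(conn.glob(allow_pattern=["*.safetensors"]))
    # Download everything except weight files into the temporary local dir.
    conn.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
    print(conn.get_local_dir())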
python/sglang/srt/connector/serde/__init__.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

# inspired by LMCache
from typing import Optional, Tuple

import torch

from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer


def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    s: Optional[Serializer] = None
    d: Optional[Deserializer] = None

    if serde_type == "safe":
        s = SafeSerializer()
        d = SafeDeserializer(torch.uint8)
    else:
        raise ValueError(f"Unknown serde type: {serde_type}")

    return s, d


__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]
python/sglang/srt/connector/serde/safe_serde.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch
from safetensors.torch import load, save

from sglang.srt.connector.serde.serde import Deserializer, Serializer


class SafeSerializer(Serializer):

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        return save({"tensor_bytes": t.cpu().contiguous()})


class SafeDeserializer(Deserializer):

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)
python/sglang/srt/connector/serde/serde.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import abc
from abc import ABC, abstractmethod

import torch


class Serializer(ABC):

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes should contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            t: the input pytorch tensor, can be on any device, in any shape,
               with any dtype

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError


class Deserializer(metaclass=abc.ABCMeta):

    def __init__(self, dtype):
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bytes: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError
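The serde pair wraps safetensors save/load, so the serialized bytes carry shape and dtype metadata alongside the data. A minimal round-trip sketch; note that create_serde("safe") constructs the deserializer with torch.uint8, so a uint8 tensor round-trips exactly:

# Sketch only: serialize a tensor to bytes and recover it.
import torch

from sglang.srt.connector.serde import create_serde

serializer, deserializer = create_serde("safe")

t = torch.arange(6, dtype=torch.uint8).reshape(2, 3)
blob = serializer.to_bytes(t)        # safetensors-encoded bytes (data + metadata)
restored = deserializer.from_bytes(blob)
assert torch.equal(restored, t)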
python/sglang/srt/connector/utils.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

from sglang.srt.connector import BaseConnector


def parse_model_name(url: str) -> str:
    """
    Parse the model name from the url.
    Only used for db connector
    """
    parsed_url = urlparse(url)
    return parsed_url.path.lstrip("/")


def pull_files_from_db(
    connector: BaseConnector,
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    files = connector.list(prefix)

    for file in files:
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        local_dir = Path(destination_file).parent
        os.makedirs(local_dir, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))
python/sglang/srt/entrypoints/engine.py

@@ -27,6 +27,9 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
+import zmq
+import zmq.asyncio
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -102,15 +108,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)
 
+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+        tokenizer_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args,
+            port_args=port_args,
+        )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
 
+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -350,6 +366,23 @@ class Engine:
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )
 
+    """
+    Execute an RPC call on all scheduler processes.
+    """
+
+    def collective_rpc(self, method: str, **kwargs):
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)
+
 
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
     _set_envs_and_config(server_args)
 
     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
 
     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
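save_remote_model and save_sharded_model are thin wrappers over collective_rpc: the method name and keyword arguments travel to the scheduler process as an RpcReqInput over the new DEALER socket, and the scheduler invokes the method of the same name on itself before replying with an RpcReqOutput. A hedged sketch of the caller side; the checkpoint path, URL, and output directory are placeholders:

# Sketch only: the Engine-side view of the new RPC-backed save calls.
from sglang import Engine

engine = Engine(model_path="/path/to/local/checkpoint", tp_size=2)

# Equivalent to collective_rpc("save_remote_model", url=...): each scheduler
# rank pushes its shard of the state dict to the remote KV store.
engine.save_remote_model(url="redis://localhost:6379/my-model")

# Equivalent to collective_rpc("save_sharded_model", ...): each rank writes
# model-rank-{rank}-part-{part}.safetensors files under the output path.
engine.save_sharded_model(path="/path/to/save", pattern=None, max_size=5 * 1024**3)

engine.shutdown()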
python/sglang/srt/hf_transformers_utils.py

@@ -37,6 +37,8 @@ from sglang.srt.configs import (
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent
 
+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
python/sglang/srt/managers/io_struct.py

@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str
python/sglang/srt/managers/scheduler.py

@@ -32,6 +32,7 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
 
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
             ]
         )
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )
 
+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
python/sglang/srt/model_executor/model_runner.py

@@ -1009,6 +1009,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"
 
+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
python/sglang/srt/model_loader/loader.py

@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
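get_model_loader dispatches purely on LoadConfig.load_format, so the new loader is selected by setting the new enum value. A minimal sketch, assuming LoadConfig's remaining fields keep their defaults:

# Sketch only: selecting the remote loader through the load-format dispatcher.
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.loader import RemoteModelLoader, get_model_loader

loader = get_model_loader(LoadConfig(load_format=LoadFormat.REMOTE))
assert isinstance(loader, RemoteModelLoader)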
python/sglang/srt/model_loader/weight_utils.py

@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
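set_runai_streamer_env only translates two optional keys of model_loader_extra_config into Run:ai streamer environment variables, and mirrors AWS_ENDPOINT_URL into RUNAI_STREAMER_S3_ENDPOINT when the latter is unset. A small sketch of that mapping; the numeric values are placeholders:

# Sketch only: extra-config keys become RUNAI_STREAMER_* environment variables.
import os

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

load_config = LoadConfig(
    load_format=LoadFormat.REMOTE,
    model_loader_extra_config={"concurrency": 8, "memory_limit": 2 * 1024**3},
)
set_runai_streamer_env(load_config)

print(os.environ["RUNAI_STREAMER_CONCURRENCY"])   # "8"
print(os.environ["RUNAI_STREAMER_MEMORY_LIMIT"])  # "2147483648"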