Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3a581e99
Unverified
Commit
3a581e99
authored
Jan 25, 2024
by
Cody Yu
Committed by
GitHub
Jan 25, 2024
Browse files
Dynamic model class loading (#101)
parent
0147f940
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
40 additions
and
28 deletions
+40
-28
python/pyproject.toml
python/pyproject.toml
+1
-1
python/sglang/srt/managers/router/model_runner.py
python/sglang/srt/managers/router/model_runner.py
+31
-27
python/sglang/srt/models/llama2.py
python/sglang/srt/models/llama2.py
+2
-0
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+2
-0
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+2
-0
python/sglang/srt/models/qwen.py
python/sglang/srt/models/qwen.py
+2
-0
No files found.
python/pyproject.toml
View file @
3a581e99
...
@@ -20,7 +20,7 @@ dependencies = [
...
@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
[project.optional-dependencies]
srt
=
[
"aiohttp"
,
"fastapi"
,
"psutil"
,
"rpyc"
,
"torch"
,
"uvloop"
,
"uvicorn"
,
srt
=
[
"aiohttp"
,
"fastapi"
,
"psutil"
,
"rpyc"
,
"torch"
,
"uvloop"
,
"uvicorn"
,
"zmq"
,
"vllm>=0.2.5"
,
"interegular"
,
"lark"
,
"numba"
,
"zmq"
,
"vllm>=0.2.5"
,
"interegular"
,
"lark"
,
"numba"
,
"pydantic"
,
"diskcache"
,
"cloudpickle"
]
"pydantic"
,
"diskcache"
,
"cloudpickle"
,
"pillow"
]
openai
=
[
"openai>=1.0"
,
"numpy"
]
openai
=
[
"openai>=1.0"
,
"numpy"
]
anthropic
=
[
"anthropic"
,
"numpy"
]
anthropic
=
[
"anthropic"
,
"numpy"
]
all
=
["sglang[srt]
", "
sglang
[openai]
", "
sglang
[anthropic]"]
all
=
["sglang[srt]
", "
sglang
[openai]
", "
sglang
[anthropic]"]
...
...
python/sglang/srt/managers/router/model_runner.py
View file @
3a581e99
import
importlib
import
logging
import
logging
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
functools
import
lru_cache
from
pathlib
import
Path
from
typing
import
List
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
sglang
from
sglang.srt.managers.router.infer_batch
import
Batch
,
ForwardMode
from
sglang.srt.managers.router.infer_batch
import
Batch
,
ForwardMode
from
sglang.srt.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.utils
import
is_multimodal_model
from
sglang.srt.utils
import
is_multimodal_model
...
@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
...
@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
global_model_mode
:
List
[
str
]
=
[]
global_model_mode
:
List
[
str
]
=
[]
@lru_cache()
def import_model_classes():
    """Discover all model implementations shipped under ``sglang.srt.models``.

    Scans every Python module in the ``srt/models`` directory, imports it, and
    collects the class exposed via the module-level ``EntryClass`` attribute.

    Returns:
        dict: Mapping from architecture class name (e.g. ``"LlamaForCausalLM"``)
        to the model class itself. Cached via ``lru_cache`` so the directory is
        scanned and imported at most once per process.
    """
    model_arch_name_to_cls = {}
    models_dir = Path(sglang.__file__).parent / "srt" / "models"
    for module_path in models_dir.glob("*.py"):
        # Skip __init__.py and private helpers: importing the package under the
        # name "sglang.srt.models.__init__" would create a duplicate module
        # entry in sys.modules (and duplicate class objects).
        if module_path.stem.startswith("_"):
            continue
        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
        # Only modules that opt in by defining EntryClass are registered.
        if hasattr(module, "EntryClass"):
            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls
def get_model_cls_by_arch_name(model_arch_names):
    """Resolve the first supported model class for the given architectures.

    Args:
        model_arch_names: Iterable of architecture names (typically
            ``hf_config.architectures``), tried in order.

    Returns:
        The model class registered for the first matching architecture.

    Raises:
        ValueError: If ``model_arch_names`` is empty or none of the names is
            registered by ``import_model_classes``.
    """
    model_arch_name_to_cls = import_model_classes()

    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            return model_arch_name_to_cls[arch]

    # Report every candidate we tried. The original for/else referenced the
    # loop variable here, which raised UnboundLocalError when the input list
    # was empty, and only named the last candidate in the message.
    raise ValueError(
        f"Unsupported architectures: {list(model_arch_names)}. "
        f"Supported list: {list(model_arch_name_to_cls.keys())}"
    )
@
dataclass
@
dataclass
class
InputMetadata
:
class
InputMetadata
:
model_runner
:
"ModelRunner"
model_runner
:
"ModelRunner"
...
@@ -237,34 +266,9 @@ class ModelRunner:
...
@@ -237,34 +266,9 @@ class ModelRunner:
def
load_model
(
self
):
def
load_model
(
self
):
"""See also vllm/model_executor/model_loader.py::get_model"""
"""See also vllm/model_executor/model_loader.py::get_model"""
from
sglang.srt.models.llama2
import
LlamaForCausalLM
from
sglang.srt.models.llava
import
LlavaLlamaForCausalLM
from
sglang.srt.models.mixtral
import
MixtralForCausalLM
from
sglang.srt.models.qwen
import
QWenLMHeadModel
# Select model class
# Select model class
architectures
=
getattr
(
self
.
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
model_config
.
hf_config
,
"architectures"
,
[])
model_class
=
get_model_cls_by_arch_name
(
architectures
)
model_class
=
None
for
arch
in
architectures
:
if
arch
==
"LlamaForCausalLM"
:
model_class
=
LlamaForCausalLM
break
if
arch
==
"MistralForCausalLM"
:
model_class
=
LlamaForCausalLM
break
if
arch
==
"LlavaLlamaForCausalLM"
:
model_class
=
LlavaLlamaForCausalLM
break
if
arch
==
"MixtralForCausalLM"
:
model_class
=
MixtralForCausalLM
break
if
arch
==
"QWenLMHeadModel"
:
model_class
=
QWenLMHeadModel
break
if
model_class
is
None
:
raise
ValueError
(
f
"Unsupported architectures:
{
architectures
}
"
)
logger
.
info
(
f
"Rank
{
self
.
tp_rank
}
: load weight begin."
)
logger
.
info
(
f
"Rank
{
self
.
tp_rank
}
: load weight begin."
)
# Load weights
# Load weights
...
...
python/sglang/srt/models/llama2.py
View file @
3a581e99
...
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
...
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
LlamaForCausalLM
python/sglang/srt/models/llava.py
View file @
3a581e99
...
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
...
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
"forward"
,
"forward"
,
clip_vision_embed_forward
,
clip_vision_embed_forward
,
)
)
EntryClass
=
LlavaLlamaForCausalLM
python/sglang/srt/models/mixtral.py
View file @
3a581e99
...
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
...
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
MixtralForCausalLM
python/sglang/srt/models/qwen.py
View file @
3a581e99
...
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
...
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
QWenLMHeadModel
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment