Commit 3a581e99
Authored Jan 25, 2024 by Cody Yu; committed via GitHub on Jan 25, 2024
Parent: 0147f940

Dynamic model class loading (#101)
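This change removes the hard-coded architecture-to-class mapping in model_runner.py. Instead, every module under python/sglang/srt/models that exposes an EntryClass attribute is imported and registered by the new import_model_classes() helper, and load_model() resolves the class for the Hugging Face config's "architectures" entries via get_model_cls_by_arch_name(). Each existing model file (llama2, llava, mixtral, qwen) now declares its EntryClass, and the srt extra gains a pillow dependency.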
Showing 6 changed files with 40 additions and 28 deletions (+40, -28).
python/pyproject.toml (+1, -1)
python/sglang/srt/managers/router/model_runner.py (+31, -27)
python/sglang/srt/models/llama2.py (+2, -0)
python/sglang/srt/models/llava.py (+2, -0)
python/sglang/srt/models/mixtral.py (+2, -0)
python/sglang/srt/models/qwen.py (+2, -0)
python/pyproject.toml

@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
-       "pydantic", "diskcache", "cloudpickle"]
+       "pydantic", "diskcache", "cloudpickle", "pillow"]
 openai = ["openai>=1.0", "numpy"]
 anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
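Note (not part of the diff): since pillow is now listed in the srt extra, installing the runtime with something like `pip install "sglang[srt]"` (or `pip install -e "python[srt]"` from a source checkout) should pull it in automatically; the exact invocation depends on your environment and is only an illustration.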
python/sglang/srt/managers/router/model_runner.py

+import importlib
 import logging
 from dataclasses import dataclass
-from enum import Enum, auto
+from functools import lru_cache
+from pathlib import Path
 from typing import List

 import numpy as np
 import torch

+import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model

@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
 global_model_mode: List[str] = []


+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
+        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
+        if hasattr(module, "EntryClass"):
+            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+    return model_arch_name_to_cls
+
+
+def get_model_cls_by_arch_name(model_arch_names):
+    model_arch_name_to_cls = import_model_classes()
+
+    model_class = None
+    for arch in model_arch_names:
+        if arch in model_arch_name_to_cls:
+            model_class = model_arch_name_to_cls[arch]
+            break
+    else:
+        raise ValueError(
+            f"Unsupported architectures: {arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_class
+
+
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"

@@ -237,34 +266,9 @@ class ModelRunner:
     def load_model(self):
         """See also vllm/model_executor/model_loader.py::get_model"""
-        from sglang.srt.models.llama2 import LlamaForCausalLM
-        from sglang.srt.models.llava import LlavaLlamaForCausalLM
-        from sglang.srt.models.mixtral import MixtralForCausalLM
-        from sglang.srt.models.qwen import QWenLMHeadModel
-
         # Select model class
         architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = None
-        for arch in architectures:
-            if arch == "LlamaForCausalLM":
-                model_class = LlamaForCausalLM
-                break
-            if arch == "MistralForCausalLM":
-                model_class = LlamaForCausalLM
-                break
-            if arch == "LlavaLlamaForCausalLM":
-                model_class = LlavaLlamaForCausalLM
-                break
-            if arch == "MixtralForCausalLM":
-                model_class = MixtralForCausalLM
-                break
-            if arch == "QWenLMHeadModel":
-                model_class = QWenLMHeadModel
-                break
-        if model_class is None:
-            raise ValueError(f"Unsupported architectures: {architectures}")
+        model_class = get_model_cls_by_arch_name(architectures)

         logger.info(f"Rank {self.tp_rank}: load weight begin.")

         # Load weights
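For illustration only (not part of the diff): with these helpers in place, resolving a model class from a config's architectures list would look roughly like the sketch below; the architectures value shown is an assumed example.

# Sketch: resolve a model class the way load_model() now does.
from sglang.srt.managers.router.model_runner import get_model_cls_by_arch_name

architectures = ["LlamaForCausalLM"]  # e.g. getattr(hf_config, "architectures", [])
model_class = get_model_cls_by_arch_name(architectures)
print(model_class.__name__)  # "LlamaForCausalLM"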
python/sglang/srt/models/llama2.py

@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = LlamaForCausalLM
python/sglang/srt/models/llava.py

@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
         "forward",
         clip_vision_embed_forward,
     )
+
+EntryClass = LlavaLlamaForCausalLM
python/sglang/srt/models/mixtral.py

@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = MixtralForCausalLM
python/sglang/srt/models/qwen.py

@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = QWenLMHeadModel
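As a hedged sketch of what this pattern enables (not part of the diff): a new architecture can be added by dropping a module into python/sglang/srt/models and exporting EntryClass. The file and class names below are hypothetical.

# Hypothetical file: python/sglang/srt/models/my_model.py (name assumed).
# import_model_classes() discovers it as long as EntryClass.__name__ matches
# an entry in the checkpoint config's "architectures" list.
from torch import nn


class MyModelForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        # ... layers, forward(), and load_weights() as in the other model files ...


EntryClass = MyModelForCausalLM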