sglang / Commits

Commit 75235419 (unverified)
Authored Mar 24, 2024 by Liangsheng Yin; committed by GitHub on Mar 24, 2024
`model_rpc` style improvement (#293)
parent 64ee9c03
Showing 3 changed files with 31 additions and 24 deletions:

python/sglang/backend/openai.py (+2, -1)
python/sglang/srt/managers/router/model_rpc.py (+15, -9)
python/sglang/srt/managers/router/model_runner.py (+14, -14)
python/sglang/backend/openai.py
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
 try:
-    import openai
     import tiktoken
+
+    import openai
 except ImportError as e:
     openai = tiktoken = e
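The guarded block above uses a common optional-dependency pattern: when the import fails, the module names are bound to the ImportError itself, so the failure surfaces only when the backend is actually used. A minimal sketch of the idea; the check_openai helper is illustrative, not sglang's code:

try:
    import openai
except ImportError as e:
    openai = e  # bind the error to the name instead of failing at import time


def check_openai():
    # Raise the original ImportError lazily, at first use of the backend.
    if isinstance(openai, Exception):
        raise openai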
python/sglang/srt/managers/router/model_rpc.py
@@ -6,7 +6,6 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
 logger = logging.getLogger("model_rpc")
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
             batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
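This hunk is the core of the refactor: ModelRpcServer drops the rpyc.Service base class and becomes a plain class with a normal __init__, while a thin ModelRpcService exposes it over RPC through rpyc's `exposed_` attribute convention. In the single-process path (tp_size == 1) the class is now constructed directly, with no RPC layer in between. A minimal sketch of the pattern, using hypothetical names (Worker, port 18861) and config rather than sglang's:

import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class Worker:
    """Plain class; no rpyc base, so in-process use needs no RPC layer."""

    def __init__(self, rank: int):
        self.rank = rank

    def ping(self):
        return f"worker {self.rank} alive"


class WorkerService(rpyc.Service):
    # rpyc exposes attributes prefixed with `exposed_`; remote clients see
    # them without the prefix, e.g. conn.root.Worker(...).
    exposed_Worker = Worker


if __name__ == "__main__":
    # In-process path (tp_size == 1 in the diff): construct directly.
    local = WorkerService().exposed_Worker(0)
    print(local.ping())

    # Remote path: serve WorkerService, then instantiate Worker over RPC.
    server = ThreadedServer(
        WorkerService(),
        port=18861,
        protocol_config={"allow_public_attrs": True},
    )
    threading.Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)  # crude wait for the server to bind

    conn = rpyc.connect("localhost", 18861)
    remote = conn.root.Worker(1)  # resolves to exposed_Worker on the service
    print(remote.ping())  # attribute access permitted by allow_public_attrs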
@@ -629,14 +633,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]
 
                 # Init model
                 def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )
 
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+                self.model_servers = executor.map(init_model, range(tp_size))
 
             # Wrap functions
             def async_wrap(func_name):
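In the multi-process path, executor.map(init_model, range(tp_size)) now yields the per-rank remote ModelRpcServer handles directly, without an explicit obtain round-trip, and async_wrap turns their blocking RPC methods into coroutines. A sketch of one way such a wrapper can be written; the helper name, fan-out, and rank-0 return convention are assumptions, not taken from this diff:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def make_async_wrap(model_servers, executor: ThreadPoolExecutor):
    def async_wrap(func_name):
        fs = [getattr(server, func_name) for server in model_servers]

        async def wrapper(*args, **kwargs):
            loop = asyncio.get_running_loop()
            # Fan the blocking RPC out to every tensor-parallel rank
            # without stalling the event loop.
            tasks = [
                loop.run_in_executor(executor, lambda f=f: f(*args, **kwargs))
                for f in fs
            ]
            results = await asyncio.gather(*tasks)
            return results[0]  # assume ranks agree; keep rank 0's result

        return wrapper

    return async_wrap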
@@ -654,7 +660,7 @@ class ModelRpcClient:
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
python/sglang/srt/managers/router/model_runner.py
 import importlib
-import logging
-
+import importlib.resources
+import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 
-import importlib.resources
-
 import numpy as np
 import torch
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-import importlib
-import pkgutil
-
-import sglang
-
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
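The loop above implements a plugin-style model registry: every module under sglang.srt.models that defines an EntryClass is discovered and imported at startup. A sketch of the complete function under that reading; the registry key and return value are inferred from context, not shown in this hunk:

import importlib
import pkgutil


def import_model_classes(package_name="sglang.srt.models"):
    model_arch_name_to_cls = {}
    package = importlib.import_module(package_name)
    # iter_modules yields (finder, name, ispkg); the prefix argument makes
    # `name` a fully qualified module path suitable for import_module.
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            # Each model file advertises its architecture class via EntryClass.
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                model_arch_name_to_cls[entry.__name__] = entry
    return model_arch_name_to_cls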
@@ -144,9 +139,12 @@ class InputMetadata:
 
             # flashinfer >= 0.0.3
             # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+            if (
+                len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
+                == 7
+            ):
                 args.append(self.model_runner.model_config.head_dim)
 
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
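The reformatted condition is a version shim: it probes begin_forward's arity with inspect.signature so the extra head_dim argument is only passed to flashinfer builds that accept it. A standalone sketch of the technique; the wrapper object, helper name, and argument list are placeholders:

import inspect


def call_begin_forward(wrapper, args, head_dim):
    # Newer releases added a head_dim parameter; detect it by counting
    # parameters instead of pinning an exact library version.
    n_params = len(inspect.signature(wrapper.begin_forward).parameters)
    if n_params == 7:
        args = [*args, head_dim]
    wrapper.begin_forward(*args)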
@@ -307,9 +305,11 @@ class ModelRunner:
             hf_quant_method = hf_quant_config["quant_method"]
 
             # compat: autogptq uses is_marlin_format within quant config
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+            if (
+                hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]
+            ):
                 hf_quant_method = "marlin"
 
             quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
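QUANTIONCONFIG_MAPPING dispatches from the checkpoint's quant_method string to a vllm quantization config class, with autogptq's marlin-format gptq checkpoints remapped to marlin first. A hedged sketch of how such a lookup is typically consumed; the error handling and from_config call are assumptions about the surrounding code, not part of this hunk:

quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
if quant_config_class is None:
    raise ValueError(f"Unsupported quantization method: {hf_quant_method}")
# vllm quantization config classes are built from the HF quant config dict.
quant_config = quant_config_class.from_config(hf_quant_config)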