Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
75235419
"vscode:/vscode.git/clone" did not exist on "e335f05fb15fec92b523f28fc4d9f019a35b7e75"
Unverified
Commit
75235419
authored
Mar 24, 2024
by
Liangsheng Yin
Committed by
GitHub
Mar 24, 2024
Browse files
`model_rpc` style improvement (#293)
parent
64ee9c03
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
24 deletions
+31
-24
python/sglang/backend/openai.py
python/sglang/backend/openai.py
+2
-1
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+15
-9
python/sglang/srt/managers/router/model_runner.py
python/sglang/srt/managers/router/model_runner.py
+14
-14
No files found.
python/sglang/backend/openai.py
View file @
75235419
...
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
...
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
from
sglang.lang.ir
import
SglSamplingParams
from
sglang.lang.ir
import
SglSamplingParams
try
:
try
:
import
openai
import
tiktoken
import
tiktoken
import
openai
except
ImportError
as
e
:
except
ImportError
as
e
:
openai
=
tiktoken
=
e
openai
=
tiktoken
=
e
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
75235419
...
@@ -6,7 +6,6 @@ import warnings
...
@@ -6,7 +6,6 @@ import warnings
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
List
from
typing
import
List
import
numpy
as
np
import
rpyc
import
rpyc
import
torch
import
torch
from
rpyc.utils.classic
import
obtain
from
rpyc.utils.classic
import
obtain
...
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
...
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
logger
=
logging
.
getLogger
(
"model_rpc"
)
logger
=
logging
.
getLogger
(
"model_rpc"
)
class
ModelRpcServer
(
rpyc
.
Service
)
:
class
ModelRpcServer
:
def
exposed
_init_
model
(
def
_
_init_
_
(
self
,
self
,
tp_rank
:
int
,
tp_rank
:
int
,
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
...
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
...
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
batch
.
reqs
=
[]
batch
.
reqs
=
[]
class
ModelRpcService
(
rpyc
.
Service
):
exposed_ModelRpcServer
=
ModelRpcServer
class
ModelRpcClient
:
class
ModelRpcClient
:
def
__init__
(
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
):
def
__init__
(
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
):
tp_size
=
server_args
.
tp_size
tp_size
=
server_args
.
tp_size
if
tp_size
==
1
:
if
tp_size
==
1
:
# Init model
# Init model
self
.
model_server
=
ModelRpcServer
()
self
.
model_server
=
ModelRpcService
().
exposed_ModelRpcServer
(
self
.
model_server
.
exposed_init_model
(
0
,
server_args
,
port_args
)
0
,
server_args
,
port_args
)
# Wrap functions
# Wrap functions
def
async_wrap
(
f
):
def
async_wrap
(
f
):
...
@@ -629,14 +633,16 @@ class ModelRpcClient:
...
@@ -629,14 +633,16 @@ class ModelRpcClient:
with
ThreadPoolExecutor
(
tp_size
)
as
executor
:
with
ThreadPoolExecutor
(
tp_size
)
as
executor
:
# Launch model processes
# Launch model processes
rets
=
executor
.
map
(
start_model_process
,
port_args
.
model_rpc_ports
)
rets
=
executor
.
map
(
start_model_process
,
port_args
.
model_rpc_ports
)
self
.
model
_serve
r
s
=
[
x
[
0
]
for
x
in
rets
]
self
.
remote
_serv
ic
es
=
[
x
[
0
]
for
x
in
rets
]
self
.
procs
=
[
x
[
1
]
for
x
in
rets
]
self
.
procs
=
[
x
[
1
]
for
x
in
rets
]
# Init model
# Init model
def
init_model
(
i
):
def
init_model
(
i
):
return
self
.
model_servers
[
i
].
init_model
(
i
,
server_args
,
port_args
)
return
self
.
remote_services
[
i
].
ModelRpcServer
(
i
,
server_args
,
port_args
)
rets
=
[
obtain
(
x
)
for
x
in
executor
.
map
(
init_model
,
range
(
tp_size
))
]
self
.
model_servers
=
executor
.
map
(
init_model
,
range
(
tp_size
))
# Wrap functions
# Wrap functions
def
async_wrap
(
func_name
):
def
async_wrap
(
func_name
):
...
@@ -654,7 +660,7 @@ class ModelRpcClient:
...
@@ -654,7 +660,7 @@ class ModelRpcClient:
def
_init_service
(
port
):
def
_init_service
(
port
):
t
=
ThreadedServer
(
t
=
ThreadedServer
(
ModelRpcServe
r
(),
ModelRpcServ
ic
e
(),
port
=
port
,
port
=
port
,
protocol_config
=
{
"allow_pickle"
:
True
,
"sync_request_timeout"
:
1800
},
protocol_config
=
{
"allow_pickle"
:
True
,
"sync_request_timeout"
:
1800
},
)
)
...
...
python/sglang/srt/managers/router/model_runner.py
View file @
75235419
import
importlib
import
importlib
import
logging
import
importlib.resources
import
inspect
import
inspect
import
logging
import
pkgutil
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
pathlib
import
Path
import
importlib.resources
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
...
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from
vllm.model_executor.model_loader
import
_set_default_torch_dtype
from
vllm.model_executor.model_loader
import
_set_default_torch_dtype
from
vllm.model_executor.parallel_utils.parallel_state
import
initialize_model_parallel
from
vllm.model_executor.parallel_utils.parallel_state
import
initialize_model_parallel
import
importlib
import
pkgutil
import
sglang
QUANTIONCONFIG_MAPPING
=
{
"awq"
:
AWQConfig
,
"gptq"
:
GPTQConfig
,
"marlin"
:
MarlinConfig
}
QUANTIONCONFIG_MAPPING
=
{
"awq"
:
AWQConfig
,
"gptq"
:
GPTQConfig
,
"marlin"
:
MarlinConfig
}
logger
=
logging
.
getLogger
(
"model_runner"
)
logger
=
logging
.
getLogger
(
"model_runner"
)
...
@@ -37,7 +32,7 @@ def import_model_classes():
...
@@ -37,7 +32,7 @@ def import_model_classes():
model_arch_name_to_cls
=
{}
model_arch_name_to_cls
=
{}
package_name
=
"sglang.srt.models"
package_name
=
"sglang.srt.models"
package
=
importlib
.
import_module
(
package_name
)
package
=
importlib
.
import_module
(
package_name
)
for
finder
,
name
,
ispkg
in
pkgutil
.
iter_modules
(
package
.
__path__
,
package_name
+
'.'
):
for
_
,
name
,
ispkg
in
pkgutil
.
iter_modules
(
package
.
__path__
,
package_name
+
"."
):
if
not
ispkg
:
if
not
ispkg
:
module
=
importlib
.
import_module
(
name
)
module
=
importlib
.
import_module
(
name
)
if
hasattr
(
module
,
"EntryClass"
):
if
hasattr
(
module
,
"EntryClass"
):
...
@@ -144,9 +139,12 @@ class InputMetadata:
...
@@ -144,9 +139,12 @@ class InputMetadata:
# flashinfer >= 0.0.3
# flashinfer >= 0.0.3
# FIXME: Drop this when flashinfer updates to 0.0.4
# FIXME: Drop this when flashinfer updates to 0.0.4
if
len
(
inspect
.
signature
(
self
.
prefill_wrapper
.
begin_forward
).
parameters
)
==
7
:
if
(
len
(
inspect
.
signature
(
self
.
prefill_wrapper
.
begin_forward
).
parameters
)
==
7
):
args
.
append
(
self
.
model_runner
.
model_config
.
head_dim
)
args
.
append
(
self
.
model_runner
.
model_config
.
head_dim
)
self
.
prefill_wrapper
.
begin_forward
(
*
args
)
self
.
prefill_wrapper
.
begin_forward
(
*
args
)
else
:
else
:
self
.
decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
...
@@ -307,9 +305,11 @@ class ModelRunner:
...
@@ -307,9 +305,11 @@ class ModelRunner:
hf_quant_method
=
hf_quant_config
[
"quant_method"
]
hf_quant_method
=
hf_quant_config
[
"quant_method"
]
# compat: autogptq uses is_marlin_format within quant config
# compat: autogptq uses is_marlin_format within quant config
if
(
hf_quant_method
==
"gptq"
if
(
and
"is_marlin_format"
in
hf_quant_config
hf_quant_method
==
"gptq"
and
hf_quant_config
[
"is_marlin_format"
]):
and
"is_marlin_format"
in
hf_quant_config
and
hf_quant_config
[
"is_marlin_format"
]
):
hf_quant_method
=
"marlin"
hf_quant_method
=
"marlin"
quant_config_class
=
QUANTIONCONFIG_MAPPING
.
get
(
hf_quant_method
)
quant_config_class
=
QUANTIONCONFIG_MAPPING
.
get
(
hf_quant_method
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment