change / sglang / Commits / d774acad

Commit d774acad (unverified)
Authored Jul 18, 2024 by Mingyi; committed by GitHub on Jul 18, 2024
Parent: d93388da

Remove the dependency of rpyc (#646)

Showing 11 changed files with 303 additions and 551 deletions (+303, -551)
Files changed:

  python/pyproject.toml                                      +1    -1
  python/sglang/launch_server.py                             +1    -1
  python/sglang/launch_server_llavavid.py                    +1    -4
  python/sglang/srt/managers/controller/dp_worker.py         +0    -113
  python/sglang/srt/managers/controller/manager_multi.py     +102  -102
  python/sglang/srt/managers/controller/manager_single.py    +58   -96
  python/sglang/srt/managers/controller/tp_worker.py         +72   -98
  python/sglang/srt/managers/tokenizer_manager.py            +1    -1
  python/sglang/srt/server.py                                +63   -77
  python/sglang/srt/server_args.py                           +2    -9
  python/sglang/srt/utils.py                                 +2    -49
python/pyproject.toml

@@ -21,7 +21,7 @@ dependencies = [

 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
...
python/sglang/launch_server.py

@@ -11,4 +11,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args, None)
+    launch_server(server_args)
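
With the pipe argument dropped, the CLI entrypoint only needs a ServerArgs object. A minimal sketch of a programmatic launch under the new signature (the model path and port below are placeholders, not values from this commit):

    # Hedged sketch: programmatic launch with the new launch_server signature.
    # "your-org/your-model" and the port number are placeholders, not part of this diff.
    from sglang.srt.server import ServerArgs, launch_server

    if __name__ == "__main__":
        server_args = ServerArgs(model_path="your-org/your-model", port=30000)
        # model_overide_args and pipe_finish_writer now default to None.
        launch_server(server_args)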
python/sglang/launch_server_llavavid.py

 """Launch the inference server for Llava-video model."""

 import argparse
-import multiprocessing as mp

 from sglang.srt.server import ServerArgs, launch_server
...
@@ -27,6 +26,4 @@ if __name__ == "__main__":
     server_args = ServerArgs.from_cli_args(args)

-    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-    launch_server(server_args, pipe_writer, model_overide_args)
+    launch_server(server_args, model_overide_args, None)
python/sglang/srt/managers/controller/dp_worker.py  (deleted, 100644 → 0)

Entire file removed:

"""A data parallel worker thread."""

import asyncio
import logging
import queue
import threading
from typing import Callable, List

import uvloop
import zmq

from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")

CHECKING_INTERVAL = 5

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class DataParallelWorkerThread(threading.Thread):
    def __init__(
        self,
        worker_id: int,
        request_queue: queue.Queue,
        detokenizer_port: int,
        step_func: Callable,
    ):
        super(DataParallelWorkerThread, self).__init__()
        self.worker_id = worker_id
        self.request_queue = request_queue
        self.liveness = True
        self.request_dependency_delay = global_config.request_dependency_delay

        context = zmq.asyncio.Context()
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")

        self.step = step_func

    async def loop_for_forward(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())

            out_pyobjs: List[BatchTokenIDOut] = []
            try:
                out_pyobjs = await self.step(requests)
            except Exception:
                for r in requests:
                    self.request_queue.put(r)
                logger.error(
                    f"Worker thread {self.worker_id}: "
                    f"failed to get back from Model Server\n"
                    f"{get_exception_traceback()}"
                )
                self.liveness = False
                # Crash the whole server when there are any errors.
                # TODO(lianmin): make this an option.
                kill_parent_process()
                return

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(
                    [obj.finished_reason is not None for obj in out_pyobjs]
                )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def monitoring(self):
        while True:
            await asyncio.sleep(CHECKING_INTERVAL)
            # can plug in monitoring logic here

    def run(self):
        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(self.monitoring())
        loop.run_until_complete(self.loop_for_forward())


def start_data_parallel_worker(
    server_args: ServerArgs,
    port_args: PortArgs,
    model_overide_args,
    gpu_ids: List[int],
    worker_id: int,
):
    model_tp_client = ModelTpClient(
        gpu_ids,
        server_args,
        port_args.model_port_args[worker_id],
        model_overide_args,
    )
    worker_thread = DataParallelWorkerThread(
        worker_id=worker_id,
        request_queue=queue.Queue(),
        detokenizer_port=port_args.detokenizer_port,
        step_func=model_tp_client.step,
    )
    worker_thread.start()
    return worker_thread
python/sglang/srt/managers/controller/manager_multi.py

@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """
-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

+import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
...
@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
+    """Load balance method."""

     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
...
@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
-    """A controller that manages multiple data parallel workers."""
-
-    def __init__(
-        self,
-        load_balance_method: str,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        model_overide_args,
-    ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
-        self.server_args = server_args
-        self.port_args = port_args
-
-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
-
-        self.dispatch_lookup = {
-            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-        }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
-
-        # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
-        for i in range(server_args.dp_size):
-            start_dp_worker(i)
-
-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
-
-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
-
-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
-
-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
-        for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
-            )
-        return
-
-    async def shortest_queue_scheduler(self, input_requests):
-        for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
-
-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
-
-            await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
-        while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            if isinstance(recv_req, FlushCacheReq):
-                # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
-            elif isinstance(recv_req, AbortReq):
-                in_queue = False
-                for i, req in enumerate(self.recv_reqs):
-                    if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
-                        in_queue = True
-                        break
-                if not in_queue:
-                    # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
-            else:
-                logger.error(f"Invalid object: {recv_req}")
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
+    """A controller that manages multiple data parallel workers."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        # Parse args
+        self.server_args = server_args
+        self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )
+
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
+
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = dispatch_lookup[self.load_balance_method]
+
+        # Start data parallel workers
+        self.workers = []
+        for i in range(server_args.dp_size):
+            self.start_dp_worker(i)
+
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
+
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )
+
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()
+
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )
+
+    def round_robin_scheduler(self, input_requests):
+        for r in input_requests:
+            self.workers[self.round_robin_counter].queue.put(r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+
+    def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)
+
+    def recv_requests(self):
+        recv_reqs = []
+
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(recv_reqs):
+                    if req.rid == recv_req.rid:
+                        recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+        return recv_reqs


 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+    kill_parent_process()
View file @
d774acad
...
@@ -3,126 +3,61 @@
...
@@ -3,126 +3,61 @@
import
logging
import
logging
import
multiprocessing
import
multiprocessing
import
os
import
os
import
pickle
from
typing
import
List
import
torch
import
torch.distributed
as
dist
import
zmq
import
zmq
import
zmq.asyncio
from
sglang.srt.managers.controller.tp_worker
import
ModelTpServer
from
sglang.srt.managers.controller.tp_worker
import
(
from
sglang.srt.server_args
import
ModelPortArgs
,
PortArgs
,
ServerArgs
broadcast_recv_input
,
launch_tp_servers
,
ModelTpServer
)
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
kill_parent_process
from
sglang.srt.utils
import
kill_parent_process
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
"srt.controller"
)
logger
=
logging
.
getLogger
(
"srt.controller"
)
def
run_tp_server
(
gpu_id
:
int
,
tp_rank
:
int
,
server_args
:
ServerArgs
,
model_port_args
:
ModelPortArgs
,
model_overide_args
:
dict
,
):
"""Run a tp server."""
try
:
model_server
=
ModelTpServer
(
gpu_id
,
tp_rank
,
server_args
,
model_port_args
,
model_overide_args
,
)
tp_cpu_group
=
model_server
.
model_runner
.
tp_group
.
cpu_group
while
True
:
recv_reqs
=
broadcast_recv_input
(
None
,
tp_rank
,
tp_cpu_group
)
model_server
.
exposed_step
(
recv_reqs
)
except
Exception
:
logger
.
error
(
"Exception in run_tp_server:
\n
"
+
get_exception_traceback
())
raise
def
launch_tp_servers
(
gpu_ids
,
tp_rank_range
,
server_args
,
model_port_args
,
model_overide_args
):
"""Launch multiple tp servers."""
procs
=
[]
for
i
in
tp_rank_range
:
proc
=
multiprocessing
.
Process
(
target
=
run_tp_server
,
args
=
(
gpu_ids
[
i
],
i
,
server_args
,
model_port_args
,
model_overide_args
),
)
proc
.
start
()
procs
.
append
(
proc
)
return
procs
def
broadcast_recv_input
(
data
,
rank
,
dist_group
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if
rank
==
0
:
if
len
(
data
)
==
0
:
tensor_size
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
)
dist
.
broadcast
(
tensor_size
,
src
=
0
,
group
=
dist_group
)
else
:
serialized_data
=
pickle
.
dumps
(
data
)
size
=
len
(
serialized_data
)
tensor_data
=
torch
.
ByteTensor
(
list
(
serialized_data
))
tensor_size
=
torch
.
tensor
([
size
],
dtype
=
torch
.
long
)
dist
.
broadcast
(
tensor_size
,
src
=
0
,
group
=
dist_group
)
dist
.
broadcast
(
tensor_data
,
src
=
0
,
group
=
dist_group
)
else
:
tensor_size
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
)
dist
.
broadcast
(
tensor_size
,
src
=
0
,
group
=
dist_group
)
size
=
tensor_size
.
item
()
if
size
==
0
:
return
[]
tensor_data
=
torch
.
empty
(
size
,
dtype
=
torch
.
uint8
)
dist
.
broadcast
(
tensor_data
,
src
=
0
,
group
=
dist_group
)
serialized_data
=
bytes
(
tensor_data
.
tolist
())
data
=
pickle
.
loads
(
serialized_data
)
return
data
class
ControllerSingle
:
class
ControllerSingle
:
"""A controller that manages a group of tensor parallel workers."""
"""A controller that manages a group of tensor parallel workers."""
def
__init__
(
def
__init__
(
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
model_overide_args
:
dict
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
model_overide_args
:
dict
,
gpu_ids
:
List
[
int
],
is_data_parallel_worker
:
bool
,
dp_worker_id
:
int
,
mp_queue
:
multiprocessing
.
Queue
,
):
):
# Parse args
# Parse args
self
.
server_args
=
server_args
self
.
tp_size
=
server_args
.
tp_size
self
.
tp_procs
=
[]
self
.
is_dp_worker
=
is_data_parallel_worker
self
.
dp_worker_id
=
dp_worker_id
self
.
mp_queue
=
mp_queue
# Init communication
# Init communication
context
=
zmq
.
Context
(
2
)
context
=
zmq
.
Context
(
2
)
if
not
self
.
is_dp_worker
:
self
.
recv_from_tokenizer
=
context
.
socket
(
zmq
.
PULL
)
self
.
recv_from_tokenizer
=
context
.
socket
(
zmq
.
PULL
)
self
.
recv_from_tokenizer
.
bind
(
f
"tcp://127.0.0.1:
{
port_args
.
rout
er_port
}
"
)
self
.
recv_from_tokenizer
.
bind
(
f
"tcp://127.0.0.1:
{
port_args
.
controll
er_port
}
"
)
self
.
send_to_detokenizer
=
context
.
socket
(
zmq
.
PUSH
)
self
.
send_to_detokenizer
=
context
.
socket
(
zmq
.
PUSH
)
self
.
send_to_detokenizer
.
connect
(
self
.
send_to_detokenizer
.
connect
(
f
"tcp://127.0.0.1:
{
port_args
.
detokenizer_port
}
"
f
"tcp://127.0.0.1:
{
port_args
.
detokenizer_port
}
"
)
)
# Init model server
tp_size_local
=
server_args
.
tp_size
//
server_args
.
nnodes
gpu_ids
=
[
i
for
_
in
range
(
server_args
.
nnodes
)
for
i
in
range
(
tp_size_local
)]
# Launch other tp ranks
# Launch other tp ranks
tp_size_local
=
server_args
.
tp_size
//
server_args
.
nnodes
self
.
tp_procs
=
[]
if
tp_size_local
>
1
:
if
tp_size_local
>
1
:
tp_rank_range
=
range
(
1
,
tp_size_local
)
tp_rank_range
=
range
(
1
,
tp_size_local
)
self
.
tp_procs
=
launch_tp_servers
(
self
.
tp_procs
=
launch_tp_servers
(
gpu_ids
,
gpu_ids
,
tp_rank_range
,
tp_rank_range
,
server_args
,
server_args
,
port_args
.
mode
l_port
_args
[
0
],
port_args
.
ncc
l_port
s
[
dp_worker_id
],
model_overide_args
,
model_overide_args
,
)
)
...
@@ -131,16 +66,19 @@ class ControllerSingle:
...
@@ -131,16 +66,19 @@ class ControllerSingle:
gpu_ids
[
0
],
gpu_ids
[
0
],
0
,
0
,
server_args
,
server_args
,
port_args
.
mode
l_port
_args
[
0
],
port_args
.
ncc
l_port
s
[
dp_worker_id
],
model_overide_args
,
model_overide_args
,
)
)
self
.
tp_cpu_group
=
self
.
tp_server
.
model_runner
.
tp_group
.
cpu_group
self
.
tp_cpu_group
=
self
.
tp_server
.
model_runner
.
tp_group
.
cpu_group
def
loop_for_forward
(
self
):
def
loop_for_forward
(
self
):
while
True
:
while
True
:
recv_reqs
=
self
.
recv_requests
()
if
not
self
.
is_dp_worker
:
recv_reqs
=
self
.
recv_requests_from_zmq
()
else
:
recv_reqs
=
self
.
recv_requests_from_mp_queue
()
if
self
.
server_args
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
broadcast_recv_input
(
recv_reqs
,
0
,
self
.
tp_cpu_group
)
broadcast_recv_input
(
recv_reqs
,
0
,
self
.
tp_cpu_group
)
out_pyobjs
=
self
.
tp_server
.
exposed_step
(
recv_reqs
)
out_pyobjs
=
self
.
tp_server
.
exposed_step
(
recv_reqs
)
...
@@ -148,27 +86,51 @@ class ControllerSingle:
...
@@ -148,27 +86,51 @@ class ControllerSingle:
for
obj
in
out_pyobjs
:
for
obj
in
out_pyobjs
:
self
.
send_to_detokenizer
.
send_pyobj
(
obj
)
self
.
send_to_detokenizer
.
send_pyobj
(
obj
)
def
recv_requests
(
self
):
def
recv_requests
_from_zmq
(
self
):
recv_reqs
=
[]
recv_reqs
=
[]
while
True
:
while
True
:
try
:
try
:
recv_req
=
self
.
recv_from_tokenizer
.
recv_pyobj
(
zmq
.
NOBLOCK
)
recv_req
=
self
.
recv_from_tokenizer
.
recv_pyobj
(
zmq
.
NOBLOCK
)
recv_reqs
.
append
(
recv_req
)
except
zmq
.
ZMQError
:
except
zmq
.
ZMQError
:
break
break
recv_reqs
.
append
(
recv_req
)
return
recv_reqs
def
recv_requests_from_mp_queue
(
self
):
recv_reqs
=
[]
while
not
self
.
mp_queue
.
empty
():
recv_reqs
.
append
(
self
.
mp_queue
.
get
())
return
recv_reqs
return
recv_reqs
def
start_controller_process
(
def
start_controller_process
(
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
pipe_writer
,
model_overide_args
:
dict
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
pipe_writer
:
multiprocessing
.
connection
.
Connection
,
model_overide_args
:
dict
,
is_data_parallel_worker
:
bool
=
False
,
gpu_ids
:
List
[
int
]
=
None
,
dp_worker_id
:
int
=
None
,
queue
:
multiprocessing
.
connection
.
Connection
=
None
,
):
):
"""Start a controller process."""
logging
.
basicConfig
(
logging
.
basicConfig
(
level
=
getattr
(
logging
,
server_args
.
log_level
.
upper
()),
level
=
getattr
(
logging
,
server_args
.
log_level
.
upper
()),
format
=
"%(message)s"
,
format
=
"%(message)s"
,
)
)
if
not
is_data_parallel_worker
:
tp_size_local
=
server_args
.
tp_size
//
server_args
.
nnodes
gpu_ids
=
[
i
for
_
in
range
(
server_args
.
nnodes
)
for
i
in
range
(
tp_size_local
)]
dp_worker_id
=
0
queue
=
None
try
:
try
:
controller
=
ControllerSingle
(
server_args
,
port_args
,
model_overide_args
)
controller
=
ControllerSingle
(
server_args
,
port_args
,
model_overide_args
,
gpu_ids
,
is_data_parallel_worker
,
dp_worker_id
,
queue
)
except
Exception
:
except
Exception
:
pipe_writer
.
send
(
get_exception_traceback
())
pipe_writer
.
send
(
get_exception_traceback
())
raise
raise
...
...
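
The broadcast_recv_input helper, now imported from tp_worker, is the piece that replaces the rpyc calls between tp_rank 0 and the other ranks: it pickles arbitrary Python objects and moves them through a torch.distributed CPU group. A self-contained two-process sketch of the same pattern on the gloo backend (the port and payload are illustrative, not values from this commit):

    # Self-contained sketch of the pickle-over-torch.distributed pattern behind
    # broadcast_recv_input. Two CPU processes, gloo backend; the payload is illustrative.
    import pickle

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def broadcast_pyobj(data, rank):
        # Rank 0 serializes and broadcasts the size, then the bytes; other ranks mirror it.
        if rank == 0:
            buf = pickle.dumps(data)
            size = torch.tensor([len(buf)], dtype=torch.long)
            dist.broadcast(size, src=0)
            dist.broadcast(torch.ByteTensor(list(buf)), src=0)
            return data
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0)
        buf = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(buf, src=0)
        return pickle.loads(bytes(buf.tolist()))


    def worker(rank):
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=2
        )
        obj = broadcast_pyobj({"reqs": ["hello"]} if rank == 0 else None, rank)
        print(f"rank {rank} received: {obj}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(worker, nprocs=2)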
python/sglang/srt/managers/controller/tp_worker.py

 """A tensor parallel worker."""

-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional

-import rpyc
 import torch
-from rpyc.utils.classic import obtain
+import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
...
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
...
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()

         # Copy arguments
...
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=model_port_args.nccl_port,
+            nccl_port=nccl_port,
             server_args=server_args,
         )
...
@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:
...
@@ -425,12 +418,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )

         # Return the new batch
         new_batch = Batch.init_new(
...
@@ -733,87 +720,74 @@ class ModelTpServer:
             break


-class ModelTpService(rpyc.Service):
-    exposed_ModelTpServer = ModelTpServer
-
-
-class ModelTpClient:
-    def __init__(
-        self,
-        gpu_ids: List[int],
-        server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
-        model_overide_args,
-    ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-        self.tp_size = server_args.tp_size
-
-        if self.tp_size * server_args.dp_size == 1:
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(self.tp_size) as executor:
-                # Launch model processes
-                if server_args.nnodes == 1:
-                    self.procs = list(
-                        executor.map(
-                            lambda args: start_rpyc_service_process(*args),
-                            [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                        )
-                    )
-                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                else:
-                    addrs = [
-                        (ip, port)
-                        for ip, port in zip(
-                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                        )
-                    ]
-
-                self.model_services = list(
-                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                )
-
-                # Init model
-                def init_model(i):
-                    return self.model_services[i].ModelTpServer(
-                        gpu_ids[i],
-                        i,
-                        server_args,
-                        model_port_args,
-                        model_overide_args,
-                    )
-
-                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()

+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+
+        return data
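
With the rpyc client gone, rank 0 now calls exposed_step on its in-process ModelTpServer while the other ranks sit in run_tp_server waiting on broadcast_recv_input. A toy sketch of that request/step loop with a dummy server and a plain queue in place of the distributed broadcast (DummyTpServer and the string requests are stand-ins, not sglang APIs):

    # Toy sketch of the controller/worker stepping loop that replaces the rpyc client.
    # DummyTpServer stands in for ModelTpServer; requests are plain strings.
    import multiprocessing


    class DummyTpServer:
        def exposed_step(self, recv_reqs):
            return [f"out({r})" for r in recv_reqs]


    def run_tp_rank(request_queue, result_queue):
        # Non-zero ranks would block on broadcast_recv_input; a queue stands in here.
        server = DummyTpServer()
        while True:
            reqs = request_queue.get()
            if reqs is None:
                break
            result_queue.put(server.exposed_step(reqs))


    if __name__ == "__main__":
        req_q, res_q = multiprocessing.Queue(), multiprocessing.Queue()
        proc = multiprocessing.Process(target=run_tp_rank, args=(req_q, res_q))
        proc.start()
        req_q.put(["hello", "world"])
        print(res_q.get())   # ['out(hello)', 'out(world)']
        req_q.put(None)      # shut the worker down
        proc.join()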
python/sglang/srt/managers/tokenizer_manager.py

@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
         self.hf_config = get_config(
...
python/sglang/srt/server.py

@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback
...
@@ -98,6 +96,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:

         async def stream_results():
...
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }


-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """Launch an HTTP server."""
     global tokenizer_manager

     logging.basicConfig(
...
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, ...)
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

     _set_global_server_args(server_args)

     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )

-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
...
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, ...)
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )

         while True:
...
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, ...)
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
...
@@ -255,27 +241,44 @@ def launch_server(server_args: ServerArgs, ...)
     proc_detoken.start()

     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()

-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}",
+            flush=True,
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()

     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

     # Send a warmup request
-    def _wait_and_warmup():
+    t = threading.Thread(
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+    )
+    t.start()
+
+    # Listen for requests
+    try:
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
+    finally:
+        t.join()
+
+
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
...
@@ -316,22 +319,6 @@ def launch_server(server_args: ServerArgs, ...)
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")

-    t = threading.Thread(target=_wait_and_warmup)
-    t.start()
-
-    # Listen for requests
-    try:
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level_http or server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
-    finally:
-        t.join()
-

 class Runtime:
     """
...
@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
...
@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
...
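
launch_server now starts the warmup logic on a background thread and then blocks in uvicorn.run, joining the thread on shutdown. A minimal sketch of that start-then-serve-then-join pattern (the /health endpoint and the polling loop are placeholders; the real _wait_and_warmup hits the sglang server's own endpoints):

    # Minimal sketch of the start-warmup-thread-then-serve pattern used by launch_server.
    # The endpoint and polling logic below are placeholders, not the real warmup request.
    import threading
    import time

    import requests
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()


    @app.get("/health")
    def health():
        return {"ok": True}


    def _wait_and_warmup(base_url):
        # Poll until the HTTP server is up, then fire one warmup call.
        for _ in range(20):
            try:
                requests.get(f"{base_url}/health", timeout=1)
                break
            except requests.exceptions.RequestException:
                time.sleep(0.5)


    if __name__ == "__main__":
        t = threading.Thread(target=_wait_and_warmup, args=("http://127.0.0.1:8000",))
        t.start()
        try:
            uvicorn.run(app, host="127.0.0.1", port=8000, timeout_keep_alive=5)
        finally:
            t.join()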
python/sglang/srt/server_args.py

@@ -337,16 +337,9 @@ class ServerArgs:
     )


-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
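
Under the new dataclass there is no per-DP-worker ModelPortArgs; a flat list of NCCL ports simply follows the three fixed ports. A sketch of how an allocated port list maps onto the new PortArgs, mirroring the construction in server.py above (the port numbers are arbitrary, and dp_size=2 is assumed):

    # Sketch: mapping an allocated port list onto the new flat PortArgs layout.
    # Port numbers are arbitrary; with dp_size=2, two NCCL ports follow the first three.
    from sglang.srt.server_args import PortArgs

    ports = [30010, 30011, 30012, 30013, 30014]
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],
    )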
python/sglang/srt/utils.py

@@ -3,7 +3,6 @@

 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
 import socket
...
@@ -16,12 +15,10 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

 logger = logging.getLogger(__name__)
...
@@ -148,7 +145,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
...
@@ -160,8 +156,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
...
@@ -371,49 +367,6 @@ def load_image(image_file):
     return image, image_size


-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-        repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
...
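
The comment and count change in allocate_init_ports reflect the smaller port footprint once the per-rank rpyc ports are gone: only one NCCL port per DP worker remains. A quick before/after count with illustrative sizes:

    # Port-count arithmetic behind the allocate_init_ports change (illustrative sizes).
    tp_size, dp_size = 4, 2
    old_needed = 4 + dp_size * (1 + tp_size)  # HTTP + tokenizer + controller + detokenizer + per-rank ports
    new_needed = 4 + dp_size                  # plus one NCCL port per DP worker only
    print(old_needed, new_needed)  # 14 6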