change / sglang · Commits

Commit d774acad (Unverified)
Remove the dependency of rpyc (#646)

Authored Jul 18, 2024 by Mingyi; committed by GitHub on Jul 18, 2024.
Parent: d93388da

Showing 11 changed files with 303 additions and 551 deletions (+303 −551).
Files changed:

  python/pyproject.toml                                     +1    −1
  python/sglang/launch_server.py                            +1    −1
  python/sglang/launch_server_llavavid.py                   +1    −4
  python/sglang/srt/managers/controller/dp_worker.py        +0    −113
  python/sglang/srt/managers/controller/manager_multi.py    +102  −102
  python/sglang/srt/managers/controller/manager_single.py   +58   −96
  python/sglang/srt/managers/controller/tp_worker.py        +72   −98
  python/sglang/srt/managers/tokenizer_manager.py           +1    −1
  python/sglang/srt/server.py                               +63   −77
  python/sglang/srt/server_args.py                          +2    −9
  python/sglang/srt/utils.py                                +2    −49
python/pyproject.toml  (+1 −1)

@@ -21,7 +21,7 @@ dependencies = [

 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
python/sglang/launch_server.py  (+1 −1)

@@ -11,4 +11,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args, None)
+    launch_server(server_args)
python/sglang/launch_server_llavavid.py  (+1 −4)

 """Launch the inference server for Llava-video model."""

 import argparse
-import multiprocessing as mp

 from sglang.srt.server import ServerArgs, launch_server

@@ -27,6 +26,4 @@ if __name__ == "__main__":
     server_args = ServerArgs.from_cli_args(args)

-    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-    launch_server(server_args, pipe_writer, model_overide_args)
+    launch_server(server_args, model_overide_args, None)
python/sglang/srt/managers/controller/dp_worker.py  (+0 −113, deleted 100644 → 0)

The whole file is removed. Its former contents:

"""A data parallel worker thread."""

import asyncio
import logging
import queue
import threading
from typing import Callable, List

import uvloop
import zmq

from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")

CHECKING_INTERVAL = 5

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class DataParallelWorkerThread(threading.Thread):
    def __init__(
        self,
        worker_id: int,
        request_queue: queue.Queue,
        detokenizer_port: int,
        step_func: Callable,
    ):
        super(DataParallelWorkerThread, self).__init__()
        self.worker_id = worker_id
        self.request_queue = request_queue
        self.liveness = True
        self.request_dependency_delay = global_config.request_dependency_delay

        context = zmq.asyncio.Context()
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")

        self.step = step_func

    async def loop_for_forward(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())

            out_pyobjs: List[BatchTokenIDOut] = []
            try:
                out_pyobjs = await self.step(requests)
            except Exception:
                for r in requests:
                    self.request_queue.put(r)
                logger.error(
                    f"Worker thread {self.worker_id}: "
                    f"failed to get back from Model Server\n"
                    f"{get_exception_traceback()}"
                )
                self.liveness = False
                # Crash the whole server when there are any errors.
                # TODO(lianmin): make this an option.
                kill_parent_process()
                return

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(
                    [obj.finished_reason is not None for obj in out_pyobjs]
                )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)

            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def monitoring(self):
        while True:
            await asyncio.sleep(CHECKING_INTERVAL)
            # can plug in monitoring logic here

    def run(self):
        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(self.monitoring())
        loop.run_until_complete(self.loop_for_forward())


def start_data_parallel_worker(
    server_args: ServerArgs,
    port_args: PortArgs,
    model_overide_args,
    gpu_ids: List[int],
    worker_id: int,
):
    model_tp_client = ModelTpClient(
        gpu_ids,
        server_args,
        port_args.model_port_args[worker_id],
        model_overide_args,
    )
    worker_thread = DataParallelWorkerThread(
        worker_id=worker_id,
        request_queue=queue.Queue(),
        detokenizer_port=port_args.detokenizer_port,
        step_func=model_tp_client.step,
    )
    worker_thread.start()
    return worker_thread
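Note: the thread-based data parallel worker above is replaced by running each data parallel worker as its own ControllerSingle process that drains a multiprocessing.Queue (see manager_single.py and manager_multi.py below). For orientation, here is a minimal, self-contained sketch of that queue-draining pattern; the request strings and the fixed number of polling rounds are made up for the demo and are not part of sglang:

    import multiprocessing as mp
    import time


    def worker_loop(q: mp.Queue):
        # Drain whatever is currently queued without blocking, mirroring the
        # polling style of the new per-worker controller process.
        for _ in range(3):  # a few polling rounds are enough for the demo
            batch = []
            while not q.empty():
                batch.append(q.get())
            if batch:
                print("worker got:", batch)
            time.sleep(0.1)


    if __name__ == "__main__":
        q = mp.Queue()
        proc = mp.Process(target=worker_loop, args=(q,))
        proc.start()
        for i in range(5):
            q.put(f"req-{i}")  # the dispatcher side simply puts requests on the queue
        proc.join()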
python/sglang/srt/managers/controller/manager_multi.py  (+102 −102)

@@ -3,19 +3,17 @@
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """

-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

+import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,

@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
     """Load balance method."""

     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()

@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
     """A controller that manages multiple data parallel workers."""

     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )

-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
-
-        self.dispatch_lookup = {
-            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-        }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
         # Init communication
-        context = zmq.asyncio.Context()
+        context = zmq.Context()
         self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

-        # Init status
-        self.recv_reqs = []
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = dispatch_lookup[self.load_balance_method]

         # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
-
-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
-
-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
-
-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
+            self.start_dp_worker(i)
+
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
+
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )
+
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()
+
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )

-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
+                self.workers
             )
-        return

-    async def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
-
-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
-
-            await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
-        while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-
-            if isinstance(recv_req, FlushCacheReq):
-                # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
-            elif isinstance(recv_req, AbortReq):
-                in_queue = False
-                for i, req in enumerate(self.recv_reqs):
-                    if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
-                        in_queue = True
-                        break
-                if not in_queue:
-                    # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
-            else:
-                logger.error(f"Invalid object: {recv_req}")
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)
+
+    def recv_requests(self):
+        recv_reqs = []
+
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(recv_reqs):
+                    if req.rid == recv_req.rid:
+                        recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+        return recv_reqs


 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
     """Start a controller process."""
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-    pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
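Note: the two dispatch policies kept by ControllerMulti are small enough to check by hand. A minimal sketch of the same round-robin and shortest-queue logic, using plain lists as stand-ins for the workers' multiprocessing queues (names here are illustrative, not sglang's):

    import numpy as np

    # Plain lists stand in for the workers' multiprocessing queues.
    workers = [[], [], []]
    round_robin_counter = 0


    def round_robin_dispatch(requests):
        global round_robin_counter
        for r in requests:
            workers[round_robin_counter].append(r)
            round_robin_counter = (round_robin_counter + 1) % len(workers)


    def shortest_queue_dispatch(requests):
        for r in requests:
            sizes = [len(w) for w in workers]
            workers[int(np.argmin(sizes))].append(r)


    round_robin_dispatch(["a", "b", "c", "d"])
    shortest_queue_dispatch(["e", "f"])
    print(workers)  # [['a', 'd'], ['b', 'e'], ['c', 'f']]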
python/sglang/srt/managers/controller/manager_single.py  (+58 −96)

@@ -3,126 +3,61 @@
 import logging
 import multiprocessing
 import os
-import pickle
+from typing import List

-import torch
-import torch.distributed as dist
 import zmq
 import zmq.asyncio

-from sglang.srt.managers.controller.tp_worker import ModelTpServer
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import (
+    broadcast_recv_input,
+    launch_tp_servers,
+    ModelTpServer,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    model_port_args: ModelPortArgs,
-    model_overide_args: dict,
-):
-    """Run a tp server."""
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
-):
-    """Launch multiple tp servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(
-            target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
-        )
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(data, rank, dist_group):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-
-    return data
-
-
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""

     def __init__(
-        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
     ):
         # Parse args
-        self.server_args = server_args
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue

         # Init communication
         context = zmq.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.controller_port}"
+            )

         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init model server
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-
         # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
         self.tp_procs = []
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
                 gpu_ids,
                 tp_rank_range,
                 server_args,
-                port_args.model_port_args[0],
+                port_args.nccl_ports[dp_worker_id],
                 model_overide_args,
             )

@@ -131,16 +66,19 @@ class ControllerSingle:
             gpu_ids[0],
             0,
             server_args,
-            port_args.model_port_args[0],
+            port_args.nccl_ports[dp_worker_id],
             model_overide_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

     def loop_for_forward(self):
         while True:
-            recv_reqs = self.recv_requests()
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()

-            if self.server_args.tp_size > 1:
+            if self.tp_size > 1:
                 broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

             out_pyobjs = self.tp_server.exposed_step(recv_reqs)

@@ -148,27 +86,51 @@ class ControllerSingle:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-    def recv_requests(self):
+    def recv_requests_from_zmq(self):
         recv_reqs = []
         while True:
             try:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
             except zmq.ZMQError:
                 break
-            recv_reqs.append(recv_req)

         return recv_reqs

+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
+
+        return recv_reqs
+

 def start_controller_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
-    model_overide_args: dict,
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
-        controller = ControllerSingle(server_args, port_args, model_overide_args)
+        controller = ControllerSingle(
+            server_args,
+            port_args,
+            model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
+        )
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
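Note: ControllerSingle now pulls requests either from a ZMQ PULL socket (recv_requests_from_zmq) or from a multiprocessing queue. A standalone sketch of the non-blocking ZMQ drain it relies on, assuming only pyzmq; the port is chosen at random and the payloads are made up for the demo:

    import time

    import zmq

    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    port = pull.bind_to_random_port("tcp://127.0.0.1")

    push = ctx.socket(zmq.PUSH)
    push.connect(f"tcp://127.0.0.1:{port}")
    for i in range(3):
        push.send_pyobj({"rid": i})
    time.sleep(0.2)  # give the messages time to arrive for this demo

    # Non-blocking drain: read everything currently queued, stop on ZMQError.
    recv_reqs = []
    while True:
        try:
            recv_reqs.append(pull.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break
    print(recv_reqs)  # [{'rid': 0}, {'rid': 1}, {'rid': 2}]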
python/sglang/srt/managers/controller/tp_worker.py  (+72 −98)

 """A tensor parallel worker."""

-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional

-import rpyc
 import torch
-from rpyc.utils.classic import obtain
+import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache

@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback

@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()

         # Copy arguments

@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=model_port_args.nccl_port,
+            nccl_port=nccl_port,
             server_args=server_args,
         )

@@ -178,9 +174,6 @@ class ModelTpServer:
     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:

@@ -425,12 +418,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )

         # Return the new batch
         new_batch = Batch.init_new(

@@ -733,87 +720,74 @@ class ModelTpServer:
             break


-class ModelTpService(rpyc.Service):
-    exposed_ModelTpServer = ModelTpServer
-
-
-class ModelTpClient:
-    def __init__(
-        self,
-        gpu_ids: List[int],
-        server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
-        model_overide_args,
-    ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-        self.tp_size = server_args.tp_size
-
-        if self.tp_size * server_args.dp_size == 1:
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(self.tp_size) as executor:
-                # Launch model processes
-                if server_args.nnodes == 1:
-                    self.procs = list(
-                        executor.map(
-                            lambda args: start_rpyc_service_process(*args),
-                            [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                        )
-                    )
-                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                else:
-                    addrs = [
-                        (ip, port)
-                        for ip, port in zip(
-                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                        )
-                    ]
-
-                self.model_services = list(
-                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                )
-
-                # Init model
-                def init_model(i):
-                    return self.model_services[i].ModelTpServer(
-                        gpu_ids[i],
-                        i,
-                        server_args,
-                        model_port_args,
-                        model_overide_args,
-                    )
-
-                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+
+    return data
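Note: broadcast_recv_input is what replaces the rpyc round trips between TP rank 0 and the other ranks: rank 0 pickles the request list and broadcasts its size and bytes over a torch.distributed group. A self-contained sketch of the same idea, runnable on CPU with the gloo backend; the port number and request strings are arbitrary demo values, not part of sglang:

    import pickle

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def broadcast_pyobj(data, rank, group=None):
        # Rank 0 pickles the object and broadcasts (length, bytes); the other
        # ranks allocate a uint8 buffer of that length and unpickle it.
        if rank == 0:
            serialized = pickle.dumps(data)
            size = torch.tensor([len(serialized)], dtype=torch.long)
            payload = torch.ByteTensor(list(serialized))
            dist.broadcast(size, src=0, group=group)
            dist.broadcast(payload, src=0, group=group)
            return data
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=group)
        return pickle.loads(bytes(payload.tolist()))


    def run(rank, world_size):
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://127.0.0.1:29511",  # arbitrary free port for the demo
            rank=rank,
            world_size=world_size,
        )
        reqs = ["req-0", "req-1"] if rank == 0 else None
        print(rank, broadcast_pyobj(reqs, rank))
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)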
python/sglang/srt/managers/tokenizer_manager.py  (+1 −1)

@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
         self.hf_config = get_config(
python/sglang/srt/server.py  (+63 −77)

@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     receive_addrs,
     send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback

@@ -98,6 +96,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:

         async def stream_results():

@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }


-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
     """Launch an HTTP server."""
     global tokenizer_manager

     logging.basicConfig(

@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

     _set_global_server_args(server_args)

     # Allocate ports
     assert server_args.tp_size % server_args.nnodes == 0
     tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )

-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."

@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )
         while True:

@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(

@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()

     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()

-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()

     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

-    # Send a warmup request
-    def _wait_and_warmup():
-        headers = {}
-        url = server_args.url()
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
     t.start()

     # Listen for requests

@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()


+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.

@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )

@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
python/sglang/srt/server_args.py  (+2 −9)

@@ -337,16 +337,9 @@ class ServerArgs:
     )


-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
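Note: the slimmed-down PortArgs now slices the additional_ports list directly, with one nccl port per data parallel worker instead of a nested ModelPortArgs per worker. A hypothetical mirror of that layout for dp_size = 2 (the dataclass and port numbers below are illustrative, not sglang's):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class PortArgsSketch:  # hypothetical stand-in, not sglang's class
        tokenizer_port: int
        controller_port: int
        detokenizer_port: int
        nccl_ports: List[int]


    # dp_size = 2 example: three fixed ports plus one nccl port per DP worker.
    additional_ports = [30001, 30002, 30003, 30004, 30005]
    port_args = PortArgsSketch(
        tokenizer_port=additional_ports[0],
        controller_port=additional_ports[1],
        detokenizer_port=additional_ports[2],
        nccl_ports=additional_ports[3:],
    )
    print(port_args.nccl_ports)  # [30004, 30005]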
python/sglang/srt/utils.py  (+2 −49)

@@ -3,7 +3,6 @@
 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
 import socket

@@ -16,12 +15,10 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

 logger = logging.getLogger(__name__)

@@ -148,7 +145,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""

@@ -160,8 +156,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)

@@ -371,49 +367,6 @@ def load_image(image_file):
     return image, image_size


-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-        repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
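Note: the new port budget is easy to verify by hand against the updated comment (HTTP + tokenizer + controller + detokenizer + one nccl port per data parallel worker). A tiny sketch of the count:

    def num_ports_needed(dp_size: int) -> int:
        # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
        return 4 + dp_size


    print(num_ports_needed(1))  # 5
    print(num_ports_needed(4))  # 8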