Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a1e697b2
Unverified
Commit
a1e697b2
authored
Dec 08, 2024
by
Byron Hsu
Committed by
GitHub
Dec 08, 2024
Browse files
[router] Improve cleanup logic (#2411)
parent
a6ca736c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
99 deletions
+85
-99
rust/py_src/sglang_router/launch_server.py
rust/py_src/sglang_router/launch_server.py
+37
-80
rust/py_test/test_launch_server.py
rust/py_test/test_launch_server.py
+48
-19
No files found.
rust/py_src/sglang_router/launch_server.py
View file @
a1e697b2
...
@@ -10,12 +10,12 @@ import time
...
@@ -10,12 +10,12 @@ import time
from
typing
import
List
from
typing
import
List
import
requests
import
requests
from
setproctitle
import
setproctitle
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
from
sglang.srt.server
import
launch_server
from
sglang.srt.server
import
launch_server
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
is_port_available
from
sglang.srt.utils
import
is_port_available
from
sglang.utils
import
get_exception_traceback
def
setup_logger
():
def
setup_logger
():
...
@@ -34,10 +34,12 @@ def setup_logger():
...
@@ -34,10 +34,12 @@ def setup_logger():
return
logger
return
logger
logger
=
setup_logger
()
# Create new process group
# Create new process group
def
run_server
(
server_args
,
dp_rank
):
def
run_server
(
server_args
,
dp_rank
):
os
.
setpgrp
()
# Create new process group
setproctitle
(
f
"sglang::server"
)
# Set SGLANG_DP_RANK environment variable
# Set SGLANG_DP_RANK environment variable
os
.
environ
[
"SGLANG_DP_RANK"
]
=
str
(
dp_rank
)
os
.
environ
[
"SGLANG_DP_RANK"
]
=
str
(
dp_rank
)
...
@@ -58,36 +60,6 @@ def launch_server_process(
...
@@ -58,36 +60,6 @@ def launch_server_process(
return
proc
return
proc
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
logger
=
logging
.
getLogger
(
"router"
)
logger
.
info
(
"Cleaning up processes..."
)
for
proc
in
processes
:
if
proc
.
is_alive
():
try
:
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGTERM
)
proc
.
join
(
timeout
=
3
)
if
proc
.
is_alive
():
logger
.
warning
(
f
"Process
{
proc
.
pid
}
did not terminate gracefully, force killing..."
)
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGKILL
)
except
ProcessLookupError
:
pass
def
setup_signal_handlers
(
cleanup_func
):
"""Setup handlers for various termination signals."""
def
signal_handler
(
signum
,
frame
):
cleanup_func
()
sys
.
exit
(
1
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
if
hasattr
(
signal
,
"SIGQUIT"
):
signal
.
signal
(
signal
.
SIGQUIT
,
signal_handler
)
def
wait_for_server_health
(
host
:
str
,
port
:
int
,
timeout
:
int
=
300
)
->
bool
:
def
wait_for_server_health
(
host
:
str
,
port
:
int
,
timeout
:
int
=
300
)
->
bool
:
"""Wait for server to be healthy by checking /health endpoint."""
"""Wait for server to be healthy by checking /health endpoint."""
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
...
@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
return
available_ports
return
available_ports
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
for
process
in
processes
:
process
.
terminate
()
def
main
():
def
main
():
logger
=
setup_logger
()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp
.
set_start_method
(
"spawn"
)
mp
.
set_start_method
(
"spawn"
)
...
@@ -148,52 +124,33 @@ def main():
...
@@ -148,52 +124,33 @@ def main():
# Start server processes
# Start server processes
server_processes
=
[]
server_processes
=
[]
try
:
for
i
,
worker_port
in
enumerate
(
worker_ports
):
for
i
,
worker_port
in
enumerate
(
worker_ports
):
logger
.
info
(
f
"Launching DP server process
{
i
}
on port
{
worker_port
}
"
)
logger
.
info
(
f
"Launching DP server process
{
i
}
on port
{
worker_port
}
"
)
proc
=
launch_server_process
(
server_args
,
worker_port
,
i
)
proc
=
launch_server_process
(
server_args
,
worker_port
,
i
)
server_processes
.
append
(
proc
)
server_processes
.
append
(
proc
)
signal
.
signal
(
signal
.
SIGINT
,
lambda
sig
,
frame
:
cleanup_processes
(
server_processes
))
# Setup cleanup handler
signal
.
signal
(
setup_signal_handlers
(
lambda
:
cleanup_processes
(
server_processes
))
signal
.
SIGTERM
,
lambda
sig
,
frame
:
cleanup_processes
(
server_processes
)
)
# Wait for all servers to be healthy
signal
.
signal
(
all_healthy
=
True
signal
.
SIGQUIT
,
lambda
sig
,
frame
:
cleanup_processes
(
server_processes
)
)
for
port
in
worker_ports
:
if
not
wait_for_server_health
(
server_args
.
host
,
port
):
for
port
in
worker_ports
:
logger
.
error
(
f
"Server on port
{
port
}
failed to become healthy"
)
if
not
wait_for_server_health
(
server_args
.
host
,
port
):
all_healthy
=
False
logger
.
error
(
f
"Server on port
{
port
}
failed to become healthy"
)
break
break
if
not
all_healthy
:
logger
.
info
(
"All servers are healthy. Starting router..."
)
logger
.
error
(
"Not all servers are healthy. Shutting down..."
)
cleanup_processes
(
server_processes
)
# Update router args with worker URLs
sys
.
exit
(
1
)
router_args
.
worker_urls
=
[
f
"http://
{
server_args
.
host
}
:
{
port
}
"
for
port
in
worker_ports
logger
.
info
(
"All servers are healthy. Starting router..."
)
]
# Update router args with worker URLs
# Start the router
router_args
.
worker_urls
=
[
router
=
launch_router
(
router_args
)
f
"http://
{
server_args
.
host
}
:
{
port
}
"
for
port
in
worker_ports
]
# Start the router
router
=
launch_router
(
router_args
)
if
router
is
None
:
logger
.
error
(
"Failed to start router. Shutting down..."
)
cleanup_processes
(
server_processes
)
sys
.
exit
(
1
)
except
KeyboardInterrupt
:
logger
.
info
(
"Received shutdown signal..."
)
except
Exception
as
e
:
logger
.
error
(
f
"Error occurred:
{
e
}
"
)
logger
.
error
(
get_exception_traceback
())
finally
:
logger
.
info
(
"Cleaning up processes..."
)
cleanup_processes
(
server_processes
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
rust/py_test/test_launch_server.py
View file @
a1e697b2
...
@@ -6,7 +6,6 @@ from types import SimpleNamespace
...
@@ -6,7 +6,6 @@ from types import SimpleNamespace
import
requests
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
...
@@ -104,23 +103,52 @@ def popen_launch_server(
...
@@ -104,23 +103,52 @@ def popen_launch_server(
return
process
return
process
def
terminate_and_wait
(
process
,
timeout
=
300
):
"""Terminate a process and wait until it is terminated.
Args:
process: subprocess.Popen object
timeout: maximum time to wait in seconds
Raises:
TimeoutError: if process does not terminate within timeout
"""
if
process
is
None
:
return
process
.
terminate
()
start_time
=
time
.
time
()
while
process
.
poll
()
is
None
:
print
(
f
"Terminating process
{
process
.
pid
}
"
)
if
time
.
time
()
-
start_time
>
timeout
:
raise
TimeoutError
(
f
"Process
{
process
.
pid
}
failed to terminate within
{
timeout
}
s"
)
time
.
sleep
(
1
)
print
(
f
"Process
{
process
.
pid
}
is successfully terminated"
)
class
TestLaunchServer
(
unittest
.
TestCase
):
class
TestLaunchServer
(
unittest
.
TestCase
):
@
classmethod
def
setUp
(
self
):
def
setUpClass
(
cls
):
self
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
self
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
self
.
process
=
None
cls
.
process
=
None
self
.
other_process
=
[]
cls
.
other_process
=
[]
def
tearDown
(
self
):
@
classmethod
print
(
"Running tearDown..."
)
def
tearDownClass
(
cls
):
if
self
.
process
:
kill_process_tree
(
cls
.
process
.
pid
)
terminate_and_wait
(
self
.
process
)
for
process
in
cls
.
other_process
:
for
process
in
self
.
other_process
:
kill_process_tree
(
process
.
pid
)
terminate_and_wait
(
process
)
print
(
"tearDown done"
)
def
test_mmlu
(
self
):
def
test_1_mmlu
(
self
):
print
(
"Running test_1_mmlu..."
)
# DP size = 2
# DP size = 2
TestLaunchServer
.
process
=
popen_launch_router
(
self
.
process
=
popen_launch_router
(
self
.
model
,
self
.
model
,
self
.
base_url
,
self
.
base_url
,
dp_size
=
2
,
dp_size
=
2
,
...
@@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase):
...
@@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase):
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
def
test_add_and_remove_worker
(
self
):
def
test_2_add_and_remove_worker
(
self
):
print
(
"Running test_2_add_and_remove_worker..."
)
# DP size = 1
# DP size = 1
TestLaunchServer
.
process
=
popen_launch_router
(
self
.
process
=
popen_launch_router
(
self
.
model
,
self
.
model
,
self
.
base_url
,
self
.
base_url
,
dp_size
=
1
,
dp_size
=
1
,
...
@@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase):
...
@@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase):
worker_process
=
popen_launch_server
(
worker_process
=
popen_launch_server
(
self
.
model
,
worker_url
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
self
.
model
,
worker_url
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
)
TestLaunchServer
.
other_process
.
append
(
worker_process
)
self
.
other_process
.
append
(
worker_process
)
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
# 2. use /add_worker api to add it the the router. It will be used by router after it is healthy
with
requests
.
Session
()
as
session
:
with
requests
.
Session
()
as
session
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment