Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
91f93f14
Unverified
Commit
91f93f14
authored
Jun 07, 2024
by
Lianmin Zheng
Committed by
GitHub
Jun 07, 2024
Browse files
Crash the server when error or OOM happens (#514)
parent
f70f7258
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
63 additions
and
22 deletions
+63
-22
python/sglang/srt/managers/controller/dp_worker.py
python/sglang/srt/managers/controller/dp_worker.py
+5
-0
python/sglang/srt/managers/controller/manager_single.py
python/sglang/srt/managers/controller/manager_single.py
+10
-1
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+7
-5
python/sglang/srt/managers/controller/tp_worker.py
python/sglang/srt/managers/controller/tp_worker.py
+3
-2
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+1
-1
python/sglang/srt/server.py
python/sglang/srt/server.py
+14
-13
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+23
-0
No files found.
python/sglang/srt/managers/controller/dp_worker.py
View file @
91f93f14
...
...
@@ -12,6 +12,7 @@ from sglang.global_config import global_config
from
sglang.srt.managers.controller.tp_worker
import
ModelTpClient
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
kill_parent_process
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
"srt.controller"
)
...
...
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
f
"
{
get_exception_traceback
()
}
"
)
self
.
liveness
=
False
# Crash the whole server when there are any errors.
# TODO(lianmin): make this an option.
kill_parent_process
()
return
for
obj
in
out_pyobjs
:
self
.
send_to_detokenizer
.
send_pyobj
(
obj
)
...
...
python/sglang/srt/managers/controller/manager_single.py
View file @
91f93f14
"""A controller that manages a group of tensor parallel workers."""
import
asyncio
import
logging
import
time
import
uvloop
import
zmq
...
...
@@ -9,10 +10,13 @@ import zmq.asyncio
from
sglang.global_config
import
global_config
from
sglang.srt.managers.controller.tp_worker
import
ModelTpClient
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
kill_parent_process
from
sglang.utils
import
get_exception_traceback
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
logger
=
logging
.
getLogger
(
"srt.controller"
)
class
ControllerSingle
:
def
__init__
(
self
,
model_client
:
ModelTpClient
,
port_args
:
PortArgs
):
...
...
@@ -85,4 +89,9 @@ def start_controller_process(
loop
=
asyncio
.
new_event_loop
()
asyncio
.
set_event_loop
(
loop
)
loop
.
create_task
(
controller
.
loop_for_recv_requests
())
loop
.
run_until_complete
(
controller
.
loop_for_forward
())
\ No newline at end of file
try
:
loop
.
run_until_complete
(
controller
.
loop_for_forward
())
except
Exception
:
logger
.
error
(
"Exception in ControllerSingle:
\n
"
+
get_exception_traceback
())
finally
:
kill_parent_process
()
\ No newline at end of file
python/sglang/srt/managers/controller/model_runner.py
View file @
91f93f14
...
...
@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from
sglang.srt.managers.controller.infer_batch
import
Batch
,
ForwardMode
from
sglang.srt.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_available_gpu_memory
,
is_multimodal_model
from
sglang.srt.utils
import
get_available_gpu_memory
,
is_multimodal_model
,
monkey_patch_vllm_p2p_access_check
logger
=
logging
.
getLogger
(
"srt.model_runner"
)
...
...
@@ -240,10 +240,12 @@ class ModelRunner:
logger
.
info
(
f
"[gpu_id=
{
self
.
gpu_id
}
] Set cuda device."
)
torch
.
cuda
.
set_device
(
self
.
gpu_id
)
logger
.
info
(
f
"[gpu_id=
{
self
.
gpu_id
}
] Init nccl begin."
)
monkey_patch_vllm_p2p_access_check
()
init_distributed_environment
(
backend
=
"nccl"
,
world_size
=
self
.
tp_size
,
rank
=
self
.
tp_rank
,
local_rank
=
self
.
gpu_id
,
distributed_init_method
=
f
"tcp://127.0.0.1:
{
self
.
nccl_port
}
"
,
)
initialize_model_parallel
(
tensor_model_parallel_size
=
self
.
tp_size
)
...
...
@@ -265,7 +267,7 @@ class ModelRunner:
def
load_model
(
self
):
logger
.
info
(
f
"[gpu_id=
{
self
.
gpu_id
}
] Load weight begin. "
f
"
A
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
f
"
a
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
)
device_config
=
DeviceConfig
()
...
...
@@ -295,8 +297,8 @@ class ModelRunner:
)
logger
.
info
(
f
"[gpu_id=
{
self
.
gpu_id
}
] Load weight end. "
f
"
T
ype=
{
type
(
self
.
model
).
__name__
}
.
"
f
"
A
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
f
"
t
ype=
{
type
(
self
.
model
).
__name__
}
,
"
f
"
a
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
)
def
profile_max_num_token
(
self
,
total_gpu_memory
):
...
...
@@ -333,7 +335,7 @@ class ModelRunner:
)
logger
.
info
(
f
"[gpu_id=
{
self
.
gpu_id
}
] Memory pool end. "
f
"
A
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
f
"
a
vail mem=
{
get_available_gpu_memory
(
self
.
gpu_id
):.
2
f
}
GB"
)
@
torch
.
inference_mode
()
...
...
python/sglang/srt/managers/controller/tp_worker.py
View file @
91f93f14
...
...
@@ -34,7 +34,7 @@ from sglang.srt.utils import (
)
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
"srt.
model_tp
"
)
logger
=
logging
.
getLogger
(
"srt.
tp_worker
"
)
class
ModelTpServer
:
...
...
@@ -187,7 +187,8 @@ class ModelTpServer:
# Forward
self
.
forward_step
()
except
Exception
:
logger
.
error
(
"Exception in ModelTpClient:
\n
"
+
get_exception_traceback
())
logger
.
error
(
"Exception in ModelTpServer:
\n
"
+
get_exception_traceback
())
raise
# Return results
ret
=
self
.
out_pyobjs
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
91f93f14
...
...
@@ -87,7 +87,7 @@ def start_detokenizer_process(
try
:
manager
=
DetokenizerManager
(
server_args
,
port_args
)
except
Exception
as
e
:
except
Exception
:
pipe_writer
.
send
(
get_exception_traceback
())
raise
pipe_writer
.
send
(
"init ok"
)
...
...
python/sglang/srt/server.py
View file @
91f93f14
...
...
@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
# Send a warmup request
try
:
res
=
requests
.
post
(
url
+
"/generate"
,
json
=
{
"text"
:
"The capital city of France is"
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
for
_
in
range
(
server_args
.
dp_size
):
res
=
requests
.
post
(
url
+
"/generate"
,
json
=
{
"text"
:
"The capital city of France is"
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
},
},
},
headers
=
headers
,
timeout
=
600
,
)
assert
res
.
status_code
==
200
except
Exception
as
e
:
headers
=
headers
,
timeout
=
600
,
)
assert
res
.
status_code
==
200
except
Exception
:
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
get_exception_traceback
())
print
(
f
"Initialization failed. warmup error:
{
e
}
"
)
...
...
python/sglang/srt/utils.py
View file @
91f93f14
...
...
@@ -12,6 +12,7 @@ from io import BytesIO
from
typing
import
List
,
Optional
import
numpy
as
np
import
psutil
import
requests
import
rpyc
import
torch
...
...
@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
)
def kill_parent_process():
    """Kill the parent process and all children of the parent process.

    Every descendant of the parent (the siblings of the calling process and
    anything they spawned) is sent signal 9, and then the parent itself is
    sent signal 9. The calling process is deliberately not signalled in the
    loop so that it can finish delivering the remaining signals.

    Does nothing when the calling process has no parent. Targets that have
    already exited are skipped.
    """
    current_process = psutil.Process()
    parent_process = current_process.parent()
    if parent_process is None:
        # psutil reports no parent for this process; there is nothing to kill.
        return

    # Enumerate the *parent's* descendants, as the docstring promises. The
    # calling process is one of them, which is why its pid is filtered below.
    try:
        children = parent_process.children(recursive=True)
    except psutil.NoSuchProcess:
        # The parent vanished while we were inspecting it.
        children = []

    for child in children:
        if child.pid == current_process.pid:
            continue
        try:
            os.kill(child.pid, 9)
        except ProcessLookupError:
            # The child exited between enumeration and the kill; keep going
            # so that the remaining processes and the parent still get killed.
            pass

    try:
        os.kill(parent_process.pid, 9)
    except ProcessLookupError:
        # The parent is already gone.
        pass
def monkey_patch_vllm_p2p_access_check():
    """
    Replace vllm's slow p2p access check with a stub that always succeeds.

    NOTE: The stub reports that p2p access is allowed unconditionally, which
    can be wrong for some setups.
    """
    import vllm.distributed.device_communicators.custom_all_reduce_utils as custom_ar_utils

    def _always_allowed(*arg, **kwargs):
        # Accept and ignore whatever arguments the real check would take.
        return True

    custom_ar_utils.gpu_p2p_access_check = _always_allowed
API_KEY_HEADER_NAME
=
"X-API-Key"
...
...
@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
)
response
=
await
call_next
(
request
)
return
response
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment