sglang · Commits · Commit 91f93f14 (Unverified)

Crash the server when error or OOM happens (#514)

Authored Jun 07, 2024 by Lianmin Zheng; committed by GitHub on Jun 07, 2024.
Parent: f70f7258
Showing 7 changed files with 63 additions and 22 deletions (+63, -22).
python/sglang/srt/managers/controller/dp_worker.py        +5   -0
python/sglang/srt/managers/controller/manager_single.py   +10  -1
python/sglang/srt/managers/controller/model_runner.py     +7   -5
python/sglang/srt/managers/controller/tp_worker.py        +3   -2
python/sglang/srt/managers/detokenizer_manager.py         +1   -1
python/sglang/srt/server.py                               +14  -13
python/sglang/srt/utils.py                                +23  -0
python/sglang/srt/managers/controller/dp_worker.py

```diff
@@ -12,6 +12,7 @@ from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
                 f"{get_exception_traceback()}"
             )
             self.liveness = False
+            # Crash the whole server when there are any errors.
+            # TODO(lianmin): make this an option.
+            kill_parent_process()
+            return

         for obj in out_pyobjs:
             self.send_to_detokenizer.send_pyobj(obj)
```
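The effect of the second hunk: an exception escaping the worker loop no longer just flips `self.liveness`; it now takes the whole process tree down so the failure is immediately visible. A minimal sketch of this fail-fast pattern, with `kill_parent_process` stubbed out so the demo is safe to run (the thread class and `step` are illustrative stand-ins, not sglang's API):

```python
import logging
import threading
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("srt.controller")


def kill_parent_process():
    # Safe stub for the real sglang.srt.utils.kill_parent_process added by
    # this commit; the real one SIGKILLs the parent and all of its children.
    print("would SIGKILL parent process and all its children")


class DataParallelWorkerThread(threading.Thread):
    def __init__(self):
        super().__init__()
        self.liveness = True

    def step(self):
        raise RuntimeError("simulated OOM")  # hypothetical per-iteration work

    def run(self):
        try:
            while self.liveness:
                self.step()
        except Exception:
            logger.error(f"Worker hit an exception:\n{traceback.format_exc()}")
            self.liveness = False
            # Crash the whole server instead of leaving a zombie worker behind.
            kill_parent_process()
            return


t = DataParallelWorkerThread()
t.start()
t.join()
```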
python/sglang/srt/managers/controller/manager_single.py

```diff
 """A controller that manages a group of tensor parallel workers."""
 import asyncio
 import logging
+import time

 import uvloop
 import zmq
@@ -9,10 +10,13 @@ import zmq.asyncio
 from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger("srt.controller")
+

 class ControllerSingle:
     def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
@@ -85,4 +89,9 @@ def start_controller_process(
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
\ No newline at end of file
```
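The final hunk wraps the controller's blocking event loop so that any escaping exception is logged and, via `finally`, still tears the server down. A self-contained sketch of the same shutdown pattern (the coroutine is a stand-in for `controller.loop_for_forward`, and `kill_parent_process` is again stubbed for safety):

```python
import asyncio
import logging
import traceback

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("srt.controller")


def kill_parent_process():
    print("would SIGKILL parent process and all its children")  # safe stub


async def loop_for_forward():
    raise RuntimeError("simulated controller failure")  # stand-in workload


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    loop.run_until_complete(loop_for_forward())
except Exception:
    logger.error("Exception in ControllerSingle:\n" + traceback.format_exc())
finally:
    # Runs whether the loop exits cleanly or not, so the launcher process
    # never blocks forever on a dead controller.
    kill_parent_process()
```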
python/sglang/srt/managers/controller/model_runner.py

```diff
@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check

 logger = logging.getLogger("srt.model_runner")
@@ -240,10 +240,12 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+        monkey_patch_vllm_p2p_access_check()
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
+            local_rank=self.gpu_id,
             distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
@@ -265,7 +267,7 @@ class ModelRunner:
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

         device_config = DeviceConfig()
@@ -295,8 +297,8 @@ class ModelRunner:
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
-            f"Type={type(self.model).__name__}. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"type={type(self.model).__name__}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

     def profile_max_num_token(self, total_gpu_memory):
@@ -333,7 +335,7 @@ class ModelRunner:
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Memory pool end. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

     @torch.inference_mode()
```
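Note the ordering: `monkey_patch_vllm_p2p_access_check` (defined in `utils.py` below) is called before `init_distributed_environment`, so the patch is in place before vllm would run its expensive per-GPU-pair probe, and the added `local_rank=self.gpu_id` pins each tensor-parallel rank to its own device. A hedged sketch of the underlying monkey-patching technique on a stand-in module, not vllm's real layout:

```python
import time
import types

# Stand-in for vllm.distributed.device_communicators.custom_all_reduce_utils.
tgt = types.ModuleType("custom_all_reduce_utils")


def gpu_p2p_access_check(src: int, dst: int) -> bool:
    time.sleep(5)  # imagine an expensive pairwise CUDA probe here
    return True


tgt.gpu_p2p_access_check = gpu_p2p_access_check

# The patch: assume peer-to-peer access is always allowed. As the NOTE in the
# new sglang helper warns, this assumption can be wrong on some hardware.
setattr(tgt, "gpu_p2p_access_check", lambda *args, **kwargs: True)

assert tgt.gpu_p2p_access_check(0, 1) is True  # returns instantly now
```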
python/sglang/srt/managers/controller/tp_worker.py

```diff
@@ -34,7 +34,7 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

-logger = logging.getLogger("srt.model_tp")
+logger = logging.getLogger("srt.tp_worker")


 class ModelTpServer:
@@ -187,7 +187,8 @@ class ModelTpServer:
             # Forward
             self.forward_step()
         except Exception:
-            logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
+            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+            raise

         # Return results
         ret = self.out_pyobjs
```
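Besides correcting the logger name and the message (this code is `ModelTpServer`, not `ModelTpClient`), the added `raise` re-throws after logging, so the error propagates to the caller and reaches the fail-fast path above instead of being silently swallowed. A tiny sketch of this log-and-reraise pattern:

```python
import logging
import traceback

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("srt.tp_worker")


def forward_step():
    raise RuntimeError("simulated forward failure")  # hypothetical model step


def exposed_step():
    try:
        forward_step()
    except Exception:
        logger.error("Exception in ModelTpServer:\n" + traceback.format_exc())
        raise  # without this, the caller would never learn the step failed


try:
    exposed_step()
except RuntimeError:
    print("caller sees the error and can crash the server")
```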
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -87,7 +87,7 @@ def start_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")

```
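The `as e` binding was unused here, because the handler forwards the full traceback over the pipe and re-raises. A sketch of that parent/child handshake, with the manager construction replaced by a no-op (names are illustrative):

```python
import multiprocessing as mp
import traceback


def start_detokenizer_process(pipe_writer):
    try:
        manager = object()  # stand-in for DetokenizerManager(server_args, port_args)
    except Exception:
        pipe_writer.send(traceback.format_exc())  # parent gets the full story
        raise
    pipe_writer.send("init ok")


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=start_detokenizer_process, args=(writer,))
    proc.start()
    print(reader.recv())  # "init ok", or a traceback on failure
    proc.join()
```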
python/sglang/srt/server.py

```diff
@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Send a warmup request
     try:
-        res = requests.post(
-            url + "/generate",
-            json={
-                "text": "The capital city of France is",
-                "sampling_params": {
-                    "temperature": 0,
-                    "max_new_tokens": 16,
-                },
-            },
-            headers=headers,
-            timeout=600,
-        )
-        assert res.status_code == 200
-    except Exception:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 16,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
         print(f"Initialization failed. warmup error: {e}")
```
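Two fixes in one hunk: the warmup is now issued once per data-parallel replica, since a single request only reaches whichever replica the router dispatches it to, and `except Exception as e:` repairs a latent `NameError` — the old handler printed `{e}` without ever binding it. The new logic as a standalone client sketch (the URL and `dp_size` values are placeholders):

```python
import requests


def warmup(url: str = "http://127.0.0.1:30000", dp_size: int = 2) -> None:
    # One request per data-parallel replica, so every worker loads its CUDA
    # kernels and caches before real traffic arrives.
    for _ in range(dp_size):
        res = requests.post(
            url + "/generate",
            json={
                "text": "The capital city of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 16},
            },
            timeout=600,
        )
        res.raise_for_status()
```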
python/sglang/srt/utils.py

```diff
@@ -12,6 +12,7 @@ from io import BytesIO
 from typing import List, Optional

 import numpy as np
+import psutil
 import requests
 import rpyc
 import torch
@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
         )


+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response
+
```
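Signal 9 is `SIGKILL`, which cannot be caught or ignored, so `kill_parent_process` brings the process tree down even if a process is wedged inside native code. A dry-run sketch that walks the same psutil traversal but only prints instead of killing, safe to execute (assuming `psutil` is installed):

```python
import psutil


def show_kill_targets():
    # Mirrors kill_parent_process, but prints the os.kill calls it would make.
    current_process = psutil.Process()
    parent_process = current_process.parent()
    children = current_process.children(recursive=True)
    for child in children:
        if child.pid != current_process.pid:
            print(f"would os.kill({child.pid}, 9)  # child: {child.name()}")
    print(f"would os.kill({parent_process.pid}, 9)  # parent: {parent_process.name()}")


if __name__ == "__main__":
    show_kill_targets()
```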