Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d4fc1a70
"vscode:/vscode.git/clone" did not exist on "55ab9f371a198a190b423763a1d48e4495fe520e"
Unverified
Commit
d4fc1a70
authored
Nov 28, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 28, 2024
Browse files
Crash the server correctly during error (#2231)
parent
db674e3d
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
92 additions
and
84 deletions
+92
-84
python/sglang/bench_one_batch.py
python/sglang/bench_one_batch.py
+3
-6
python/sglang/bench_one_batch_server.py
python/sglang/bench_one_batch_server.py
+4
-3
python/sglang/launch_server.py
python/sglang/launch_server.py
+2
-2
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+7
-11
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+7
-4
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+8
-5
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+2
-2
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+11
-2
python/sglang/srt/server.py
python/sglang/srt/server.py
+16
-5
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+8
-20
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+6
-6
python/sglang/utils.py
python/sglang/utils.py
+2
-2
rust/py_test/test_launch_server.py
rust/py_test/test_launch_server.py
+2
-2
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
.../sampling/penaltylib/test_srt_endpoint_with_penalizers.py
+2
-2
test/srt/test_cache_report.py
test/srt/test_cache_report.py
+2
-2
test/srt/test_data_parallelism.py
test/srt/test_data_parallelism.py
+2
-2
test/srt/test_double_sparsity.py
test/srt/test_double_sparsity.py
+2
-2
test/srt/test_dp_attention.py
test/srt/test_dp_attention.py
+2
-2
test/srt/test_embedding_openai_server.py
test/srt/test_embedding_openai_server.py
+2
-2
test/srt/test_eval_accuracy_large.py
test/srt/test_eval_accuracy_large.py
+2
-2
No files found.
python/sglang/bench_one_batch.py
View file @
d4fc1a70
...
...
@@ -47,6 +47,7 @@ import itertools
import
json
import
logging
import
multiprocessing
import
os
import
time
from
typing
import
Tuple
...
...
@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server
import
_set_envs_and_config
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
(
configure_logger
,
kill_child_process
,
suppress_other_loggers
,
)
from
sglang.srt.utils
import
configure_logger
,
kill_process_tree
,
suppress_other_loggers
@
dataclasses
.
dataclass
...
...
@@ -468,4 +465,4 @@ if __name__ == "__main__":
main
(
server_args
,
bench_args
)
finally
:
if
server_args
.
tp_size
!=
1
:
kill_
child_process
(
)
kill_
process_tree
(
os
.
getpid
(),
include_parent
=
False
)
python/sglang/bench_one_batch_server.py
View file @
d4fc1a70
...
...
@@ -15,6 +15,7 @@ import dataclasses
import
itertools
import
json
import
multiprocessing
import
os
import
time
from
typing
import
Tuple
...
...
@@ -23,7 +24,7 @@ import requests
from
sglang.srt.server
import
launch_server
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
@
dataclasses
.
dataclass
...
...
@@ -69,7 +70,7 @@ def launch_server_internal(server_args):
except
Exception
as
e
:
raise
e
finally
:
kill_
child_process
(
)
kill_
process_tree
(
os
.
getpid
(),
include_parent
=
False
)
def
launch_server_process
(
server_args
:
ServerArgs
):
...
...
@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
finally
:
if
proc
:
kill_
child_
process
(
proc
.
pid
,
include_self
=
True
)
kill_process
_tree
(
proc
.
pid
)
print
(
f
"
\n
Results are saved to
{
bench_args
.
result_filename
}
"
)
...
...
python/sglang/launch_server.py
View file @
d4fc1a70
...
...
@@ -4,7 +4,7 @@ import sys
from
sglang.srt.server
import
launch_server
from
sglang.srt.server_args
import
prepare_server_args
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
if
__name__
==
"__main__"
:
server_args
=
prepare_server_args
(
sys
.
argv
[
1
:])
...
...
@@ -12,4 +12,4 @@ if __name__ == "__main__":
try
:
launch_server
(
server_args
)
finally
:
kill_
child_process
(
)
kill_
process_tree
(
os
.
getpid
(),
include_parent
=
False
)
python/sglang/srt/managers/data_parallel_controller.py
View file @
d4fc1a70
...
...
@@ -15,9 +15,11 @@
import
logging
import
multiprocessing
as
mp
import
signal
import
threading
from
enum
import
Enum
,
auto
import
psutil
import
zmq
from
sglang.srt.managers.io_struct
import
(
...
...
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
)
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
(
bind_port
,
configure_logger
,
get_zmq_socket
,
kill_parent_process
,
suppress_other_loggers
,
)
from
sglang.srt.utils
import
bind_port
,
configure_logger
,
get_zmq_socket
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
pipe_writer
,
):
configure_logger
(
server_args
)
suppress_other_loggers
()
parent_process
=
psutil
.
Process
().
parent
()
try
:
controller
=
DataParallelController
(
server_args
,
port_args
)
...
...
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
)
controller
.
event_loop
()
except
Exception
:
msg
=
get_exception_traceback
()
logger
.
error
(
msg
)
kill_
parent_process
(
)
traceback
=
get_exception_traceback
()
logger
.
error
(
f
"DataParallelController hit an exception:
{
traceback
}
"
)
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
python/sglang/srt/managers/detokenizer_manager.py
View file @
d4fc1a70
...
...
@@ -15,9 +15,11 @@
import
dataclasses
import
logging
import
signal
from
collections
import
OrderedDict
from
typing
import
List
,
Union
import
psutil
import
zmq
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
...
...
@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
)
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
,
FINISH_MATCHED_TOKEN
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
configure_logger
,
get_zmq_socket
,
kill_parent_process
from
sglang.srt.utils
import
configure_logger
,
get_zmq_socket
from
sglang.utils
import
find_printable_text
,
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -193,11 +195,12 @@ def run_detokenizer_process(
port_args
:
PortArgs
,
):
configure_logger
(
server_args
)
parent_process
=
psutil
.
Process
().
parent
()
try
:
manager
=
DetokenizerManager
(
server_args
,
port_args
)
manager
.
event_loop
()
except
Exception
:
msg
=
get_exception_traceback
()
logger
.
error
(
msg
)
kill_
parent_process
(
)
traceback
=
get_exception_traceback
()
logger
.
error
(
f
"DetokenizerManager hit an exception:
{
traceback
}
"
)
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
python/sglang/srt/managers/scheduler.py
View file @
d4fc1a70
...
...
@@ -15,6 +15,7 @@
import
logging
import
os
import
signal
import
threading
import
time
import
warnings
...
...
@@ -23,6 +24,7 @@ from concurrent import futures
from
types
import
SimpleNamespace
from
typing
import
List
,
Optional
import
psutil
import
torch
import
zmq
...
...
@@ -73,7 +75,6 @@ from sglang.srt.utils import (
crash_on_warnings
,
get_bool_env_var
,
get_zmq_socket
,
kill_parent_process
,
set_gpu_proc_affinity
,
set_random_seed
,
suppress_other_loggers
,
...
...
@@ -316,6 +317,7 @@ class Scheduler:
self
.
watchdog_timeout
=
server_args
.
watchdog_timeout
t
=
threading
.
Thread
(
target
=
self
.
watchdog_thread
,
daemon
=
True
)
t
.
start
()
self
.
parent_process
=
psutil
.
Process
().
parent
()
# Init profiler
if
os
.
getenv
(
"SGLANG_TORCH_PROFILER_DIR"
,
""
)
==
""
:
...
...
@@ -359,7 +361,7 @@ class Scheduler:
self
.
watchdog_last_time
=
time
.
time
()
time
.
sleep
(
self
.
watchdog_timeout
/
2
)
kill_
parent_process
(
)
self
.
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
@
torch
.
no_grad
()
def
event_loop_normal
(
self
):
...
...
@@ -1423,6 +1425,7 @@ def run_scheduler_process(
configure_logger
(
server_args
,
prefix
=
f
" DP
{
dp_rank
}
TP
{
tp_rank
}
"
)
suppress_other_loggers
()
parent_process
=
psutil
.
Process
().
parent
()
try
:
scheduler
=
Scheduler
(
server_args
,
port_args
,
gpu_id
,
tp_rank
,
dp_rank
)
...
...
@@ -1434,6 +1437,6 @@ def run_scheduler_process(
else
:
scheduler
.
event_loop_normal
()
except
Exception
:
msg
=
get_exception_traceback
()
logger
.
error
(
msg
)
kill_
parent_process
(
)
traceback
=
get_exception_traceback
()
logger
.
error
(
f
"Scheduler hit an exception:
{
traceback
}
"
)
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
python/sglang/srt/managers/tokenizer_manager.py
View file @
d4fc1a70
...
...
@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
get_zmq_socket
,
kill_
child_
process
from
sglang.srt.utils
import
get_zmq_socket
,
kill_process
_tree
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
...
@@ -532,7 +532,7 @@ class TokenizerManager:
else
:
break
kill_
child_
process
(
include_
self
=
True
)
kill_process
_tree
(
os
.
getpid
(),
include_
parent
=
True
)
sys
.
exit
(
0
)
async
def
handle_loop
(
self
):
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
d4fc1a70
...
...
@@ -15,16 +15,19 @@
import
dataclasses
import
logging
import
signal
import
threading
from
queue
import
Queue
from
typing
import
Optional
import
psutil
import
torch
from
sglang.srt.managers.io_struct
import
UpdateWeightReqInput
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.server_args
import
ServerArgs
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
target
=
self
.
forward_thread_func
,
)
self
.
forward_thread
.
start
()
self
.
parent_process
=
psutil
.
Process
().
parent
()
def
get_worker_info
(
self
):
return
self
.
worker
.
get_worker_info
()
...
...
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
)
def
forward_thread_func
(
self
):
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
self
.
forward_thread_func_
()
try
:
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
self
.
forward_thread_func_
()
except
Exception
:
traceback
=
get_exception_traceback
()
logger
.
error
(
f
"TpModelWorkerClient hit an exception:
{
traceback
}
"
)
self
.
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
@
torch
.
no_grad
()
def
forward_thread_func_
(
self
):
...
...
python/sglang/srt/server.py
View file @
d4fc1a70
...
...
@@ -23,6 +23,8 @@ import json
import
logging
import
multiprocessing
as
mp
import
os
import
signal
import
sys
import
threading
import
time
from
http
import
HTTPStatus
...
...
@@ -79,7 +81,7 @@ from sglang.srt.utils import (
configure_logger
,
delete_directory
,
is_port_available
,
kill_
child_
process
,
kill_process
_tree
,
maybe_set_triton_cache_manager
,
prepare_model_and_tokenizer
,
set_prometheus_multiproc_dir
,
...
...
@@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html."
,
)
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def
sigquit_handler
(
signum
,
frame
):
kill_process_tree
(
os
.
getpid
())
signal
.
signal
(
signal
.
SIGQUIT
,
sigquit_handler
)
# Set mp start method
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
...
...
@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
last_traceback
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
kill_
child_
process
(
include_self
=
True
)
kill_process
_tree
(
os
.
getpid
()
)
return
model_info
=
res
.
json
()
...
...
@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
last_traceback
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
kill_
child_
process
(
include_self
=
True
)
kill_process
_tree
(
os
.
getpid
()
)
return
# logger.info(f"{res.json()=}")
...
...
@@ -700,7 +711,7 @@ class Runtime:
def
shutdown
(
self
):
if
self
.
pid
is
not
None
:
kill_
child_
process
(
self
.
pid
,
include_self
=
True
)
kill_process
_tree
(
self
.
pid
)
self
.
pid
=
None
def
cache_prefix
(
self
,
prefix
:
str
):
...
...
@@ -924,7 +935,7 @@ class Engine:
return
ret
def
shutdown
(
self
):
kill_
child_process
(
)
kill_
process_tree
(
os
.
getpid
(),
include_parent
=
False
)
def
get_tokenizer
(
self
):
global
tokenizer_manager
...
...
python/sglang/srt/utils.py
View file @
d4fc1a70
...
...
@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
)
def
kill_parent_process
():
"""Kill the parent process and all children of the parent process."""
current_process
=
psutil
.
Process
()
parent_process
=
current_process
.
parent
()
kill_child_process
(
parent_process
.
pid
,
include_self
=
True
,
skip_pid
=
current_process
.
pid
)
try
:
current_process
.
kill
()
except
psutil
.
NoSuchProcess
:
pass
def
kill_child_process
(
pid
=
None
,
include_self
=
False
,
skip_pid
=
None
):
"""Kill the process and all its children process."""
if
pid
is
None
:
pid
=
os
.
getpid
()
def
kill_process_tree
(
parent_pid
,
include_parent
:
bool
=
True
,
skip_pid
:
int
=
None
):
"""Kill the process and all its child processes."""
if
parent_pid
is
None
:
parent_pid
=
os
.
getpid
()
include_parent
=
False
try
:
itself
=
psutil
.
Process
(
pid
)
itself
=
psutil
.
Process
(
parent_
pid
)
except
psutil
.
NoSuchProcess
:
return
...
...
@@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
except
psutil
.
NoSuchProcess
:
pass
if
include_
self
:
if
include_
parent
:
try
:
itself
.
kill
()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself
.
send_signal
(
signal
.
SIGI
N
T
)
itself
.
send_signal
(
signal
.
SIG
QU
IT
)
except
psutil
.
NoSuchProcess
:
pass
...
...
python/sglang/test/test_utils.py
View file @
d4fc1a70
...
...
@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark
from
sglang.global_config
import
global_config
from
sglang.lang.backend.openai
import
OpenAI
from
sglang.lang.backend.runtime_endpoint
import
RuntimeEndpoint
from
sglang.srt.utils
import
get_bool_env_var
,
kill_
child_
process
from
sglang.srt.utils
import
get_bool_env_var
,
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.utils
import
get_exception_traceback
...
...
@@ -504,7 +504,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
)
assert
ret_code
==
0
except
TimeoutError
:
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
process
.
pid
)
time
.
sleep
(
5
)
print
(
f
"
\n
Timeout after
{
timeout_per_file
}
seconds when running
{
filename
}
\n
"
,
...
...
@@ -578,7 +578,7 @@ def run_bench_serving(
run_benchmark
(
warmup_args
)
res
=
run_benchmark
(
args
)
finally
:
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
process
.
pid
)
assert
res
[
"completed"
]
==
num_prompts
return
res
...
...
@@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args):
lastline
=
output
.
split
(
"
\n
"
)[
-
3
]
output_throughput
=
float
(
lastline
.
split
(
" "
)[
-
2
])
finally
:
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
process
.
pid
)
return
output_throughput
...
...
@@ -710,8 +710,8 @@ def run_and_check_memory_leak(
workload_func
(
base_url
,
model
)
# Clean up everything
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
process
.
pid
)
kill_process
_tree
(
process
.
pid
)
stdout
.
close
()
stderr
.
close
()
if
os
.
path
.
exists
(
STDOUT_FILENAME
):
...
...
python/sglang/utils.py
View file @
d4fc1a70
...
...
@@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
def
terminate_process
(
process
):
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
kill_
child_
process
(
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
process
.
pid
)
def
print_highlight
(
html_content
:
str
):
...
...
rust/py_test/test_launch_server.py
View file @
d4fc1a70
...
...
@@ -5,7 +5,7 @@ from types import SimpleNamespace
import
requests
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -79,7 +79,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
View file @
d4fc1a70
...
...
@@ -4,7 +4,7 @@ from multiprocessing import Process
import
requests
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
...
...
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
run_decode
(
self
,
...
...
test/srt/test_cache_report.py
View file @
d4fc1a70
...
...
@@ -4,7 +4,7 @@ import unittest
import
openai
import
requests
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
...
...
@@ -44,7 +44,7 @@ class TestCacheReport(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
run_decode
(
self
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
response
=
requests
.
post
(
...
...
test/srt/test_data_parallelism.py
View file @
d4fc1a70
...
...
@@ -4,7 +4,7 @@ from types import SimpleNamespace
import
requests
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -28,7 +28,7 @@ class TestDataParallelism(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
test/srt/test_double_sparsity.py
View file @
d4fc1a70
...
...
@@ -2,7 +2,7 @@ import os
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -45,7 +45,7 @@ class TestDoubleSparsity(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
test/srt/test_dp_attention.py
View file @
d4fc1a70
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MLA_MODEL_NAME_FOR_TEST
,
...
...
@@ -30,7 +30,7 @@ class TestDPAttention(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
test/srt/test_embedding_openai_server.py
View file @
d4fc1a70
...
...
@@ -3,7 +3,7 @@ import unittest
import
openai
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
...
...
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
run_embedding
(
self
,
use_list_input
,
token_input
):
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
...
...
test/srt/test_eval_accuracy_large.py
View file @
d4fc1a70
...
...
@@ -6,7 +6,7 @@ python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_
child_
process
from
sglang.srt.utils
import
kill_process
_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
kill_
child_
process
(
cls
.
process
.
pid
,
include_self
=
True
)
kill_process
_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment