Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
86fc0d79
"docs/vscode:/vscode.git/clone" did not exist on "75d53cc83966b4046e5a329ddf7baa6aa24f52e2"
Unverified
Commit
86fc0d79
authored
Oct 27, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 27, 2024
Browse files
Add a watch dog thread (#1816)
parent
1be853ee
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
82 additions
and
39 deletions
+82
-39
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+1
-1
python/sglang/bench_server_latency.py
python/sglang/bench_server_latency.py
+2
-3
python/sglang/launch_server.py
python/sglang/launch_server.py
+1
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+33
-5
python/sglang/srt/server.py
python/sglang/srt/server.py
+6
-6
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+15
-6
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+5
-5
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
.../sampling/penaltylib/test_srt_endpoint_with_penalizers.py
+1
-1
test/srt/test_cache_report.py
test/srt/test_cache_report.py
+1
-1
test/srt/test_data_parallelism.py
test/srt/test_data_parallelism.py
+1
-1
test/srt/test_double_sparsity.py
test/srt/test_double_sparsity.py
+1
-1
test/srt/test_embedding_openai_server.py
test/srt/test_embedding_openai_server.py
+1
-1
test/srt/test_eval_accuracy_large.py
test/srt/test_eval_accuracy_large.py
+1
-1
test/srt/test_eval_accuracy_large_chunked_prefill.py
test/srt/test_eval_accuracy_large_chunked_prefill.py
+1
-1
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
+1
-1
test/srt/test_eval_accuracy_mini.py
test/srt/test_eval_accuracy_mini.py
+1
-1
test/srt/test_json_constrained.py
test/srt/test_json_constrained.py
+1
-1
test/srt/test_large_max_new_tokens.py
test/srt/test_large_max_new_tokens.py
+1
-1
test/srt/test_matched_stop.py
test/srt/test_matched_stop.py
+1
-1
No files found.
python/sglang/bench_latency.py
View file @
86fc0d79
...
@@ -550,4 +550,4 @@ if __name__ == "__main__":
...
@@ -550,4 +550,4 @@ if __name__ == "__main__":
except
Exception
as
e
:
except
Exception
as
e
:
raise
e
raise
e
finally
:
finally
:
kill_child_process
(
os
.
getpid
(),
including_parent
=
False
)
kill_child_process
()
python/sglang/bench_server_latency.py
View file @
86fc0d79
...
@@ -15,7 +15,6 @@ import dataclasses
...
@@ -15,7 +15,6 @@ import dataclasses
import
itertools
import
itertools
import
json
import
json
import
multiprocessing
import
multiprocessing
import
os
import
time
import
time
from
typing
import
Tuple
from
typing
import
Tuple
...
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
...
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
except
Exception
as
e
:
except
Exception
as
e
:
raise
e
raise
e
finally
:
finally
:
kill_child_process
(
os
.
getpid
(),
including_parent
=
False
)
kill_child_process
()
def
launch_server_process
(
server_args
:
ServerArgs
):
def
launch_server_process
(
server_args
:
ServerArgs
):
...
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
...
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
)
finally
:
finally
:
if
proc
:
if
proc
:
kill_child_process
(
proc
.
pid
)
kill_child_process
(
proc
.
pid
,
include_self
=
True
)
print
(
f
"
\n
Results are saved to
{
bench_args
.
result_filename
}
"
)
print
(
f
"
\n
Results are saved to
{
bench_args
.
result_filename
}
"
)
...
...
python/sglang/launch_server.py
View file @
86fc0d79
...
@@ -15,4 +15,4 @@ if __name__ == "__main__":
...
@@ -15,4 +15,4 @@ if __name__ == "__main__":
except
Exception
as
e
:
except
Exception
as
e
:
raise
e
raise
e
finally
:
finally
:
kill_child_process
(
os
.
getpid
(),
including_parent
=
False
)
kill_child_process
()
python/sglang/srt/managers/scheduler.py
View file @
86fc0d79
...
@@ -18,6 +18,7 @@ limitations under the License.
...
@@ -18,6 +18,7 @@ limitations under the License.
import
json
import
json
import
logging
import
logging
import
os
import
os
import
threading
import
time
import
time
import
warnings
import
warnings
from
collections
import
deque
from
collections
import
deque
...
@@ -222,10 +223,11 @@ class Scheduler:
...
@@ -222,10 +223,11 @@ class Scheduler:
self
.
waiting_queue
:
List
[
Req
]
=
[]
self
.
waiting_queue
:
List
[
Req
]
=
[]
self
.
running_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
running_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
cur_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
cur_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
decode_
forward_ct
=
0
self
.
forward_ct
=
0
self
.
stream_interval
=
server_args
.
stream_interval
self
.
forward_ct_decode
=
0
self
.
num_generated_tokens
=
0
self
.
num_generated_tokens
=
0
self
.
last_stats_tic
=
time
.
time
()
self
.
last_stats_tic
=
time
.
time
()
self
.
stream_interval
=
server_args
.
stream_interval
# Init chunked prefill
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
...
@@ -272,6 +274,11 @@ class Scheduler:
...
@@ -272,6 +274,11 @@ class Scheduler:
self
.
batch_is_full
=
False
self
.
batch_is_full
=
False
# Init watchdog thread
self
.
watchdog_timeout
=
server_args
.
watchdog_timeout
t
=
threading
.
Thread
(
target
=
self
.
watchdog_thread
,
daemon
=
True
)
t
.
start
()
# Init profiler
# Init profiler
if
os
.
getenv
(
"SGLANG_TORCH_PROFILER_DIR"
,
""
)
==
""
:
if
os
.
getenv
(
"SGLANG_TORCH_PROFILER_DIR"
,
""
)
==
""
:
self
.
profiler
=
None
self
.
profiler
=
None
...
@@ -289,6 +296,23 @@ class Scheduler:
...
@@ -289,6 +296,23 @@ class Scheduler:
with_stack
=
True
,
with_stack
=
True
,
)
)
def
watchdog_thread
(
self
):
self
.
watchdog_last_forward_ct
=
0
self
.
watchdog_last_time
=
time
.
time
()
while
True
:
if
self
.
cur_batch
is
not
None
:
if
self
.
watchdog_last_forward_ct
==
self
.
forward_ct
:
if
time
.
time
()
>
self
.
watchdog_last_time
+
self
.
watchdog_timeout
:
logger
.
error
(
f
"Watchdog timeout (
{
self
.
watchdog_timeout
=
}
)"
)
break
else
:
self
.
watchdog_last_forward_ct
=
self
.
forward_ct
self
.
watchdog_last_time
=
time
.
time
()
time
.
sleep
(
self
.
watchdog_timeout
/
2
)
kill_parent_process
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
event_loop_normal
(
self
):
def
event_loop_normal
(
self
):
"""A normal blocking scheduler loop."""
"""A normal blocking scheduler loop."""
...
@@ -299,6 +323,7 @@ class Scheduler:
...
@@ -299,6 +323,7 @@ class Scheduler:
self
.
process_input_requests
(
recv_reqs
)
self
.
process_input_requests
(
recv_reqs
)
batch
=
self
.
get_next_batch_to_run
()
batch
=
self
.
get_next_batch_to_run
()
self
.
cur_batch
=
batch
if
batch
:
if
batch
:
result
=
self
.
run_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
...
@@ -746,6 +771,8 @@ class Scheduler:
...
@@ -746,6 +771,8 @@ class Scheduler:
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
"""Run a batch."""
"""Run a batch."""
self
.
forward_ct
+=
1
if
self
.
is_generation
:
if
self
.
is_generation
:
if
batch
.
forward_mode
.
is_decode
()
or
batch
.
extend_num_tokens
!=
0
:
if
batch
.
forward_mode
.
is_decode
()
or
batch
.
extend_num_tokens
!=
0
:
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
...
@@ -778,6 +805,7 @@ class Scheduler:
...
@@ -778,6 +805,7 @@ class Scheduler:
self
.
process_batch_result_prefill
(
batch
,
result
)
self
.
process_batch_result_prefill
(
batch
,
result
)
def
process_batch_result_prefill
(
self
,
batch
:
ScheduleBatch
,
result
):
def
process_batch_result_prefill
(
self
,
batch
:
ScheduleBatch
,
result
):
if
self
.
is_generation
:
if
self
.
is_generation
:
logits_output
,
next_token_ids
,
bid
=
result
logits_output
,
next_token_ids
,
bid
=
result
...
@@ -890,8 +918,8 @@ class Scheduler:
...
@@ -890,8 +918,8 @@ class Scheduler:
self
.
token_to_kv_pool
.
free_group_end
()
self
.
token_to_kv_pool
.
free_group_end
()
self
.
decode_
forward_ct
=
(
self
.
decode_
forward_ct
+
1
)
%
(
1
<<
30
)
self
.
forward_ct
_decode
=
(
self
.
forward_ct
_decode
+
1
)
%
(
1
<<
30
)
if
self
.
tp_rank
==
0
and
self
.
decode_
forward_ct
%
40
==
0
:
if
self
.
tp_rank
==
0
and
self
.
forward_ct
_decode
%
40
==
0
:
self
.
print_decode_stats
()
self
.
print_decode_stats
()
def
add_logprob_return_values
(
def
add_logprob_return_values
(
...
@@ -984,7 +1012,7 @@ class Scheduler:
...
@@ -984,7 +1012,7 @@ class Scheduler:
else
:
# embedding or reward model
else
:
# embedding or reward model
output_embeddings
=
[]
output_embeddings
=
[]
is_stream_iter
=
self
.
decode_
forward_ct
%
self
.
stream_interval
==
0
is_stream_iter
=
self
.
forward_ct
_decode
%
self
.
stream_interval
==
0
for
req
in
reqs
:
for
req
in
reqs
:
if
req
.
finished
()
or
(
if
req
.
finished
()
or
(
...
...
python/sglang/srt/server.py
View file @
86fc0d79
...
@@ -441,7 +441,7 @@ def launch_server(
...
@@ -441,7 +441,7 @@ def launch_server(
# Send a warmup request
# Send a warmup request
t
=
threading
.
Thread
(
t
=
threading
.
Thread
(
target
=
_wait_and_warmup
,
args
=
(
server_args
,
pipe_finish_writer
,
os
.
getpid
()
)
target
=
_wait_and_warmup
,
args
=
(
server_args
,
pipe_finish_writer
)
)
)
t
.
start
()
t
.
start
()
...
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
...
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
def
_wait_and_warmup
(
server_args
,
pipe_finish_writer
,
pid
):
def
_wait_and_warmup
(
server_args
,
pipe_finish_writer
):
headers
=
{}
headers
=
{}
url
=
server_args
.
url
()
url
=
server_args
.
url
()
if
server_args
.
api_key
:
if
server_args
.
api_key
:
...
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
...
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if
pipe_finish_writer
is
not
None
:
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
last_traceback
)
pipe_finish_writer
.
send
(
last_traceback
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
kill_child_process
(
pid
,
includ
ing_parent
=
Fals
e
)
kill_child_process
(
includ
e_self
=
Tru
e
)
return
return
model_info
=
res
.
json
()
model_info
=
res
.
json
()
...
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
...
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if
pipe_finish_writer
is
not
None
:
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
last_traceback
)
pipe_finish_writer
.
send
(
last_traceback
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
logger
.
error
(
f
"Initialization failed. warmup error:
{
last_traceback
}
"
)
kill_child_process
(
pid
,
includ
ing_parent
=
Fals
e
)
kill_child_process
(
includ
e_self
=
Tru
e
)
return
return
# logger.info(f"{res.json()=}")
# logger.info(f"{res.json()=}")
...
@@ -617,7 +617,7 @@ class Runtime:
...
@@ -617,7 +617,7 @@ class Runtime:
def
shutdown
(
self
):
def
shutdown
(
self
):
if
self
.
pid
is
not
None
:
if
self
.
pid
is
not
None
:
kill_child_process
(
self
.
pid
)
kill_child_process
(
self
.
pid
,
include_self
=
True
)
self
.
pid
=
None
self
.
pid
=
None
def
cache_prefix
(
self
,
prefix
:
str
):
def
cache_prefix
(
self
,
prefix
:
str
):
...
@@ -834,7 +834,7 @@ class Engine:
...
@@ -834,7 +834,7 @@ class Engine:
return
ret
return
ret
def
shutdown
(
self
):
def
shutdown
(
self
):
kill_child_process
(
os
.
getpid
(),
including_parent
=
Fals
e
)
kill_child_process
(
include_self
=
Tru
e
)
def
get_tokenizer
(
self
):
def
get_tokenizer
(
self
):
global
tokenizer_manager
global
tokenizer_manager
...
...
python/sglang/srt/server_args.py
View file @
86fc0d79
...
@@ -74,6 +74,7 @@ class ServerArgs:
...
@@ -74,6 +74,7 @@ class ServerArgs:
api_key
:
Optional
[
str
]
=
None
api_key
:
Optional
[
str
]
=
None
file_storage_pth
:
str
=
"SGLang_storage"
file_storage_pth
:
str
=
"SGLang_storage"
enable_cache_report
:
bool
=
False
enable_cache_report
:
bool
=
False
watchdog_timeout
:
float
=
600
# Data parallelism
# Data parallelism
dp_size
:
int
=
1
dp_size
:
int
=
1
...
@@ -429,6 +430,12 @@ class ServerArgs:
...
@@ -429,6 +430,12 @@ class ServerArgs:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Return number of cached tokens in usage.prompt_tokens_details for each openai request."
,
help
=
"Return number of cached tokens in usage.prompt_tokens_details for each openai request."
,
)
)
parser
.
add_argument
(
"--watchdog-timeout"
,
type
=
float
,
default
=
ServerArgs
.
watchdog_timeout
,
help
=
"Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging."
,
)
# Data parallelism
# Data parallelism
parser
.
add_argument
(
parser
.
add_argument
(
...
...
python/sglang/srt/utils.py
View file @
86fc0d79
...
@@ -398,17 +398,26 @@ def kill_parent_process():
...
@@ -398,17 +398,26 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
"""Kill the parent process and all children of the parent process."""
current_process
=
psutil
.
Process
()
current_process
=
psutil
.
Process
()
parent_process
=
current_process
.
parent
()
parent_process
=
current_process
.
parent
()
kill_child_process
(
parent_process
.
pid
,
skip_pid
=
current_process
.
pid
)
kill_child_process
(
parent_process
.
pid
,
include_self
=
True
,
skip_pid
=
current_process
.
pid
)
try
:
current_process
.
kill
()
except
psutil
.
NoSuchProcess
:
pass
def
kill_child_process
(
pid
,
includ
ing_parent
=
Tru
e
,
skip_pid
=
None
):
def
kill_child_process
(
pid
=
None
,
includ
e_self
=
Fals
e
,
skip_pid
=
None
):
"""Kill the process and all its children process."""
"""Kill the process and all its children process."""
if
pid
is
None
:
pid
=
os
.
getpid
()
try
:
try
:
parent
=
psutil
.
Process
(
pid
)
itself
=
psutil
.
Process
(
pid
)
except
psutil
.
NoSuchProcess
:
except
psutil
.
NoSuchProcess
:
return
return
children
=
parent
.
children
(
recursive
=
True
)
children
=
itself
.
children
(
recursive
=
True
)
for
child
in
children
:
for
child
in
children
:
if
child
.
pid
==
skip_pid
:
if
child
.
pid
==
skip_pid
:
continue
continue
...
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
...
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
except
psutil
.
NoSuchProcess
:
except
psutil
.
NoSuchProcess
:
pass
pass
if
includ
ing_parent
:
if
includ
e_self
:
try
:
try
:
parent
.
kill
()
itself
.
kill
()
except
psutil
.
NoSuchProcess
:
except
psutil
.
NoSuchProcess
:
pass
pass
...
...
python/sglang/test/test_utils.py
View file @
86fc0d79
...
@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
...
@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
)
)
assert
ret_code
==
0
assert
ret_code
==
0
except
TimeoutError
:
except
TimeoutError
:
kill_child_process
(
process
.
pid
)
kill_child_process
(
process
.
pid
,
include_self
=
True
)
time
.
sleep
(
5
)
time
.
sleep
(
5
)
print
(
print
(
f
"
\n
Timeout after
{
timeout_per_file
}
seconds when running
{
filename
}
\n
"
,
f
"
\n
Timeout after
{
timeout_per_file
}
seconds when running
{
filename
}
\n
"
,
...
@@ -563,7 +563,7 @@ def run_bench_serving(
...
@@ -563,7 +563,7 @@ def run_bench_serving(
try
:
try
:
res
=
run_benchmark
(
args
)
res
=
run_benchmark
(
args
)
finally
:
finally
:
kill_child_process
(
process
.
pid
)
kill_child_process
(
process
.
pid
,
include_self
=
True
)
assert
res
[
"completed"
]
==
num_prompts
assert
res
[
"completed"
]
==
num_prompts
return
res
return
res
...
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
...
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
lastline
=
output
.
split
(
"
\n
"
)[
-
3
]
lastline
=
output
.
split
(
"
\n
"
)[
-
3
]
output_throughput
=
float
(
lastline
.
split
(
" "
)[
-
2
])
output_throughput
=
float
(
lastline
.
split
(
" "
)[
-
2
])
finally
:
finally
:
kill_child_process
(
process
.
pid
)
kill_child_process
(
process
.
pid
,
include_self
=
True
)
return
output_throughput
return
output_throughput
...
@@ -707,8 +707,8 @@ def run_mmlu_test(
...
@@ -707,8 +707,8 @@ def run_mmlu_test(
pass
pass
# Clean up everything
# Clean up everything
kill_child_process
(
process
.
pid
)
kill_child_process
(
process
.
pid
,
include_self
=
True
)
kill_child_process
(
process
.
pid
)
kill_child_process
(
process
.
pid
,
include_self
=
True
)
stdout
.
close
()
stdout
.
close
()
stderr
.
close
()
stderr
.
close
()
if
os
.
path
.
exists
(
STDOUT_FILENAME
):
if
os
.
path
.
exists
(
STDOUT_FILENAME
):
...
...
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
View file @
86fc0d79
...
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
...
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
run_decode
(
def
run_decode
(
self
,
self
,
...
...
test/srt/test_cache_report.py
View file @
86fc0d79
...
@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
...
@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
run_decode
(
self
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
def
run_decode
(
self
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
response
=
requests
.
post
(
response
=
requests
.
post
(
...
...
test/srt/test_data_parallelism.py
View file @
86fc0d79
...
@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
...
@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_double_sparsity.py
View file @
86fc0d79
...
@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
...
@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_embedding_openai_server.py
View file @
86fc0d79
...
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
...
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
run_embedding
(
self
,
use_list_input
,
token_input
):
def
run_embedding
(
self
,
use_list_input
,
token_input
):
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
...
...
test/srt/test_eval_accuracy_large.py
View file @
86fc0d79
...
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
...
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_eval_accuracy_large_chunked_prefill.py
View file @
86fc0d79
...
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
...
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
View file @
86fc0d79
...
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
...
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_eval_accuracy_mini.py
View file @
86fc0d79
...
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
...
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
...
test/srt/test_json_constrained.py
View file @
86fc0d79
...
@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
...
@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
run_decode
(
self
,
json_schema
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
def
run_decode
(
self
,
json_schema
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
response
=
requests
.
post
(
response
=
requests
.
post
(
...
...
test/srt/test_large_max_new_tokens.py
View file @
86fc0d79
...
@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
...
@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
cls
.
stdout
.
close
()
cls
.
stdout
.
close
()
cls
.
stderr
.
close
()
cls
.
stderr
.
close
()
os
.
remove
(
"stdout.txt"
)
os
.
remove
(
"stdout.txt"
)
...
...
test/srt/test_matched_stop.py
View file @
86fc0d79
...
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
...
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
run_completions_generation
(
def
run_completions_generation
(
self
,
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment