sglang commit 3694f8f9 (unverified)
Mixed style of chunked prefill (#1013)
Authored Aug 16, 2024 by Liangsheng Yin; committed by GitHub on Aug 16, 2024
Parent commit: 5a261bd0
Showing 14 changed files with 195 additions and 59 deletions (+195 −59).
Changed files:
- python/sglang/srt/managers/policy_scheduler.py (+5 −2)
- python/sglang/srt/managers/schedule_batch.py (+27 −0)
- python/sglang/srt/managers/tp_worker.py (+36 −10)
- python/sglang/srt/model_executor/forward_batch_info.py (+19 −19)
- python/sglang/srt/server.py (+0 −9)
- python/sglang/srt/server_args.py (+6 −0)
- python/sglang/test/simple_eval_common.py (+9 −10)
- python/sglang/test/simple_eval_gpqa.py (+2 −1)
- python/sglang/test/simple_eval_humaneval.py (+2 −2)
- python/sglang/test/simple_eval_math.py (+2 −1)
- python/sglang/test/simple_eval_mmlu.py (+2 −1)
- test/srt/test_chunked_prefill.py (+12 −3)
- test/srt/test_eval_accuracy_large_chunked_prefill.py (+0 −1)
- test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py (new file, +73 −0)
python/sglang/srt/managers/policy_scheduler.py

@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens

         self.can_run_list = []
         self.new_inflight_req = None
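Note on the change above: when the new prefill batch will share its forward pass with in-flight decode requests, each of those requests consumes one extra token slot, so every remaining-token budget is reduced up front by `mixed_with_decode_tokens`. A minimal standalone sketch of that bookkeeping (the class name and numbers are illustrative, not from the repository):

```python
from typing import Optional


class TokenBudgetSketch:
    """Illustrative version of PrefillAdder's budget reservation (not the real class)."""

    def __init__(
        self,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        # One token per mixed-in decode request is reserved up front, since those
        # tokens ride along in the same chunked forward pass.
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens


budget = TokenBudgetSketch(16384, 8192, 2048, mixed_with_decode_tokens=32)
print(budget.rem_total_tokens, budget.rem_input_tokens, budget.rem_chunk_tokens)
# -> 16352 8160 2016
```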
python/sglang/srt/managers/schedule_batch.py

@@ -329,6 +329,9 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None

+    # For mixed chunked prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

@@ -462,9 +465,33 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

         self.batch_sampling_params(vocab_size)

+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
+
     def check_decode_mem(self):
         bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
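The new `mix_with_running` method folds the running decode batch into the prefill batch: each decode request becomes a one-token extend request whose already-cached context is everything except its newest token. A small numeric illustration with made-up lengths:

```python
# Hypothetical decode request: 10 prompt tokens, 4 tokens generated so far.
origin_input_ids = list(range(10))
output_ids = [100, 101, 102, 103]

# When folded into a mixed prefill batch it becomes a one-token extend request:
fill_ids = origin_input_ids + output_ids                  # full sequence so far
extend_input_len = 1                                      # only the newest token runs
prefix_len = len(origin_input_ids) + len(output_ids) - 1  # 13 tokens already in the KV cache

assert len(fill_ids) - prefix_len == extend_input_len
print(len(fill_ids), prefix_len, extend_input_len)  # -> 14 13 1
```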
python/sglang/srt/managers/tp_worker.py

@@ -174,6 +174,9 @@ class ModelTpServer:
         # Chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )

         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:

@@ -366,11 +369,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )

         if self.running_batch is not None:

@@ -416,15 +422,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

         # Return the new batch
         new_batch = ScheduleBatch.init_new(

@@ -440,6 +458,13 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)

+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:

@@ -481,7 +506,8 @@ class ModelTpServer:
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                else:
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)

                 if req is self.current_inflight_req:
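Putting the tp_worker pieces together: when `--enable-mixed-chunk` is active and a decode batch is running, the scheduler prepares that batch for decode, merges it into the freshly built prefill batch, runs one forward pass, and afterwards skips `cache_unfinished_req` for the merged decode requests. Below is a rough, runnable sketch of this control flow; the `FakeBatch` stand-in and its methods are illustrative, not the real `ScheduleBatch` API:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeBatch:
    """Tiny stand-in for ScheduleBatch, just enough to show the control flow."""

    reqs: List[str]

    def prepare_for_decode(self):
        # The real method allocates the next-token KV slot for every request.
        pass

    def mix_with_running(self, running_batch: "FakeBatch"):
        # The real method concatenates input_ids / out_cache_loc and merges reqs.
        self.reqs = self.reqs + running_batch.reqs


def forward_prefill_step(
    new_batch: FakeBatch,
    running_batch: Optional[FakeBatch],
    is_mixed_chunk: bool,
) -> List[str]:
    # Mirrors the tp_worker change: fold the running decode batch into the
    # freshly built prefill batch so both run in one forward pass.
    decoding_reqs: List[str] = []
    if is_mixed_chunk and running_batch is not None:
        running_batch.prepare_for_decode()
        new_batch.mix_with_running(running_batch)
        decoding_reqs = running_batch.reqs
        running_batch = None
    # ... model forward on new_batch would happen here ...
    # Only the prefill requests are re-inserted into the radix cache afterwards.
    return [r for r in new_batch.reqs if r not in decoding_reqs]


print(forward_prefill_step(FakeBatch(["p0", "p1"]), FakeBatch(["d0"]), True))
# -> ['p0', 'p1']; the mixed-in decode request 'd0' is skipped
```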
python/sglang/srt/model_executor/forward_batch_info.py

@@ -88,11 +88,11 @@ class InputMetadata:
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             (
-                (r.image_offset - len(r.prefix_indices))
+                (r.image_offset - batch.prefix_lens_cpu[i])
                 if r.image_offset is not None
                 else 0
             )
-            for r in reqs
+            for i, r in enumerate(reqs)
         ]

     def compute_positions(self, batch: ScheduleBatch):

@@ -109,8 +109,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),

@@ -123,7 +123,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)

@@ -141,12 +141,13 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)

     @classmethod
     def from_schedule_batch(

@@ -180,14 +181,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)

-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)

         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:

@@ -198,30 +193,35 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True

         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged
+            model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
         )

         return ret

-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
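The motivation for threading `prefix_lens_cpu` through `InputMetadata` is that, for mixed-in decode requests, the number of already-computed tokens no longer equals `len(r.prefix_indices)`; positions and extend lengths are therefore derived from the per-request prefix lengths recorded on the batch. A small numeric example of that arithmetic (all lengths made up):

```python
import numpy as np

# Made-up mixed batch: one chunked-prefill request and one mixed-in decode request.
fill_lens = [12, 14]       # len(req.fill_ids) per request
prefix_lens_cpu = [4, 13]  # tokens already computed / cached per request

positions = np.concatenate(
    [np.arange(p, f) for p, f in zip(prefix_lens_cpu, fill_lens)], axis=0
)
extend_seq_lens = [f - p for f, p in zip(fill_lens, prefix_lens_cpu)]

print(positions)        # [ 4  5  6  7  8  9 10 11 13]
print(extend_seq_lens)  # [8, 1]  -- the decode request extends by exactly one token
```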
python/sglang/srt/server.py

@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)

-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
-
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
python/sglang/srt/server_args.py

@@ -80,6 +80,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False

@@ -396,6 +397,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a chunked batch.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
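The feature is opt-in: `--enable-mixed-chunk` only takes effect when `--chunked-prefill-size` is also set (see the `is_mixed_chunk` check in tp_worker above). As a usage sketch, a test server can be launched with both flags through the same helper the new tests use; the timeout and chunk size below are arbitrary example values:

```python
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    popen_launch_server,
)

# Launch a test server with mixed chunked prefill enabled, mirroring the new tests.
process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    timeout=300,
    other_args=["--chunked-prefill-size", "256", "--enable-mixed-chunk"],
)
try:
    pass  # send requests against DEFAULT_URL_FOR_UNIT_TEST here
finally:
    kill_child_process(process.pid)
```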
python/sglang/test/simple_eval_common.py

 # Adapted from https://github.com/openai/simple-evals/

-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import httpx
 import jinja2

@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """

-    score: float | None  # top-line metric
-    metrics: Dict[str, float] | None  # other metrics
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations

@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """

-    score: float | None
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str | None = None
-    convo: MessageList | None = None  # sampled conversation
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation


 class Eval:

@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str | None = None,
-        system_message: str | None = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):

@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]] | None = None,
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
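The `simple_eval_*` changes in this commit are unrelated to chunked prefill: they replace PEP 604 `X | None` annotations with `typing.Optional[X]` so the modules import on Python versions before 3.10, where the `|` union syntax in an eagerly evaluated annotation raises a `TypeError`. For example:

```python
from typing import Optional

# On Python 3.8/3.9, a module- or class-level annotation such as
#     score: float | None = None
# raises "TypeError: unsupported operand type(s) for |" at import time because
# the union is evaluated eagerly.  The typing.Optional spelling is equivalent
# and works on every version these eval scripts support:
score: Optional[float] = None
```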
python/sglang/test/simple_eval_gpqa.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022

 import random
 import re
+from typing import Optional

 import pandas

@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
python/sglang/test/simple_eval_humaneval.py

@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional

 import tqdm

@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
python/sglang/test/simple_eval_math.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874

 import random
 import re
+from typing import Optional

 import pandas

@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
python/sglang/test/simple_eval_mmlu.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300

 import random
 import re
+from typing import Optional

 import pandas

@@ -84,7 +85,7 @@ subject2category = {
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
test/srt/test_chunked_prefill.py

@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
 )


 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache):
+    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
         other_args = ["--chunked-prefill-size", "32"]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
+        if enable_mixed_chunk:
+            other_args += ["--enable-mixed-chunk"]
+
         model = DEFAULT_MODEL_NAME_FOR_TEST
         base_url = DEFAULT_URL_FOR_UNIT_TEST
         process = popen_launch_server(

@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
             kill_child_process(process.pid)

     def test_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False)
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)

     def test_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True)
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill_without_radix_cache(self):
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)


 if __name__ == "__main__":
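The test matrix now crosses radix cache on/off with mixed chunking on/off, each case launching a server with `--chunked-prefill-size 32` plus the relevant flags. A quick standalone sketch of the argument lists the four cases produce (mirroring `run_mmlu` above; the helper name is illustrative):

```python
def build_server_args(disable_radix_cache: bool, enable_mixed_chunk: bool):
    # Mirrors run_mmlu's argument construction in the updated test.
    other_args = ["--chunked-prefill-size", "32"]
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    if enable_mixed_chunk:
        other_args += ["--enable-mixed-chunk"]
    return other_args


for radix_off in (False, True):
    for mixed in (False, True):
        print(radix_off, mixed, build_server_args(radix_off, mixed))
```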
test/srt/test_eval_accuracy_large_chunked_prefill.py

@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_ACCURACY_TEST,
-    DEFAULT_URL_FOR_UNIT_TEST,
     popen_launch_server,
 )
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py (new file, 0 → 100644)

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_ACCURACY_TEST,
    popen_launch_server,
)


class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--log-level-http",
                "warning",
                "--chunked-prefill-size",
                "256",
                "--enable-mixed-chunk",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=3000,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.71, f"{metrics}"

    def test_human_eval(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="humaneval",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.64, f"{metrics}"

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.84, f"{metrics}"


if __name__ == "__main__":
    unittest.main()