sglang · Commit 3694f8f9 (unverified)

Mixed style of chunked prefill (#1013)

Authored Aug 16, 2024 by Liangsheng Yin; committed by GitHub on Aug 16, 2024.
Parent: 5a261bd0

Showing 14 changed files with 195 additions and 59 deletions (+195 / -59).
python/sglang/srt/managers/policy_scheduler.py (+5 / -2)
python/sglang/srt/managers/schedule_batch.py (+27 / -0)
python/sglang/srt/managers/tp_worker.py (+36 / -10)
python/sglang/srt/model_executor/forward_batch_info.py (+19 / -19)
python/sglang/srt/server.py (+0 / -9)
python/sglang/srt/server_args.py (+6 / -0)
python/sglang/test/simple_eval_common.py (+9 / -10)
python/sglang/test/simple_eval_gpqa.py (+2 / -1)
python/sglang/test/simple_eval_humaneval.py (+2 / -2)
python/sglang/test/simple_eval_math.py (+2 / -1)
python/sglang/test/simple_eval_mmlu.py (+2 / -1)
test/srt/test_chunked_prefill.py (+12 / -3)
test/srt/test_eval_accuracy_large_chunked_prefill.py (+0 / -1)
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py (new file, +73 / -0)
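For orientation, a minimal, hypothetical sketch of the idea this commit implements: with --enable-mixed-chunk set (and chunked prefill active), one forward pass serves a prefill chunk plus one decode step for the currently running requests, where each running request contributes exactly one "extend" token. The names below are invented and do not mirror the real sglang classes.

# Illustrative sketch only; invented names, not the actual sglang API.
def build_mixed_batch(prefill_chunks, num_running_decode_reqs, chunk_budget):
    # Running decode requests eat into the chunk budget, one token each.
    budget = chunk_budget - num_running_decode_reqs

    packed = []
    for name, remaining_prefill in prefill_chunks:
        take = min(budget, remaining_prefill)
        if take <= 0:
            break
        packed.append((name, take))   # prefill part of the mixed batch
        budget -= take

    # Each running decode request extends by exactly one token.
    packed.extend(("decode", 1) for _ in range(num_running_decode_reqs))
    return packed


# 512-token chunk budget, 32 running decode requests, one long prompt being chunked.
print(build_mixed_batch([("req-A", 2000)], 32, 512))  # [('req-A', 480), ('decode', 1), ...]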
python/sglang/srt/managers/policy_scheduler.py

@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens

         self.can_run_list = []
         self.new_inflight_req = None
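The constructor change reserves room for the mixed-in decode tokens up front. A worked example with made-up numbers, mirroring the accounting above:

# Hypothetical numbers, only to show the accounting PrefillAdder.__init__ now performs.
rem_total_tokens = 40960
rem_input_tokens = 16384
rem_chunk_tokens = 512
mixed_with_decode_tokens = 32          # one token per running decode request

rem_total_tokens -= mixed_with_decode_tokens
rem_input_tokens -= mixed_with_decode_tokens
if rem_chunk_tokens is not None:
    rem_chunk_tokens -= mixed_with_decode_tokens

assert (rem_total_tokens, rem_input_tokens, rem_chunk_tokens) == (40928, 16352, 480)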
python/sglang/srt/managers/schedule_batch.py

@@ -329,6 +329,9 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None

+    # For mixed chunked prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

@@ -462,9 +465,33 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

         self.batch_sampling_params(vocab_size)

+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
+
     def check_decode_mem(self):
         bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
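To see what mix_with_running tracks per request, here is a self-contained illustration with invented sizes: a running decode request that started from a 10-token prompt and has generated 4 tokens has 13 tokens already in the KV cache and extends by exactly one token in the mixed batch.

# Invented request sizes, only to illustrate the bookkeeping in mix_with_running().
origin_input_ids = list(range(10))    # 10 prompt tokens
output_ids = [101, 102, 103, 104]     # 4 generated tokens so far

prefix_len = len(origin_input_ids) + len(output_ids) - 1   # 13 tokens already cached
fill_ids = origin_input_ids + output_ids                   # 14 tokens total
extend_input_len = 1                                       # only the newest token is computed

assert len(fill_ids) - prefix_len == extend_input_len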
python/sglang/srt/managers/tp_worker.py

@@ -174,6 +174,9 @@ class ModelTpServer:
         # Chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )

         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:

@@ -366,11 +369,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )

         if self.running_batch is not None:

@@ -416,15 +422,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

         # Return the new batch
         new_batch = ScheduleBatch.init_new(

@@ -440,6 +458,13 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)

+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:

@@ -481,7 +506,8 @@ class ModelTpServer:
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
-            else:
+            elif req not in decoding_reqs:
+                # To reduce overhead, only cache prefill reqs
                 self.tree_cache.cache_unfinished_req(req)

             if req is self.current_inflight_req:
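Taken together, the tp_worker changes make the scheduler's choice roughly: mix only when chunked prefill is configured, mixing is enabled, and there is both a waiting prefill and a running decode batch. A toy, self-contained illustration of that decision (invented names, not the ModelTpServer API):

# Toy sketch of the scheduling decision; not verbatim sglang code.
def choose_mode(chunked_prefill_size, enable_mixed_chunk, has_running_batch, has_waiting_prefill):
    is_mixed_chunk = chunked_prefill_size is not None and enable_mixed_chunk
    if has_waiting_prefill and has_running_batch and is_mixed_chunk:
        return "mixed"     # prefill chunk + one decode step in a single forward pass
    if has_waiting_prefill:
        return "prefill"   # plain (possibly chunked) prefill batch
    if has_running_batch:
        return "decode"    # plain decode step
    return "idle"

assert choose_mode(8192, True, True, True) == "mixed"
assert choose_mode(None, True, True, True) == "prefill"   # mixing requires chunked prefill
assert choose_mode(8192, False, True, True) == "prefill"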
python/sglang/srt/model_executor/forward_batch_info.py

@@ -88,11 +88,11 @@ class InputMetadata:
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             (
-                (r.image_offset - len(r.prefix_indices))
+                (r.image_offset - batch.prefix_lens_cpu[i])
                 if r.image_offset is not None
                 else 0
             )
-            for r in reqs
+            for i, r in enumerate(reqs)
         ]

     def compute_positions(self, batch: ScheduleBatch):

@@ -109,8 +109,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),

@@ -123,7 +123,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)

@@ -141,12 +141,13 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)

     @classmethod
     def from_schedule_batch(

@@ -180,14 +181,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)

-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)

         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:

@@ -198,30 +193,35 @@ class InputMetadata:
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, prefix_lens, flashinfer_use_ragged
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )

         return ret

-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
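The effect of switching from len(r.prefix_indices) to batch.prefix_lens_cpu[i] is that mixed-in decode requests report their true cached prefix length even though the last decode step was never inserted into the radix cache. A small, self-contained illustration of the position computation, using invented lengths for a mixed batch:

import numpy as np

# Invented per-request lengths: two prefill chunks and one mixed-in decode request.
prefix_lens_cpu = [0, 128, 13]     # tokens already in the KV cache per request
fill_lens = [96, 160, 14]          # len(req.fill_ids) per request

# Positions of the tokens being extended, as in compute_positions().
positions = np.concatenate(
    [np.arange(prefix_lens_cpu[i], fill_lens[i]) for i in range(len(fill_lens))],
    axis=0,
)
extend_lens = [fill_lens[i] - prefix_lens_cpu[i] for i in range(len(fill_lens))]

assert len(positions) == sum(extend_lens)   # 96 + 32 + 1 tokens in this forward pass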
python/sglang/srt/server.py

@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)

-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
-
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
python/sglang/srt/server_args.py

@@ -80,6 +80,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False

@@ -396,6 +397,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a chunked batch.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
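A hedged usage example for the new flag (model path and port below are placeholders): per the is_mixed_chunk check in tp_worker.py, mixing only takes effect when --chunked-prefill-size is also set.

# Sketch of launching the server with mixed chunking enabled; model/port are placeholders.
import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",   # placeholder model
    "--port", "30000",                                        # placeholder port
    "--chunked-prefill-size", "8192",
    "--enable-mixed-chunk",
])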
python/sglang/test/simple_eval_common.py

 # Adapted from https://github.com/openai/simple-evals/

-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import httpx
 import jinja2

@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """

-    score: float | None  # top-line metric
-    metrics: Dict[str, float] | None  # other metrics
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations

@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """

-    score: float | None
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str | None = None
-    convo: MessageList | None = None  # sampled conversation
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation


 class Eval:

@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str | None = None,
-        system_message: str | None = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):

@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]] | None = None,
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
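The simple_eval_* changes swap PEP 604 unions (float | None) for typing.Optional, presumably so these helpers still import on Python 3.8/3.9, where X | None in a class-body annotation is evaluated at definition time and raises TypeError. A minimal check of the pattern, assuming nothing beyond the standard library:

from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Optional[float] works on Python < 3.10, whereas "score: float | None" in a
# dataclass body would raise TypeError there (annotations are evaluated eagerly).
@dataclass
class EvalResultExample:              # illustrative stand-in, not the real EvalResult
    score: Optional[float]
    metrics: Optional[Dict[str, float]]
    htmls: List[str] = field(default_factory=list)

r = EvalResultExample(score=0.71, metrics=None)
assert r.score == 0.71 and r.metrics is None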
python/sglang/test/simple_eval_gpqa.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
 import random
 import re
+from typing import Optional

 import pandas

@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
python/sglang/test/simple_eval_humaneval.py

@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional

 import tqdm

@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
python/sglang/test/simple_eval_math.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
 import random
 import re
+from typing import Optional

 import pandas

@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
python/sglang/test/simple_eval_mmlu.py

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
 import random
 import re
+from typing import Optional

 import pandas

@@ -84,7 +85,7 @@ subject2category = {
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
test/srt/test_chunked_prefill.py

@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache):
+    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
         other_args = ["--chunked-prefill-size", "32"]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
+        if enable_mixed_chunk:
+            other_args += ["--enable-mixed-chunk"]
+
         model = DEFAULT_MODEL_NAME_FOR_TEST
         base_url = DEFAULT_URL_FOR_UNIT_TEST
         process = popen_launch_server(

@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
             kill_child_process(process.pid)

     def test_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False)
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)

     def test_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True)
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill_without_radix_cache(self):
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)


 if __name__ == "__main__":
test/srt/test_eval_accuracy_large_chunked_prefill.py

@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_ACCURACY_TEST,
-    DEFAULT_URL_FOR_UNIT_TEST,
     popen_launch_server,
 )
test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py (new file, mode 100644)

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_ACCURACY_TEST,
    popen_launch_server,
)


class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--log-level-http",
                "warning",
                "--chunked-prefill-size",
                "256",
                "--enable-mixed-chunk",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=3000,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.71, f"{metrics}"

    def test_human_eval(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="humaneval",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.64, f"{metrics}"

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.84, f"{metrics}"


if __name__ == "__main__":
    unittest.main()