sglang / Commits / 7d671e4a

Enable overlap by default (#2067)

Unverified commit 7d671e4a, authored Nov 19, 2024 by Lianmin Zheng, committed via GitHub on Nov 19, 2024.
Parent: 699384cb
Showing 17 changed files with 91 additions and 74 deletions (+91, -74)

python/sglang/bench_latency.py (+4, -18)
python/sglang/srt/constrained/outlines_backend.py (+4, -1)
python/sglang/srt/managers/schedule_batch.py (+1, -4)
python/sglang/srt/managers/scheduler.py (+22, -2)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+10, -5)
python/sglang/srt/model_executor/model_runner.py (+1, -3)
python/sglang/srt/sampling/sampling_batch_info.py (+2, -7)
python/sglang/srt/server_args.py (+20, -10)
python/sglang/test/test_utils.py (+5, -5)
test/srt/run_suite.py (+1, -1)
test/srt/test_bench_serving.py (+2, -2)
test/srt/test_json_constrained.py (+5, -4)
test/srt/test_moe_eval_accuracy_large.py (+1, -1)
test/srt/test_non_overlap_scheduler.py (+4, -4)
test/srt/test_radix_attention.py (+2, -2)
test/srt/test_torch_compile.py (+4, -3)
test/srt/test_torch_compile_moe.py (+3, -2)
python/sglang/bench_latency.py (view file @ 7d671e4a)

@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs
 
 
-def _extend(reqs, model_runner):
+@torch.no_grad
+def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch
 
 
-def extend(reqs, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _extend(reqs, model_runner)
-
-
-def _decode(input_token_ids, batch, model_runner):
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
-def decode(input_token_ids, batch, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _decode(input_token_ids, batch, model_runner)
-
-
 def correctness_test(
     server_args,
     port_args,
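The benchmark now relies on the plain `@torch.no_grad` decorator instead of wrapping each call in `torch.inference_mode(use_inf_mode)`. A minimal standalone sketch of the decorator form (illustrative only, not sglang code; the bare `@torch.no_grad` spelling assumes a PyTorch version that accepts the class itself as a decorator, which the diff relies on):

```python
import torch


@torch.no_grad  # equivalent to @torch.no_grad() on recent PyTorch
def double(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded for anything computed in here.
    return x * 2


y = double(torch.ones(3, requires_grad=True))
print(y.requires_grad)  # False
```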
python/sglang/srt/constrained/outlines_backend.py (view file @ 7d671e4a)

@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
         vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
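The rewritten `fill_vocab_mask` builds an int64 index tensor once, moves it to the mask's device with `non_blocking=True`, and clears the allowed positions with `scatter_` instead of indexing with a Python list. A small self-contained sketch of the same masking pattern (example sizes and token ids, not sglang code):

```python
import torch

vocab_size = 8
allowed_tokens = [1, 3, 5]  # tokens the grammar permits at this step (example values)

vocab_mask = torch.ones(vocab_size, dtype=torch.bool)       # True means "masked out"
tokens = torch.tensor(allowed_tokens, dtype=torch.int64)    # .to(device, non_blocking=True) on GPU
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

print(vocab_mask)  # positions 1, 3, 5 are False (allowed), the rest stay True
```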
python/sglang/srt/managers/schedule_batch.py (view file @ 7d671e4a)

@@ -899,10 +899,7 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.output_ids = None
-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
         bs = len(self.reqs)
python/sglang/srt/managers/scheduler.py (view file @ 7d671e4a)

@@ -30,7 +30,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
@@ -159,6 +159,23 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info(
+                "Overlap scheduler is disabled if using triton attention backend."
+            )
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -903,6 +920,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -958,6 +976,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)
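Putting the new scheduler checks together: overlap is now on unless the user passes `--disable-overlap-schedule`, and it is switched back off for embedding models, the triton attention backend, double sparsity, and MLA models. A simplified sketch of that decision (function and parameter names are illustrative, not the actual Scheduler API):

```python
def resolve_enable_overlap(
    disable_overlap_schedule: bool,
    is_generation: bool,
    attention_backend: str,
    enable_double_sparsity: bool,
    uses_mla: bool,
) -> bool:
    # Default flipped by this commit: overlap is on unless explicitly disabled.
    enable_overlap = not disable_overlap_schedule
    if not is_generation:
        enable_overlap = False  # embedding / reward models
    if attention_backend == "triton" or enable_double_sparsity or uses_mla:
        enable_overlap = False  # configurations the overlap worker does not support yet
    return enable_overlap


print(resolve_enable_overlap(False, True, "flashinfer", False, False))  # True
print(resolve_enable_overlap(False, True, "triton", False, False))      # False
```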
python/sglang/srt/managers/tp_worker_overlap_thread.py (view file @ 7d671e4a)

@@ -157,14 +157,19 @@ class TpModelWorkerClient:
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        _ = model_worker_batch.seq_lens[0].item()
+        torch.cuda.current_stream().synchronize()
 
-        # Push a new batch to the queue
-        model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info,
+        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+        sampling_info = model_worker_batch.sampling_info
+        sampling_info.update_penalties()
+        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+            sampling_info,
             sampling_info_done=threading.Event(),
+            scaling_penalties=sampling_info.scaling_penalties,
+            linear_penalties=sampling_info.linear_penalties,
         )
-        self.cur_sampling_info = model_worker_batch.sampling_info
 
+        # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
 
         # Allocate output future objects
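The comment added here explains the `dataclasses.replace` call: the scheduler keeps mutating its `sampling_info` in place while preparing the next batch, so the worker thread must enqueue its own copy. A standalone sketch of that copy-before-enqueue pattern (the dataclass below is a stand-in, not sglang's `SamplingBatchInfo`):

```python
import dataclasses
import threading
from typing import List, Optional


@dataclasses.dataclass
class FakeSamplingInfo:
    temperatures: List[float]
    sampling_info_done: Optional[threading.Event] = None


shared = FakeSamplingInfo(temperatures=[0.7])

# Snapshot the fields into a new object before handing it to the worker thread.
queued = dataclasses.replace(shared, sampling_info_done=threading.Event())

shared.temperatures = [1.0]  # the scheduler moves on to the next batch
print(queued.temperatures)   # [0.7] - the queued batch still sees its own values
```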
python/sglang/srt/model_executor/model_runner.py (view file @ 7d671e4a)

@@ -116,7 +116,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -636,13 +636,11 @@ class ModelRunner:
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
             if sampling_info.grammars:
                 sampling_info.sampling_info_done.wait()
-            sampling_info.update_penalties()
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
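With the penalty update moved into `forward_batch_generation`, the sampling path here only has to wait for the vocab mask when a grammar is active. A simplified sketch of the two branches (helper names are illustrative, not the ModelRunner API):

```python
import threading
from typing import Callable, Optional


def prepare_sampling(
    done: Optional[threading.Event], has_grammar: bool, build_mask: Callable[[], None]
) -> None:
    if done is not None:
        # Overlap mode: the mask was built while the previous batch was processed;
        # only block if a grammar actually needs it.
        if has_grammar:
            done.wait()
    else:
        # Normal mode: do the CPU-heavy mask update here, overlapped with the forward pass.
        build_mask()


evt = threading.Event()
evt.set()
prepare_sampling(evt, has_grammar=True, build_mask=lambda: None)            # returns immediately
prepare_sampling(None, has_grammar=True, build_mask=lambda: print("mask"))  # builds the mask inline
```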
python/sglang/srt/sampling/sampling_batch_info.py (view file @ 7d671e4a)

@@ -132,9 +132,6 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
-        if not self.penalizer_orchestrator:
-            return
-
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -176,8 +173,7 @@ class SamplingBatchInfo:
             grammar.fill_vocab_mask(self.vocab_mask, i)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -216,8 +212,7 @@ class SamplingBatchInfo:
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
python/sglang/srt/server_args.py (view file @ 7d671e4a)

@@ -123,7 +123,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
@@ -172,9 +172,7 @@ class ServerArgs:
         if gpu_mem < 25000:
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
-            logger.warning(
-                "Automatically adjust --chunked-prefill-size for small GPUs."
-            )
+            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
 
         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -192,15 +190,22 @@ class ServerArgs:
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.enable_overlap_schedule = False
-            logger.warning(
+            self.disable_overlap_schedule = True
+            logger.info(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size."
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )
 
-        if self.enable_overlap_schedule:
+        if self.enable_mixed_chunk:
+            logger.info(
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
+            )
+            self.disable_overlap_schedule = True
+
+        if not self.disable_overlap_schedule:
             self.disable_jump_forward = True
 
     @staticmethod
@@ -624,9 +629,9 @@ class ServerArgs:
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
@@ -692,6 +697,11 @@ class ServerArgs:
         )
 
         # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action=DeprecatedAction,
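For users, the net effect of this file is a flag flip: the overlap scheduler is enabled by default, `--disable-overlap-schedule` turns it off, and the old `--enable-overlap-schedule` flag is kept only as a deprecated argument that tells people to drop it. A minimal argparse sketch of that pattern (the `DeprecatedAction` below is a stand-in for sglang's helper, not its actual implementation):

```python
import argparse


class DeprecatedAction(argparse.Action):
    """Reject a removed flag with a helpful message."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-overlap-schedule",
    action="store_true",
    help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
    "--enable-overlap-schedule",
    action=DeprecatedAction,
    help="'--enable-overlap-schedule' is deprecated. It is enabled by default now.",
)

args = parser.parse_args([])
print(args.disable_overlap_schedule)  # False, so the overlap scheduler stays on
```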
python/sglang/test/test_utils.py (view file @ 7d671e4a)

@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
     workload_func,
     disable_radix_cache,
     enable_mixed_chunk,
-    enable_overlap,
+    disable_overlap,
     chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
         other_args += ["--enable-mixed-chunk"]
-    if enable_overlap:
-        other_args += ["--enable-overlap-schedule"]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
 
     model = DEFAULT_MODEL_NAME_FOR_TEST
     port = random.randint(4000, 5000)
@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
-    enable_overlap=False,
+    disable_overlap=False,
     chunked_prefill_size=32,
 ):
     def workload_func(base_url, model):
@@ -754,7 +754,7 @@ def run_mmlu_test(
         workload_func,
         disable_radix_cache,
         enable_mixed_chunk,
-        enable_overlap,
+        disable_overlap,
         chunked_prefill_size,
     )
test/srt/run_suite.py (view file @ 7d671e4a)

@@ -17,8 +17,8 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
+        "test_non_overlap_scheduler.py",
         "test_openai_server.py",
-        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",
test/srt/test_bench_serving.py (view file @ 7d671e4a)

@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             self.assertLess(res["median_e2e_latency_ms"], 12000)
-            self.assertLess(res["median_ttft_ms"], 80)
-            self.assertLess(res["median_itl_ms"], 11)
+            self.assertLess(res["median_ttft_ms"], 86)
+            self.assertLess(res["median_itl_ms"], 10)
 
     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(
test/srt/test_json_constrained.py (view file @ 7d671e4a)

@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)
 
         # Make sure jump forward is triggered
-        self.assertGreater(
-            ret["meta_info"]["completion_tokens"],
-            ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        )
+        # NOTE: This is skipped because overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )
 
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
test/srt/test_moe_eval_accuracy_large.py (view file @ 7d671e4a)

@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.41)
+        self.assertGreater(metrics["score"], 0.40)
 
     def test_mgsm_en(self):
         args = SimpleNamespace(
test/srt/test_overlap_schedule.py → test/srt/test_non_overlap_scheduler.py (view file @ 7d671e4a)

@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
         )
 
     def test_no_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
         )
 
     def test_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
         )
 
     def test_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
         )
test/srt/test_radix_attention.py (view file @ 7d671e4a)

@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
     )
 
 
-class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
+class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
-                "--enable-overlap-schedule",
+                "--disable-overlap-schedule",
                 "--chunked-prefill-size",
                 "128",
                 "--max-total-tokens",
test/srt/test_torch_compile.py (view file @ 7d671e4a)

+import time
 import unittest
 from types import SimpleNamespace

@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()
 
     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)
 
         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
-        print(res["text"])
+        print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
         self.assertGreaterEqual(throughput, 152)
test/srt/test_torch_compile_moe.py (view file @ 7d671e4a)

+import time
 import unittest
 from types import SimpleNamespace

@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()
 
     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)
 
         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()