Commit 7d671e4a (unverified), authored Nov 19, 2024 by Lianmin Zheng; committed via GitHub on Nov 19, 2024.
Enable overlap by default (#2067)
Parent: 699384cb
Showing 17 changed files with 91 additions and 74 deletions:

python/sglang/bench_latency.py (+4, -18)
python/sglang/srt/constrained/outlines_backend.py (+4, -1)
python/sglang/srt/managers/schedule_batch.py (+1, -4)
python/sglang/srt/managers/scheduler.py (+22, -2)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+10, -5)
python/sglang/srt/model_executor/model_runner.py (+1, -3)
python/sglang/srt/sampling/sampling_batch_info.py (+2, -7)
python/sglang/srt/server_args.py (+20, -10)
python/sglang/test/test_utils.py (+5, -5)
test/srt/run_suite.py (+1, -1)
test/srt/test_bench_serving.py (+2, -2)
test/srt/test_json_constrained.py (+5, -4)
test/srt/test_moe_eval_accuracy_large.py (+1, -1)
test/srt/test_non_overlap_scheduler.py (+4, -4)
test/srt/test_radix_attention.py (+2, -2)
test/srt/test_torch_compile.py (+4, -3)
test/srt/test_torch_compile_moe.py (+3, -2)
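
Taken together, the diffs below flip the overlap scheduler from opt-in to opt-out: the ServerArgs field enable_overlap_schedule becomes disable_overlap_schedule, the Scheduler derives enable_overlap = not disable_overlap_schedule, and the old --enable-overlap-schedule flag is kept only as a deprecated argument. A minimal standalone illustration of the flag flip (plain argparse, not sglang's actual CLI wiring):

```python
import argparse

# Standalone sketch of the semantics after this commit: overlap is on unless
# the user passes the new opt-out flag. This is not sglang's real CLI code.
parser = argparse.ArgumentParser()
parser.add_argument("--disable-overlap-schedule", action="store_true")

args = parser.parse_args([])  # no flag passed
enable_overlap = not args.disable_overlap_schedule
print(enable_overlap)  # True: overlap scheduler enabled by default
```
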
python/sglang/bench_latency.py
@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


-def _extend(reqs, model_runner):
+@torch.no_grad
+def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


-def extend(reqs, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _extend(reqs, model_runner)
-
-
-def _decode(input_token_ids, batch, model_runner):
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def decode(input_token_ids, batch, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _decode(input_token_ids, batch, model_runner)
-
-
 def correctness_test(
     server_args,
     port_args,

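The benchmark now applies torch.no_grad directly as a decorator instead of wrapping calls in torch.inference_mode. A tiny standalone check of that decorator form (plain PyTorch, unrelated to sglang internals):

```python
import torch

# Recent PyTorch accepts no_grad as a decorator without parentheses,
# which is the form used in the diff above.
@torch.no_grad
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

y = double(torch.ones(2, requires_grad=True))
print(y.requires_grad)  # False: gradients are not tracked inside the decorated call
```
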
python/sglang/srt/constrained/outlines_backend.py
@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
         vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):

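fill_vocab_mask now builds the allowed-token list as an int64 tensor, moves it to the mask's device once (non-blocking), and clears those positions with scatter_ instead of advanced indexing. A small CPU-only sketch showing the two forms produce the same mask (toy sizes; `allowed` stands in for the Outlines guide's token list):

```python
import torch

vocab_size = 8
allowed = torch.tensor([1, 3, 5], dtype=torch.int64)  # stand-in for guide tokens

mask_index = torch.ones(vocab_size, dtype=torch.bool)
mask_index[allowed] = 0  # old form: advanced indexing

tokens = allowed.to(mask_index.device, non_blocking=True)  # one device transfer
mask_scatter = torch.ones(vocab_size, dtype=torch.bool)
mask_scatter.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))  # new form

assert torch.equal(mask_index, mask_scatter)
print(mask_scatter)  # positions 1, 3, 5 are False; the rest stay True
```
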
python/sglang/srt/managers/schedule_batch.py
@@ -899,10 +899,7 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.output_ids = None

-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)

python/sglang/srt/managers/scheduler.py
@@ -30,7 +30,7 @@ import torch
 import zmq

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
@@ -159,6 +159,23 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )

+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info(
+                "Overlap scheduler is disabled if using triton attention backend."
+            )
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -903,6 +920,7 @@ class Scheduler:
                 self.process_batch_result_prefill(batch, result)
             elif batch.forward_mode.is_dummy_first():
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -958,6 +976,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)

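The new checks above turn overlap off for embedding models, for the triton attention backend, for double sparsity, and for MLA-style attention (unless MLA is disabled). A hypothetical helper condensing the same conditions, purely for readability; the actual Scheduler keeps this logic inline exactly as shown in the diff:

```python
def overlap_supported(is_generation: bool, attention_backend: str,
                      enable_double_sparsity: bool, uses_mla: bool) -> bool:
    # Hypothetical condensation of the inline checks in Scheduler.__init__ above.
    if not is_generation:
        return False  # embedding / reward models
    if attention_backend == "triton" or enable_double_sparsity or uses_mla:
        return False  # configurations the overlap scheduler does not yet support
    return True
```
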
python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -157,14 +157,19 @@ class TpModelWorkerClient:
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        _ = model_worker_batch.seq_lens[0].item()
+        torch.cuda.current_stream().synchronize()

-        # Push a new batch to the queue
-        model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info,
+        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+        sampling_info = model_worker_batch.sampling_info
+        sampling_info.update_penalties()
+        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+            sampling_info,
             sampling_info_done=threading.Event(),
+            scaling_penalties=sampling_info.scaling_penalties,
+            linear_penalties=sampling_info.linear_penalties,
         )
-        self.cur_sampling_info = model_worker_batch.sampling_info
+
+        # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects

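forward_batch_generation now snapshots sampling_info with dataclasses.replace before queuing the batch, so the scheduler can keep rebinding fields on its own copy while preparing the next batch. A minimal stand-in dataclass showing what the replace call does (field names are illustrative, not sglang's SamplingBatchInfo):

```python
import dataclasses
import threading
from typing import Any, List, Optional

@dataclasses.dataclass
class FakeSamplingInfo:
    # Illustrative stand-in, not sglang's SamplingBatchInfo.
    temperatures: List[float]
    scaling_penalties: Any = None
    linear_penalties: Any = None
    sampling_info_done: Optional[threading.Event] = None

current = FakeSamplingInfo(temperatures=[0.7, 1.0])

# dataclasses.replace returns a new object sharing the existing field values,
# with only the named fields overridden (here: a fresh Event per batch).
queued = dataclasses.replace(current, sampling_info_done=threading.Event())

current.temperatures = [0.2]  # scheduler rebinds fields for the next batch
print(queued.temperatures)    # [0.7, 1.0]: the queued snapshot is unaffected
```
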
python/sglang/srt/model_executor/model_runner.py
@@ -116,7 +116,7 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -636,13 +636,11 @@ class ModelRunner:
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
             if sampling_info.grammars:
                 sampling_info.sampling_info_done.wait()
-            sampling_info.update_penalties()
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()

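The sampling_info_done event is the handshake between the scheduler and the sampler in overlap mode: the scheduler prepares the vocab mask for the next batch and sets the event, and sampling waits on it only when grammars are in use. A minimal, self-contained sketch of that pattern (illustrative names only, not the sglang objects):

```python
import threading
import time

sampling_info_done = threading.Event()

def scheduler_side():
    time.sleep(0.01)          # stands in for update_regex_vocab_mask()
    sampling_info_done.set()  # signal: the mask for the next batch is ready

def sampler_side():
    sampling_info_done.wait() # only waited on when grammars are present
    return "apply vocab mask, then sample"

t = threading.Thread(target=scheduler_side)
t.start()
print(sampler_side())
t.join()
```
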
python/sglang/srt/sampling/sampling_batch_info.py
@@ -132,9 +132,6 @@ class SamplingBatchInfo:
         return len(self.temperatures)

     def update_penalties(self):
-        if not self.penalizer_orchestrator:
-            return
-
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -176,7 +173,6 @@ class SamplingBatchInfo:
             grammar.fill_vocab_mask(self.vocab_mask, i)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
@@ -216,7 +212,6 @@ class SamplingBatchInfo:
         return None

     def merge_batch(self, other: "SamplingBatchInfo"):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [

python/sglang/srt/server_args.py
@@ -123,7 +123,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
@@ -172,9 +172,7 @@ class ServerArgs:
         if gpu_mem < 25000:
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
-            logger.warning(
-                "Automatically adjust --chunked-prefill-size for small GPUs."
-            )
+            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -192,15 +190,22 @@ class ServerArgs:
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.enable_overlap_schedule = False
-            logger.warning(
+            self.disable_overlap_schedule = True
+            logger.info(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size."
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )

-        if self.enable_overlap_schedule:
+        if self.enable_mixed_chunk:
+            logger.info(
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
+            )
+            self.disable_overlap_schedule = True
+
+        if not self.disable_overlap_schedule:
             self.disable_jump_forward = True

     @staticmethod
@@ -624,9 +629,9 @@ class ServerArgs:
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
@@ -692,6 +697,11 @@ class ServerArgs:
         )

         # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action=DeprecatedAction,

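--enable-overlap-schedule is re-registered with action=DeprecatedAction so that old launch commands get a clear message instead of silently toggling anything. The sketch below shows one plausible way such an argparse action could be written; sglang's actual DeprecatedAction may behave differently:

```python
import argparse

class DeprecatedAction(argparse.Action):
    """Plausible sketch of a deprecated-flag action; not sglang's actual class."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Surface the help text as the error message when the old flag is used.
        parser.error(self.help)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-overlap-schedule",
    action=DeprecatedAction,
    help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. "
    "Please drop this argument.",
)
```
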
python/sglang/test/test_utils.py
@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
     workload_func,
     disable_radix_cache,
     enable_mixed_chunk,
-    enable_overlap,
+    disable_overlap,
     chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
         other_args += ["--enable-mixed-chunk"]
-    if enable_overlap:
-        other_args += ["--enable-overlap-schedule"]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]

     model = DEFAULT_MODEL_NAME_FOR_TEST
     port = random.randint(4000, 5000)
@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
-    enable_overlap=False,
+    disable_overlap=False,
     chunked_prefill_size=32,
 ):
     def workload_func(base_url, model):
@@ -754,7 +754,7 @@ def run_mmlu_test(
         workload_func,
         disable_radix_cache,
         enable_mixed_chunk,
-        enable_overlap,
+        disable_overlap,
         chunked_prefill_size,
     )

test/srt/run_suite.py
@@ -17,8 +17,8 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
+        "test_non_overlap_scheduler.py",
         "test_openai_server.py",
-        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",

test/srt/test_bench_serving.py
@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             self.assertLess(res["median_e2e_latency_ms"], 12000)
-            self.assertLess(res["median_ttft_ms"], 80)
-            self.assertLess(res["median_itl_ms"], 11)
+            self.assertLess(res["median_ttft_ms"], 86)
+            self.assertLess(res["median_itl_ms"], 10)

     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(

test/srt/test_json_constrained.py
@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        self.assertGreater(
-            ret["meta_info"]["completion_tokens"],
-            ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        )
+        # NOTE: This is skipped because overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

test/srt/test_moe_eval_accuracy_large.py
@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.41)
+        self.assertGreater(metrics["score"], 0.40)

     def test_mgsm_en(self):
         args = SimpleNamespace(

test/srt/test_overlap_schedule.py → test/srt/test_non_overlap_scheduler.py
@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
         )

     def test_no_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
         )

     def test_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
         )

     def test_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
         )

test/srt/test_radix_attention.py
@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
     )


-class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
+class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
-                "--enable-overlap-schedule",
+                "--disable-overlap-schedule",
                 "--chunked-prefill-size",
                 "128",
                 "--max-total-tokens",

test/srt/test_torch_compile.py
+import time
 import unittest
 from types import SimpleNamespace
@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
-        print(res["text"])
+        print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
         self.assertGreaterEqual(throughput, 152)

test/srt/test_torch_compile_moe.py
+import time
 import unittest
 from types import SimpleNamespace
@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
+        # Warmup
+        res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()