Commit 7d671e4a (unverified)
Authored Nov 19, 2024 by Lianmin Zheng · Committed by GitHub on Nov 19, 2024

Enable overlap by default (#2067)
Parent: 699384cb
Showing 17 changed files with 91 additions and 74 deletions (+91, -74)
python/sglang/bench_latency.py                          (+4, -18)
python/sglang/srt/constrained/outlines_backend.py       (+4, -1)
python/sglang/srt/managers/schedule_batch.py             (+1, -4)
python/sglang/srt/managers/scheduler.py                   (+22, -2)
python/sglang/srt/managers/tp_worker_overlap_thread.py   (+10, -5)
python/sglang/srt/model_executor/model_runner.py          (+1, -3)
python/sglang/srt/sampling/sampling_batch_info.py         (+2, -7)
python/sglang/srt/server_args.py                          (+20, -10)
python/sglang/test/test_utils.py                          (+5, -5)
test/srt/run_suite.py                                     (+1, -1)
test/srt/test_bench_serving.py                            (+2, -2)
test/srt/test_json_constrained.py                         (+5, -4)
test/srt/test_moe_eval_accuracy_large.py                  (+1, -1)
test/srt/test_non_overlap_scheduler.py                    (+4, -4)
test/srt/test_radix_attention.py                          (+2, -2)
test/srt/test_torch_compile.py                            (+4, -3)
test/srt/test_torch_compile_moe.py                        (+3, -2)
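
The core of the change is a flip of the overlap-scheduler flag from opt-in to opt-out: the old enable_overlap_schedule field and --enable-overlap-schedule argument are replaced by disable_overlap_schedule / --disable-overlap-schedule, so the overlap scheduler now runs unless explicitly turned off. A minimal sketch of the flip (field names are taken from the diff below; the dataclasses themselves are illustrative stand-ins, not the real ServerArgs):

from dataclasses import dataclass

@dataclass
class ServerArgsBefore:
    enable_overlap_schedule: bool = False   # overlap was opt-in

@dataclass
class ServerArgsAfter:
    disable_overlap_schedule: bool = False  # overlap is now on unless explicitly disabled
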
python/sglang/bench_latency.py

@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


-def _extend(reqs, model_runner):
+@torch.no_grad
+def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


-def extend(reqs, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _extend(reqs, model_runner)
-
-
-def _decode(input_token_ids, batch, model_runner):
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def decode(input_token_ids, batch, model_runner):
-    # Disable inference mode for now when torch TP is applied. We can remove
-    # this workaround once DTensor adds support for inference mode.
-    use_inf_mode = not model_runner.torch_tp_applied
-    with torch.inference_mode(use_inf_mode):
-        return _decode(input_token_ids, batch, model_runner)
-
-
 def correctness_test(
     server_args,
     port_args,
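
The removed extend/decode wrappers existed only to conditionally fall back from torch.inference_mode when torch tensor parallelism (DTensor) was applied; the new code keeps a single function under torch.no_grad. A standalone sketch of the two styles (illustrative only, not code from this commit):

import torch

@torch.no_grad()
def run_step_new_style(x: torch.Tensor) -> torch.Tensor:
    # Autograd is simply disabled for the whole call.
    return x * 2

def run_step_old_style(x: torch.Tensor, torch_tp_applied: bool) -> torch.Tensor:
    # The old wrappers toggled inference mode off when torch TP was applied,
    # because DTensor did not support inference mode.
    with torch.inference_mode(mode=not torch_tp_applied):
        return x * 2

print(run_step_new_style(torch.ones(2)))
print(run_step_old_style(torch.ones(2), torch_tp_applied=True))
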
python/sglang/srt/constrained/outlines_backend.py

@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
         vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
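
The rewritten fill_vocab_mask builds the allowed-token index tensor once, moves it to the mask's device with non_blocking=True, and clears those positions with scatter_ instead of fancy indexing with a Python list, apparently so the host-to-device copy can be issued asynchronously. A small self-contained sketch of the masking pattern, mirroring the diff (CPU tensors, made-up sizes):

import torch

vocab_size = 8
# Stand-in for self.guide.get_next_instruction(self.state).tokens
allowed = torch.tensor([1, 4, 5], dtype=torch.int64)

mask = torch.zeros(vocab_size, dtype=torch.bool)
mask.fill_(1)  # block every token by default, as in the diff
mask.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))  # unblock allowed ids
print(mask)  # tensor([ True, False,  True,  True, False, False,  True,  True])
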
python/sglang/srt/managers/schedule_batch.py

@@ -899,10 +899,7 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.output_ids = None
-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)
python/sglang/srt/managers/scheduler.py

@@ -30,7 +30,7 @@ import torch
 import zmq

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
@@ -159,6 +159,23 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )

+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled if using triton attention backend.")
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -903,6 +920,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -958,6 +976,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         else:
             # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
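
Put together, the scheduler now derives enable_overlap from the new opt-out flag and then force-disables it for embedding models, the triton attention backend, double sparsity, and MLA attention. A standalone restatement of that decision logic (names mirror the diff; the helper and the AttentionArch stand-in are illustrative, not the actual sglang API):

from enum import Enum, auto
from types import SimpleNamespace

class AttentionArch(Enum):  # stand-in for sglang.srt.configs.model_config.AttentionArch
    MHA = auto()
    MLA = auto()

def overlap_enabled(server_args, model_config, is_generation: bool) -> bool:
    enable = not server_args.disable_overlap_schedule
    if not is_generation:
        return False  # overlap scheduler is disabled for embedding models
    if (
        server_args.attention_backend == "triton"
        or server_args.enable_double_sparsity
        or (model_config.attention_arch == AttentionArch.MLA and not server_args.disable_mla)
    ):
        return False  # unsupported backends fall back to the non-overlap path
    return enable

args = SimpleNamespace(
    disable_overlap_schedule=False,
    attention_backend="flashinfer",
    enable_double_sparsity=False,
    disable_mla=False,
)
cfg = SimpleNamespace(attention_arch=AttentionArch.MHA)
print(overlap_enabled(args, cfg, is_generation=True))  # True: overlap stays on by default
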
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -157,14 +157,19 @@ class TpModelWorkerClient:
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        _ = model_worker_batch.seq_lens[0].item()
+        torch.cuda.current_stream().synchronize()

-        # Push a new batch to the queue
-        model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info,
+        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+        sampling_info = model_worker_batch.sampling_info
+        sampling_info.update_penalties()
+        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+            sampling_info,
             sampling_info_done=threading.Event(),
+            scaling_penalties=sampling_info.scaling_penalties,
+            linear_penalties=sampling_info.linear_penalties,
         )
-        self.cur_sampling_info = model_worker_batch.sampling_info

+        # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects
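
The reworked forward_batch_generation replaces the implicit .item() sync with an explicit stream synchronize and, more importantly, snapshots sampling_info with dataclasses.replace before queueing the batch, because the scheduler keeps mutating its own copy while preparing the next batch. A simplified illustration of why the shallow copy is enough here (the dataclass is a stand-in, not the real SamplingBatchInfo):

import dataclasses
import threading
from typing import Optional

import torch

@dataclasses.dataclass
class SamplingInfoStub:
    temperatures: torch.Tensor
    scaling_penalties: Optional[torch.Tensor] = None
    linear_penalties: Optional[torch.Tensor] = None
    sampling_info_done: Optional[threading.Event] = None

scheduler_side = SamplingInfoStub(temperatures=torch.ones(4))

# The worker snapshots the fields it needs plus a fresh Event for this batch.
queued = dataclasses.replace(
    scheduler_side,
    sampling_info_done=threading.Event(),
    scaling_penalties=scheduler_side.scaling_penalties,
    linear_penalties=scheduler_side.linear_penalties,
)

# The scheduler rebinding its fields for the next batch does not touch the snapshot.
scheduler_side.temperatures = torch.full((4,), 0.7)
print(queued.temperatures)  # still tensor([1., 1., 1., 1.])
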
python/sglang/srt/model_executor/model_runner.py

@@ -116,7 +116,7 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -636,13 +636,11 @@ class ModelRunner:
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
             if sampling_info.grammars:
                 sampling_info.sampling_info_done.wait()
-            sampling_info.update_penalties()
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
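
This is the consumer side of the sampling_info_done event set by the scheduler above: in overlap mode the vocab mask for the current batch was already prepared while the previous batch ran, so the model runner only waits on the event when grammars are involved; in normal mode it does the CPU-heavy mask update itself. A bare-bones sketch of that handshake with a plain threading.Event (illustrative only):

import threading
import time

sampling_info_done = threading.Event()

def scheduler_thread():
    # ... update_regex_vocab_mask() for the next batch, then signal readiness ...
    time.sleep(0.01)
    sampling_info_done.set()

def model_runner_sample(have_grammars: bool):
    if have_grammars:
        sampling_info_done.wait()  # block until the mask for this batch is ready
    # ... apply the vocab mask and sample ...

t = threading.Thread(target=scheduler_thread)
t.start()
model_runner_sample(have_grammars=True)
t.join()
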
python/sglang/srt/sampling/sampling_batch_info.py

@@ -132,9 +132,6 @@ class SamplingBatchInfo:
         return len(self.temperatures)

     def update_penalties(self):
-        if not self.penalizer_orchestrator:
-            return
-
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -176,8 +173,7 @@ class SamplingBatchInfo:
             grammar.fill_vocab_mask(self.vocab_mask, i)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
             "temperatures",
@@ -216,8 +212,7 @@ class SamplingBatchInfo:
         return None

     def merge_batch(self, other: "SamplingBatchInfo"):
-        if self.penalizer_orchestrator:
-            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [
             "temperatures",
python/sglang/srt/server_args.py

@@ -123,7 +123,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
@@ -172,9 +172,7 @@ class ServerArgs:
         if gpu_mem < 25000:
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
-            logger.warning(
-                "Automatically adjust --chunked-prefill-size for small GPUs."
-            )
+            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -192,15 +190,22 @@ class ServerArgs:
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.enable_overlap_schedule = False
-            logger.warning(
+            self.disable_overlap_schedule = True
+            logger.info(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size."
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )

-        if self.enable_overlap_schedule:
+        if self.enable_mixed_chunk:
+            logger.info(
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
+            )
+            self.disable_overlap_schedule = True
+
+        if not self.disable_overlap_schedule:
             self.disable_jump_forward = True

     @staticmethod
@@ -624,9 +629,9 @@ class ServerArgs:
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
@@ -692,6 +697,11 @@ class ServerArgs:
         )

         # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action=DeprecatedAction,
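
The post-init adjustments above chain together: DP attention and mixed-style chunked prefill both force the overlap scheduler off, and whenever overlap stays on, jump-forward decoding is disabled. Restated as a standalone function over a simplified args object (not the real ServerArgs API):

from dataclasses import dataclass

@dataclass
class Args:
    disable_overlap_schedule: bool = False  # new default: overlap scheduler on
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    disable_jump_forward: bool = False

def resolve(args: Args) -> Args:
    if args.enable_dp_attention:
        args.disable_overlap_schedule = True   # DP attention cannot use the overlap scheduler
    if args.enable_mixed_chunk:
        args.disable_overlap_schedule = True   # mixed-style chunked prefill cannot either
    if not args.disable_overlap_schedule:
        args.disable_jump_forward = True       # overlap scheduling does not support jump-forward
    return args

print(resolve(Args()))                          # overlap on, jump-forward disabled
print(resolve(Args(enable_mixed_chunk=True)))   # overlap off, jump-forward untouched
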
python/sglang/test/test_utils.py

@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
     workload_func,
     disable_radix_cache,
     enable_mixed_chunk,
-    enable_overlap,
+    disable_overlap,
     chunked_prefill_size,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
         other_args += ["--enable-mixed-chunk"]
-    if enable_overlap:
-        other_args += ["--enable-overlap-schedule"]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]

     model = DEFAULT_MODEL_NAME_FOR_TEST
     port = random.randint(4000, 5000)
@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
-    enable_overlap=False,
+    disable_overlap=False,
     chunked_prefill_size=32,
 ):
     def workload_func(base_url, model):
@@ -754,7 +754,7 @@ def run_mmlu_test(
         workload_func,
         disable_radix_cache,
         enable_mixed_chunk,
-        enable_overlap,
+        disable_overlap,
         chunked_prefill_size,
     )
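
With the helper parameter renamed from enable_overlap to disable_overlap, the tests below (e.g. the renamed test_non_overlap_scheduler.py) now exercise the non-overlap path explicitly, roughly like this (assumed typical call, mirroring the test diff):

from sglang.test.test_utils import run_mmlu_test

# Launches a server with --disable-overlap-schedule and runs the MMLU workload.
run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True)
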
test/srt/run_suite.py

@@ -17,8 +17,8 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
+        "test_non_overlap_scheduler.py",
         "test_openai_server.py",
-        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",
test/srt/test_bench_serving.py

@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             self.assertLess(res["median_e2e_latency_ms"], 12000)
-            self.assertLess(res["median_ttft_ms"], 80)
-            self.assertLess(res["median_itl_ms"], 11)
+            self.assertLess(res["median_ttft_ms"], 86)
+            self.assertLess(res["median_itl_ms"], 10)

     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(
test/srt/test_json_constrained.py

@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        self.assertGreater(
-            ret["meta_info"]["completion_tokens"],
-            ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        )
+        # NOTE: This is skipped because overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
test/srt/test_moe_eval_accuracy_large.py

@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.41)
+        self.assertGreater(metrics["score"], 0.40)

     def test_mgsm_en(self):
         args = SimpleNamespace(
test/srt/test_overlap_schedule.py → test/srt/test_non_overlap_scheduler.py

@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
         )

     def test_no_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
         )

     def test_radix_attention_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
         )

     def test_radix_attention_no_chunked_prefill(self):
         run_mmlu_test(
-            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+            disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
         )
test/srt/test_radix_attention.py

@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
 )
-class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
+class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
-                "--enable-overlap-schedule",
+                "--disable-overlap-schedule",
                 "--chunked-prefill-size",
                 "128",
                 "--max-total-tokens",
test/srt/test_torch_compile.py

+import time
 import unittest
 from types import SimpleNamespace

@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
-
         # Warmup
         res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()
-        print(res["text"])
+        print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
         self.assertGreaterEqual(throughput, 152)
test/srt/test_torch_compile_moe.py

+import time
 import unittest
 from types import SimpleNamespace

@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
         return response.json()

     def test_throughput(self):
-        import time
-
         # Warmup
         res = self.run_decode(16)

         max_tokens = 256
         tic = time.time()
         res = self.run_decode(max_tokens)
         tok = time.time()