change / sglang / Commits / 4ede6770

Fix retract for page size > 1 (#4914)

Unverified commit 4ede6770, authored Mar 30, 2025 by Lianmin Zheng, committed via GitHub on Mar 30, 2025. Parent: b26bc86b
Showing 10 changed files with 68 additions and 120 deletions (+68, -120).
.github/workflows/pr-test.yml                           +2   -44
python/sglang/srt/constrained/base_grammar_backend.py   +5   -1
python/sglang/srt/managers/schedule_batch.py            +6   -3
python/sglang/srt/metrics/collector.py                  +23  -53
python/sglang/srt/server_args.py                        +12  -8
python/sglang/test/test_utils.py                        +0   -4
test/srt/models/lora/test_lora_tp.py                    +3   -3
test/srt/run_suite.py                                   +13  -3
test/srt/test_dp_attention.py                           +4   -0
test/srt/test_metrics.py                                +0   -1
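The substantive fix is in python/sglang/srt/managers/schedule_batch.py: when a retracted request is released with a paged KV cache (page_size > 1), the boundary between cached tokens and to-be-freed tokens is now rounded up to the next page boundary, so the allocator only ever frees whole pages. A standalone sketch of that rounding (round_up_to_page is a hypothetical helper; the diff inlines the expression):

    # Round a token position up to the next page boundary, as the retract
    # fix does for last_uncached_pos in schedule_batch.py.
    def round_up_to_page(pos: int, page_size: int) -> int:
        return (pos + page_size - 1) // page_size * page_size

    assert round_up_to_page(5, 4) == 8   # a partially filled page stays intact
    assert round_up_to_page(8, 4) == 8   # aligned positions are unchanged
    assert round_up_to_page(5, 1) == 5   # page_size == 1 reproduces the old behavior

The remaining files are supporting changes: a new "none" grammar backend option, metrics cleanups, and consolidation of the 2-GPU CI tests into a single suite.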
.github/workflows/pr-test.yml

@@ -87,53 +87,11 @@ jobs:
         run: |
           bash scripts/ci_install_dependency.sh
 
-      - name: Test data parallelism (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_data_parallelism.py
-
-      - name: Test data parallelism attention (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_dp_attention.py
-
-      - name: Test update weights from distributed
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_update_weights_from_distributed.py
-
-      - name: Test VerlEngine
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_verl_engine.py
-
-      - name: Test Patch Torch
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_patch_torch.py
-
-      - name: Test expert parallelism (EP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_moe_ep.py
-
-      - name: Test torch compile (TP=2)
+      - name: Run test
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 test_mla_tp.py
-
-      - name: Test lora tensor parallelism (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt/models/lora
-          python3 test_lora_tp.py
+          python3 run_suite.py --suite per-commit-2-gpu
 
   performance-test-1-gpu-part-1:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
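With this change the eight dedicated 2-GPU CI steps collapse into one suite invocation, and the suite membership moves into test/srt/run_suite.py (see its diff below). The same suite can be reproduced locally:

    cd test/srt
    python3 run_suite.py --suite per-commit-2-gpu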
python/sglang/srt/constrained/base_grammar_backend.py

@@ -169,7 +169,9 @@ class BaseGrammarBackend(ABC):
         self.cache.clear()
 
 
-def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+def create_grammar_backend(
+    server_args: ServerArgs, tokenizer, vocab_size: int
+) -> Optional[BaseGrammarBackend]:
     if server_args.grammar_backend == "outlines":
         from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
             tokenizer=tokenizer,
             whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
+    elif server_args.grammar_backend == "none":
+        return None
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
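create_grammar_backend gains type annotations and an escape hatch: with grammar_backend == "none" it returns None instead of raising, so callers must now handle an absent backend. A minimal sketch of the new dispatch shape (StubBackend and create_backend are illustrative stand-ins, not sglang's real classes; only the "none" branch and the Optional return mirror the diff):

    from typing import Optional

    class StubBackend:
        pass

    def create_backend(grammar_backend: str) -> Optional[StubBackend]:
        if grammar_backend == "xgrammar":
            return StubBackend()
        elif grammar_backend == "none":
            return None  # new: grammar-guided decoding explicitly disabled
        else:
            raise ValueError(f"Invalid grammar backend: {grammar_backend}")

    assert create_backend("none") is None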
python/sglang/srt/managers/schedule_batch.py

@@ -599,6 +599,7 @@ class Req:
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
         self.req_pool_idx = None
+        self.already_computed = 0
 
     def __repr__(self):
         return (

@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            if req.is_retracted:
-                req.already_computed = 0
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False

@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
-                last_uncached_pos = len(req.prefix_indices)
+                last_uncached_pos = (
+                    (len(req.prefix_indices) + server_args.page_size - 1)
+                    // server_args.page_size
+                    * server_args.page_size
+                )
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
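Besides the page-boundary rounding, already_computed is now initialized in Req.__init__, which lets the conditional reset for retracted requests be dropped from the batch-preparation path. A worked sketch of the cached-token accounting it feeds, with invented numbers:

    # Invented values; mirrors the cached_tokens bookkeeping in the diff.
    already_computed = 0        # a freshly retracted request starts from zero
    pre_len, seq_len = 12, 20   # 12 prefix tokens reused, 20 tokens total after extend
    cached_tokens = 0

    cached_tokens += pre_len - already_computed  # count the 12 reused tokens once
    already_computed = seq_len                   # the next extend starts counting at 20

    assert (cached_tokens, already_computed) == (12, 20)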
python/sglang/srt/metrics/collector.py

@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge
+        from prometheus_client import Gauge, Histogram
 
         self.labels = labels
         self.last_log_time = time.time()

@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
+                0.2,
+                0.4,
+                0.6,
+                0.8,
                 1,
                 2,
                 4,

@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
+                120,
+                160,
             ],
         )
 
-        self.histogram_time_per_output_token = Histogram(
-            name="sglang:time_per_output_token_seconds",
-            documentation="Histogram of time per output token in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.002,
-                0.005,
-                0.010,
-                0.020,
-                0.030,
-                0.040,
-                0.050,
-                0.060,
-                0.070,
-                0.080,
-                0.090,
-                0.100,
-                0.150,
-                0.200,
-                0.300,
-                0.400,
-                0.600,
-                0.800,
-                1.000,
-                2.000,
-                100,
-                200,
-                400,
-            ],
-        )

@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
                 0.030,
                 0.035,
                 0.040,
-                0.050,
-                0.075,
+                0.060,
+                0.080,
                 0.100,
                 0.150,
                 0.200,
                 0.300,
                 0.400,
-                0.500,
-                0.750,
+                0.600,
+                0.800,
                 1.000,
                 2.000,
                 4.000,
                 6.000,
                 8.000,
             ],
         )

@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
                 0.1,
                 0.2,
                 0.4,
                 0.6,
                 0.8,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
                 60,
                 80,
                 100,
-                150,
                 200,
-                250,
-                300,
-                350,
-                500,
-                1000,
+                400,
+                800,
             ],
         )

@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
-        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
+        if cached_tokens > 0:
+            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
-        if generation_tokens >= 1:
-            self.histogram_time_per_output_token.labels(**self.labels).observe(
-                e2e_latency / generation_tokens
-            )
 
     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
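Two behavioral changes ride along with the bucket rebalancing: cached_tokens_total is only incremented when there is something to count, and the derived time-per-output-token observation (e2e_latency / generation_tokens) is dropped together with its histogram, leaving inter_token_latency_seconds as the per-token latency signal. A sketch of the guarded counter update (prometheus_client is the real dependency; the label set here is invented for illustration):

    from prometheus_client import Counter

    cached_tokens_total = Counter(
        "sglang:cached_tokens_total",
        "Number of cached prompt tokens.",
        ["model_name"],  # invented label set for the sketch
    )

    def record_cached(cached_tokens: int) -> None:
        if cached_tokens > 0:  # skip zero-valued increments, as in the diff
            cached_tokens_total.labels(model_name="demo").inc(cached_tokens)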
python/sglang/srt/server_args.py

@@ -128,7 +128,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] = "xgrammar"
+    grammar_backend: Optional[str] = None
 
     # Speculative decoding
     speculative_algorithm: Optional[str] = None

@@ -193,6 +193,13 @@ class ServerArgs:
     disaggregation_bootstrap_port: int = 8998
 
     def __post_init__(self):
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path

@@ -274,12 +281,9 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )
+        # Choose grammar backend
+        if self.grammar_backend is None:
+            self.grammar_backend = "xgrammar"
 
         # Data parallelism attention
         if self.enable_dp_attention:

@@ -813,7 +817,7 @@ class ServerArgs:
         parser.add_argument(
             "--grammar-backend",
             type=str,
-            choices=["xgrammar", "outlines", "llguidance"],
+            choices=["xgrammar", "outlines", "llguidance", "none"],
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
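Moving the "xgrammar" default out of the field declaration and into __post_init__ lets None mean "not specified" while keeping --grammar-backend none as a real opt-out. Assuming the standard launch entry point, disabling grammar-guided decoding would look like (model path is just an example):

    python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
        --grammar-backend none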
python/sglang/test/test_utils.py

@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 class CustomTestCase(unittest.TestCase):
-    pass
-
-    """
     def _callTestMethod(self, method):
         max_retry = int(
             os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
         )

@@ -1023,4 +1020,3 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
-    """
test/srt/models/lora/test_lora_tp.py

@@ -33,6 +33,9 @@ CI_LORA_MODELS = [
         ],
         max_loras_per_batch=1,
     ),
+]
+
+ALL_OTHER_LORA_MODELS = [
     LoRAModelCase(
         base="meta-llama/Llama-3.1-8B-Instruct",
         adaptors=[

@@ -43,9 +46,6 @@ CI_LORA_MODELS = [
         ],
         max_loras_per_batch=1,
     ),
-]
-
-ALL_OTHER_LORA_MODELS = [
     LoRAModelCase(
         base="meta-llama/Llama-2-7b-hf",
         adaptors=[
             LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
test/srt/run_suite.py

@@ -16,7 +16,7 @@ suites = {
         TestFile("models/lora/test_lora.py", 76),
         TestFile("models/lora/test_lora_backend.py", 420),
         TestFile("models/lora/test_multi_lora_backend.py", 144),
-        TestFile("models/test_embedding_models.py", 119),
+        TestFile("models/test_embedding_models.py", 35),
         TestFile("models/test_generation_models.py", 103),
         TestFile("models/test_grok_models.py", 60),
         TestFile("models/test_qwen_models.py", 82),

@@ -38,7 +38,7 @@ suites = {
         TestFile("test_metrics.py", 32),
         TestFile("test_mla.py", 92),
         TestFile("test_mla_deepseek_v3.py", 221),
-        TestFile("test_mla_int8_deepseek_v3.py", 421),
+        TestFile("test_mla_int8_deepseek_v3.py", 522),
         TestFile("test_mla_flashinfer.py", 395),
         TestFile("test_mla_fp8.py", 93),
         TestFile("test_no_chunked_prefill.py", 126),

@@ -59,7 +59,7 @@ suites = {
         TestFile("test_srt_endpoint.py", 94),
         TestFile("test_torch_compile.py", 76),
         TestFile("test_torch_compile_moe.py", 85),
-        TestFile("test_torch_native_attention_backend.py", 149),
+        TestFile("test_torch_native_attention_backend.py", 123),
         TestFile("test_torchao.py", 70),
         TestFile("test_triton_attention_kernels.py", 4),
         TestFile("test_triton_attention_backend.py", 134),

@@ -76,6 +76,16 @@ suites = {
         TestFile("test_hicache.py", 60),
         TestFile("test_hicache_mla.py", 90),
     ],
+    "per-commit-2-gpu": [
+        TestFile("test_data_parallelism.py", 90),
+        TestFile("test_dp_attention.py", 90),
+        TestFile("test_update_weights_from_distributed.py", 100),
+        TestFile("test_verl_engine.py", 100),
+        TestFile("test_patch_torch.py", 30),
+        TestFile("test_moe_ep.py", 220),
+        TestFile("test_mla_tp.py", 420),
+        TestFile("test_lora_tp.py", 300),
+    ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
     ],
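The new per-commit-2-gpu suite lists exactly the files whose dedicated steps were deleted from pr-test.yml; the second argument to TestFile is the file's estimated runtime in seconds, which the runner uses for scheduling.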
test/srt/test_dp_attention.py

@@ -60,3 +60,7 @@ class TestDPAttentionDP2TP2(CustomTestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
+
+
+if __name__ == "__main__":
+    unittest.main()
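The added __main__ guard matters now that the file is invoked through run_suite.py, which executes each test file as a standalone script.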
test/srt/test_metrics.py

@@ -63,7 +63,6 @@ class TestEnableMetrics(CustomTestCase):
                 "sglang:cached_tokens_total",
                 "sglang:num_requests_total",
                 "sglang:time_to_first_token_seconds",
-                "sglang:time_per_output_token_seconds",
                 "sglang:inter_token_latency_seconds",
                 "sglang:e2e_request_latency_seconds",
             ]