Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4ede6770
Unverified
Commit
4ede6770
authored
Mar 30, 2025
by
Lianmin Zheng
Committed by
GitHub
Mar 30, 2025
Browse files
Fix retract for page size > 1 (#4914)
parent
b26bc86b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
68 additions
and
120 deletions
+68
-120
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+2
-44
python/sglang/srt/constrained/base_grammar_backend.py
python/sglang/srt/constrained/base_grammar_backend.py
+5
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+6
-3
python/sglang/srt/metrics/collector.py
python/sglang/srt/metrics/collector.py
+23
-53
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+12
-8
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+0
-4
test/srt/models/lora/test_lora_tp.py
test/srt/models/lora/test_lora_tp.py
+3
-3
test/srt/run_suite.py
test/srt/run_suite.py
+13
-3
test/srt/test_dp_attention.py
test/srt/test_dp_attention.py
+4
-0
test/srt/test_metrics.py
test/srt/test_metrics.py
+0
-1
No files found.
.github/workflows/pr-test.yml
View file @
4ede6770
...
@@ -87,53 +87,11 @@ jobs:
...
@@ -87,53 +87,11 @@ jobs:
run
:
|
run
:
|
bash scripts/ci_install_dependency.sh
bash scripts/ci_install_dependency.sh
-
name
:
Test data parallelism (DP=2)
-
name
:
Run test
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_data_parallelism.py
-
name
:
Test data parallelism attention (DP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_dp_attention.py
-
name
:
Test update weights from distributed
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_update_weights_from_distributed.py
-
name
:
Test VerlEngine
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_verl_engine.py
-
name
:
Test Patch Torch
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_patch_torch.py
-
name
:
Test expert parallelism (EP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_moe_ep.py
-
name
:
Test torch compile (TP=2)
timeout-minutes
:
10
timeout-minutes
:
10
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 test_mla_tp.py
python3 run_suite.py --suite per-commit-2-gpu
-
name
:
Test lora tensor parallelism (TP=2)
timeout-minutes
:
10
run
:
|
cd test/srt/models/lora
python3 test_lora_tp.py
performance-test-1-gpu-part-1
:
performance-test-1-gpu-part-1
:
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
...
python/sglang/srt/constrained/base_grammar_backend.py
View file @
4ede6770
...
@@ -169,7 +169,9 @@ class BaseGrammarBackend(ABC):
...
@@ -169,7 +169,9 @@ class BaseGrammarBackend(ABC):
self
.
cache
.
clear
()
self
.
cache
.
clear
()
def
create_grammar_backend
(
server_args
:
ServerArgs
,
tokenizer
,
vocab_size
):
def
create_grammar_backend
(
server_args
:
ServerArgs
,
tokenizer
,
vocab_size
:
int
)
->
Optional
[
BaseGrammarBackend
]:
if
server_args
.
grammar_backend
==
"outlines"
:
if
server_args
.
grammar_backend
==
"outlines"
:
from
sglang.srt.constrained.outlines_backend
import
OutlinesGrammarBackend
from
sglang.srt.constrained.outlines_backend
import
OutlinesGrammarBackend
...
@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
...
@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
)
elif
server_args
.
grammar_backend
==
"none"
:
return
None
else
:
else
:
raise
ValueError
(
f
"Invalid grammar backend:
{
server_args
.
grammar_backend
}
"
)
raise
ValueError
(
f
"Invalid grammar backend:
{
server_args
.
grammar_backend
}
"
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
4ede6770
...
@@ -599,6 +599,7 @@ class Req:
...
@@ -599,6 +599,7 @@ class Req:
self
.
extend_logprob_start_len
=
0
self
.
extend_logprob_start_len
=
0
self
.
is_chunked
=
0
self
.
is_chunked
=
0
self
.
req_pool_idx
=
None
self
.
req_pool_idx
=
None
self
.
already_computed
=
0
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
return
(
...
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
# If req.input_embeds is already a list, append its content directly
input_embeds
.
extend
(
req
.
input_embeds
)
# Use extend to avoid nesting
input_embeds
.
extend
(
req
.
input_embeds
)
# Use extend to avoid nesting
if
req
.
is_retracted
:
req
.
already_computed
=
0
req
.
cached_tokens
+=
pre_len
-
req
.
already_computed
req
.
cached_tokens
+=
pre_len
-
req
.
already_computed
req
.
already_computed
=
seq_len
req
.
already_computed
=
seq_len
req
.
is_retracted
=
False
req
.
is_retracted
=
False
...
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
else
:
else
:
# TODO: apply more fine-grained retraction
# TODO: apply more fine-grained retraction
last_uncached_pos
=
len
(
req
.
prefix_indices
)
last_uncached_pos
=
(
(
len
(
req
.
prefix_indices
)
+
server_args
.
page_size
-
1
)
//
server_args
.
page_size
*
server_args
.
page_size
)
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
last_uncached_pos
:
seq_lens_cpu
[
idx
]
req
.
req_pool_idx
,
last_uncached_pos
:
seq_lens_cpu
[
idx
]
]
]
...
...
python/sglang/srt/metrics/collector.py
View file @
4ede6770
...
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
...
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
def
__init__
(
self
,
labels
:
Dict
[
str
,
str
])
->
None
:
def
__init__
(
self
,
labels
:
Dict
[
str
,
str
])
->
None
:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from
prometheus_client
import
Gauge
from
prometheus_client
import
Gauge
,
Histogram
self
.
labels
=
labels
self
.
labels
=
labels
self
.
last_log_time
=
time
.
time
()
self
.
last_log_time
=
time
.
time
()
...
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
...
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
labelnames
=
labels
.
keys
(),
labelnames
=
labels
.
keys
(),
buckets
=
[
buckets
=
[
0.1
,
0.1
,
0.
3
,
0.
2
,
0.
5
,
0.
4
,
0.
7
,
0.
6
,
0.
9
,
0.
8
,
1
,
1
,
2
,
2
,
4
,
4
,
...
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
...
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
40
,
40
,
60
,
60
,
80
,
80
,
120
,
100
,
160
,
200
,
],
400
,
)
self
.
histogram_time_per_output_token
=
Histogram
(
name
=
"sglang:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
labelnames
=
labels
.
keys
(),
buckets
=
[
0.002
,
0.005
,
0.010
,
0.020
,
0.030
,
0.040
,
0.050
,
0.060
,
0.070
,
0.080
,
0.090
,
0.100
,
0.150
,
0.200
,
0.300
,
0.400
,
0.600
,
0.800
,
1.000
,
2.000
,
],
],
)
)
...
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
...
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
0.030
,
0.030
,
0.035
,
0.035
,
0.040
,
0.040
,
0.0
5
0
,
0.0
6
0
,
0.0
75
,
0.0
80
,
0.100
,
0.100
,
0.150
,
0.200
,
0.200
,
0.300
,
0.400
,
0.400
,
0.
5
00
,
0.
6
00
,
0.
75
0
,
0.
80
0
,
1.000
,
1.000
,
2.000
,
2.000
,
4.000
,
6.000
,
8.000
,
],
],
)
)
...
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
...
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
0.1
,
0.1
,
0.2
,
0.2
,
0.4
,
0.4
,
0.6
,
0.8
,
0.8
,
1
,
1
,
2
,
2
,
5
,
4
,
6
,
8
,
10
,
10
,
20
,
20
,
40
,
40
,
60
,
60
,
80
,
80
,
100
,
100
,
150
,
200
,
200
,
250
,
400
,
300
,
800
,
350
,
500
,
1000
,
],
],
)
)
...
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
...
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
):
):
self
.
prompt_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
prompt_tokens
)
self
.
prompt_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
prompt_tokens
)
self
.
generation_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
generation_tokens
)
self
.
generation_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
generation_tokens
)
self
.
cached_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
cached_tokens
)
if
cached_tokens
>
0
:
self
.
cached_tokens_total
.
labels
(
**
self
.
labels
).
inc
(
cached_tokens
)
self
.
num_requests_total
.
labels
(
**
self
.
labels
).
inc
(
1
)
self
.
num_requests_total
.
labels
(
**
self
.
labels
).
inc
(
1
)
self
.
_log_histogram
(
self
.
histogram_e2e_request_latency
,
e2e_latency
)
self
.
_log_histogram
(
self
.
histogram_e2e_request_latency
,
e2e_latency
)
if
generation_tokens
>=
1
:
self
.
histogram_time_per_output_token
.
labels
(
**
self
.
labels
).
observe
(
e2e_latency
/
generation_tokens
)
def
observe_time_to_first_token
(
self
,
value
:
float
):
def
observe_time_to_first_token
(
self
,
value
:
float
):
self
.
histogram_time_to_first_token
.
labels
(
**
self
.
labels
).
observe
(
value
)
self
.
histogram_time_to_first_token
.
labels
(
**
self
.
labels
).
observe
(
value
)
...
...
python/sglang/srt/server_args.py
View file @
4ede6770
...
@@ -128,7 +128,7 @@ class ServerArgs:
...
@@ -128,7 +128,7 @@ class ServerArgs:
# Kernel backend
# Kernel backend
attention_backend
:
Optional
[
str
]
=
None
attention_backend
:
Optional
[
str
]
=
None
sampling_backend
:
Optional
[
str
]
=
None
sampling_backend
:
Optional
[
str
]
=
None
grammar_backend
:
Optional
[
str
]
=
"xgrammar"
grammar_backend
:
Optional
[
str
]
=
None
# Speculative decoding
# Speculative decoding
speculative_algorithm
:
Optional
[
str
]
=
None
speculative_algorithm
:
Optional
[
str
]
=
None
...
@@ -193,6 +193,13 @@ class ServerArgs:
...
@@ -193,6 +193,13 @@ class ServerArgs:
disaggregation_bootstrap_port
:
int
=
8998
disaggregation_bootstrap_port
:
int
=
8998
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Expert parallelism
if
self
.
enable_ep_moe
:
self
.
ep_size
=
self
.
tp_size
logger
.
info
(
f
"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[
{
self
.
tp_size
}
]."
)
# Set missing default values
# Set missing default values
if
self
.
tokenizer_path
is
None
:
if
self
.
tokenizer_path
is
None
:
self
.
tokenizer_path
=
self
.
model_path
self
.
tokenizer_path
=
self
.
model_path
...
@@ -274,12 +281,9 @@ class ServerArgs:
...
@@ -274,12 +281,9 @@ class ServerArgs:
)
)
self
.
disable_cuda_graph
=
True
self
.
disable_cuda_graph
=
True
# Expert parallelism
# Choose grammar backend
if
self
.
enable_ep_moe
:
if
self
.
grammar_backend
is
None
:
self
.
ep_size
=
self
.
tp_size
self
.
grammar_backend
=
"xgrammar"
logger
.
info
(
f
"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[
{
self
.
tp_size
}
]."
)
# Data parallelism attention
# Data parallelism attention
if
self
.
enable_dp_attention
:
if
self
.
enable_dp_attention
:
...
@@ -813,7 +817,7 @@ class ServerArgs:
...
@@ -813,7 +817,7 @@ class ServerArgs:
parser
.
add_argument
(
parser
.
add_argument
(
"--grammar-backend"
,
"--grammar-backend"
,
type
=
str
,
type
=
str
,
choices
=
[
"xgrammar"
,
"outlines"
,
"llguidance"
],
choices
=
[
"xgrammar"
,
"outlines"
,
"llguidance"
,
"none"
],
default
=
ServerArgs
.
grammar_backend
,
default
=
ServerArgs
.
grammar_backend
,
help
=
"Choose the backend for grammar-guided decoding."
,
help
=
"Choose the backend for grammar-guided decoding."
,
)
)
...
...
python/sglang/test/test_utils.py
View file @
4ede6770
...
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
...
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
class
CustomTestCase
(
unittest
.
TestCase
):
class
CustomTestCase
(
unittest
.
TestCase
):
pass
"""
def
_callTestMethod
(
self
,
method
):
def
_callTestMethod
(
self
,
method
):
max_retry
=
int
(
max_retry
=
int
(
os
.
environ
.
get
(
"SGLANG_TEST_MAX_RETRY"
,
"2"
if
is_in_ci
()
else
"0"
)
os
.
environ
.
get
(
"SGLANG_TEST_MAX_RETRY"
,
"2"
if
is_in_ci
()
else
"0"
)
...
@@ -1023,4 +1020,3 @@ class CustomTestCase(unittest.TestCase):
...
@@ -1023,4 +1020,3 @@ class CustomTestCase(unittest.TestCase):
lambda
:
super
(
CustomTestCase
,
self
).
_callTestMethod
(
method
),
lambda
:
super
(
CustomTestCase
,
self
).
_callTestMethod
(
method
),
max_retry
=
max_retry
,
max_retry
=
max_retry
,
)
)
"""
test/srt/models/lora/test_lora_tp.py
View file @
4ede6770
...
@@ -33,6 +33,9 @@ CI_LORA_MODELS = [
...
@@ -33,6 +33,9 @@ CI_LORA_MODELS = [
],
],
max_loras_per_batch
=
1
,
max_loras_per_batch
=
1
,
),
),
]
ALL_OTHER_LORA_MODELS
=
[
LoRAModelCase
(
LoRAModelCase
(
base
=
"meta-llama/Llama-3.1-8B-Instruct"
,
base
=
"meta-llama/Llama-3.1-8B-Instruct"
,
adaptors
=
[
adaptors
=
[
...
@@ -43,9 +46,6 @@ CI_LORA_MODELS = [
...
@@ -43,9 +46,6 @@ CI_LORA_MODELS = [
],
],
max_loras_per_batch
=
1
,
max_loras_per_batch
=
1
,
),
),
]
ALL_OTHER_LORA_MODELS
=
[
LoRAModelCase
(
LoRAModelCase
(
base
=
"meta-llama/Llama-2-7b-hf"
,
base
=
"meta-llama/Llama-2-7b-hf"
,
adaptors
=
[
LoRAAdaptor
(
name
=
"winddude/wizardLM-LlaMA-LoRA-7B"
)],
adaptors
=
[
LoRAAdaptor
(
name
=
"winddude/wizardLM-LlaMA-LoRA-7B"
)],
...
...
test/srt/run_suite.py
View file @
4ede6770
...
@@ -16,7 +16,7 @@ suites = {
...
@@ -16,7 +16,7 @@ suites = {
TestFile
(
"models/lora/test_lora.py"
,
76
),
TestFile
(
"models/lora/test_lora.py"
,
76
),
TestFile
(
"models/lora/test_lora_backend.py"
,
420
),
TestFile
(
"models/lora/test_lora_backend.py"
,
420
),
TestFile
(
"models/lora/test_multi_lora_backend.py"
,
144
),
TestFile
(
"models/lora/test_multi_lora_backend.py"
,
144
),
TestFile
(
"models/test_embedding_models.py"
,
119
),
TestFile
(
"models/test_embedding_models.py"
,
35
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_grok_models.py"
,
60
),
TestFile
(
"models/test_grok_models.py"
,
60
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
...
@@ -38,7 +38,7 @@ suites = {
...
@@ -38,7 +38,7 @@ suites = {
TestFile
(
"test_metrics.py"
,
32
),
TestFile
(
"test_metrics.py"
,
32
),
TestFile
(
"test_mla.py"
,
92
),
TestFile
(
"test_mla.py"
,
92
),
TestFile
(
"test_mla_deepseek_v3.py"
,
221
),
TestFile
(
"test_mla_deepseek_v3.py"
,
221
),
TestFile
(
"test_mla_int8_deepseek_v3.py"
,
421
),
TestFile
(
"test_mla_int8_deepseek_v3.py"
,
522
),
TestFile
(
"test_mla_flashinfer.py"
,
395
),
TestFile
(
"test_mla_flashinfer.py"
,
395
),
TestFile
(
"test_mla_fp8.py"
,
93
),
TestFile
(
"test_mla_fp8.py"
,
93
),
TestFile
(
"test_no_chunked_prefill.py"
,
126
),
TestFile
(
"test_no_chunked_prefill.py"
,
126
),
...
@@ -59,7 +59,7 @@ suites = {
...
@@ -59,7 +59,7 @@ suites = {
TestFile
(
"test_srt_endpoint.py"
,
94
),
TestFile
(
"test_srt_endpoint.py"
,
94
),
TestFile
(
"test_torch_compile.py"
,
76
),
TestFile
(
"test_torch_compile.py"
,
76
),
TestFile
(
"test_torch_compile_moe.py"
,
85
),
TestFile
(
"test_torch_compile_moe.py"
,
85
),
TestFile
(
"test_torch_native_attention_backend.py"
,
1
49
),
TestFile
(
"test_torch_native_attention_backend.py"
,
1
23
),
TestFile
(
"test_torchao.py"
,
70
),
TestFile
(
"test_torchao.py"
,
70
),
TestFile
(
"test_triton_attention_kernels.py"
,
4
),
TestFile
(
"test_triton_attention_kernels.py"
,
4
),
TestFile
(
"test_triton_attention_backend.py"
,
134
),
TestFile
(
"test_triton_attention_backend.py"
,
134
),
...
@@ -76,6 +76,16 @@ suites = {
...
@@ -76,6 +76,16 @@ suites = {
TestFile
(
"test_hicache.py"
,
60
),
TestFile
(
"test_hicache.py"
,
60
),
TestFile
(
"test_hicache_mla.py"
,
90
),
TestFile
(
"test_hicache_mla.py"
,
90
),
],
],
"per-commit-2-gpu"
:
[
TestFile
(
"test_data_parallelism.py"
,
90
),
TestFile
(
"test_dp_attention.py"
,
90
),
TestFile
(
"test_update_weights_from_distributed.py"
,
100
),
TestFile
(
"test_verl_engine.py"
,
100
),
TestFile
(
"test_patch_torch.py"
,
30
),
TestFile
(
"test_moe_ep.py"
,
220
),
TestFile
(
"test_mla_tp.py"
,
420
),
TestFile
(
"test_lora_tp.py"
,
300
),
],
"nightly"
:
[
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
TestFile
(
"test_nightly_gsm8k_eval.py"
),
],
],
...
...
test/srt/test_dp_attention.py
View file @
4ede6770
...
@@ -60,3 +60,7 @@ class TestDPAttentionDP2TP2(CustomTestCase):
...
@@ -60,3 +60,7 @@ class TestDPAttentionDP2TP2(CustomTestCase):
metrics
=
run_eval
(
args
)
metrics
=
run_eval
(
args
)
print
(
f
"
{
metrics
=
}
"
)
print
(
f
"
{
metrics
=
}
"
)
self
.
assertGreater
(
metrics
[
"score"
],
0.8
)
self
.
assertGreater
(
metrics
[
"score"
],
0.8
)
if
__name__
==
"__main__"
:
unittest
.
main
()
test/srt/test_metrics.py
View file @
4ede6770
...
@@ -63,7 +63,6 @@ class TestEnableMetrics(CustomTestCase):
...
@@ -63,7 +63,6 @@ class TestEnableMetrics(CustomTestCase):
"sglang:cached_tokens_total"
,
"sglang:cached_tokens_total"
,
"sglang:num_requests_total"
,
"sglang:num_requests_total"
,
"sglang:time_to_first_token_seconds"
,
"sglang:time_to_first_token_seconds"
,
"sglang:time_per_output_token_seconds"
,
"sglang:inter_token_latency_seconds"
,
"sglang:inter_token_latency_seconds"
,
"sglang:e2e_request_latency_seconds"
,
"sglang:e2e_request_latency_seconds"
,
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment