Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
632b7d8c
Unverified
Commit
632b7d8c
authored
Sep 23, 2025
by
Liangsheng Yin
Committed by
GitHub
Sep 23, 2025
Browse files
Use simulate acc len from `sglang.environ` (#10771)
parent
16adf3dc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
13 deletions
+31
-13
python/sglang/environ.py
python/sglang/environ.py
+2
-0
python/sglang/launch_server.py
python/sglang/launch_server.py
+14
-0
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+15
-13
No files found.
python/sglang/environ.py
View file @
632b7d8c
...
@@ -124,6 +124,8 @@ class Envs:
...
@@ -124,6 +124,8 @@ class Envs:
SGLANG_TEST_REQUEST_TIME_STATS
=
EnvBool
(
False
)
SGLANG_TEST_REQUEST_TIME_STATS
=
EnvBool
(
False
)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK
=
EnvBool
(
False
)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK
=
EnvBool
(
False
)
SGLANG_DISABLE_REQUEST_LOGGING
=
EnvBool
(
False
)
SGLANG_DISABLE_REQUEST_LOGGING
=
EnvBool
(
False
)
SGLANG_SIMULATE_ACC_LEN
=
EnvFloat
(
-
1
)
SGLANG_SIMULATE_ACC_METHOD
=
EnvStr
(
"multinomial"
)
# Model Parallel
# Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER
=
EnvBool
(
True
)
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER
=
EnvBool
(
True
)
...
...
python/sglang/launch_server.py
View file @
632b7d8c
...
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
...
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
from
sglang.srt.server_args
import
prepare_server_args
from
sglang.srt.server_args
import
prepare_server_args
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
MOVE_ENVS_WARN
=
"""
########################################################################
# For contributors and developers: #
# Please move environment variable definitions to 'sglang/environ.py' #
# using the following pattern: #
# SGLANG_XXX = EnvBool(False) #
# #
########################################################################
"""
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
server_args
=
prepare_server_args
(
sys
.
argv
[
1
:])
server_args
=
prepare_server_args
(
sys
.
argv
[
1
:])
from
sglang.srt.server_args
import
print_deprecated_warning
print_deprecated_warning
(
MOVE_ENVS_WARN
)
try
:
try
:
launch_server
(
server_args
)
launch_server
(
server_args
)
finally
:
finally
:
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
632b7d8c
...
@@ -12,6 +12,7 @@ import torch.nn.functional as F
...
@@ -12,6 +12,7 @@ import torch.nn.functional as F
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.environ
import
envs
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
...
@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
if
is_cuda
():
if
is_cuda
():
...
@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
...
@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN
=
os
.
environ
.
get
(
"SIMULATE_ACC_LEN"
)
SIMULATE_ACC_LEN
=
envs
.
SGLANG_SIMULATE_ACC_LEN
.
get
()
# turn off if < 0
SIMULATE_ACC_METHOD
=
os
.
environ
.
get
(
"
SIMULATE_ACC_METHOD
"
,
"multinomial"
)
SIMULATE_ACC_METHOD
=
envs
.
SGLANG_
SIMULATE_ACC_METHOD
.
get
(
)
TREE_TRAVERSE_TIME_THRESHOLD
=
1
# TODO: set this properly
TREE_TRAVERSE_TIME_THRESHOLD
=
1
# TODO: set this properly
...
@@ -500,13 +501,12 @@ class EagleVerifyInput:
...
@@ -500,13 +501,12 @@ class EagleVerifyInput:
deterministic
=
True
,
deterministic
=
True
,
)
)
if
SIMULATE_ACC_LEN
:
if
SIMULATE_ACC_LEN
>
0.0
:
# Do simulation
# Do simulation
accept_index
=
_generate_simulated_accept_index
(
accept_index
=
_generate_simulated_accept_index
(
accept_index
=
accept_index
,
accept_index
=
accept_index
,
predict
=
predict
,
# mutable
predict
=
predict
,
# mutable
accept_length
=
accept_length
,
# mutable
accept_length
=
accept_length
,
# mutable
simulate_acc_len
=
SIMULATE_ACC_LEN
,
bs
=
bs
,
bs
=
bs
,
spec_steps
=
self
.
spec_steps
,
spec_steps
=
self
.
spec_steps
,
)
)
...
@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
...
@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
accept_index
,
accept_index
,
predict
,
predict
,
accept_length
,
accept_length
,
simulate_acc_len
,
bs
,
bs
,
spec_steps
,
spec_steps
,
simulate_acc_len
:
float
=
SIMULATE_ACC_LEN
,
simulate_acc_method
:
str
=
SIMULATE_ACC_METHOD
,
):
):
simulate_acc_len_float
=
float
(
simulate_acc_len
)
assert
simulate_acc_len
>
0.0
if
SIMULATE_ACC_METHOD
==
"multinomial"
:
if
simulate_acc_method
==
"multinomial"
:
simulated_values
=
torch
.
normal
(
simulated_values
=
torch
.
normal
(
mean
=
simulate_acc_len
_float
,
mean
=
simulate_acc_len
,
std
=
1.0
,
std
=
1.0
,
size
=
(
1
,),
size
=
(
1
,),
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
...
@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
# clamp simulated values to be between 1 and self.spec_steps
# clamp simulated values to be between 1 and self.spec_steps
simulated_values
=
torch
.
clamp
(
simulated_values
,
min
=
1.0
,
max
=
spec_steps
+
1
)
simulated_values
=
torch
.
clamp
(
simulated_values
,
min
=
1.0
,
max
=
spec_steps
+
1
)
simulate_acc_len
=
int
(
simulated_values
.
round
().
item
())
simulate_acc_len
=
int
(
simulated_values
.
round
().
item
())
elif
SIMULATE_ACC_METHOD
==
"match-expected"
:
elif
simulate_acc_method
==
"match-expected"
:
# multinomial sampling does not match the expected length
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
# either round down or round up of the expected length
simulate_acc_len
_float
=
max
(
1.0
,
min
(
spec_steps
+
1
,
simulate_acc_len
_float
))
simulate_acc_len
=
max
(
1.0
,
min
(
spec_steps
+
1
,
simulate_acc_len
))
lower
=
int
(
simulate_acc_len
_float
//
1
)
lower
=
int
(
simulate_acc_len
//
1
)
upper
=
lower
+
1
if
lower
<
spec_steps
+
1
else
lower
upper
=
lower
+
1
if
lower
<
spec_steps
+
1
else
lower
if
lower
==
upper
:
if
lower
==
upper
:
simulate_acc_len
=
lower
simulate_acc_len
=
lower
else
:
else
:
weight_upper
=
simulate_acc_len
_float
-
lower
weight_upper
=
simulate_acc_len
-
lower
weight_lower
=
1.0
-
weight_upper
weight_lower
=
1.0
-
weight_upper
probs
=
torch
.
tensor
([
weight_lower
,
weight_upper
],
device
=
"cpu"
)
probs
=
torch
.
tensor
([
weight_lower
,
weight_upper
],
device
=
"cpu"
)
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment