Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2ac46e94
"docs/vscode:/vscode.git/clone" did not exist on "4bfebd857d9b94dff98e88c8cc59880f8fa54ec7"
Unverified
Commit
2ac46e94
authored
Oct 12, 2025
by
Lianmin Zheng
Committed by
GitHub
Oct 12, 2025
Browse files
Sync changes on io_struct.py and deterministic ops (#11498)
parent
0aa65f94
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
73 additions
and
25 deletions
+73
-25
docs/advanced_features/server_arguments.md
docs/advanced_features/server_arguments.md
+0
-1
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+4
-1
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+14
-7
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+2
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+35
-2
python/sglang/srt/managers/multi_tokenizer_mixin.py
python/sglang/srt/managers/multi_tokenizer_mixin.py
+12
-0
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+2
-1
python/sglang/srt/models/grok.py
python/sglang/srt/models/grok.py
+0
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+0
-6
python/sglang/test/test_cutlass_moe.py
python/sglang/test/test_cutlass_moe.py
+1
-1
python/sglang/test/test_deterministic_utils.py
python/sglang/test/test_deterministic_utils.py
+3
-3
No files found.
docs/advanced_features/server_arguments.md
View file @
2ac46e94
...
...
@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
`--debug-tensor-dump-output-folder`
| The output folder for debug tensor dumps. | None |
|
`--debug-tensor-dump-input-file`
| The input file for debug tensor dumps. | None |
|
`--debug-tensor-dump-inject`
| Enable injection of debug tensor dumps. | False |
|
`--debug-tensor-dump-prefill-only`
| Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation
...
...
python/sglang/srt/distributed/parallel_state.py
View file @
2ac46e94
...
...
@@ -240,6 +240,7 @@ class GroupCoordinator:
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
torch_compile
:
Optional
[
bool
]
=
None
,
gloo_timeout
:
timedelta
=
timedelta
(
seconds
=
120
*
60
),
):
# Set group info
group_name
=
group_name
or
"anonymous"
...
...
@@ -259,7 +260,9 @@ class GroupCoordinator:
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
,
timeout
=
gloo_timeout
)
if
self
.
rank
in
ranks
:
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
...
...
python/sglang/srt/layers/sampler.py
View file @
2ac46e94
...
...
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
batch_next_token_ids
=
torch
.
argmax
(
logits
,
-
1
)
if
return_logprob
:
logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
else
:
# If requested, cache probabilities from original logits before temperature scaling.
if
return_logprob
and
RETURN_ORIGINAL_LOGPROB
:
...
...
@@ -288,21 +287,29 @@ def multinomial_with_seed(
"""
n
,
m
=
inputs
.
shape
col_indices
=
torch
.
arange
(
m
,
device
=
inputs
.
device
).
unsqueeze
(
0
)
step_seed
=
seed
*
19349663
^
positions
*
73856093
step_seed
=
(
seed
*
19349663
)
^
(
positions
*
73856093
)
seed_expanded
=
step_seed
.
unsqueeze
(
-
1
)
hashed
=
seed_expanded
*
8589934591
^
col_indices
*
479001599
hashed
=
(
seed_expanded
*
8589934591
)
^
(
col_indices
*
479001599
)
uniform_samples
=
(
hashed
%
(
2
**
24
)).
float
()
/
(
2
**
24
)
epsilon
=
1e-9
gumbel_noise
=
-
torch
.
log
(
-
torch
.
log
(
uniform_samples
+
epsilon
)
+
epsilon
)
epsilon
=
1e-10
uniform_samples
=
uniform_samples
.
clamp
(
epsilon
,
1.0
-
epsilon
)
gumbel_noise
=
-
torch
.
log
(
-
torch
.
log
(
uniform_samples
))
log_probs
=
torch
.
log
(
inputs
+
epsilon
)
perturbed_log_probs
=
log_probs
+
gumbel_noise
return
torch
.
argmax
(
perturbed_log_probs
,
dim
=
1
,
keepdim
=
True
)
def
sampling_from_probs_torch
(
probs
:
torch
.
Tensor
):
def
sampling_from_probs_torch
(
probs
:
torch
.
Tensor
,
sampling_seed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
if
sampling_seed
is
not
None
:
sampled_index
=
multinomial_with_seed
(
probs
,
sampling_seed
,
positions
)
else
:
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
batch_next_token_ids
=
sampled_index
.
view
(
-
1
).
to
(
torch
.
int32
)
return
batch_next_token_ids
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
2ac46e94
...
...
@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
input_token_ids_logprobs_idx
=
recv_obj
.
input_token_ids_logprobs_idx
,
output_token_ids_logprobs_val
=
recv_obj
.
output_token_ids_logprobs_val
,
output_token_ids_logprobs_idx
=
recv_obj
.
output_token_ids_logprobs_idx
,
output_token_entropy_val
=
recv_obj
.
output_token_entropy_val
,
output_hidden_states
=
recv_obj
.
output_hidden_states
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
recv_obj
.
token_steps
,
)
def
handle_multimodal_decode_req
(
self
,
recv_obj
:
BatchMultimodalDecodeReq
):
...
...
python/sglang/srt/managers/io_struct.py
View file @
2ac46e94
...
...
@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
return_bytes
:
bool
=
False
# Whether to return entropy
return_entropy
:
bool
=
False
def
contains_mm_input
(
self
)
->
bool
:
return
(
has_valid_data
(
self
.
image_data
)
...
...
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
no_logs
=
self
.
no_logs
,
custom_labels
=
self
.
custom_labels
,
return_bytes
=
self
.
return_bytes
,
return_entropy
=
self
.
return_entropy
,
)
...
...
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
return_bytes
:
bool
=
False
# Whether to return entropy
return_entropy
:
bool
=
False
@
dataclass
class
BatchTokenizedGenerateReqInput
(
BaseBatchReq
):
...
...
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
input_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_entropy_val
:
List
[
float
]
# Hidden states
output_hidden_states
:
List
[
List
[
float
]]
...
...
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
class
BatchMultimodalDecodeReq
(
BaseBatchReq
):
...
...
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
completion_tokens
:
List
[
int
]
cached_tokens
:
List
[
int
]
# Placeholder token info
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
return_bytes
:
bool
=
False
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
...
...
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
input_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_entropy_val
:
List
[
float
]
# Hidden states
output_hidden_states
:
List
[
List
[
float
]]
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
class
BatchMultimodalOutput
(
BaseBatchReq
):
...
...
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
torch_empty_cache
:
bool
=
False
# Whether to keep the scheduler paused after weight update
keep_pause
:
bool
=
False
# The trainer step id. Used to know which step's weights are used for sampling.
token_step
:
int
=
0
@
dataclass
...
...
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
loads
:
List
[
GetLoadReqOutput
]
@
dataclass
class
LazyDumpTensorsReqInput
(
BaseReq
):
pass
@
dataclass
class
LazyDumpTensorsReqOutput
(
BaseReq
):
success
:
bool
def
_check_all_req_types
():
"""A helper function to check all request types are defined in this file."""
import
inspect
...
...
python/sglang/srt/managers/multi_tokenizer_mixin.py
View file @
2ac46e94
...
...
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
if
output
.
output_token_ids_logprobs_idx
else
None
),
output_token_entropy_val
=
(
[
output
.
output_token_entropy_val
[
i
]]
if
output
.
output_token_entropy_val
else
None
),
output_hidden_states
=
(
[
output
.
output_hidden_states
[
i
]]
if
output
.
output_hidden_states
...
...
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
),
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
([
output
.
token_steps
[
i
]]
if
output
.
token_steps
else
None
),
)
elif
isinstance
(
output
,
BatchEmbeddingOutput
):
new_output
=
BatchEmbeddingOutput
(
...
...
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
if
output
.
output_token_ids_logprobs_idx
else
None
),
output_token_entropy_val
=
(
[
output
.
output_token_entropy_val
[
i
]]
if
output
.
output_token_entropy_val
else
None
),
output_hidden_states
=
(
[
output
.
output_hidden_states
[
i
]]
if
output
.
output_hidden_states
...
...
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
),
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
([
output
.
token_steps
[
i
]]
if
output
.
token_steps
else
None
),
)
elif
isinstance
(
output
,
BatchMultimodalOutput
):
new_output
=
BatchMultimodalOutput
(
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
2ac46e94
...
...
@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
input_token_ids_logprobs_idx
,
output_token_ids_logprobs_val
,
output_token_ids_logprobs_idx
,
output_hidden_states
,
output_token_entropy_val
=
None
,
output_hidden_states
=
output_hidden_states
,
rids
=
rids
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
...
...
python/sglang/srt/models/grok.py
View file @
2ac46e94
...
...
@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
# Dump tensors for debugging
debug_tensor_dump_output_folder
=
None
debug_tensor_dump_prefill_only
=
False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs
=
False
debug_tensor_dump_inject
=
False
debug_tensor_dump_layers
=
None
debug_tensor_dump_test
=
False
...
...
python/sglang/srt/server_args.py
View file @
2ac46e94
...
...
@@ -455,7 +455,6 @@ class ServerArgs:
debug_tensor_dump_output_folder
:
Optional
[
str
]
=
None
debug_tensor_dump_input_file
:
Optional
[
str
]
=
None
debug_tensor_dump_inject
:
bool
=
False
debug_tensor_dump_prefill_only
:
bool
=
False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode
:
Literal
[
"null"
,
"prefill"
,
"decode"
]
=
"null"
...
...
@@ -2831,11 +2830,6 @@ class ServerArgs:
default
=
ServerArgs
.
debug_tensor_dump_inject
,
help
=
"Inject the outputs from jax as the input of every layer."
,
)
parser
.
add_argument
(
"--debug-tensor-dump-prefill-only"
,
action
=
"store_true"
,
help
=
"Only dump the tensors for prefill requests (i.e. batch size > 1)."
,
)
parser
.
add_argument
(
"--enable-dynamic-batch-tokenizer"
,
action
=
"store_true"
,
...
...
python/sglang/test/test_cutlass_moe.py
View file @
2ac46e94
...
...
@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
"topk"
:
topk
,
"hidden_size"
:
config
.
hidden_size
,
"shard_intermediate_size"
:
shard_intermediate_size
,
"dtype"
:
config
.
torch_
dtype
,
"dtype"
:
config
.
dtype
,
"block_shape"
:
config
.
quantization_config
[
"weight_block_size"
],
}
...
...
python/sglang/test/test_deterministic_utils.py
View file @
2ac46e94
import
time
import
unittest
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_deterministic
import
BenchArgs
,
test_deterministic
from
sglang.test.test_utils
import
(
...
...
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
n_start
=
10
args
.
n_trials
=
20
results
=
test_deterministic
(
args
)
args
.
temperature
=
0.5
# test for deterministic sampling
for
result
in
results
:
assert
result
==
1
...
...
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
test_mode
=
"mixed"
args
.
n_start
=
10
args
.
n_trials
=
20
args
.
temperature
=
0.5
# test for deterministic sampling
results
=
test_deterministic
(
args
)
for
result
in
results
:
assert
result
==
1
...
...
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
test_mode
=
"prefix"
args
.
n_start
=
10
args
.
n_trials
=
10
args
.
temperature
=
0.5
# test for deterministic sampling
results
=
test_deterministic
(
args
)
for
result
in
results
:
assert
result
==
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment