Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2ac46e94
Unverified
Commit
2ac46e94
authored
Oct 12, 2025
by
Lianmin Zheng
Committed by
GitHub
Oct 12, 2025
Browse files
Sync changes on io_struct.py and deterministic ops (#11498)
parent
0aa65f94
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
73 additions
and
25 deletions
+73
-25
docs/advanced_features/server_arguments.md
docs/advanced_features/server_arguments.md
+0
-1
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+4
-1
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+14
-7
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+2
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+35
-2
python/sglang/srt/managers/multi_tokenizer_mixin.py
python/sglang/srt/managers/multi_tokenizer_mixin.py
+12
-0
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+2
-1
python/sglang/srt/models/grok.py
python/sglang/srt/models/grok.py
+0
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+0
-6
python/sglang/test/test_cutlass_moe.py
python/sglang/test/test_cutlass_moe.py
+1
-1
python/sglang/test/test_deterministic_utils.py
python/sglang/test/test_deterministic_utils.py
+3
-3
No files found.
docs/advanced_features/server_arguments.md
View file @
2ac46e94
...
@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
...
@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
`--debug-tensor-dump-output-folder`
| The output folder for debug tensor dumps. | None |
|
`--debug-tensor-dump-output-folder`
| The output folder for debug tensor dumps. | None |
|
`--debug-tensor-dump-input-file`
| The input file for debug tensor dumps. | None |
|
`--debug-tensor-dump-input-file`
| The input file for debug tensor dumps. | None |
|
`--debug-tensor-dump-inject`
| Enable injection of debug tensor dumps. | False |
|
`--debug-tensor-dump-inject`
| Enable injection of debug tensor dumps. | False |
|
`--debug-tensor-dump-prefill-only`
| Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation
## PD disaggregation
...
...
python/sglang/srt/distributed/parallel_state.py
View file @
2ac46e94
...
@@ -240,6 +240,7 @@ class GroupCoordinator:
...
@@ -240,6 +240,7 @@ class GroupCoordinator:
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
group_name
:
Optional
[
str
]
=
None
,
torch_compile
:
Optional
[
bool
]
=
None
,
torch_compile
:
Optional
[
bool
]
=
None
,
gloo_timeout
:
timedelta
=
timedelta
(
seconds
=
120
*
60
),
):
):
# Set group info
# Set group info
group_name
=
group_name
or
"anonymous"
group_name
=
group_name
or
"anonymous"
...
@@ -259,7 +260,9 @@ class GroupCoordinator:
...
@@ -259,7 +260,9 @@ class GroupCoordinator:
)
)
# a group with `gloo` backend, to allow direct coordination between
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
# processes through the CPU.
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
,
timeout
=
gloo_timeout
)
if
self
.
rank
in
ranks
:
if
self
.
rank
in
ranks
:
self
.
ranks
=
ranks
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
self
.
world_size
=
len
(
ranks
)
...
...
python/sglang/srt/layers/sampler.py
View file @
2ac46e94
...
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
...
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
batch_next_token_ids
=
torch
.
argmax
(
logits
,
-
1
)
batch_next_token_ids
=
torch
.
argmax
(
logits
,
-
1
)
if
return_logprob
:
if
return_logprob
:
logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
else
:
else
:
# If requested, cache probabilities from original logits before temperature scaling.
# If requested, cache probabilities from original logits before temperature scaling.
if
return_logprob
and
RETURN_ORIGINAL_LOGPROB
:
if
return_logprob
and
RETURN_ORIGINAL_LOGPROB
:
...
@@ -288,21 +287,29 @@ def multinomial_with_seed(
...
@@ -288,21 +287,29 @@ def multinomial_with_seed(
"""
"""
n
,
m
=
inputs
.
shape
n
,
m
=
inputs
.
shape
col_indices
=
torch
.
arange
(
m
,
device
=
inputs
.
device
).
unsqueeze
(
0
)
col_indices
=
torch
.
arange
(
m
,
device
=
inputs
.
device
).
unsqueeze
(
0
)
step_seed
=
seed
*
19349663
^
positions
*
73856093
step_seed
=
(
seed
*
19349663
)
^
(
positions
*
73856093
)
seed_expanded
=
step_seed
.
unsqueeze
(
-
1
)
seed_expanded
=
step_seed
.
unsqueeze
(
-
1
)
hashed
=
seed_expanded
*
8589934591
^
col_indices
*
479001599
hashed
=
(
seed_expanded
*
8589934591
)
^
(
col_indices
*
479001599
)
uniform_samples
=
(
hashed
%
(
2
**
24
)).
float
()
/
(
2
**
24
)
uniform_samples
=
(
hashed
%
(
2
**
24
)).
float
()
/
(
2
**
24
)
epsilon
=
1e-9
epsilon
=
1e-10
gumbel_noise
=
-
torch
.
log
(
-
torch
.
log
(
uniform_samples
+
epsilon
)
+
epsilon
)
uniform_samples
=
uniform_samples
.
clamp
(
epsilon
,
1.0
-
epsilon
)
gumbel_noise
=
-
torch
.
log
(
-
torch
.
log
(
uniform_samples
))
log_probs
=
torch
.
log
(
inputs
+
epsilon
)
log_probs
=
torch
.
log
(
inputs
+
epsilon
)
perturbed_log_probs
=
log_probs
+
gumbel_noise
perturbed_log_probs
=
log_probs
+
gumbel_noise
return
torch
.
argmax
(
perturbed_log_probs
,
dim
=
1
,
keepdim
=
True
)
return
torch
.
argmax
(
perturbed_log_probs
,
dim
=
1
,
keepdim
=
True
)
def
sampling_from_probs_torch
(
probs
:
torch
.
Tensor
):
def
sampling_from_probs_torch
(
probs
:
torch
.
Tensor
,
sampling_seed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
):
"""A sampling implementation with native pytorch operations, without
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
top-k, top-p, or min-p filtering."""
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
if
sampling_seed
is
not
None
:
sampled_index
=
multinomial_with_seed
(
probs
,
sampling_seed
,
positions
)
else
:
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
batch_next_token_ids
=
sampled_index
.
view
(
-
1
).
to
(
torch
.
int32
)
batch_next_token_ids
=
sampled_index
.
view
(
-
1
).
to
(
torch
.
int32
)
return
batch_next_token_ids
return
batch_next_token_ids
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
2ac46e94
...
@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
...
@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
input_token_ids_logprobs_idx
=
recv_obj
.
input_token_ids_logprobs_idx
,
input_token_ids_logprobs_idx
=
recv_obj
.
input_token_ids_logprobs_idx
,
output_token_ids_logprobs_val
=
recv_obj
.
output_token_ids_logprobs_val
,
output_token_ids_logprobs_val
=
recv_obj
.
output_token_ids_logprobs_val
,
output_token_ids_logprobs_idx
=
recv_obj
.
output_token_ids_logprobs_idx
,
output_token_ids_logprobs_idx
=
recv_obj
.
output_token_ids_logprobs_idx
,
output_token_entropy_val
=
recv_obj
.
output_token_entropy_val
,
output_hidden_states
=
recv_obj
.
output_hidden_states
,
output_hidden_states
=
recv_obj
.
output_hidden_states
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
recv_obj
.
token_steps
,
)
)
def
handle_multimodal_decode_req
(
self
,
recv_obj
:
BatchMultimodalDecodeReq
):
def
handle_multimodal_decode_req
(
self
,
recv_obj
:
BatchMultimodalDecodeReq
):
...
...
python/sglang/srt/managers/io_struct.py
View file @
2ac46e94
...
@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
...
@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
# (Internal) Whether to return bytes for image generation
return_bytes
:
bool
=
False
return_bytes
:
bool
=
False
# Whether to return entropy
return_entropy
:
bool
=
False
def
contains_mm_input
(
self
)
->
bool
:
def
contains_mm_input
(
self
)
->
bool
:
return
(
return
(
has_valid_data
(
self
.
image_data
)
has_valid_data
(
self
.
image_data
)
...
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
...
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
no_logs
=
self
.
no_logs
,
no_logs
=
self
.
no_logs
,
custom_labels
=
self
.
custom_labels
,
custom_labels
=
self
.
custom_labels
,
return_bytes
=
self
.
return_bytes
,
return_bytes
=
self
.
return_bytes
,
return_entropy
=
self
.
return_entropy
,
)
)
...
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
...
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
# (Internal) Whether to return bytes for image generation
return_bytes
:
bool
=
False
return_bytes
:
bool
=
False
# Whether to return entropy
return_entropy
:
bool
=
False
@
dataclass
@
dataclass
class
BatchTokenizedGenerateReqInput
(
BaseBatchReq
):
class
BatchTokenizedGenerateReqInput
(
BaseBatchReq
):
...
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
...
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
input_token_ids_logprobs_idx
:
List
[
List
]
input_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_entropy_val
:
List
[
float
]
# Hidden states
# Hidden states
output_hidden_states
:
List
[
List
[
float
]]
output_hidden_states
:
List
[
List
[
float
]]
...
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
...
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
@
dataclass
class
BatchMultimodalDecodeReq
(
BaseBatchReq
):
class
BatchMultimodalDecodeReq
(
BaseBatchReq
):
...
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
...
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
completion_tokens
:
List
[
int
]
completion_tokens
:
List
[
int
]
cached_tokens
:
List
[
int
]
cached_tokens
:
List
[
int
]
# Placeholder token info
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
return_bytes
:
bool
=
False
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
@
dataclass
...
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
...
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
input_token_ids_logprobs_idx
:
List
[
List
]
input_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_val
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_ids_logprobs_idx
:
List
[
List
]
output_token_entropy_val
:
List
[
float
]
# Hidden states
# Hidden states
output_hidden_states
:
List
[
List
[
float
]]
output_hidden_states
:
List
[
List
[
float
]]
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_idx
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
placeholder_tokens_val
:
List
[
Optional
[
List
[
int
]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps
:
List
[
List
[
int
]]
=
None
@
dataclass
@
dataclass
class
BatchMultimodalOutput
(
BaseBatchReq
):
class
BatchMultimodalOutput
(
BaseBatchReq
):
...
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
...
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
torch_empty_cache
:
bool
=
False
torch_empty_cache
:
bool
=
False
# Whether to keep the scheduler paused after weight update
# Whether to keep the scheduler paused after weight update
keep_pause
:
bool
=
False
keep_pause
:
bool
=
False
# The trainer step id. Used to know which step's weights are used for sampling.
token_step
:
int
=
0
@
dataclass
@
dataclass
...
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
...
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
loads
:
List
[
GetLoadReqOutput
]
loads
:
List
[
GetLoadReqOutput
]
@
dataclass
class
LazyDumpTensorsReqInput
(
BaseReq
):
pass
@
dataclass
class
LazyDumpTensorsReqOutput
(
BaseReq
):
success
:
bool
def
_check_all_req_types
():
def
_check_all_req_types
():
"""A helper function to check all request types are defined in this file."""
"""A helper function to check all request types are defined in this file."""
import
inspect
import
inspect
...
...
python/sglang/srt/managers/multi_tokenizer_mixin.py
View file @
2ac46e94
...
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
...
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
if
output
.
output_token_ids_logprobs_idx
if
output
.
output_token_ids_logprobs_idx
else
None
else
None
),
),
output_token_entropy_val
=
(
[
output
.
output_token_entropy_val
[
i
]]
if
output
.
output_token_entropy_val
else
None
),
output_hidden_states
=
(
output_hidden_states
=
(
[
output
.
output_hidden_states
[
i
]]
[
output
.
output_hidden_states
[
i
]]
if
output
.
output_hidden_states
if
output
.
output_hidden_states
...
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
...
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
),
),
placeholder_tokens_idx
=
None
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
([
output
.
token_steps
[
i
]]
if
output
.
token_steps
else
None
),
)
)
elif
isinstance
(
output
,
BatchEmbeddingOutput
):
elif
isinstance
(
output
,
BatchEmbeddingOutput
):
new_output
=
BatchEmbeddingOutput
(
new_output
=
BatchEmbeddingOutput
(
...
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
...
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
if
output
.
output_token_ids_logprobs_idx
if
output
.
output_token_ids_logprobs_idx
else
None
else
None
),
),
output_token_entropy_val
=
(
[
output
.
output_token_entropy_val
[
i
]]
if
output
.
output_token_entropy_val
else
None
),
output_hidden_states
=
(
output_hidden_states
=
(
[
output
.
output_hidden_states
[
i
]]
[
output
.
output_hidden_states
[
i
]]
if
output
.
output_hidden_states
if
output
.
output_hidden_states
...
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
...
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
),
),
placeholder_tokens_idx
=
None
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
placeholder_tokens_val
=
None
,
token_steps
=
([
output
.
token_steps
[
i
]]
if
output
.
token_steps
else
None
),
)
)
elif
isinstance
(
output
,
BatchMultimodalOutput
):
elif
isinstance
(
output
,
BatchMultimodalOutput
):
new_output
=
BatchMultimodalOutput
(
new_output
=
BatchMultimodalOutput
(
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
2ac46e94
...
@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
...
@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
input_token_ids_logprobs_idx
,
input_token_ids_logprobs_idx
,
output_token_ids_logprobs_val
,
output_token_ids_logprobs_val
,
output_token_ids_logprobs_idx
,
output_token_ids_logprobs_idx
,
output_hidden_states
,
output_token_entropy_val
=
None
,
output_hidden_states
=
output_hidden_states
,
rids
=
rids
,
rids
=
rids
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
placeholder_tokens_val
=
None
,
...
...
python/sglang/srt/models/grok.py
View file @
2ac46e94
...
@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
...
@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
# Dump tensors for debugging
# Dump tensors for debugging
debug_tensor_dump_output_folder
=
None
debug_tensor_dump_output_folder
=
None
debug_tensor_dump_prefill_only
=
False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs
=
False
debug_tensor_dump_inject
=
False
debug_tensor_dump_inject
=
False
debug_tensor_dump_layers
=
None
debug_tensor_dump_layers
=
None
debug_tensor_dump_test
=
False
debug_tensor_dump_test
=
False
...
...
python/sglang/srt/server_args.py
View file @
2ac46e94
...
@@ -455,7 +455,6 @@ class ServerArgs:
...
@@ -455,7 +455,6 @@ class ServerArgs:
debug_tensor_dump_output_folder
:
Optional
[
str
]
=
None
debug_tensor_dump_output_folder
:
Optional
[
str
]
=
None
debug_tensor_dump_input_file
:
Optional
[
str
]
=
None
debug_tensor_dump_input_file
:
Optional
[
str
]
=
None
debug_tensor_dump_inject
:
bool
=
False
debug_tensor_dump_inject
:
bool
=
False
debug_tensor_dump_prefill_only
:
bool
=
False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode
:
Literal
[
"null"
,
"prefill"
,
"decode"
]
=
"null"
disaggregation_mode
:
Literal
[
"null"
,
"prefill"
,
"decode"
]
=
"null"
...
@@ -2831,11 +2830,6 @@ class ServerArgs:
...
@@ -2831,11 +2830,6 @@ class ServerArgs:
default
=
ServerArgs
.
debug_tensor_dump_inject
,
default
=
ServerArgs
.
debug_tensor_dump_inject
,
help
=
"Inject the outputs from jax as the input of every layer."
,
help
=
"Inject the outputs from jax as the input of every layer."
,
)
)
parser
.
add_argument
(
"--debug-tensor-dump-prefill-only"
,
action
=
"store_true"
,
help
=
"Only dump the tensors for prefill requests (i.e. batch size > 1)."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-dynamic-batch-tokenizer"
,
"--enable-dynamic-batch-tokenizer"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
...
python/sglang/test/test_cutlass_moe.py
View file @
2ac46e94
...
@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
...
@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
"topk"
:
topk
,
"topk"
:
topk
,
"hidden_size"
:
config
.
hidden_size
,
"hidden_size"
:
config
.
hidden_size
,
"shard_intermediate_size"
:
shard_intermediate_size
,
"shard_intermediate_size"
:
shard_intermediate_size
,
"dtype"
:
config
.
torch_
dtype
,
"dtype"
:
config
.
dtype
,
"block_shape"
:
config
.
quantization_config
[
"weight_block_size"
],
"block_shape"
:
config
.
quantization_config
[
"weight_block_size"
],
}
}
...
...
python/sglang/test/test_deterministic_utils.py
View file @
2ac46e94
import
time
import
unittest
import
unittest
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_deterministic
import
BenchArgs
,
test_deterministic
from
sglang.test.test_deterministic
import
BenchArgs
,
test_deterministic
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
...
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
...
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
n_start
=
10
args
.
n_start
=
10
args
.
n_trials
=
20
args
.
n_trials
=
20
results
=
test_deterministic
(
args
)
results
=
test_deterministic
(
args
)
args
.
temperature
=
0.5
# test for deterministic sampling
for
result
in
results
:
for
result
in
results
:
assert
result
==
1
assert
result
==
1
...
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
...
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
test_mode
=
"mixed"
args
.
test_mode
=
"mixed"
args
.
n_start
=
10
args
.
n_start
=
10
args
.
n_trials
=
20
args
.
n_trials
=
20
args
.
temperature
=
0.5
# test for deterministic sampling
results
=
test_deterministic
(
args
)
results
=
test_deterministic
(
args
)
for
result
in
results
:
for
result
in
results
:
assert
result
==
1
assert
result
==
1
...
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
...
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
args
.
test_mode
=
"prefix"
args
.
test_mode
=
"prefix"
args
.
n_start
=
10
args
.
n_start
=
10
args
.
n_trials
=
10
args
.
n_trials
=
10
args
.
temperature
=
0.5
# test for deterministic sampling
results
=
test_deterministic
(
args
)
results
=
test_deterministic
(
args
)
for
result
in
results
:
for
result
in
results
:
assert
result
==
1
assert
result
==
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment