# sglang · commit 608668e1

**Slightly improve the sampler to skip unnecessary steps (#6956)**

Unverified commit, authored Jun 08, 2025 by Lianmin Zheng; committed via GitHub on Jun 08, 2025.
Parent: 6c0a4828

7 changed files, with 109 additions and 93 deletions (+109 / -93):
| File | Additions | Deletions |
| --- | --- | --- |
| python/sglang/srt/layers/sampler.py | +81 | -75 |
| python/sglang/srt/managers/tokenizer_manager.py | +1 | -1 |
| python/sglang/srt/models/llama.py | +1 | -1 |
| python/sglang/srt/sampling/sampling_batch_info.py | +13 | -1 |
| python/sglang/srt/sampling/sampling_params.py | +2 | -1 |
| python/sglang/srt/two_batch_overlap.py | +2 | -7 |
| python/sglang/srt/utils.py | +9 | -7 |
## python/sglang/srt/layers/sampler.py
```diff
@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 from torch import nn

-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
```
```diff
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group

         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
```
```diff
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)

         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
```
```diff
@@ -81,49 +81,39 @@ class Sampler(nn.Module):
             probs = logits
             del logits

-            if global_server_args_dict["sampling_backend"] == "flashinfer":
-                if return_logprob:
-                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                    # https://github.com/flashinfer-ai/flashinfer/issues/708
-                    # so we use the torch implementation.
-
-                    # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                    # same rule.
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-
-                max_top_k_round, batch_size = 32, probs.shape[0]
-                if sampling_info.need_min_p_sampling:
-                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                    batch_next_token_ids = min_p_sampling_from_probs(
-                        probs, sampling_info.min_ps
-                    )
-                else:
-                    # Check Nan will throw exception, only check when crash_on_warnings is True
-                    check_nan = self.use_nan_detection and crash_on_warnings()
-                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                        probs.contiguous(),
-                        sampling_info.top_ks,
-                        sampling_info.top_ps,
-                        filter_apply_order="joint",
-                        check_nan=check_nan,
-                    )
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
-                # A slower fallback implementation with torch native operations.
-                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                    probs,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    sampling_info.min_ps,
-                    sampling_info.need_min_p_sampling,
-                )
-                if return_logprob:
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-            else:
-                raise ValueError(
-                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-                )
+            if True:
+                # Keep this redundant check to simplify some internal code sync
+                if global_server_args_dict["sampling_backend"] == "flashinfer":
+                    if sampling_info.need_min_p_sampling:
+                        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                        batch_next_token_ids = min_p_sampling_from_probs(
+                            probs, sampling_info.min_ps
+                        )
+                    else:
+                        batch_next_token_ids = top_k_top_p_sampling_from_probs(
+                            probs,
+                            sampling_info.top_ks,
+                            sampling_info.top_ps,
+                            filter_apply_order="joint",
+                            check_nan=self.use_nan_detection,
+                        )
+                elif global_server_args_dict["sampling_backend"] == "pytorch":
+                    # A slower fallback implementation with torch native operations.
+                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                        probs,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        sampling_info.min_ps,
+                        sampling_info.need_min_p_sampling,
+                    )
+                else:
+                    raise ValueError(
+                        f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                    )
+
+            if return_logprob:
+                # clamp to avoid -inf
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
```
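Both backends previously computed the clamped log-probabilities inside their own branch; the rewrite computes them once after the backend dispatch. A minimal standalone sketch of the clamp pattern itself (the example tensor is made up):

```python
import torch

# A row where top-k/top-p filtering zeroed out some tokens.
probs = torch.tensor([[0.7, 0.3, 0.0, 0.0]])

# torch.log(0) is -inf; clamping to the dtype's minimum keeps every logprob
# finite, matching the "clamp to avoid -inf" comment in the diff.
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
assert torch.isfinite(logprobs).all()
```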
```diff
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
         return batch_next_token_ids

-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )
-

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
```
```diff
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids


+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids
+
+
 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
```
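Per its docstring, the new `sampling_from_probs_torch` helper handles batches where no top-k, top-p, or min-p filtering is needed at all. A quick standalone usage sketch (shapes are illustrative):

```python
import torch

probs = torch.softmax(torch.randn(4, 128), dim=-1)  # (batch, vocab), rows sum to 1

# One draw per row; torch.multinomial returns shape (4, 1).
sampled_index = torch.multinomial(probs, num_samples=1)
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)  # (4,)
assert batch_next_token_ids.shape == (4,)
```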
```diff
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])

     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can
+    contain multiple tokens. By default, we assume each batch contains only 1 token.
+    """
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )
```
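The method is promoted to a module-level function, and the new `num_tokens_in_batch` parameter lets one per-request mask cover speculative decoding, where each request contributes several logit rows. The key step is `torch.repeat_interleave`, which stretches the boolean request mask to token granularity; a small illustration with made-up values:

```python
import torch

batch_mask = torch.tensor([True, False, True])  # 3 requests, 2 need the processor
num_tokens_in_batch = 2                         # e.g. spec decoding: 2 tokens per request

token_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
# token_mask == tensor([ True,  True, False, False,  True,  True]),
# so logits[token_mask] selects every logit row belonging to the masked requests.
```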
## python/sglang/srt/managers/tokenizer_manager.py
```diff
@@ -852,7 +852,7 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if True:
+        if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
             async with self.model_update_lock.writer_lock:
```
## python/sglang/srt/models/llama.py
```diff
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
```
## python/sglang/srt/sampling/sampling_batch_info.py
```diff
@@ -9,10 +9,12 @@ import torch

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.sampling_params import TOP_K_ALL

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

 logger = logging.getLogger(__name__)
```
```diff
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
     # Whether all requests use greedy sampling
     is_all_greedy: bool

+    # Whether any requests use top_p sampling
+    need_top_p_sampling: bool
+
+    # Whether any requests use top_k sampling
+    need_top_k_sampling: bool
+
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
```
```diff
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
+            need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
             penalizer_orchestrator=penalizer_orchestrator,
```
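These batch-level flags record whether any request actually deviates from the no-filtering defaults (`top_p == 1.0`, `top_k == TOP_K_ALL`, `min_p == 0`), presumably so downstream sampling steps can be skipped when they are all false, consistent with the commit title. A sketch of the predicate logic with stand-in request objects (hypothetical values, `SimpleNamespace` used in place of real requests):

```python
from types import SimpleNamespace

TOP_K_ALL = 1 << 30  # sentinel from sampling_params.py

# Hypothetical requests: one uses top-p filtering, none use top-k or min-p.
reqs = [
    SimpleNamespace(sampling_params=SimpleNamespace(top_p=0.9, top_k=TOP_K_ALL, min_p=0.0)),
    SimpleNamespace(sampling_params=SimpleNamespace(top_p=1.0, top_k=TOP_K_ALL, min_p=0.0)),
]

need_top_p_sampling = any(r.sampling_params.top_p != 1.0 for r in reqs)        # True
need_top_k_sampling = any(r.sampling_params.top_k != TOP_K_ALL for r in reqs)  # False
need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)           # False
```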
```diff
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar and not grammar.finished:
+            if grammar and not grammar.finished and not grammar.is_terminated():
                 grammar.fill_vocab_mask(self.vocab_mask, i)

         # Move the mask to the device if needed
```
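A tiny stub showing the strengthened guard: a grammar that reports `is_terminated()` no longer gets its vocab mask filled, even if it is not yet marked `finished` (the stub class below is hypothetical, for illustration only):

```python
class FakeGrammar:
    finished = False

    def is_terminated(self):
        return True  # terminated: no further constraints to apply

grammar = FakeGrammar()
should_fill = bool(grammar) and not grammar.finished and not grammar.is_terminated()
assert should_fill is False  # fill_vocab_mask would be skipped for this grammar
```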
```diff
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
                 setattr(self, item, torch.cat([self_val, other_val]))

         self.is_all_greedy &= other.is_all_greedy
+        self.need_top_p_sampling |= other.need_top_p_sampling
+        self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
```
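The merge semantics mirror the flags' meaning: a merged batch stays all-greedy only if both halves are (`&=`), while a filtering step is needed if either half needs it (`|=`). Illustrative values only:

```python
# Flag merging when concatenating two SamplingBatchInfo halves.
a = {"is_all_greedy": True,  "need_top_p_sampling": False}
b = {"is_all_greedy": False, "need_top_p_sampling": True}

merged_is_all_greedy = a["is_all_greedy"] and b["is_all_greedy"]          # False
merged_need_top_p = a["need_top_p_sampling"] or b["need_top_p_sampling"]  # True
```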
## python/sglang/srt/sampling/sampling_params.py
```diff
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6
+TOP_K_ALL = 1 << 30


 class SamplingParams:
```
```diff
@@ -84,7 +85,7 @@ class SamplingParams:
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
-            self.top_k = 1 << 30  # whole vocabulary
+            self.top_k = TOP_K_ALL  # whole vocabulary

     def verify(self):
         if self.temperature < 0.0:
```
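`TOP_K_ALL` replaces the inline `1 << 30` magic number: a value far above any real vocabulary size, so a top-k of `TOP_K_ALL` keeps every token, and the named constant gives `SamplingBatchInfo` something exact to compare against when deciding whether top-k can be skipped. A small sketch of the normalization, using only the semantics shown in the diff (the helper name is made up):

```python
TOP_K_ALL = 1 << 30  # 1_073_741_824: effectively "no top-k filtering"

def normalize_top_k(top_k: int) -> int:
    # The user-facing convention maps -1 to "whole vocabulary".
    return TOP_K_ALL if top_k == -1 else top_k

assert normalize_top_k(-1) == TOP_K_ALL
assert normalize_top_k(50) == 50
```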
## python/sglang/srt/two_batch_overlap.py
```diff
 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence

 import torch

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
     CommunicateSimpleFn,
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
 from sglang.srt.managers.schedule_batch import global_server_args_dict
```
```diff
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var

-if TYPE_CHECKING:
-    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

 logger = logging.getLogger(__name__)
```
```diff
@@ -46,7 +41,7 @@ def compute_split_seq_index(
         assert num_tokens == 0
         return 0
     else:
-        raise NotImplementedError
+        raise NotImplementedError()


 def _split_array_by_half_sum(arr: Sequence[int]) -> int:
```
## python/sglang/srt/utils.py
```diff
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)


-@contextmanager
-def empty_context(*args, **kwargs):
-    try:
-        # Setup code goes here
-        yield
-    finally:
-        # Cleanup code goes here
-        pass
+class EmptyContextManager:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+def empty_context(*args, **kwargs):
+    return EmptyContextManager()


 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
```
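One plausible motivation for swapping the `@contextmanager` generator for a plain class (the commit message does not say): a generator-based context manager is single-use, while a class instance can be entered repeatedly and constructed without the `contextlib` machinery. A standalone usage sketch:

```python
class EmptyContextManager:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass  # returning None: exceptions are never suppressed

ctx = EmptyContextManager()
with ctx:
    pass  # no setup, no cleanup
with ctx:
    pass  # reusable: a @contextmanager generator would fail on a second use
```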