sglang · Commits · ca929118

Commit ca929118 (Unverified)
Authored Jun 10, 2025 by Brayden Zhong; committed by GitHub on Jun 10, 2025

[Feature] Add Logit Bias (#6579)

Co-authored-by: Cinjon Resnick <cinjon.resnick@gmail.com>

Parent: 344adb00
Showing 5 changed files with 183 additions and 0 deletions.
python/sglang/srt/openai_api/adapter.py             +2    -0
python/sglang/srt/sampling/sampling_batch_info.py   +24   -0
python/sglang/srt/sampling/sampling_params.py       +2    -0
python/sglang/srt/utils.py                          +39   -0
test/srt/test_srt_endpoint.py                       +116  -0
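The new `logit_bias` sampling parameter maps token IDs (given as strings) to additive biases applied to the logits before sampling: a large positive value pushes the model toward a token, a large negative value effectively forbids it. Below is a minimal client sketch against the native /generate endpoint, mirroring the tests added in this commit; the server address and the token ID are illustrative assumptions, not part of the diff.

import requests  # third-party HTTP client, as used by the tests in this commit

TARGET_TOKEN_ID = 60704  # illustrative token ID; look up real IDs with the model's tokenizer

response = requests.post(
    "http://localhost:30000/generate",  # assumed local sglang server address
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 1.0,
            "max_new_tokens": 4,
            # Keys are token IDs as strings, values are additive logit biases.
            "logit_bias": {str(TARGET_TOKEN_ID): 100.0},
        },
        "return_logprob": True,  # lets us inspect which token IDs were sampled
    },
)
print(response.json())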
python/sglang/srt/openai_api/adapter.py  (view file @ ca929118)

@@ -582,6 +582,7 @@ def v1_generate_request(
                 "no_stop_trim": request.no_stop_trim,
                 "ignore_eos": request.ignore_eos,
                 "skip_special_tokens": request.skip_special_tokens,
+                "logit_bias": request.logit_bias,
             }
         )
         return_logprobs.append(request.logprobs is not None)

@@ -1219,6 +1220,7 @@ def v1_chat_generate_request(
         "no_stop_trim": request.no_stop_trim,
         "ignore_eos": request.ignore_eos,
         "skip_special_tokens": request.skip_special_tokens,
+        "logit_bias": request.logit_bias,
     }

     if request.response_format and request.response_format.type == "json_schema":
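Since the adapter now forwards `request.logit_bias` in both `v1_generate_request` and `v1_chat_generate_request`, the OpenAI-compatible endpoints accept the parameter in the usual OpenAI shape. A hedged sketch using the `openai` Python client; the base URL, model name, and token ID are assumptions for illustration only.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")  # assumed local server

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",  # whichever model the server was launched with
    messages=[{"role": "user", "content": "The capital of France is"}],
    max_tokens=4,
    logit_bias={"60704": 100.0},  # token ID (string) -> additive bias, OpenAI-style
)
print(completion.choices[0].message.content)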
python/sglang/srt/sampling/sampling_batch_info.py  (view file @ ca929118)

@@ -10,6 +10,7 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
+from sglang.srt.utils import merge_bias_tensor

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -63,6 +64,9 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"

+    # Handle logit bias
+    logit_bias: Optional[torch.Tensor] = None
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs

@@ -85,6 +89,14 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        logit_bias = None
+        if any(r.sampling_params.logit_bias is not None for r in reqs):
+            logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
+            for i, r in enumerate(reqs):
+                if r.sampling_params.logit_bias is not None:
+                    for key, value in r.sampling_params.logit_bias.items():
+                        logit_bias[i, int(key)] = value
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.

@@ -150,6 +162,7 @@ class SamplingBatchInfo:
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
             device=device,
+            logit_bias=logit_bias,
         )
         return ret

@@ -206,6 +219,9 @@ class SamplingBatchInfo:
         if self.vocab_mask is not None:
             self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)

+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
     def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
         self.penalizer_orchestrator.filter(keep_indices_device)

@@ -221,6 +237,9 @@ class SamplingBatchInfo:
                 value = getattr(self, item, None)
                 setattr(self, item, value[keep_indices_device])

+        if self.logit_bias is not None:
+            self.logit_bias = self.logit_bias[keep_indices_device]
+
     def _filter_batch_custom_logit_processor(
         self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):

@@ -321,3 +340,8 @@ class SamplingBatchInfo:
         self.need_top_p_sampling |= other.need_top_p_sampling
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
+
+        # Merge logit bias
+        self.logit_bias = merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
+        )
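To summarize the batching path: `from_schedule_batch` densifies each request's `{token_id: bias}` dict into one `(batch_size, vocab_size)` tensor, `apply_logits_bias` adds it to the logits in place, and `filter_batch` / `merge_batch` keep the tensor aligned with the surviving or merged requests. A standalone PyTorch sketch of the densify-and-add step, with illustrative shapes and values (not sglang code):

import torch

vocab_size = 8
# Request 0 biases token 3 strongly; request 1 has no bias.
per_request_bias = [{"3": 100.0}, None]

# Densify the per-request dicts into a (batch_size, vocab_size) tensor.
logit_bias = torch.zeros(len(per_request_bias), vocab_size)
for i, bias in enumerate(per_request_bias):
    if bias is not None:
        for key, value in bias.items():
            logit_bias[i, int(key)] = value

logits = torch.randn(len(per_request_bias), vocab_size)
logits.add_(logit_bias)  # same in-place add as SamplingBatchInfo.apply_logits_bias
print(logits.argmax(dim=-1))  # row 0 now overwhelmingly favors token 3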
python/sglang/srt/sampling/sampling_params.py  (view file @ ca929118)

@@ -52,6 +52,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop

@@ -78,6 +79,7 @@ class SamplingParams:
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
         self.stream_interval = stream_interval
+        self.logit_bias = logit_bias

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
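On the engine side the new argument is simply stored on the request's `SamplingParams`. A small construction sketch, assuming the remaining constructor arguments keep their defaults; the token ID is illustrative:

from sglang.srt.sampling.sampling_params import SamplingParams

# Keys are token IDs as strings, matching the int(key) cast in the batch code.
params = SamplingParams(max_new_tokens=4, logit_bias={"60704": 100.0})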
python/sglang/srt/utils.py  (view file @ ca929118)

@@ -2210,6 +2210,45 @@ class Withable(Generic[T]):
         self._value = None


+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
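An illustrative use of `merge_bias_tensor` when two batches are merged and only one carries a bias: the missing side is padded with the `default` value so the concatenated tensor stays aligned with the merged batch. Shapes and values below are made up for the example:

import torch

from sglang.srt.utils import merge_bias_tensor  # added in this commit

lhs = None                               # batch of 2 requests, no bias
rhs = torch.tensor([[0.0, 100.0, 0.0]])  # batch of 1 request, biased toward token 1

merged = merge_bias_tensor(lhs, rhs, bs1=2, bs2=1, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([3, 3]): two default-filled rows followed by the biased row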
test/srt/test_srt_endpoint.py  (view file @ ca929118)

@@ -504,6 +504,122 @@ class TestSRTEndpoint(CustomTestCase):
         version = response_json["version"]
         self.assertIsInstance(version, str)

+    def test_logit_bias(self):
+        """Test that a very high logit bias forces sampling of a specific token."""
+        # Choose a token ID to bias
+        target_token_id = 60704  # "Paris" for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        logit_bias = {str(target_token_id): 100.0}  # Very high positive bias
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,  # Use high temperature to encourage exploration
+                    "max_new_tokens": 4,
+                    "logit_bias": logit_bias,
+                },
+                "return_logprob": True,
+            },
+        )
+        response_json = response.json()
+
+        # Extract the sampled token IDs from the output
+        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # Verify that all sampled tokens are the target token
+        self.assertTrue(
+            all(x == target_token_id for x in sampled_tokens),
+            f"Expected all tokens to be {target_token_id}, but got {sampled_tokens}",
+        )
+
+    def test_forbidden_token(self):
+        """Test that a forbidden token (very negative logit bias) doesn't appear in the output."""
+        # Choose a token ID to forbid
+        forbidden_token_id = 23994  # "rice" for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        logit_bias = {str(forbidden_token_id): -100.0}  # Very negative bias to forbid the token
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "Only output 'rice' exactly like this, in lowercase ONLY: rice",
+                "sampling_params": {
+                    "temperature": 1.0,  # Use high temperature to encourage diverse output
+                    "max_new_tokens": 50,  # Generate enough tokens to give the forbidden token a chance to appear
+                    "logit_bias": logit_bias,
+                },
+                "return_logprob": True,
+            },
+        )
+        response_json = response.json()
+
+        # Extract the sampled token IDs from the output
+        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # Verify that the forbidden token doesn't appear in the output
+        self.assertNotIn(
+            forbidden_token_id,
+            sampled_tokens,
+            f"Expected forbidden token {forbidden_token_id} not to be present, but it was found",
+        )
+
+    def test_logit_bias_isolation(self):
+        """Test that logit_bias applied to one request doesn't affect other requests in a batch."""
+        # Choose a token ID to bias in the first request only
+        biased_token_id = 60704  # "Paris" for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        # Prepare batch requests: one with logit_bias and one without
+        requests_data = [
+            {
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 4,
+                    "logit_bias": {str(biased_token_id): 100.0},  # Strong bias
+                },
+                "return_logprob": True,
+            },
+            {
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 4,
+                },
+                "return_logprob": True,
+            },
+        ]
+
+        # Send both requests
+        responses = []
+        for req in requests_data:
+            response = requests.post(self.base_url + "/generate", json=req)
+            responses.append(response.json())
+
+        # Extract token IDs from each response
+        biased_tokens = [x[1] for x in responses[0]["meta_info"]["output_token_logprobs"]]
+        unbiased_tokens = [x[1] for x in responses[1]["meta_info"]["output_token_logprobs"]]
+
+        # Verify the first response contains only the biased token
+        self.assertTrue(
+            all(x == biased_token_id for x in biased_tokens),
+            f"Expected all tokens to be {biased_token_id} in first response, but got {biased_tokens}",
+        )
+
+        # Verify the second response contains at least some different tokens
+        # (We can't guarantee exactly what tokens will be generated,
+        # but they shouldn't all be the biased token.)
+        self.assertTrue(
+            any(x != biased_token_id for x in unbiased_tokens),
+            f"Expected some tokens to be different from {biased_token_id} in second response, but got {unbiased_tokens}",
+        )
+
     def test_get_server_info_concurrent(self):
         """Make sure the concurrent get_server_info doesn't crash the server."""
         tp = ThreadPoolExecutor(max_workers=30)