sglang / Commits / 583697cd

Unverified commit 583697cd, authored Jan 20, 2025 by Hongpeng Guo, committed via GitHub on Jan 20, 2025.
[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
Parent: 2584f6d9
Showing 6 changed files with 79 additions and 28 deletions (+79, −28):
  python/sglang/bench_one_batch.py                    +1   −0
  python/sglang/srt/layers/sampler.py                +10   −0
  python/sglang/srt/managers/schedule_batch.py        +6   −0
  python/sglang/srt/managers/scheduler.py             +2   −0
  python/sglang/srt/sampling/sampling_batch_info.py  +34  −19
  test/srt/test_srt_endpoint.py                      +26   −9
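For background: a custom logit processor is a user-supplied callable, serialized with CustomLogitProcessor.to_str() and attached to a request, that rewrites the model's output logits in place before sampling. A minimal sketch of the DeterministicLogitProcessor exercised by the tests below (illustrative only; sglang's real CustomLogitProcessor base class defines the exact __call__ signature and the to_str()/from_str() serialization):

    import torch

    class DeterministicLogitProcessor:
        # Forces sampling of custom_params["token_id"] by masking every
        # other logit; custom_param_list holds one params dict per request.
        def __call__(self, logits: torch.Tensor, custom_param_list):
            for i, params in enumerate(custom_param_list):
                logits[i, :] = -float("inf")
                logits[i, params["token_id"]] = 0.0
            return logits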
python/sglang/bench_one_batch.py
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
python/sglang/srt/layers/sampler.py
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
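As a toy illustration of the masked application these asserts guard (plain torch, not sglang's types): rows of logits selected by batch_mask are rewritten in place, and the mask length must equal the logits batch size.

    import torch

    batch_size, vocab = 4, 8
    logits = torch.zeros(batch_size, vocab)
    batch_mask = torch.tensor([True, False, True, False])  # requests using this processor

    # The invariant the new asserts enforce: mask length == logits batch size.
    assert batch_mask.shape[0] == logits.shape[0]

    def force_token_3(rows: torch.Tensor) -> torch.Tensor:
        # Stand-in for a custom logit processor applied to the masked rows.
        out = torch.full_like(rows, -float("inf"))
        out[:, 3] = 0.0
        return out

    # Only the masked rows are rewritten, as in Sampler's per-processor loop.
    logits[batch_mask] = force_token_3(logits[batch_mask])
    assert logits[0].argmax().item() == 3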
python/sglang/srt/managers/schedule_batch.py
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None

+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )

     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )

     def __str__(self):
python/sglang/srt/managers/scheduler.py
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
                 self.model_config,
                 self.enable_overlap,
                 self.spec_algorithm,
+                self.server_args.enable_custom_logit_processor,
             )
             idle_batch.prepare_for_idle()
         return idle_batch
python/sglang/srt/sampling/sampling_batch_info.py
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)

         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )

         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
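Pulled out of from_schedule_batch, the new gating reads as below (a standalone sketch; Req here is a stand-in for sglang's request object):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Req:  # stand-in for sglang's request object
        custom_logit_processor: Optional[str] = None

    def gate(enable_flag: bool, reqs: List[Req]) -> bool:
        # The server-level flag short-circuits the per-request scan, so
        # processors attached to requests are ignored unless the feature
        # was enabled at server launch.
        return enable_flag and any(r.custom_logit_processor for r in reqs)

    assert gate(False, [Req("serialized-processor")]) is False
    assert gate(True, [Req(None), Req("serialized-processor")]) is True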
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
+        """Filter the custom logit processor and custom params"""
         if not self.custom_logit_processor:
             return

         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]

-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
                 )
                 merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+                assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                    f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                    f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                    f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                    f"\n{lhs=}\n{rhs=}"
+                )

         return merged_dict

     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-
         # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True

+        # Note: because the __len__() operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
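The merge path concatenates the two sides' boolean masks per processor, padding with all-False when only one side of the merge has a given processor; a toy illustration of the shape invariant the new assert checks:

    import torch

    bs1, bs2 = 2, 3
    left_mask = torch.tensor([True, False])          # processor present in lhs only
    right_mask = torch.zeros(bs2, dtype=torch.bool)  # all-False padding for rhs

    merged_mask = torch.cat([left_mask, right_mask])
    # Exactly the shape check the commit adds after merging.
    assert merged_mask.shape[0] == bs1 + bs2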
test/srt/test_srt_endpoint.py
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 """

 import json
+import random
 import unittest
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 import numpy as np
 import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertTrue(all(x is not None for x in logprobs))

-    def run_custom_logit_processor(self, target_token_id: int):
-        """Test custom logit processor with custom params."""
+    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+        """Test custom logit processor with custom params.
+
+        If target_token_id is None, the custom logit processor won't be passed in.
+        """

         custom_params = {"token_id": target_token_id}
@@ -285,7 +290,11 @@ class TestSRTEndpoint(unittest.TestCase):
         # Custom json data with custom logit processor and params.
         custom_json = base_json.copy()
-        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
+        # Only set the custom logit processor if target_token_id is not None.
+        if target_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicLogitProcessor().to_str()
+            )
         custom_json["sampling_params"]["custom_params"] = custom_params

         custom_response = requests.post(
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
         sampled_tokens = [x[1] for x in output_token_logprobs]
         # The logit processor should always sample the given token, as the logits are deterministic.
-        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+        if target_token_id is not None:
+            self.assertTrue(
+                all(x == custom_params["token_id"] for x in sampled_tokens),
+                # Print the detailed test case info if the test fails.
+                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )

     def test_custom_logit_processor(self):
         """Test custom logit processor with a single request."""
-        # Temporarily skipped due to buggy implementation
-        return
         self.run_custom_logit_processor(target_token_id=5)

     def test_custom_logit_processor_batch(self):
         """Test custom logit processor with a batch of requests."""
-        # Temporarily skipped due to buggy implementation
-        return
         target_token_ids = list(range(32))
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))

+    def test_custom_logit_processor_batch_mixed(self):
+        """Test a batch mixing requests with and without a custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
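End to end, a client request in the style of these tests looks roughly like this (a sketch: the base_url, a running server launched with the custom-logit-processor feature enabled, and a processor_str produced by CustomLogitProcessor.to_str() are all assumed):

    import requests

    base_url = "http://localhost:30000"  # assumed local sglang server
    processor_str = "..."  # placeholder for DeterministicLogitProcessor().to_str()

    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 8,
            "custom_params": {"token_id": 5},
        },
        "custom_logit_processor": processor_str,
    }
    print(requests.post(base_url + "/generate", json=payload).json())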