sglang · Commits · 583697cd

[Enhancement] Custom Logit Processor Improvement (#2998)

Unverified commit 583697cd, authored Jan 20, 2025 by Hongpeng Guo, committed by GitHub on Jan 20, 2025.

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>

Parent: 2584f6d9

Showing 6 changed files with 79 additions and 28 deletions (+79 −28)
python/sglang/bench_one_batch.py                     +1  −0
python/sglang/srt/layers/sampler.py                  +10 −0
python/sglang/srt/managers/schedule_batch.py         +6  −0
python/sglang/srt/managers/scheduler.py              +2  −0
python/sglang/srt/sampling/sampling_batch_info.py    +34 −19
test/srt/test_srt_endpoint.py                        +26 −9
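For orientation before the diffs: a custom logit processor lets a request ship serialized code that the sampler applies to that request's rows of the logits, parameterized by a per-request `custom_params` dict; this commit gates the feature behind a server-level flag and hardens the batch bookkeeping around it. The test file below drives the feature with a `DeterministicLogitProcessor`. A plausible minimal definition, assuming the `CustomLogitProcessor` base class (whose `to_str()` serializer the test uses) lives in `sglang.srt.sampling.custom_logit_processor`, and inferring the `__call__` contract from the sampler change below:

```python
import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Force sampling of a fixed token by masking every other logit."""

    def __call__(self, logits: torch.Tensor, custom_param_list: list) -> torch.Tensor:
        # One custom_params dict per selected batch row.
        assert logits.shape[0] == len(custom_param_list)
        for i, params in enumerate(custom_param_list):
            logits[i, :] = -float("inf")         # suppress every token ...
            logits[i, params["token_id"]] = 0.0  # ... except the requested one
        return logits
```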
python/sglang/bench_one_batch.py (view file @ 583697cd)

```diff
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
```
python/sglang/srt/layers/sampler.py (view file @ 583697cd)

```diff
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
```
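The masked application above is plain boolean-mask indexing in PyTorch. A standalone sketch of the pattern, using a hypothetical `boost_token` processor in place of a user-supplied one (illustrative code, not sglang's API):

```python
import torch


def boost_token(masked_logits: torch.Tensor, token_id: int) -> torch.Tensor:
    # Hypothetical processor: raise one token's logit for the selected rows.
    masked_logits[:, token_id] += 5.0
    return masked_logits


batch_size, vocab_size = 4, 8
logits = torch.zeros(batch_size, vocab_size)
batch_mask = torch.tensor([True, False, True, False])

# Same pattern as the sampler: recover row indices from the boolean mask ...
batch_indices = batch_mask.nonzero(as_tuple=True)[0]  # tensor([0, 2])

# ... and write the processor's output back over the masked rows only.
logits[batch_mask] = boost_token(logits[batch_mask], token_id=3)

assert torch.all(logits[batch_indices, 3] == 5.0)  # selected rows changed
assert torch.all(logits[~batch_mask] == 0.0)       # other rows untouched
```

Note that `logits[batch_mask]` produces a copy, which is why the sampler assigns the processor's result back with `logits[batch_mask] = processor(...)` rather than relying on in-place mutation alone.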
python/sglang/srt/managers/schedule_batch.py (view file @ 583697cd)

```diff
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
```
python/sglang/srt/managers/scheduler.py (view file @ 583697cd)

```diff
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
```
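Taken together with schedule_batch.py above, these call sites trace the flag's path: it originates in `server_args`, is passed positionally into `ScheduleBatch.init_new`, and is copied onto every batch derived from an existing one. A toy reduction of that plumbing (the class names and everything except the `enable_custom_logit_processor` field are illustrative, not sglang's actual types):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyServerArgs:
    # Assumed to default to off; bench_one_batch.py passes False explicitly.
    enable_custom_logit_processor: bool = False


@dataclass
class ToyBatch:
    reqs: List[str]
    enable_custom_logit_processor: bool = False

    @classmethod
    def init_new(cls, reqs: List[str], enable_custom_logit_processor: bool) -> "ToyBatch":
        return cls(reqs=reqs, enable_custom_logit_processor=enable_custom_logit_processor)

    def copy_subset(self, indices: List[int]) -> "ToyBatch":
        # Derived batches inherit the flag, mirroring the copy-style
        # constructor change in schedule_batch.py above.
        return ToyBatch(
            reqs=[self.reqs[i] for i in indices],
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )


args = ToyServerArgs(enable_custom_logit_processor=True)
batch = ToyBatch.init_new(["req-0", "req-1"], args.enable_custom_logit_processor)
assert batch.copy_subset([0]).enable_custom_logit_processor
```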
python/sglang/srt/sampling/sampling_batch_info.py (view file @ 583697cd)

```diff
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-        if not self.custom_logit_processor:
-            return
+
         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
 
-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
 
         return merged_dict
 
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True
 
+        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
```
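The new assertion in `merge_custom_logit_processor` pins down the invariant that every merged mask spans both batches, i.e. has `bs1 + bs2` rows. A standalone sketch of that invariant, with `merge_masks` as a hypothetical stand-in for the pad-and-concatenate step the method performs per processor (a processor missing from one side gets an all-False mask for that side):

```python
import torch


def merge_masks(left_mask, right_mask, bs1: int, bs2: int, device: str = "cpu"):
    # Pad a missing side with all-False so the merged mask covers bs1 + bs2 rows.
    if left_mask is None:
        left_mask = torch.zeros(bs1, dtype=torch.bool, device=device)
    if right_mask is None:
        right_mask = torch.zeros(bs2, dtype=torch.bool, device=device)
    merged = torch.cat([left_mask, right_mask])
    # The invariant the commit's new assertion enforces:
    assert merged.shape[0] == bs1 + bs2
    return merged


# A processor active for rows {0, 2} of the first batch and absent from the second:
mask = merge_masks(torch.tensor([True, False, True]), None, bs1=3, bs2=2)
print(mask)  # tensor([ True, False,  True, False, False])
```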
test/srt/test_srt_endpoint.py (view file @ 583697cd)

```diff
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 """
 
 import json
+import random
 import unittest
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 import numpy as np
 import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertTrue(all(x is not None for x in logprobs))
 
-    def run_custom_logit_processor(self, target_token_id: int):
-        """Test custom logit processor with custom params."""
+    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+        """Test custom logit processor with custom params.
+
+        If target_token_id is None, the custom logit processor won't be passed in.
+        """
 
         custom_params = {"token_id": target_token_id}
@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
         # Custom json data with custom logit processor and params.
         custom_json = base_json.copy()
-        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
-        custom_json["sampling_params"]["custom_params"] = custom_params
+        # Only set the custom logit processor if target_token_id is not None.
+        if target_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicLogitProcessor().to_str()
+            )
+            custom_json["sampling_params"]["custom_params"] = custom_params
 
         custom_response = requests.post(
             self.base_url + "/generate",
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
         sampled_tokens = [x[1] for x in output_token_logprobs]
 
         # The logit processor should always sample the given token as the logits is deterministic.
-        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+        if target_token_id is not None:
+            self.assertTrue(
+                all(x == custom_params["token_id"] for x in sampled_tokens),
+                # Print the detailed test case info if the test fails.
+                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )
 
     def test_custom_logit_processor(self):
         """Test custom logit processor with a single request."""
-        # Temporarily skipped due to buggy implementation
-        return
         self.run_custom_logit_processor(target_token_id=5)
 
     def test_custom_logit_processor_batch(self):
         """Test custom logit processor with a batch of requests."""
-        # Temporarily skipped due to buggy implementation
-        return
         target_token_ids = list(range(32))
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
+    def test_custom_logit_processor_batch_mixed(self):
+        """Test a batch of requests mixed of requests with and without custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
```
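End to end, the tests exercise the HTTP path: the processor is serialized with `to_str()` and shipped per request beside its `custom_params`. A hedged sketch of such a request against a locally running server, reusing the `DeterministicLogitProcessor` sketched near the top of this page; the URL, prompt, and sampling fields other than `custom_params` are assumptions, and the server must have been launched with `enable_custom_logit_processor` turned on, or the flag check added in sampling_batch_info.py will ignore the processor:

```python
import requests

# Assumes DeterministicLogitProcessor (sketched above) is importable here.
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "max_new_tokens": 8,
        "custom_params": {"token_id": 5},  # forwarded to the processor per request
    },
    # The serialized processor travels with the request, as in the test above.
    "custom_logit_processor": DeterministicLogitProcessor().to_str(),
}
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())
```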