change / sglang · Commits · e403d237

Commit e403d237 (unverified), authored Jan 19, 2025 by Hongpeng Guo, committed via GitHub on Jan 19, 2025.

[Feature] Add sampler custom logits processor (#2396)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>

Parent: 3bcf5ece

Showing 12 changed files with 302 additions and 4 deletions (+302 -4).
Files changed:

  python/sglang/srt/layers/sampler.py                   +29   -1
  python/sglang/srt/managers/io_struct.py               +19   -0
  python/sglang/srt/managers/schedule_batch.py           +2   -0
  python/sglang/srt/managers/scheduler.py               +14   -0
  python/sglang/srt/managers/session_controller.py       +1   -0
  python/sglang/srt/managers/tokenizer_manager.py         +1   -0
  python/sglang/srt/sampling/custom_logit_processor.py  +38   -0
  python/sglang/srt/sampling/sampling_batch_info.py     +121  -1
  python/sglang/srt/sampling/sampling_params.py           +3   -1
  python/sglang/srt/server.py                             +4   -0
  python/sglang/srt/server_args.py                        +8   -0
  test/srt/test_srt_endpoint.py                          +62   -1
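This commit lets a client attach a serialized logit processor to a generate request and have the sampler call it on that request's logits, with per-request arguments passed through sampling_params.custom_params. Below is a minimal usage sketch (not part of the commit), assuming a server is already running at http://localhost:30000 and was launched with --enable-custom-logit-processor; it mirrors the pattern exercised in test/srt/test_srt_endpoint.py further down.

import requests

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Force sampling of the token id given in each request's custom_params."""

    def __call__(self, logits, custom_param_list):
        for i, param_dict in enumerate(custom_param_list):
            logits[i, :] = -float("inf")              # mask every token
            logits[i, param_dict["token_id"]] = 0.0   # keep only the target token
        return logits


response = requests.post(
    "http://localhost:30000/generate",  # assumed server address
    json={
        "text": "Question: Is Paris the Capital of France? Answer:",
        "sampling_params": {"temperature": 0.0, "custom_params": {"token_id": 5}},
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
    },
).json()
print(response["text"])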
python/sglang/srt/layers/sampler.py  (+29 -1)

 import logging
-from typing import List
+from typing import Dict, List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

@@ -35,6 +36,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits

+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(

@@ -121,6 +126,29 @@ class Sampler(nn.Module):
         return batch_next_token_ids

+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
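A standalone sketch (not from the commit) of the masked update performed by _apply_custom_logit_processor above: only the rows of logits whose entry in batch_mask is True are handed to the processor, together with the custom_params of those same requests, and the result is written back through the same mask.

import torch

logits = torch.zeros(4, 8)                       # 4 requests in the batch, vocab size 8
batch_mask = torch.tensor([True, False, True, False])
custom_params = [{"token_id": 3}, None, {"token_id": 5}, None]

# Same selection the sampler performs.
batch_indices = batch_mask.nonzero(as_tuple=True)[0]       # tensor([0, 2])
selected_params = [custom_params[i] for i in batch_indices]


def processor(rows, params):
    """Illustrative processor: keep only each request's target token."""
    rows[:] = -float("inf")
    for i, p in enumerate(params):
        rows[i, p["token_id"]] = 0.0
    return rows


logits[batch_mask] = processor(logits[batch_mask], selected_params)
# Rows 0 and 2 are now -inf everywhere except tokens 3 and 5; rows 1 and 3 are untouched.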
python/sglang/srt/managers/io_struct.py  (+19 -0)

@@ -22,6 +22,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import SamplingParams

@@ -69,6 +70,8 @@ class GenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None

     def normalize_batch_and_arguments(self):
         if (

@@ -183,6 +186,13 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid

@@ -202,6 +212,11 @@ class GenerateReqInput:
             log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )

@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None

+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+

 @dataclass
 class EmbeddingReqInput:
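A short sketch (not from the commit) of what normalize_batch_and_arguments does with the new field when a batched request carries a single serialized processor: it is broadcast to one entry per prompt, mirroring how lora_path and modalities are handled.

# `serialized` stands in for the output of CustomLogitProcessor.to_str().
serialized = '{"callable": "..."}'   # placeholder value for illustration
num = 4                              # number of prompts in the batched request

custom_logit_processor = serialized                            # a single string ...
if not isinstance(custom_logit_processor, list):
    custom_logit_processor = [custom_logit_processor] * num    # ... becomes one entry per prompt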
python/sglang/srt/managers/schedule_batch.py  (+2 -0)

@@ -232,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info

@@ -252,6 +253,7 @@ class Req:
         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor

         # Memory pool info
         self.req_pool_idx = None
python/sglang/srt/managers/scheduler.py  (+14 -0)

@@ -614,6 +614,19 @@ class Scheduler:
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

+        # Handle custom logit processor passed to the request
+        custom_logit_processor = recv_req.custom_logit_processor
+        if (
+            not self.server_args.enable_custom_logit_processor
+            and custom_logit_processor is not None
+        ):
+            logger.warning(
+                "The SGLang server is not configured to enable custom logit processor."
+                "The custom logit processor passed in will be ignored."
+                "Please set --enable-custom-logits-processor to enable this feature."
+            )
+            custom_logit_processor = None
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,

@@ -624,6 +637,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            custom_logit_processor=custom_logit_processor,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
python/sglang/srt/managers/session_controller.py  (+1 -0)

@@ -131,6 +131,7 @@ class Session:
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
python/sglang/srt/managers/tokenizer_manager.py  (+1 -0)

@@ -381,6 +381,7 @@ class TokenizerManager:
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
python/sglang/srt/sampling/custom_logit_processor.py  (new file, +38)

import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, Optional

import dill
import torch


@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
    """Deserialize a json string to a Callable object.
    This function is cached to avoid redundant deserialization.
    """
    data = json.loads(json_str)
    return dill.loads(bytes.fromhex(data["callable"]))


class CustomLogitProcessor(ABC):
    """Abstract base class for callable functions."""

    @abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Define the callable behavior."""
        raise NotImplementedError

    def to_str(self) -> str:
        """Serialize the callable function to a JSON-compatible string."""
        return json.dumps({"callable": dill.dumps(self).hex()})

    @classmethod
    def from_str(cls, json_str: str):
        """Deserialize a callable function from a JSON string."""
        return _cache_from_str(json_str)
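A sketch (not part of the new file) of the round trip this module enables: to_str() packs the dill-pickled processor as hex inside a JSON string, and from_str() rebuilds it, with the lru_cache ensuring each distinct string is deserialized only once. The BanTokenProcessor class is illustrative, not from the commit.

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BanTokenProcessor(CustomLogitProcessor):
    """Illustrative processor that bans one token id per request."""

    def __call__(self, logits, custom_param_list=None):
        for i, params in enumerate(custom_param_list or []):
            logits[i, params["banned_id"]] = -float("inf")
        return logits


serialized = BanTokenProcessor().to_str()              # JSON string, safe to ship over the wire
restored = CustomLogitProcessor.from_str(serialized)   # cached dill round trip

logits = torch.zeros(1, 10)
out = restored(logits, [{"banned_id": 7}])
assert out[0, 7] == -float("inf")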
python/sglang/srt/sampling/sampling_batch_info.py  (+121 -1)

@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import logging
 import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

@@ -14,6 +14,7 @@ if is_cuda:
     from sgl_kernel import sampling_scaling_penalties

 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

 logger = logging.getLogger(__name__)

@@ -36,6 +37,9 @@ class SamplingBatchInfo:
     # Dispatch in CUDA graph
     need_min_p_sampling: bool

+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool
+
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None

@@ -52,6 +56,14 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"

+    # Custom Parameters
+    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+    # Custom Logit Processor
+    custom_logit_processor: Optional[
+        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+    ] = None
+
     @classmethod
     def from_schedule_batch(
         cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +88,36 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        # Check if any request has custom logit processor
+        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+
+        if has_custom_logit_processor:
+            # Merge the same type of custom logit processors together
+            processor_dict = {}
+            for i, r in enumerate(reqs):
+                if r.custom_logit_processor is None:
+                    continue
+                processor_str = r.custom_logit_processor
+                if processor_str not in processor_dict:
+                    processor_dict[processor_str] = []
+                processor_dict[processor_str].append(i)
+
+            merged_custom_logit_processor = {
+                hash(processor_str): (
+                    # The deserialized custom logit processor object
+                    CustomLogitProcessor.from_str(processor_str),
+                    # The mask tensor for the requests that use this custom logit processor
+                    torch.zeros(len(reqs), dtype=torch.bool)
+                    .scatter_(0, torch.tensor(true_indices), True)
+                    .to(device, non_blocking=True),
+                )
+                for processor_str, true_indices in processor_dict.items()
+            }
+            custom_params = [r.sampling_params.custom_params for r in reqs]
+        else:
+            merged_custom_logit_processor = None
+            custom_params = None
+
         ret = cls(
             temperatures=temperatures,
             top_ps=top_ps,
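A standalone sketch (not from the commit) of the grouping that from_schedule_batch performs above: requests carrying byte-identical serialized processors share a single deserialized object, and a boolean mask built with scatter_ records which rows of the batch they own.

import torch

serialized = ["procA", "procA", None, "procB"]   # one (possibly None) entry per request

processor_dict = {}
for i, s in enumerate(serialized):
    if s is None:
        continue
    processor_dict.setdefault(s, []).append(i)
# processor_dict == {"procA": [0, 1], "procB": [3]}

masks = {
    s: torch.zeros(len(serialized), dtype=torch.bool).scatter_(
        0, torch.tensor(true_indices), True
    )
    for s, true_indices in processor_dict.items()
}
# masks["procA"] -> tensor([ True,  True, False, False])
# masks["procB"] -> tensor([False, False, False,  True])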
@@ -83,8 +125,11 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            has_custom_logit_processor=has_custom_logit_processor,
             vocab_size=vocab_size,
             device=device,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -184,6 +229,8 @@ class SamplingBatchInfo:
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

+        if self.has_custom_logit_processor:
+            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
+
         for item in [
             "temperatures",

@@ -196,6 +243,26 @@ class SamplingBatchInfo:
             if value is not None:  # logit_bias can be None
                 setattr(self, item, value[new_indices])

+    def _filter_batch_custom_logit_processor(
+        self, unfinished_indices: List[int], new_indices: torch.Tensor
+    ):
+        """Filter the custom logit processor and custom params"""
+        if not self.custom_logit_processor:
+            return
+
+        self.custom_logit_processor = {
+            k: (p, mask[new_indices])
+            for k, (p, mask) in self.custom_logit_processor.items()
+            if any(
+                mask[new_indices]
+            )  # ignore the custom logit processor whose mask is all False
+        }
+        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+        if len(self) == 0:
+            self.custom_logit_processor = None
+            self.custom_params = None
+            self.has_custom_logit_processor = False
+
     @staticmethod
     def merge_bias_tensor(
         lhs: torch.Tensor,
@@ -221,6 +288,39 @@ class SamplingBatchInfo:
         return None

+    @staticmethod
+    def merge_custom_logit_processor(
+        lhs: Optional[Dict[str, torch.Tensor]],
+        rhs: Optional[Dict[str, torch.Tensor]],
+        bs1: int,
+        bs2: int,
+        device: str,
+    ):
+        if lhs is None and rhs is None:
+            return None
+        lhs, rhs = lhs or {}, rhs or {}
+
+        keys = set(lhs.keys()).union(set(rhs.keys()))
+        merged_dict = {}
+
+        for k in keys:
+            # Get the logit processor object
+            processor = lhs[k][0] if k in lhs else rhs[k][0]
+            # Get and merge the mask tensors from the two dicts
+            left_mask = (
+                lhs[k][1]
+                if k in lhs
+                else torch.zeros(bs1, dtype=torch.bool, device=device)
+            )
+            right_mask = (
+                rhs[k][1]
+                if k in rhs
+                else torch.zeros(bs2, dtype=torch.bool, device=device)
+            )
+            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+        return merged_dict
+
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
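A short sketch (not from the commit) of how merge_custom_logit_processor combines the masks when two running batches are merged: a side that never registered a given processor contributes an all-False mask of its own batch size, and the two masks are concatenated in batch order.

import torch

bs1, bs2 = 3, 2
left_mask = torch.tensor([True, False, True])     # processor used by requests 0 and 2 of batch 1
right_mask = torch.zeros(bs2, dtype=torch.bool)   # batch 2 never registered this processor

merged = torch.cat([left_mask, right_mask])
# merged -> tensor([ True, False,  True, False, False]); length bs1 + bs2 matches the merged batch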
@@ -240,6 +340,26 @@ class SamplingBatchInfo:
         )
         self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

+        # Merge the custom logit processors and custom params lists
+        if self.has_custom_logit_processor or other.has_custom_logit_processor:
+            # Merge the custom logit processors
+            self.custom_logit_processor = (
+                SamplingBatchInfo.merge_custom_logit_processor(
+                    self.custom_logit_processor,
+                    other.custom_logit_processor,
+                    len(self),
+                    len(other),
+                    self.device,
+                )
+            )
+            # Merge the custom params lists
+            self.custom_params = self.custom_params or [None] * len(self)
+            other.custom_params = other.custom_params or [None] * len(other)
+            self.custom_params.extend(other.custom_params)
+
+            # Set the flag to True if any of the two has custom logit processor
+            self.has_custom_logit_processor = True
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
python/sglang/srt/sampling/sampling_params.py  (+3 -1)

@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p

@@ -71,6 +72,7 @@ class SamplingParams:
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
python/sglang/srt/server.py  (+4 -0)

@@ -773,6 +773,7 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
         stream: bool = False,
     ):
         obj = GenerateReqInput(

@@ -784,6 +785,7 @@ class Engine:
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
             stream=stream,
+            custom_logit_processor=custom_logit_processor,
         )

         # get the current event loop

@@ -824,6 +826,7 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[str, List[str]]] = None,
         stream: bool = False,
     ):
         obj = GenerateReqInput(

@@ -835,6 +838,7 @@ class Engine:
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
             stream=stream,
+            custom_logit_processor=custom_logit_processor,
         )
         ret = await generate_request(obj, None)
python/sglang/srt/server_args.py  (+8 -0)

@@ -159,6 +159,9 @@ class ServerArgs:
     enable_memory_saver: bool = False
     allow_auto_truncate: bool = False

+    # Custom logit processor
+    enable_custom_logit_processor: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:

@@ -865,6 +868,11 @@ class ServerArgs:
             action="store_true",
             help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
         )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
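For completeness, a sketch of offline usage through the Engine entry point added to server.py above. It assumes that Engine forwards keyword arguments to ServerArgs (so the new enable_custom_logit_processor field can be set at construction) and that generate accepts prompt and sampling_params keywords alongside the custom_logit_processor parameter shown in the diff; the model path is an example, not part of the commit.

import sglang as sgl

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Illustrative processor: force sampling of the token id in custom_params."""

    def __call__(self, logits, custom_param_list):
        for i, param_dict in enumerate(custom_param_list):
            logits[i, :] = -float("inf")
            logits[i, param_dict["token_id"]] = 0.0
        return logits


# Assumed: Engine(**kwargs) populates ServerArgs, so the new flag can be passed here.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",   # example model path
    enable_custom_logit_processor=True,
)
out = llm.generate(
    prompt="Question: Is Paris the Capital of France? Answer:",
    sampling_params={"temperature": 0.0, "custom_params": {"token_id": 5}},
    custom_logit_processor=DeterministicLogitProcessor().to_str(),
)
print(out)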
test/srt/test_srt_endpoint.py  (+62 -1)

@@ -5,10 +5,12 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 import json
 import unittest
+from concurrent.futures import ThreadPoolExecutor

 import numpy as np
 import requests

+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,

@@ -24,7 +26,10 @@ class TestSRTEndpoint(unittest.TestCase):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
-            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=("--enable-custom-logit-processor",),
         )

     @classmethod

@@ -248,6 +253,62 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertTrue(all(x is not None for x in logprobs))

+    def run_custom_logit_processor(self, target_token_id: int):
+        """Test custom logit processor with custom params."""
+
+        custom_params = {"token_id": target_token_id}
+
+        class DeterministicLogitProcessor(CustomLogitProcessor):
+            """A dummy logit processor that changes the logits to always
+            sample the given token id.
+            """
+
+            def __call__(self, logits, custom_param_list):
+                assert logits.shape[0] == len(custom_param_list)
+                key = "token_id"
+
+                for i, param_dict in enumerate(custom_param_list):
+                    # Mask all other tokens
+                    logits[i, :] = -float("inf")
+                    # Assign highest probability to the specified token
+                    logits[i, param_dict[key]] = 0.0
+                return logits
+
+        prompts = "Question: Is Paris the Capital of France? Answer:"
+
+        # Base case json data to be posted to the server.
+        base_json = {
+            "text": prompts,
+            "sampling_params": {"temperature": 0.0},
+            "return_logprob": True,
+        }
+
+        # Custom json data with custom logit processor and params.
+        custom_json = base_json.copy()
+        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
+        custom_json["sampling_params"]["custom_params"] = custom_params
+
+        custom_response = requests.post(
+            self.base_url + "/generate",
+            json=custom_json,
+        ).json()
+        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+        # The logit processor should always sample the given token as the logits is deterministic.
+        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+
+    def test_custom_logit_processor(self):
+        """Test custom logit processor with a single request."""
+        self.run_custom_logit_processor(target_token_id=5)
+
+    def test_custom_logit_processor_batch(self):
+        """Test custom logit processor with a batch of requests."""
+        target_token_ids = list(range(32))
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()