sglang / Commits / e30ef368

Feat/support rerank (#6058)

Commit e30ef368 (unverified), authored Jun 17, 2025 by woodx, committed via GitHub Jun 16, 2025. Parent: 91a066ec.

Showing 20 changed files with 684 additions and 30 deletions (+684, -30).
python/sglang/srt/configs/model_config.py               +5   -0
python/sglang/srt/entrypoints/engine.py                 +14  -0
python/sglang/srt/entrypoints/http_server.py            +11  -0
python/sglang/srt/layers/activation.py                  +19  -0
python/sglang/srt/layers/pooler.py                      +56  -0
python/sglang/srt/managers/io_struct.py                 +21  -1
python/sglang/srt/managers/schedule_batch.py            +21  -0
python/sglang/srt/managers/scheduler.py                 +1   -0
python/sglang/srt/managers/tokenizer_manager.py         +15  -2
python/sglang/srt/model_executor/forward_batch_info.py  +5   -1
python/sglang/srt/models/bert.py                        +113 -13
python/sglang/srt/models/roberta.py                     +117 -9
python/sglang/srt/openai_api/adapter.py                 +64  -1
python/sglang/srt/openai_api/protocol.py                +7   -0
python/sglang/test/runners.py                           +38  -3
python/sglang/test/test_utils.py                        +2   -0
python/sglang/utils.py                                  +9   -0
test/srt/models/test_cross_encoder_models.py            +91  -0  (new file)
test/srt/run_suite.py                                   +2   -0
test/srt/test_openai_server.py                          +73  -0
python/sglang/srt/configs/model_config.py

```diff
@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
         or "CLIPModel" in model_architectures
+        or "BertModel" in model_architectures
+        or "Contriever" in model_architectures
+        or "BertForSequenceClassification" in model_architectures
+        or "XLMRobertaModel" in model_architectures
+        or "XLMRobertaForSequenceClassification" in model_architectures
     ):
         return False
     else:
```
python/sglang/srt/entrypoints/engine.py

```diff
@@ -327,6 +327,14 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
+    def rerank(
+        self,
+        prompt: Union[List[List[str]]],
+    ) -> Dict:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
```
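For orientation, a minimal usage sketch of the new entrypoint (not part of the commit; the checkpoint name is one of the models exercised by the tests below, and `is_embedding=True` mirrors how the test server is launched):

```python
# Hedged sketch: exercising Engine.rerank() with [query, document] pairs.
# Assumes sglang with this commit is installed and the checkpoint is available.
import sglang as sgl

engine = sgl.Engine(model_path="BAAI/bge-reranker-v2-m3", is_embedding=True)

pairs = [
    ["How many people live in Berlin?", "Berlin is well known for its museums."],
    ["How many people live in Berlin?", "Berlin has about 3.5 million inhabitants."],
]
# rerank() wraps the pairs in EmbeddingReqInput(is_cross_encoder_request=True)
scores = engine.rerank(pairs)
print(scores)
engine.shutdown()
```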
python/sglang/srt/entrypoints/http_server.py

```diff
@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
+    V1RerankReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
     v1_delete_file,
     v1_embeddings,
     v1_files_create,
+    v1_rerank,
     v1_retrieve_batch,
     v1_retrieve_file,
     v1_retrieve_file_content,
@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
+@app.api_route("/v1/rerank", methods=["POST", "PUT"])
+async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
+    try:
+        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
```
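The new route accepts POST and PUT. A client-side sketch (assuming a server launched with `--is-embedding` on sglang's default port, as in the server test at the bottom of this diff):

```python
# Hedged sketch of calling the new /v1/rerank route; port and model are assumptions.
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "How many people live in Berlin?",
        "documents": ["Berlin is well known for its museums."],
    },
)
print(resp.json())  # list of {"score", "document", "index", "meta_info"} entries
```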
python/sglang/srt/layers/activation.py

```diff
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PretrainedConfig
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
@@ -165,6 +167,23 @@ def get_act_fn(
     return act_fn
 
+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    if (
+        hasattr(config, "sbert_ce_default_activation_function")
+        and config.sbert_ce_default_activation_function is not None
+    ):
+        function_name = config.sbert_ce_default_activation_function
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons"
+        )
+        return resolve_obj_by_qualname(function_name)()
+    else:
+        # adapt bge-reranker
+        return nn.Identity()
+
 if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
```
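`resolve_obj_by_qualname` (added to `python/sglang/utils.py` below) turns the sentence-transformers config string into a class. A standalone sketch of the resolution path; the `Sigmoid` qualname is an illustrative example of what a CrossEncoder config may store in `sbert_ce_default_activation_function`:

```python
# Standalone sketch of the qualname resolution used above.
import torch.nn as nn
from sglang.utils import resolve_obj_by_qualname

qualname = "torch.nn.modules.activation.Sigmoid"  # assumed example value
assert qualname.startswith("torch.nn.modules.")   # same security guard as above
act = resolve_obj_by_qualname(qualname)()
assert isinstance(act, nn.Sigmoid)
```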
python/sglang/srt/layers/pooler.py

```diff
@@ -3,10 +3,13 @@
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
```
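Because encoder batches are packed token-by-token, `CrossEncodingPooler` must slice `hidden_states` by per-request lengths before pooling. A self-contained toy re-enactment of that slicing (illustrative tensors, no sglang dependencies):

```python
# Toy sketch of the slicing logic in CrossEncodingPooler.forward.
import torch

hidden_states = torch.randn(7, 16)  # 7 packed tokens from two requests, hidden size 16
extend_seq_lens = [3, 4]            # per-request prompt lengths
classifier = torch.nn.Linear(16, 1)

offset, pooled = 0, []
for n in extend_seq_lens:
    seq = hidden_states[offset : offset + n]
    pooled.append(classifier(seq[0]))  # CLS-style pooling: first token per request
    offset += n

scores = torch.stack(pooled).squeeze(-1)
print(scores.shape)  # torch.Size([2]) -- one score per [query, document] pair
```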
python/sglang/srt/managers/io_struct.py

```diff
@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -505,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -564,6 +566,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
@@ -847,6 +861,12 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
+@dataclass
+class V1RerankReqInput:
+    query: str
+    documents: List[str]
+
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
```
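With the widened `text` type, a cross-encoder request carries a batch of `[query, document]` pairs, and `__getitem__` keeps each pair intact when the batch is split. An illustrative construction (this is exactly what `Engine.rerank()` builds):

```python
# Illustrative construction of a cross-encoder EmbeddingReqInput.
from sglang.srt.managers.io_struct import EmbeddingReqInput

obj = EmbeddingReqInput(
    text=[
        ["How many people live in Berlin?", "Berlin is well known for its museums."],
        ["How many people live in Berlin?", "Berlin has many lakes."],
    ],
    is_cross_encoder_request=True,
)
# After batch normalization, obj[i] yields a single-pair request whose text is
# [[query, document]], still flagged with is_cross_encoder_request=True.
```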
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -445,6 +445,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -470,6 +471,9 @@ class Req:
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # for cross-encoder model
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -1150,6 +1150,7 @@ class Scheduler(
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            token_type_ids=recv_req.token_type_ids,
         )
         req.tokenizer = self.tokenizer
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -459,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
 
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
@@ -477,7 +481,14 @@ class TokenizerManager:
                 "accept text prompts. Please provide input_ids or re-initialize "
                 "the engine with skip_tokenizer_init=False."
             )
-            input_ids = self.tokenizer.encode(input_text)
+            encoded = self.tokenizer(
+                input_text, return_token_type_ids=is_cross_encoder_request
+            )
+
+            input_ids = encoded["input_ids"]
+            if is_cross_encoder_request:
+                input_ids = encoded["input_ids"][0]
+                token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
@@ -493,7 +504,7 @@ class TokenizerManager:
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
     def _validate_token_len(
@@ -532,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
@@ -592,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
```
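The cross-encoder branch relies on the Hugging Face tokenizer's pair encoding: passing a `[query, document]` pair produces segment ids that separate the two sentences. A standalone sketch, using the small cross-encoder exercised by the tests in this commit:

```python
# Standalone sketch of pair tokenization producing token_type_ids.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L6-v2")
encoded = tok(
    [["How many people live in Berlin?", "Berlin is well known for its museums."]],
    return_token_type_ids=True,
)
input_ids = encoded["input_ids"][0]            # [CLS] query [SEP] document [SEP]
token_type_ids = encoded["token_type_ids"][0]  # 0 for query tokens, 1 for document tokens
print(len(input_ids), len(token_type_ids))     # equal lengths
```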
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -224,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
@@ -300,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -356,8 +360,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
```
python/sglang/srt/models/bert.py

```diff
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 BertConfig = None
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)
 
         # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
+        position_embeddings = self.position_embeddings(positions)
 
-        token_type_ids = torch.zeros(
-            input_shape, dtype=torch.long, device=inputs_embeds.device
-        )
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings
 
+class BertPooler(nn.Module):
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # simply taking the hidden state corresponding to the first token
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
 class BertEncoder(nn.Module):
     def __init__(
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()
 
+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config, quant_config=quant_config, prefix=f"encoder"
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        )
 
     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):
         hidden_states = self.embeddings(
             input_ids=input_ids,
-            position_ids=positions,
+            positions=positions,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        return self.pooler(hidden_states, forward_batch)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass
 
-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
```
python/sglang/srt/models/roberta.py

```diff
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
 class RobertaEmbedding(nn.Module):
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
 
+        token_type_ids = forward_batch.token_type_ids
+        if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
             )
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
         *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
```
python/sglang/srt/openai_api/adapter.py

```diff
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
     register_conv_template,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    V1RerankReqInput,
+)
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
     FunctionResponse,
     LogProbs,
     MultimodalEmbeddingInput,
+    RerankResponse,
     ScoringRequest,
     ScoringResponse,
     ToolCall,
@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     return response
 
+def v1_rerank_request(obj: V1RerankReqInput):
+    if obj.query is None:
+        raise ValueError("query is required")
+    if obj.documents is None or len(obj.documents) == 0:
+        raise ValueError("documents is required")
+
+    pairs = []
+    for doc in obj.documents:
+        pairs.append([obj.query, doc])
+
+    adapted_request = EmbeddingReqInput(
+        text=pairs,
+        is_cross_encoder_request=True,
+    )
+
+    return adapted_request
+
+def v1_rerank_response(ret, obj: V1RerankReqInput):
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            RerankResponse(
+                score=ret[idx]["embedding"],
+                document=obj.documents[idx],
+                index=idx,
+                meta_info=ret[idx]["meta_info"],
+            )
+        )
+    response.sort(key=lambda x: x.score, reverse=True)
+    return response
+
+async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
+    adapted_request = v1_rerank_request(obj)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_rerank_response(
+        ret,
+        obj,
+    )
+
+    return response
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
```
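End to end, the adapter turns a rerank request into `[query, document]` pairs, scores them through the embedding path, and sorts the results. A sketch of the two helpers in isolation; `fake_ret` is a stand-in for engine output, whose classifier score travels in the `"embedding"` field:

```python
# Hedged sketch of the adapter helpers; fake_ret imitates engine output.
from sglang.srt.managers.io_struct import V1RerankReqInput
from sglang.srt.openai_api.adapter import v1_rerank_request, v1_rerank_response

obj = V1RerankReqInput(
    query="How many people live in Berlin?",
    documents=["Berlin is well known for its museums.", "Berlin has many lakes."],
)
adapted = v1_rerank_request(obj)  # EmbeddingReqInput with [query, doc] pairs

fake_ret = [
    {"embedding": 0.12, "meta_info": {}},
    {"embedding": 0.87, "meta_info": {}},
]
for r in v1_rerank_response(fake_ret, obj):
    print(r.index, r.score, r.document)  # sorted by score, descending
```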
python/sglang/srt/openai_api/protocol.py

```diff
@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
     object: str = "scoring"
 
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
 def exclude_if_none(obj, field_names: List[str]):
     omit_if_none_fields = {
         k for k, v in obj.model_fields.items() if k in field_names
     }
     return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
```
python/sglang/test/runners.py

```diff
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
 
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
                 self.model = _get_sentence_transformer_embedding_model(
                     model_path, torch_dtype
                 )
-            elif self.model_type == "reward":
+            elif self.model_type == "reward" or self.model_type == "cross_encoder":
                 from transformers import AutoModelForSequenceClassification
 
                 self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                     else:
                         logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
+
+                elif self.model_type == "cross_encoder":
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    ).to("cuda")
+                    scores = self.model(**inputs).logits
+                    scores = scores.squeeze().tolist()
+                    if not isinstance(scores, list):
+                        scores = [scores]
+                    out_queue.put(ModelOutput(scores=scores))
 
                 elif self.model_type == "reward":
                     scores = []
@@ -322,7 +346,9 @@ class HFRunner:
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
                 else:
                     logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # cross encoder model
+            elif self.model_type == "cross_encoder":
+                response = self.engine.rerank(prompts)
+                if not isinstance(response, list):
+                    response = [response]
+                scores = [x["embedding"] for x in response]
+                return ModelOutput(scores=scores)
             # reward model
             else:
                 response = self.engine.encode(prompts)
```
python/sglang/test/test_utils.py

```diff
@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 
 # MLA test models
+DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
```
python/sglang/utils.py

```diff
@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
         cleaned_chunk = trim_overlap(final_text, chunk_text)
         final_text += cleaned_chunk
         yield cleaned_chunk  # yield the non-overlapping portion
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
```
test/srt/models/test_cross_encoder_models.py (new file, mode 100644)

```python
import multiprocessing as mp
import random
import unittest

import torch
from transformers import AutoConfig, AutoTokenizer

from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci

MODELS = [
    ("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
    ("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]

ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]


class TestCrossEncoderModels(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        score_tolerance,
        attention_backend,
    ) -> None:
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
        ) as hf_runner:
            hf_scores = hf_runner.forward(prompts).scores

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_scores = srt_runner.forward(prompts).scores

        for i in range(len(srt_scores)):
            score_difference = abs(hf_scores[i] - srt_scores[i])

            assert (
                score_difference < score_tolerance
            ), "cross encoder scores are not all close"

    def preprocess_prompts(self, prompt):
        processed_prompts = []
        query = prompt["query"]
        documents = prompt["documents"]
        for document in documents:
            processed_prompts.append([query, document])
        return processed_prompts

    def test_prefill_logits(self):
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for queryDocs in TEST_RERANK_QUERY_DOCS:
                    prompts = self.preprocess_prompts(queryDocs)
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            prompts,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                        )


if __name__ == "__main__":
    unittest.main()
```
test/srt/run_suite.py

```diff
@@ -19,6 +19,8 @@ suites = {
         TestFile("models/lora/test_lora_cuda_graph.py", 250),
         TestFile("models/test_embedding_models.py", 73),
         # TestFile("models/test_clip_models.py", 52),
+        TestFile("models/test_encoder_embedding_models.py", 100),
+        TestFile("models/test_cross_encoder_models.py", 100),
         TestFile("models/test_compressed_tensors_models.py", 42),
         TestFile("models/test_generation_models.py", 103),
         # TestFile("models/test_gme_qwen_models.py", 45),
```
test/srt/test_openai_server.py

```diff
@@ -17,7 +17,9 @@ import requests
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_process_tree
+from sglang.test.runners import TEST_RERANK_QUERY_DOCS
 from sglang.test.test_utils import (
+    DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
         self.assertEqual(cm.exception.status_code, 400)
 
+class TestOpenAIV1Rerank(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.score_tolerance = 1e-2
+
+        # Configure embedding-specific args
+        other_args = [
+            "--is-embedding",
+            "--enable-metrics",
+            "--disable-radix-cache",
+            "--chunked-prefill-size",
+            "-1",
+            "--attention-backend",
+            "torch_native",
+        ]
+
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=other_args,
+        )
+        cls.base_url += "/v1/rerank"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_rerank(self, query, docs):
+        response = requests.post(
+            self.base_url,
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            },
+            json={"query": query, "documents": docs},
+        )
+        return response.json()
+
+    def test_rerank_single(self):
+        """Test single rerank request"""
+        query = TEST_RERANK_QUERY_DOCS[0]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 1)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+
+    def test_rerank_batch(self):
+        """Test batch rerank request"""
+        query = TEST_RERANK_QUERY_DOCS[1]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 2)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[1]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[1]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+        self.assertTrue(isinstance(response[1]["index"], int))
+
 class TestOpenAIServerIgnoreEOS(CustomTestCase):
     @classmethod
     def setUpClass(cls):
```