sglang / Commits / e30ef368
"server/vscode:/vscode.git/clone" did not exist on "64142489b69d394cf4801d7265d4b2c3443225a0"
Commit e30ef368 (Unverified)

Feat/support rerank (#6058)

Authored Jun 17, 2025 by woodx; committed by GitHub Jun 16, 2025
Parent: 91a066ec
Showing 20 changed files with 684 additions and 30 deletions (+684, -30).
python/sglang/srt/configs/model_config.py (+5, -0)
python/sglang/srt/entrypoints/engine.py (+14, -0)
python/sglang/srt/entrypoints/http_server.py (+11, -0)
python/sglang/srt/layers/activation.py (+19, -0)
python/sglang/srt/layers/pooler.py (+56, -0)
python/sglang/srt/managers/io_struct.py (+21, -1)
python/sglang/srt/managers/schedule_batch.py (+21, -0)
python/sglang/srt/managers/scheduler.py (+1, -0)
python/sglang/srt/managers/tokenizer_manager.py (+15, -2)
python/sglang/srt/model_executor/forward_batch_info.py (+5, -1)
python/sglang/srt/models/bert.py (+113, -13)
python/sglang/srt/models/roberta.py (+117, -9)
python/sglang/srt/openai_api/adapter.py (+64, -1)
python/sglang/srt/openai_api/protocol.py (+7, -0)
python/sglang/test/runners.py (+38, -3)
python/sglang/test/test_utils.py (+2, -0)
python/sglang/utils.py (+9, -0)
test/srt/models/test_cross_encoder_models.py (+91, -0)
test/srt/run_suite.py (+2, -0)
test/srt/test_openai_server.py (+73, -0)
python/sglang/srt/configs/model_config.py

```diff
@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
         or "CLIPModel" in model_architectures
+        or "BertModel" in model_architectures
+        or "Contriever" in model_architectures
+        or "BertForSequenceClassification" in model_architectures
+        or "XLMRobertaModel" in model_architectures
+        or "XLMRobertaForSequenceClassification" in model_architectures
     ):
         return False
     else:
```
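Note: these added architecture names route cross-encoder checkpoints onto sglang's embedding/pooling execution path rather than autoregressive generation. A simplified stand-in for the dispatch (the real function also consults an `is_embedding` flag and many more architectures):

```python
# Simplified sketch of the routing effect of this hunk; not the real function.
from typing import List

CROSS_ENCODER_ARCHS = {
    "BertModel",
    "Contriever",
    "BertForSequenceClassification",
    "XLMRobertaModel",
    "XLMRobertaForSequenceClassification",
}

def is_generation_model(model_architectures: List[str]) -> bool:
    # Any match sends the model down the embedding/rerank path.
    return not any(arch in CROSS_ENCODER_ARCHS for arch in model_architectures)

print(is_generation_model(["BertForSequenceClassification"]))  # False -> embedding path
print(is_generation_model(["LlamaForCausalLM"]))               # True  -> generation path
```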
python/sglang/srt/entrypoints/engine.py

```diff
@@ -327,6 +327,20 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
+    def rerank(
+        self,
+        prompt: Union[List[List[str]]],
+    ) -> Dict:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
```
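Note: `rerank` mirrors the synchronous `encode` entry point; it wraps the `[query, document]` pairs in an `EmbeddingReqInput` flagged as a cross-encoder request and drains one result from the tokenizer manager's async generator. A minimal offline usage sketch, assuming a cross-encoder checkpoint and the engine flags used by the tests added in this commit:

```python
# Hypothetical offline use of Engine.rerank; model choice and flags are
# assumptions mirroring test/srt/models/test_cross_encoder_models.py.
import sglang as sgl

engine = sgl.Engine(
    model_path="cross-encoder/ms-marco-MiniLM-L6-v2",
    is_embedding=True,          # rerank rides the embedding code path
    disable_radix_cache=True,   # matches the new tests' settings
    attention_backend="torch_native",
)

# Each prompt is one [query, document] pair (List[List[str]]).
pairs = [
    ["How many people live in Berlin?", "Berlin is well known for its museums."],
    ["How many people live in Berlin?", "Berlin had 3,520,031 registered inhabitants."],
]
for item in engine.rerank(pairs):
    # The relevance score comes back in the "embedding" field.
    print(item["embedding"], item["meta_info"])

engine.shutdown()
```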
python/sglang/srt/entrypoints/http_server.py

```diff
@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
+    V1RerankReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
     v1_delete_file,
     v1_embeddings,
     v1_files_create,
+    v1_rerank,
     v1_retrieve_batch,
     v1_retrieve_file,
     v1_retrieve_file_content,
@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/v1/rerank", methods=["POST", "PUT"])
+async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
+    try:
+        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
```
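Note: the route accepts a JSON body shaped like `V1RerankReqInput` (a `query` string plus a `documents` list). A client sketch with `requests`, assuming a server already listening on localhost:30000 with no API key:

```python
# Minimal client sketch for the new /v1/rerank route; the port and the
# auth-free setup are assumptions about how the server was launched.
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin is well known for its museums.",
            "Berlin had 3,520,031 registered inhabitants.",
        ],
    },
)
# Items are {score, document, index, meta_info}, sorted by descending score
# (see v1_rerank_response in the adapter diff below).
for item in resp.json():
    print(f"{item['score']:.4f}  {item['document']}")
```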
python/sglang/srt/layers/activation.py

```diff
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PretrainedConfig
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
@@ -165,6 +167,23 @@ def get_act_fn(
     return act_fn
 
 
+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    if (
+        hasattr(config, "sbert_ce_default_activation_function")
+        and config.sbert_ce_default_activation_function is not None
+    ):
+        function_name = config.sbert_ce_default_activation_function
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons"
+        )
+        return resolve_obj_by_qualname(function_name)()
+    else:
+        # adapt bge-reranker
+        return nn.Identity()
+
+
 if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
```
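Note: the helper follows the sentence-transformers CrossEncoder convention. ms-marco style configs typically store `sbert_ce_default_activation_function = "torch.nn.modules.activation.Sigmoid"`, which is resolved back into a module, while bge-reranker configs omit the field and scores stay raw logits via `nn.Identity()`. A sketch of what the qualified-name resolution does in practice (the qualname below is the commonly stored value, an assumption for illustration):

```python
# Resolving the activation the way get_cross_encoder_activation_function does.
import importlib
import torch

qualname = "torch.nn.modules.activation.Sigmoid"   # typical sbert value (assumed)
module_name, obj_name = qualname.rsplit(".", 1)
act = getattr(importlib.import_module(module_name), obj_name)()

print(act)                            # Sigmoid()
print(act(torch.tensor([0.0, 2.0])))  # tensor([0.5000, 0.8808]) -- logits -> scores
```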
python/sglang/srt/layers/pooler.py

```diff
@@ -3,10 +3,13 @@
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
```
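Note: a prefill batch concatenates every sequence along a single token axis, so `CrossEncodingPooler` walks `extend_seq_lens` to slice `hidden_states` back into per-pair chunks before classifying. A self-contained toy reproduction of that loop (invented shapes, no ForwardBatch):

```python
# Toy version of CrossEncodingPooler's slice-pool-classify loop.
import torch
import torch.nn as nn

hidden_size = 8
classifier = nn.Linear(hidden_size, 1)  # CLS vector -> one relevance logit
activation = nn.Sigmoid()               # stand-in for the resolved activation

# Three packed sequences of 4, 6, and 5 tokens share one token axis.
prompt_lens = [4, 6, 5]
hidden_states = torch.randn(sum(prompt_lens), hidden_size)

offset, pooled = 0, []
for prompt_len in prompt_lens:
    seq = hidden_states[offset : offset + prompt_len]
    pooled.append(seq[0, :])  # CLS-style pooling, as BertPooler does
    offset += prompt_len

scores = activation(classifier(torch.stack(pooled))).squeeze(-1)
print(scores.shape)  # torch.Size([3]) -- one score per [query, document] pair
```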
python/sglang/srt/managers/io_struct.py

```diff
@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -505,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -564,6 +566,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
@@ -847,6 +861,12 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
 
+@dataclass
+class V1RerankReqInput:
+    query: str
+    documents: List[str]
+
+
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
```
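Note the asymmetry in `__getitem__`: a cross-encoder element is re-wrapped as `[self.text[i]]`, so each split request stays a batch of one `[query, document]` pair rather than degrading into two independent strings. A shape-level illustration (plain lists standing in for the dataclass); the tokenizer-level consequence is shown after the `tokenizer_manager` diff below:

```python
# Why the cross-encoder branch wraps the indexed element in a list.
batch = [
    ["How many people live in Berlin?", "Berlin is well known for its museums."],
    ["How many people live in Berlin?", "Berlin had 3,520,031 registered inhabitants."],
]

bare = batch[1]       # one pair -- would look like a batch of two texts
wrapped = [batch[1]]  # what the cross-encoder branch builds: a batch of one pair

print(len(bare))     # 2
print(len(wrapped))  # 1
```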
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -445,6 +445,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -470,6 +471,9 @@ class Req:
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # for cross-encoder model
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -1150,6 +1150,7 @@ class Scheduler(
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            token_type_ids=recv_req.token_type_ids,
         )
         req.tokenizer = self.tokenizer
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -459,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
@@ -477,7 +481,14 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            input_ids = self.tokenizer.encode(input_text)
+            encoded = self.tokenizer(
+                input_text, return_token_type_ids=is_cross_encoder_request
+            )
+
+            input_ids = encoded["input_ids"]
+            if is_cross_encoder_request:
+                input_ids = encoded["input_ids"][0]
+                token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
@@ -493,7 +504,7 @@ class TokenizerManager:
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
     def _validate_token_len(
@@ -532,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
@@ -592,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
```
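Note: this is where the wrapping pays off. Handing a Hugging Face tokenizer a list containing one `[query, document]` pair tokenizes a *sentence pair* (one packed sequence with segment ids), while the bare pair would tokenize as two unrelated texts. A sketch, assuming hub access; any BERT-style tokenizer behaves the same way:

```python
# Pair encoding vs. batch-of-texts encoding; model choice is an assumption.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L6-v2")
query = "How many people live in Berlin?"
doc = "Berlin is well known for its museums."

paired = tok([[query, doc]], return_token_type_ids=True)
print(len(paired["input_ids"]))     # 1 -- [CLS] query [SEP] document [SEP]
print(paired["token_type_ids"][0])  # 0s for the query segment, 1s for the document

separate = tok([query, doc])
print(len(separate["input_ids"]))   # 2 -- two independent sequences
```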
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -224,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
@@ -300,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -356,8 +360,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
+            ret.extend_num_tokens = batch.extend_num_tokens
+            if support_triton(model_runner.server_args.attention_backend):
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
```
python/sglang/srt/models/bert.py

```diff
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 BertConfig = None
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
@@ -58,8 +60,11 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)
 
         # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
+        position_embeddings = self.position_embeddings(positions)
+
+        token_type_ids = forward_batch.token_type_ids
 
-        token_type_ids = torch.zeros(
-            input_shape, dtype=torch.long, device=inputs_embeds.device
-        )
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings
 
 
+class BertPooler(nn.Module):
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # simply taking the hidden state corresponding to the first token
+        first_token_tensor = hidden_states[0, :]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
 class BertEncoder(nn.Module):
     def __init__(
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()
 
+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config, quant_config=quant_config, prefix=f"encoder"
+            config=config, quant_config=quant_config, prefix=add_prefix("encoder", prefix),
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        )
 
     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):
         hidden_states = self.embeddings(
             input_ids=input_ids,
-            position_ids=positions,
+            positions=positions,
+            forward_batch=forward_batch,
         )
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        return self.pooler(hidden_states, forward_batch)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass
 
 
-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
```
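Note: the `weight_filter` generator in `load_weights` routes a checkpoint in one pass: `bert.*` tensors are stripped of their prefix and streamed into the backbone, while everything else (the classifier head) is stashed for the outer module. The pattern in isolation, with toy names:

```python
# Stand-alone sketch of the weight_filter routing pattern (toy weight names).
weights = [
    ("bert.embeddings.word_embeddings.weight", "W0"),
    ("bert.encoder.layer.0.attention.self.query.weight", "W1"),
    ("classifier.weight", "W2"),
    ("classifier.bias", "W3"),
]

self_weights = []

def weight_filter():
    for name, weight in weights:
        if name.startswith("bert."):
            yield (name[len("bert.") :], weight)  # prefix stripped for the backbone
        else:
            self_weights.append((name, weight))   # head weights kept for later

backbone = list(weight_filter())  # normally consumed lazily by BertModel.load_weights
print([n for n, _ in backbone])      # ['embeddings...', 'encoder.layer.0...']
print([n for n, _ in self_weights])  # ['classifier.weight', 'classifier.bias']
```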
python/sglang/srt/models/roberta.py

```diff
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbedding(nn.Module):
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
 
+        token_type_ids = forward_batch.token_type_ids
+
         if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
         *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
```
python/sglang/srt/openai_api/adapter.py

```diff
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
     register_conv_template,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    V1RerankReqInput,
+)
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
     FunctionResponse,
     LogProbs,
     MultimodalEmbeddingInput,
+    RerankResponse,
     ScoringRequest,
     ScoringResponse,
     ToolCall,
@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     return response
 
 
+def v1_rerank_request(obj: V1RerankReqInput):
+    if obj.query is None:
+        raise ValueError("query is required")
+
+    if obj.documents is None or len(obj.documents) == 0:
+        raise ValueError("documents is required")
+
+    pairs = []
+    for doc in obj.documents:
+        pairs.append([obj.query, doc])
+
+    adapted_request = EmbeddingReqInput(
+        text=pairs,
+        is_cross_encoder_request=True,
+    )
+
+    return adapted_request
+
+
+def v1_rerank_response(ret, obj: V1RerankReqInput):
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            RerankResponse(
+                score=ret[idx]["embedding"],
+                document=obj.documents[idx],
+                index=idx,
+                meta_info=ret[idx]["meta_info"],
+            )
+        )
+
+    response.sort(key=lambda x: x.score, reverse=True)
+
+    return response
+
+
+async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
+    adapted_request = v1_rerank_request(obj)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_rerank_response(
+        ret,
+        obj,
+    )
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
```
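Note: the adapter round-trip is: fan the query out into `[query, document]` pairs, run them as a cross-encoder embedding request, then re-attach each document and its original index before sorting by score, so `index` still points into the request's `documents` after sorting. The response shaping, with hand-made engine output:

```python
# Hand-constructed illustration of v1_rerank_response; scores and meta_info
# values are invented.
documents = [
    "Berlin is well known for its museums.",
    "Berlin had 3,520,031 registered inhabitants.",
]
ret = [
    {"embedding": 0.02, "meta_info": {"prompt_tokens": 18}},
    {"embedding": 0.97, "meta_info": {"prompt_tokens": 24}},
]

response = [
    {"score": r["embedding"], "document": documents[i], "index": i, "meta_info": r["meta_info"]}
    for i, r in enumerate(ret)
]
response.sort(key=lambda x: x["score"], reverse=True)
print([(x["index"], x["score"]) for x in response])  # [(1, 0.97), (0, 0.02)]
```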
python/sglang/srt/openai_api/protocol.py

```diff
@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
     object: str = "scoring"
 
 
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
+
 def exclude_if_none(obj, field_names: List[str]):
     omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
     return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
```
python/sglang/test/runners.py

```diff
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
 
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
+
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
                 self.model = _get_sentence_transformer_embedding_model(
                     model_path, torch_dtype
                 )
-            elif self.model_type == "reward":
+            elif self.model_type == "reward" or self.model_type == "cross_encoder":
                 from transformers import AutoModelForSequenceClassification
 
                 self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                     else:
                         logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
+                elif self.model_type == "cross_encoder":
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    ).to("cuda")
+                    scores = self.model(**inputs).logits
+                    scores = scores.squeeze().tolist()
+                    if not isinstance(scores, list):
+                        scores = [scores]
+                    out_queue.put(ModelOutput(scores=scores))
 
                 elif self.model_type == "reward":
                     scores = []
@@ -322,7 +346,9 @@ class HFRunner:
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
                 else:
                     logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # cross encoder model
+            elif self.model_type == "cross_encoder":
+                response = self.engine.rerank(prompts)
+                if not isinstance(response, list):
+                    response = [response]
+                scores = [x["embedding"] for x in response]
+                return ModelOutput(scores=scores)
             # reward model
             else:
                 response = self.engine.encode(prompts)
```
python/sglang/test/test_utils.py

```diff
@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 
 # MLA test models
+DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
```
python/sglang/utils.py

```diff
@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
         cleaned_chunk = trim_overlap(final_text, chunk_text)
         final_text += cleaned_chunk
         yield cleaned_chunk  # yield the non-overlapping portion
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
```
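Note: `resolve_obj_by_qualname` is a generic dotted-path importer; the activation loader above restricts its input to `torch.nn.modules.`, but the helper itself resolves any importable attribute. A quick sketch with unrelated stdlib targets (illustrative only):

```python
# Generic behavior of resolve_obj_by_qualname, re-stated for illustration.
import importlib

def resolve_obj_by_qualname(qualname: str):
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)

print(resolve_obj_by_qualname("os.path.join")("a", "b"))     # a/b (POSIX)
print(resolve_obj_by_qualname("collections.OrderedDict")())  # OrderedDict()
```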
test/srt/models/test_cross_encoder_models.py (new file, 0 → 100644)

```python
import multiprocessing as mp
import random
import unittest

import torch
from transformers import AutoConfig, AutoTokenizer

from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci

MODELS = [
    ("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
    ("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]

ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]


class TestCrossEncoderModels(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        score_tolerance,
        attention_backend,
    ) -> None:
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
        ) as hf_runner:
            hf_scores = hf_runner.forward(prompts).scores

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_scores = srt_runner.forward(prompts).scores

        for i in range(len(srt_scores)):
            score_difference = abs(hf_scores[i] - srt_scores[i])

            assert (
                score_difference < score_tolerance
            ), "cross encoder scores are not all close"

    def preprocess_prompts(self, prompt):
        processed_prompts = []
        query = prompt["query"]
        documents = prompt["documents"]
        for document in documents:
            processed_prompts.append([query, document])
        return processed_prompts

    def test_prefill_logits(self):
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for queryDocs in TEST_RERANK_QUERY_DOCS:
                    prompts = self.preprocess_prompts(queryDocs)
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            prompts,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                        )


if __name__ == "__main__":
    unittest.main()
```
test/srt/run_suite.py

```diff
@@ -19,6 +19,8 @@ suites = {
         TestFile("models/lora/test_lora_cuda_graph.py", 250),
         TestFile("models/test_embedding_models.py", 73),
         # TestFile("models/test_clip_models.py", 52),
+        TestFile("models/test_encoder_embedding_models.py", 100),
+        TestFile("models/test_cross_encoder_models.py", 100),
         TestFile("models/test_compressed_tensors_models.py", 42),
         TestFile("models/test_generation_models.py", 103),
         # TestFile("models/test_gme_qwen_models.py", 45),
```
test/srt/test_openai_server.py

```diff
@@ -17,7 +17,9 @@ import requests
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_process_tree
+from sglang.test.runners import TEST_RERANK_QUERY_DOCS
 from sglang.test.test_utils import (
+    DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
         self.assertEqual(cm.exception.status_code, 400)
 
 
+class TestOpenAIV1Rerank(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.score_tolerance = 1e-2
+
+        # Configure embedding-specific args
+        other_args = [
+            "--is-embedding",
+            "--enable-metrics",
+            "--disable-radix-cache",
+            "--chunked-prefill-size",
+            "-1",
+            "--attention-backend",
+            "torch_native",
+        ]
+
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=other_args,
+        )
+        cls.base_url += "/v1/rerank"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_rerank(self, query, docs):
+        response = requests.post(
+            self.base_url,
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            },
+            json={"query": query, "documents": docs},
+        )
+        return response.json()
+
+    def test_rerank_single(self):
+        """Test single rerank request"""
+        query = TEST_RERANK_QUERY_DOCS[0]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 1)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+
+    def test_rerank_batch(self):
+        """Test batch rerank request"""
+        query = TEST_RERANK_QUERY_DOCS[1]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 2)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[1]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[1]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+        self.assertTrue(isinstance(response[1]["index"], int))
+
+
 class TestOpenAIServerIgnoreEOS(CustomTestCase):
     @classmethod
     def setUpClass(cls):
```