"server/vscode:/vscode.git/clone" did not exist on "64142489b69d394cf4801d7265d4b2c3443225a0"
Unverified Commit e30ef368 authored by woodx, committed by GitHub

Feat/support rerank (#6058)

parent 91a066ec
...@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "Qwen2ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures
or "CLIPModel" in model_architectures
or "BertModel" in model_architectures
or "Contriever" in model_architectures
or "BertForSequenceClassification" in model_architectures
or "XLMRobertaModel" in model_architectures
or "XLMRobertaForSequenceClassification" in model_architectures
):
return False
else:
...
...@@ -327,6 +327,20 @@ class Engine(EngineBase):
generator = self.tokenizer_manager.generate_request(obj, None)
return await generator.__anext__()
def rerank(
self,
prompt: List[List[str]],
) -> Dict:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
Please refer to `EmbeddingReqInput` for the documentation.
"""
obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__())
return ret
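A minimal offline usage sketch for the new Engine.rerank API (the checkpoint name and the is_embedding flag are illustrative; any cross-encoder reranker enabled by this PR should work the same way):

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="BAAI/bge-reranker-v2-m3", is_embedding=True)
    # Each inner list is one [query, document] pair to be scored.
    scores = llm.rerank(
        [
            ["How many people live in Berlin?", "Berlin is well known for its museums."],
            ["How many people live in Berlin?", "Berlin had 3,520,031 registered inhabitants."],
        ]
    )
    print(scores)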
def shutdown(self):
"""Shutdown the engine"""
kill_process_tree(os.getpid(), include_parent=False)
...
...@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
...@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
...@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
try:
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
return ret
except ValueError as e:
return _create_error_response(e)
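For reference, the new endpoint takes a plain JSON body and returns results sorted by score; a client sketch mirroring the integration test further down (URL and port assume a locally launched server):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/rerank",
    json={
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin is well known for its museums.",
            "Berlin had a population of 3,520,031 registered inhabitants.",
        ],
    },
)
# Each item carries score, document, index, and meta_info (see RerankResponse below).
print(resp.json())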
@app.api_route("/flush_cache", methods=["GET", "POST"]) @app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache(): async def flush_cache():
"""Flush the radix cache.""" """Flush the radix cache."""
......
...@@ -20,6 +20,7 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
...@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.utils import resolve_obj_by_qualname
_is_cuda = is_cuda()
...@@ -165,6 +167,23 @@ def get_act_fn(
return act_fn
def get_cross_encoder_activation_function(config: PretrainedConfig):
if (
hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None
):
function_name = config.sbert_ce_default_activation_function
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
return resolve_obj_by_qualname(function_name)()
else:
# bge-reranker style configs store no activation function; default to identity
return nn.Identity()
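A quick sketch of what the guard above permits: only dotted paths under torch.nn.modules resolve, which matches how sentence-transformers cross-encoder configs store their default activation (the qualname here is an example, not taken from a specific checkpoint):

import torch
from sglang.utils import resolve_obj_by_qualname

act_cls = resolve_obj_by_qualname("torch.nn.modules.activation.Sigmoid")
print(act_cls()(torch.tensor([0.0])))  # tensor([0.5000])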
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
...
...@@ -3,10 +3,13 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from sglang.srt.layers.activation import get_cross_encoder_activation_function
from sglang.srt.model_executor.model_runner import ForwardBatch
...@@ -54,3 +57,56 @@ class Pooler(nn.Module):
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
return EmbeddingPoolerOutput(embeddings=pooled_data)
class CrossEncodingPooler(nn.Module):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `EmbeddingPoolerOutput`.
"""
def __init__(
self,
config: PretrainedConfig,
classifier: nn.Module,
pooler: Optional[nn.Module] = None,
):
super().__init__()
self.classifier = classifier
self.pooler = pooler
self.default_activation_function = get_cross_encoder_activation_function(config)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> EmbeddingPoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = forward_batch.extend_seq_lens
offset = 0
pooled_data_lst = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset : offset + prompt_len]
if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
else:
final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor)
offset += prompt_len
pooled_output = torch.stack(pooled_data_lst)
if self.pooler is not None:
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)
scores = self.default_activation_function(pooled_output).squeeze(-1)
return EmbeddingPoolerOutput(embeddings=scores)
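Shape-wise, the classifier path reduces each pooled request vector to one relevance score; a standalone sketch with assumed sizes (hidden size 384 is illustrative, the Linear stands in for the model's classification head):

import torch
import torch.nn as nn

pooled = torch.randn(3, 384)             # one pooled vector per [query, document] pair
classifier = nn.Linear(384, 1)           # placeholder classification head
scores = classifier(pooled).squeeze(-1)  # shape [3]: one score per pair
print(scores)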
...@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
@dataclass
class EmbeddingReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[List[str]], List[str], str]] = None
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
...@@ -505,6 +505,8 @@ class EmbeddingReqInput:
log_metrics: bool = True
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
...@@ -564,6 +566,16 @@ class EmbeddingReqInput:
return self.rid
def __getitem__(self, i):
if self.is_cross_encoder_request:
return EmbeddingReqInput(
text=[self.text[i]] if self.text is not None else None,
input_ids=None,
image_data=None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
is_cross_encoder_request=True,
)
return EmbeddingReqInput(
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
...@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
input_ids: List[int]
# The image inputs
image_inputs: dict
# The token type ids
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
...@@ -847,6 +861,12 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class V1RerankReqInput:
query: str
documents: List[str]
@dataclass
class SetInternalStateReqOutput:
updated: bool
...
...@@ -445,6 +445,7 @@ class Req:
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
token_type_ids: Optional[List[int]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
...@@ -470,6 +471,9 @@ class Req:
self.session_id = session_id
self.input_embeds = input_embeds
# For cross-encoder models
self.token_type_ids = token_type_ids
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
...@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Batched arguments to model runner
input_ids: torch.Tensor = None # shape: [b], int64
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
token_type_ids: torch.Tensor = None # shape: [b], int64
req_pool_indices: torch.Tensor = None # shape: [b], int64
seq_lens: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
...@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens = [len(r.prefix_indices) for r in reqs]
extend_lens = [r.extend_input_len for r in reqs]
token_type_ids = [
r.token_type_ids for r in reqs if r.token_type_ids is not None
]
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
...@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
token_type_ids_tensor = torch.tensor(
sum(token_type_ids, []), dtype=torch.int64
).to(self.device, non_blocking=True)
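The sum(token_type_ids, []) call above is the standard flattening idiom: it concatenates the per-request id lists into one sequence before the single device copy, e.g.:

# Two requests: a 4-token pair and a 2-token pair.
assert sum([[0, 0, 1, 1], [0, 1]], []) == [0, 0, 1, 1, 0, 1]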
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
# Copy prefix and do some basic check
...@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.device, non_blocking=True
)
self.multimodal_inputs = multimodal_inputs
self.token_type_ids = token_type_ids_tensor
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
...@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
token_type_ids=self.token_type_ids,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
...@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
# The input Embeds
input_embeds: Optional[torch.tensor] = None
# For cross-encoder models
token_type_ids: Optional[torch.Tensor] = None
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
...
...@@ -1150,6 +1150,7 @@ class Scheduler(
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
)
req.tokenizer = self.tokenizer
...
...@@ -459,6 +459,10 @@ class TokenizerManager:
# Tokenize
input_embeds = None
input_text = obj.text
token_type_ids = None
is_cross_encoder_request = (
isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
)
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
...@@ -477,7 +481,14 @@ class TokenizerManager:
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
encoded = self.tokenizer(
input_text, return_token_type_ids=is_cross_encoder_request
)
input_ids = encoded["input_ids"]
if is_cross_encoder_request:
input_ids = encoded["input_ids"][0]
token_type_ids = encoded.get("token_type_ids", [None])[0]
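For a cross-encoder request, input_text is a [query, document] pair, so a BERT-style tokenizer emits token_type_ids of 0 over the query segment and 1 over the document segment; an illustrative check (model name as used in the tests below):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L6-v2")
enc = tok(
    [["How many people live in Berlin?", "Berlin is well known for its museums."]],
    return_token_type_ids=True,
)
print(enc["input_ids"][0])
print(enc["token_type_ids"][0])  # 0s for the query tokens, then 1s for the document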
if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
...@@ -493,7 +504,7 @@ class TokenizerManager:
self._validate_token_len(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
)
def _validate_token_len(
...@@ -532,6 +543,7 @@ class TokenizerManager:
input_ids: List[int],
input_embeds: Optional[Union[List[float], None]] = None,
image_inputs: Optional[Dict] = None,
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
...@@ -592,6 +604,7 @@ class TokenizerManager:
input_text,
input_ids,
image_inputs,
token_type_ids,
sampling_params,
)
...
...@@ -224,6 +224,9 @@ class ForwardBatch:
# For input embeddings
input_embeds: Optional[torch.tensor] = None
# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
...@@ -300,6 +303,7 @@ class ForwardBatch:
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
token_type_ids=batch.token_type_ids,
tbo_split_seq_index=batch.tbo_split_seq_index,
)
device = model_runner.device
...@@ -356,8 +360,8 @@ class ForwardBatch:
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
if support_triton(model_runner.server_args.attention_backend):
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,
ret.extend_seq_lens,
...
...@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
BertConfig = None
...@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
input_shape = input_ids.size()
...@@ -58,8 +60,11 @@ class BertEmbedding(nn.Module):
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(positions)
token_type_ids = forward_batch.token_type_ids
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
)
...@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
return embeddings
class BertPooler(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
# simply take the hidden state corresponding to the first token
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertEncoder(nn.Module):
def __init__(
...@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
):
super().__init__()
self.layer_id = layer_id
self.attention = BertAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
...@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
attn_output = self.attention(hidden_states, forward_batch)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
...@@ -326,16 +353,23 @@ class BertModel(nn.Module):
*,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
use_bert_pooler: bool = False,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.use_bert_pooler = use_bert_pooler
self.config = config
self.embeddings = BertEmbedding(config)
self.encoder = BertEncoder(
config=config,
quant_config=quant_config,
prefix=add_prefix("encoder", prefix),
)
self.pooler = (
BertPooler(config)
if self.use_bert_pooler
else Pooler(pooling_type=PoolingType.LAST, normalize=True)
)
@torch.no_grad()
def forward(
...@@ -351,11 +385,16 @@ class BertModel(nn.Module):
hidden_states = self.embeddings(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
)
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
if not self.use_bert_pooler:
hidden_states = self.pooler(hidden_states, forward_batch)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
...@@ -368,7 +407,7 @@ class BertModel(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("self", "self_attn")
if "pooler" in name: if not self.use_bert_pooler and "pooler" in name:
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
...@@ -395,4 +434,65 @@ class Contriever(BertModel): ...@@ -395,4 +434,65 @@ class Contriever(BertModel):
pass pass
EntryClass = [BertModel, Contriever] class BertForSequenceClassification(nn.Module):
def __init__(
self,
*,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.num_labels = config.num_labels
self.bert = BertModel(
config=config,
quant_config=quant_config,
use_bert_pooler=True,
prefix=add_prefix("bert", prefix),
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("bert."):
yield (name[len("bert.") :], weight)
else:
self_weights.append((name, weight))
self.bert.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
assert get_embedding
hidden_states = self.bert(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
get_embedding=get_embedding,
)
return self.pooler(hidden_states, forward_batch)
EntryClass = [BertModel, Contriever, BertForSequenceClassification]
...@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
RobertaConfig = None
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: RobertaConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
return x
class RobertaEmbedding(nn.Module):
def __init__(self, config: RobertaConfig):
...@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
...@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
token_type_ids = forward_batch.token_type_ids
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
...@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
return embeddings
class XLMRobertaBaseModel(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
add_pooling_layer: bool = False,
):
super().__init__()
self.config = config
self.embeddings = RobertaEmbedding(config)
self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
self.pooler = (
Pooler(pooling_type=PoolingType.CLS, normalize=True)
if add_pooling_layer
else None
)
@torch.no_grad()
def forward(
...@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
input_ids=input_ids,
position_ids=positions,
seq_lens=forward_batch.seq_lens,
forward_batch=forward_batch,
)
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
...@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("self", "self_attn")
if "pooler" in name: if self.pooler is None and "pooler" in name:
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
...@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids( ...@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
return incremental_indices.long() + padding_idx return incremental_indices.long() + padding_idx
EntryClass = [XLMRobertaModel] class XLMRobertaModel(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.roberta = XLMRobertaBaseModel(
config=config, quant_config=quant_config, prefix=prefix
)
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.roberta(
input_ids, positions, forward_batch, input_embeds, get_embedding
)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.roberta.load_weights(weights)
class XLMRobertaForSequenceClassification(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.roberta = XLMRobertaBaseModel(
config=config, quant_config=quant_config, prefix=prefix
)
self.classifier = RobertaClassificationHead(config)
self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> torch.Tensor:
assert (
get_embedding
), "XLMRobertaForSequenceClassification is only used for rerank"
hidden_states = self.roberta(
input_ids, positions, forward_batch, input_embeds, get_embedding
)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("roberta."):
yield (name[len("roberta.") :], weight)
else:
self_weights.append((name, weight))
self.roberta.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
...@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
register_conv_template,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
V1RerankReqInput,
)
from sglang.srt.openai_api.protocol import (
BatchRequest,
BatchResponse,
...@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
FunctionResponse,
LogProbs,
MultimodalEmbeddingInput,
RerankResponse,
ScoringRequest,
ScoringResponse,
ToolCall,
...@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
return response
def v1_rerank_request(obj: V1RerankReqInput):
if obj.query is None:
raise ValueError("query is required")
if obj.documents is None or len(obj.documents) == 0:
raise ValueError("documents is required")
pairs = []
for doc in obj.documents:
pairs.append([obj.query, doc])
adapted_request = EmbeddingReqInput(
text=pairs,
is_cross_encoder_request=True,
)
return adapted_request
def v1_rerank_response(ret, obj: V1RerankReqInput):
response = []
for idx, ret_item in enumerate(ret):
response.append(
RerankResponse(
score=ret_item["embedding"],
document=obj.documents[idx],
index=idx,
meta_info=ret_item["meta_info"],
)
)
response.sort(key=lambda x: x.score, reverse=True)
return response
async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
adapted_request = v1_rerank_request(obj)
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_rerank_response(
ret,
obj,
)
return response
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
...
...@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
object: str = "scoring"
class RerankResponse(BaseModel):
score: float
document: str
index: int
meta_info: Optional[dict] = None
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
...@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
]
TEST_RERANK_QUERY_DOCS = [
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin is well known for its museums.",
],
},
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
"Berlin is well known for its museums.",
],
},
]
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
...@@ -241,7 +256,7 @@ class HFRunner:
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
)
elif self.model_type == "reward" or self.model_type == "cross_encoder":
from transformers import AutoModelForSequenceClassification
self.model = AutoModelForSequenceClassification.from_pretrained(
...@@ -303,6 +318,15 @@ class HFRunner:
else:
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
elif self.model_type == "cross_encoder":
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
).to("cuda")
scores = self.model(**inputs).logits
scores = scores.squeeze().tolist()
if not isinstance(scores, list):
scores = [scores]
out_queue.put(ModelOutput(scores=scores))
elif self.model_type == "reward": elif self.model_type == "reward":
scores = [] scores = []
...@@ -322,7 +346,9 @@ class HFRunner: ...@@ -322,7 +346,9 @@ class HFRunner:
def forward( def forward(
self, self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
...@@ -526,7 +552,9 @@ class SRTRunner:
def forward(
self,
prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
...@@ -552,6 +580,13 @@ class SRTRunner:
else:
logits = [response["embedding"]]
return ModelOutput(embed_logits=logits)
# cross encoder model
elif self.model_type == "cross_encoder":
response = self.engine.rerank(prompts)
if not isinstance(response, list):
response = [response]
scores = [x["embedding"] for x in response]
return ModelOutput(scores=scores)
# reward model
else:
response = self.engine.encode(prompts)
...
...@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
# MLA test models
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
...
...@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
cleaned_chunk = trim_overlap(final_text, chunk_text)
final_text += cleaned_chunk
yield cleaned_chunk # yield the non-overlapping portion
def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully qualified name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci
MODELS = [
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]
ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]
class TestCrossEncoderModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_prefill_logits(
self,
prompts,
model_path,
tp_size,
torch_dtype,
score_tolerance,
attention_backend,
) -> None:
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="cross_encoder",
) as hf_runner:
hf_scores = hf_runner.forward(prompts).scores
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="cross_encoder",
attention_backend=attention_backend,
chunked_prefill_size=-1,
disable_radix_cache=True,
) as srt_runner:
srt_scores = srt_runner.forward(prompts).scores
for i in range(len(srt_scores)):
score_difference = abs(hf_scores[i] - srt_scores[i])
assert (
score_difference < score_tolerance
), "cross encoder scores are not all close"
def preprocess_prompts(self, prompt):
processed_prompts = []
query = prompt["query"]
documents = prompt["documents"]
for document in documents:
processed_prompts.append([query, document])
return processed_prompts
def test_prefill_logits(self):
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for attention_backend in ATTENTION_BACKEND:
for query_docs in TEST_RERANK_QUERY_DOCS:
prompts = self.preprocess_prompts(query_docs)
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
prompts,
model,
tp_size,
torch_dtype,
prefill_tolerance,
attention_backend,
)
if __name__ == "__main__":
unittest.main()
...@@ -19,6 +19,8 @@ suites = {
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
TestFile("models/test_cross_encoder_models.py", 100),
TestFile("models/test_compressed_tensors_models.py", 42), TestFile("models/test_compressed_tensors_models.py", 42),
TestFile("models/test_generation_models.py", 103), TestFile("models/test_generation_models.py", 103),
# TestFile("models/test_gme_qwen_models.py", 45), # TestFile("models/test_gme_qwen_models.py", 45),
......
...@@ -17,7 +17,9 @@ import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
from sglang.test.test_utils import (
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
...@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertEqual(cm.exception.status_code, 400)
class TestOpenAIV1Rerank(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.score_tolerance = 1e-2
# Configure embedding-specific args
other_args = [
"--is-embedding",
"--enable-metrics",
"--disable-radix-cache",
"--chunked-prefill-size",
"-1",
"--attention-backend",
"torch_native",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=other_args,
)
cls.base_url += "/v1/rerank"
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_rerank(self, query, docs):
response = requests.post(
self.base_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={"query": query, "documents": docs},
)
return response.json()
def test_rerank_single(self):
"""Test single rerank request"""
query = TEST_RERANK_QUERY_DOCS[0]["query"]
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 1)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
def test_rerank_batch(self):
"""Test batch rerank request"""
query = TEST_RERANK_QUERY_DOCS[1]["query"]
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 2)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[1]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[1]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
self.assertTrue(isinstance(response[1]["index"], int))
class TestOpenAIServerIgnoreEOS(CustomTestCase):
@classmethod
def setUpClass(cls):
...