Unverified Commit e30ef368 authored by woodx, committed by GitHub

Feat/support rerank (#6058)

parent 91a066ec
......@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "Qwen2ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures
or "CLIPModel" in model_architectures
or "BertModel" in model_architectures
or "Contriever" in model_architectures
or "BertForSequenceClassification" in model_architectures
or "XLMRobertaModel" in model_architectures
or "XLMRobertaForSequenceClassification" in model_architectures
):
return False
else:
......
......@@ -327,6 +327,20 @@ class Engine(EngineBase):
generator = self.tokenizer_manager.generate_request(obj, None)
return await generator.__anext__()
def rerank(
self,
prompt: List[List[str]],
) -> Dict:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
Please refer to `EmbeddingReqInput` for the documentation.
"""
obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__())
return ret
def shutdown(self):
"""Shutdown the engine"""
kill_process_tree(os.getpid(), include_parent=False)
......
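For reference, a minimal offline usage sketch of the new `Engine.rerank` API (the model path is an example; any cross-encoder reranker added by this PR should behave the same):

import sglang as sgl

# Example model; rerank() expects a list of [query, document] pairs.
engine = sgl.Engine(model_path="BAAI/bge-reranker-v2-m3", is_embedding=True)
query = "How many people live in Berlin?"
documents = ["Berlin is well known for its museums."]
results = engine.rerank([[query, doc] for doc in documents])
# Each result carries its score under the "embedding" key (as SRTRunner reads it below).
engine.shutdown()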
......@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
......@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
try:
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
"""Flush the radix cache."""
......
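A minimal client sketch for the new route (host and port are assumptions; the payload mirrors `V1RerankReqInput`):

import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",  # assumed local server address
    json={
        "query": "How many people live in Berlin?",
        "documents": ["Berlin is well known for its museums."],
    },
)
# Returns a list of {score, document, index, meta_info} items, sorted by score.
print(resp.json())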
......@@ -20,6 +20,7 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
......@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.utils import resolve_obj_by_qualname
_is_cuda = is_cuda()
......@@ -165,6 +167,23 @@ def get_act_fn(
return act_fn
def get_cross_encoder_activation_function(config: PretrainedConfig):
if (
hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None
):
function_name = config.sbert_ce_default_activation_function
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
return resolve_obj_by_qualname(function_name)()
else:
# default for models such as bge-reranker that do not specify one
return nn.Identity()
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
......
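As a sketch of the restricted lookup above, a config whose `sbert_ce_default_activation_function` is set (the value below is an example, as sentence-transformers CrossEncoder configs store it) resolves like this:

import importlib

function_name = "torch.nn.modules.activation.Sigmoid"
assert function_name.startswith("torch.nn.modules.")
module_name, obj_name = function_name.rsplit(".", 1)
activation = getattr(importlib.import_module(module_name), obj_name)()
# activation is nn.Sigmoid(); configs without the attribute fall back to nn.Identity().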
......@@ -3,10 +3,13 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from sglang.srt.layers.activation import get_cross_encoder_activation_function
from sglang.srt.model_executor.model_runner import ForwardBatch
......@@ -54,3 +57,56 @@ class Pooler(nn.Module):
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
return EmbeddingPoolerOutput(embeddings=pooled_data)
class CrossEncodingPooler(nn.Module):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `EmbeddingPoolerOutput`.
"""
def __init__(
self,
config: PretrainedConfig,
classifier: nn.Module,
pooler: Optional[nn.Module] = None,
):
super().__init__()
self.classifier = classifier
self.pooler = pooler
self.default_activation_function = get_cross_encoder_activation_function(config)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> EmbeddingPoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = forward_batch.extend_seq_lens
offset = 0
pooled_data_lst = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset : offset + prompt_len]
if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
else:
final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor)
offset += prompt_len
pooled_output = torch.stack(pooled_data_lst)
if self.pooler is not None:
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)
scores = self.default_activation_function(pooled_output).squeeze(-1)
return EmbeddingPoolerOutput(embeddings=scores)
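A small numeric sketch of the slicing logic in `CrossEncodingPooler.forward`, assuming two sentence pairs with `extend_seq_lens` of [3, 2] (the first-token selection stands in for a BertPooler-style pooler):

import torch

hidden_states = torch.randn(5, 8)  # 5 tokens across the batch, hidden_size = 8
prompt_lens = [3, 2]               # per-request token counts
offset, per_pair = 0, []
for prompt_len in prompt_lens:
    chunk = hidden_states[offset : offset + prompt_len]  # tokens of one pair
    per_pair.append(chunk[0])      # CLS-style first-token pooling
    offset += prompt_len
pooled = torch.stack(per_pair)     # shape [2, 8]; the classifier maps this to scores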
......@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
@dataclass
class EmbeddingReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
text: Optional[Union[List[List[str]], List[str], str]] = None
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
# Can be formatted as:
# - Single image for a single request
......@@ -505,6 +505,8 @@ class EmbeddingReqInput:
log_metrics: bool = True
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
......@@ -564,6 +566,16 @@ class EmbeddingReqInput:
return self.rid
def __getitem__(self, i):
if self.is_cross_encoder_request:
return EmbeddingReqInput(
text=[self.text[i]] if self.text is not None else None,
input_ids=None,
image_data=None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
is_cross_encoder_request=True,
)
return EmbeddingReqInput(
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
......@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
input_ids: List[int]
# The image inputs
image_inputs: dict
# The token type ids
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
......@@ -847,6 +861,12 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class V1RerankReqInput:
query: str
documents: List[str]
@dataclass
class SetInternalStateReqOutput:
updated: bool
......
......@@ -445,6 +445,7 @@ class Req:
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
token_type_ids: Optional[List[int]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
......@@ -470,6 +471,9 @@ class Req:
self.session_id = session_id
self.input_embeds = input_embeds
# For cross-encoder models
self.token_type_ids = token_type_ids
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
......@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Batched arguments to model runner
input_ids: torch.Tensor = None # shape: [b], int64
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
token_type_ids: torch.Tensor = None # shape: [b], int64
req_pool_indices: torch.Tensor = None # shape: [b], int64
seq_lens: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
......@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens = [len(r.prefix_indices) for r in reqs]
extend_lens = [r.extend_input_len for r in reqs]
token_type_ids = [
r.token_type_ids for r in reqs if r.token_type_ids is not None
]
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
......@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
token_type_ids_tensor = torch.tensor(
sum(token_type_ids, []), dtype=torch.int64
).to(self.device, non_blocking=True)
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
# Copy prefix and do some basic check
......@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.device, non_blocking=True
)
self.multimodal_inputs = multimodal_inputs
self.token_type_ids = token_type_ids_tensor
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
......@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
token_type_ids=self.token_type_ids,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
......@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
# The input Embeds
input_embeds: Optional[torch.tensor] = None
# For cross-encoder models
token_type_ids: Optional[torch.Tensor] = None
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
......
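The `token_type_ids` batching above concatenates the per-request lists so they line up with the flattened `input_ids`; a tiny sketch with illustrative values:

token_type_ids = [[0, 0, 0, 1, 1], [0, 0, 1]]  # two cross-encoder requests
flat = sum(token_type_ids, [])                  # [0, 0, 0, 1, 1, 0, 0, 1]
# torch.tensor(flat, dtype=torch.int64) then aligns with the flattened input_ids.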
......@@ -1150,6 +1150,7 @@ class Scheduler(
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
)
req.tokenizer = self.tokenizer
......
......@@ -459,6 +459,10 @@ class TokenizerManager:
# Tokenize
input_embeds = None
input_text = obj.text
token_type_ids = None
is_cross_encoder_request = (
isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
)
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
......@@ -477,7 +481,14 @@ class TokenizerManager:
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
input_ids = self.tokenizer.encode(input_text)
encoded = self.tokenizer(
input_text, return_token_type_ids=is_cross_encoder_request
)
input_ids = encoded["input_ids"]
if is_cross_encoder_request:
input_ids = encoded["input_ids"][0]
token_type_ids = encoded.get("token_type_ids", [None])[0]
if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
......@@ -493,7 +504,7 @@ class TokenizerManager:
self._validate_token_len(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
)
def _validate_token_len(
......@@ -532,6 +543,7 @@ class TokenizerManager:
input_ids: List[int],
input_embeds: Optional[Union[List[float], None]] = None,
image_inputs: Optional[Dict] = None,
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
......@@ -592,6 +604,7 @@ class TokenizerManager:
input_text,
input_ids,
image_inputs,
token_type_ids,
sampling_params,
)
......
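For cross-encoder requests, the tokenizer sees a [query, document] pair so that `token_type_ids` mark the two segments; a sketch mirroring the call above (the model name is an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L6-v2")
encoded = tokenizer(
    [["How many people live in Berlin?", "Berlin is well known for its museums."]],
    return_token_type_ids=True,
)
input_ids = encoded["input_ids"][0]
token_type_ids = encoded["token_type_ids"][0]  # 0s over the query, 1s over the document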
......@@ -224,6 +224,9 @@ class ForwardBatch:
# For input embeddings
input_embeds: Optional[torch.tensor] = None
# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
......@@ -300,6 +303,7 @@ class ForwardBatch:
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
token_type_ids=batch.token_type_ids,
tbo_split_seq_index=batch.tbo_split_seq_index,
)
device = model_runner.device
......@@ -356,8 +360,8 @@ class ForwardBatch:
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
if support_triton(model_runner.server_args.attention_backend):
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,
ret.extend_seq_lens,
......
......@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
BertConfig = None
......@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
input_shape = input_ids.size()
......@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
position_embeddings = self.position_embeddings(positions)
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
)
token_type_ids = forward_batch.token_type_ids
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
......@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
return embeddings
class BertPooler(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
# simply take the hidden state corresponding to the first token
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertEncoder(nn.Module):
def __init__(
......@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
):
super().__init__()
self.layer_id = layer_id
self.attention = BertAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
......@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
attn_output = self.attention(hidden_states, forward_batch)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
......@@ -326,16 +353,23 @@ class BertModel(nn.Module):
*,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
use_bert_pooler: bool = False,
prefix: str = "",
):
super().__init__()
self.use_bert_pooler = use_bert_pooler
self.config = config
self.embeddings = BertEmbedding(config)
self.encoder = BertEncoder(
config=config, quant_config=quant_config, prefix=f"encoder"
config=config,
quant_config=quant_config,
prefix=add_prefix("encoder", prefix),
)
self.pooler = (
BertPooler(config)
if self.use_bert_pooler
else Pooler(pooling_type=PoolingType.LAST, normalize=True)
)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# self.pooler = BertPooler(config)
@torch.no_grad()
def forward(
......@@ -351,11 +385,16 @@ class BertModel(nn.Module):
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=positions,
positions=positions,
forward_batch=forward_batch,
)
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
return self.pooler(hidden_states, forward_batch)
if not self.use_bert_pooler:
hidden_states = self.pooler(hidden_states, forward_batch)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -368,7 +407,7 @@ class BertModel(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("self", "self_attn")
if "pooler" in name:
if not self.use_bert_pooler and "pooler" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
......@@ -395,4 +434,65 @@ class Contriever(BertModel):
pass
EntryClass = [BertModel, Contriever]
class BertForSequenceClassification(nn.Module):
def __init__(
self,
*,
config: BertConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.num_labels = config.num_labels
self.bert = BertModel(
config=config,
quant_config=quant_config,
use_bert_pooler=True,
prefix=add_prefix("bert", prefix),
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("bert."):
yield (name[len("bert.") :], weight)
else:
self_weights.append((name, weight))
self.bert.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
assert get_embedding, "BertForSequenceClassification is only used for rerank"
hidden_states = self.bert(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
get_embedding=get_embedding,
)
return self.pooler(hidden_states, forward_batch)
EntryClass = [BertModel, Contriever, BertForSequenceClassification]
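The `weight_filter` generator above lazily strips the "bert." prefix for backbone weights while stashing everything else for the classifier loader; a self-contained sketch of the pattern (names and values are illustrative):

weights = [("bert.encoder.layer.0.weight", 1.0), ("classifier.weight", 2.0)]
self_weights = []

def weight_filter():
    for name, weight in weights:
        if name.startswith("bert."):
            yield name[len("bert."):], weight
        else:
            self_weights.append((name, weight))

backbone = list(weight_filter())  # [("encoder.layer.0.weight", 1.0)]
# self_weights is populated only after the generator is consumed:
# it now holds [("classifier.weight", 2.0)] for the classifier's weight_loader.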
......@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
RobertaConfig = None
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: RobertaConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
return x
class RobertaEmbedding(nn.Module):
def __init__(self, config: RobertaConfig):
......@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds=None,
token_type_ids: Optional[torch.Tensor] = None,
forward_batch: ForwardBatch,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
......@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
token_type_ids = forward_batch.token_type_ids
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=inputs_embeds.device
......@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
return embeddings
class XLMRobertaModel(nn.Module):
class XLMRobertaBaseModel(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
add_pooling_layer: bool = False,
):
super().__init__()
self.config = config
self.embeddings = RobertaEmbedding(config)
self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
self.pooler = (
Pooler(pooling_type=PoolingType.CLS, normalize=True)
if add_pooling_layer
else None
)
@torch.no_grad()
def forward(
......@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
input_ids=input_ids,
position_ids=positions,
seq_lens=forward_batch.seq_lens,
forward_batch=forward_batch,
)
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
pooler_out = self.pooler(hidden_states, forward_batch)
return pooler_out
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name = name.replace("self", "self_attn")
if "pooler" in name:
if self.pooler is None and "pooler" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
......@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
return incremental_indices.long() + padding_idx
EntryClass = [XLMRobertaModel]
class XLMRobertaModel(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.roberta = XLMRobertaBaseModel(
config=config, quant_config=quant_config, prefix=prefix
)
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.roberta(
input_ids, positions, forward_batch, input_embeds, get_embedding
)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.roberta.load_weights(weights)
class XLMRobertaForSequenceClassification(nn.Module):
def __init__(
self,
*,
config: RobertaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.roberta = XLMRobertaBaseModel(
config=config, quant_config=quant_config, prefix=prefix
)
self.classifier = RobertaClassificationHead(config)
self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> torch.Tensor:
assert (
get_embedding
), "XLMRobertaForSequenceClassification is only used for rerank"
hidden_states = self.roberta(
input_ids, positions, forward_batch, input_embeds, get_embedding
)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("roberta."):
yield (name[len("roberta.") :], weight)
else:
self_weights.append((name, weight))
self.roberta.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
......@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
register_conv_template,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
V1RerankReqInput,
)
from sglang.srt.openai_api.protocol import (
BatchRequest,
BatchResponse,
......@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
FunctionResponse,
LogProbs,
MultimodalEmbeddingInput,
RerankResponse,
ScoringRequest,
ScoringResponse,
ToolCall,
......@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
return response
def v1_rerank_request(obj: V1RerankReqInput):
if obj.query is None:
raise ValueError("query is required")
if obj.documents is None or len(obj.documents) == 0:
raise ValueError("documents is required")
pairs = []
for doc in obj.documents:
pairs.append([obj.query, doc])
adapted_request = EmbeddingReqInput(
text=pairs,
is_cross_encoder_request=True,
)
return adapted_request
def v1_rerank_response(ret, obj: V1RerankReqInput):
response = []
for idx, ret_item in enumerate(ret):
    response.append(
        RerankResponse(
            score=ret_item["embedding"],
            document=obj.documents[idx],
            index=idx,
            meta_info=ret_item["meta_info"],
        )
    )
response.sort(key=lambda x: x.score, reverse=True)
return response
async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
adapted_request = v1_rerank_request(obj)
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_rerank_response(
ret,
obj,
)
return response
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
......
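Putting the helpers together: each document is paired with the query, scored through the embedding path, and the items come back sorted by score, highest first. An illustrative result for the two-document example used in the tests below (scores and meta_info keys are made up):

response = [
    {"score": 8.5, "document": "Berlin had a population of 3,520,031 ...", "index": 0, "meta_info": {"prompt_tokens": 32}},
    {"score": -9.2, "document": "Berlin is well known for its museums.", "index": 1, "meta_info": {"prompt_tokens": 18}},
]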
......@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
object: str = "scoring"
class RerankResponse(BaseModel):
score: float
document: str
index: int
meta_info: Optional[dict] = None
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
......@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
]
TEST_RERANK_QUERY_DOCS = [
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin is well known for its museums.",
],
},
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
"Berlin is well known for its museums.",
],
},
]
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
......@@ -241,7 +256,7 @@ class HFRunner:
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
)
elif self.model_type == "reward":
elif self.model_type == "reward" or self.model_type == "cross_encoder":
from transformers import AutoModelForSequenceClassification
self.model = AutoModelForSequenceClassification.from_pretrained(
......@@ -303,6 +318,15 @@ class HFRunner:
else:
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
elif self.model_type == "cross_encoder":
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
).to("cuda")
scores = self.model(**inputs).logits
scores = scores.squeeze().tolist()
if not isinstance(scores, list):
scores = [scores]
out_queue.put(ModelOutput(scores=scores))
elif self.model_type == "reward":
scores = []
......@@ -322,7 +346,9 @@ class HFRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
......@@ -526,7 +552,9 @@ class SRTRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
......@@ -552,6 +580,13 @@ class SRTRunner:
else:
logits = [response["embedding"]]
return ModelOutput(embed_logits=logits)
# cross encoder model
elif self.model_type == "cross_encoder":
response = self.engine.rerank(prompts)
if not isinstance(response, list):
response = [response]
scores = [x["embedding"] for x in response]
return ModelOutput(scores=scores)
# reward model
else:
response = self.engine.encode(prompts)
......
......@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
# MLA test models
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
......
......@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
cleaned_chunk = trim_overlap(final_text, chunk_text)
final_text += cleaned_chunk
yield cleaned_chunk # yield the non-overlapping portion
def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully qualified name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
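A quick usage sketch of the helper:

# Resolves a fully qualified name to the object it names.
linear_cls = resolve_obj_by_qualname("torch.nn.modules.linear.Linear")
layer = linear_cls(4, 2)  # same as torch.nn.Linear(4, 2)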
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci
MODELS = [
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]
ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]
class TestCrossEncoderModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_prefill_logits(
self,
prompts,
model_path,
tp_size,
torch_dtype,
score_tolerance,
attention_backend,
) -> None:
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="cross_encoder",
) as hf_runner:
hf_scores = hf_runner.forward(prompts).scores
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="cross_encoder",
attention_backend=attention_backend,
chunked_prefill_size=-1,
disable_radix_cache=True,
) as srt_runner:
srt_scores = srt_runner.forward(prompts).scores
for i in range(len(srt_scores)):
score_difference = abs(hf_scores[i] - srt_scores[i])
assert (
score_difference < score_tolerance
), "cross encoder scores are not all close"
def preprocess_prompts(self, prompt):
processed_prompts = []
query = prompt["query"]
documents = prompt["documents"]
for document in documents:
processed_prompts.append([query, document])
return processed_prompts
def test_prefill_logits(self):
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for attention_backend in ATTENTION_BACKEND:
for query_docs in TEST_RERANK_QUERY_DOCS:
    prompts = self.preprocess_prompts(query_docs)
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
prompts,
model,
tp_size,
torch_dtype,
prefill_tolerance,
attention_backend,
)
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,8 @@ suites = {
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
TestFile("models/test_cross_encoder_models.py", 100),
TestFile("models/test_compressed_tensors_models.py", 42),
TestFile("models/test_generation_models.py", 103),
# TestFile("models/test_gme_qwen_models.py", 45),
......
......@@ -17,7 +17,9 @@ import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
from sglang.test.test_utils import (
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
......@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertEqual(cm.exception.status_code, 400)
class TestOpenAIV1Rerank(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.score_tolerance = 1e-2
# Configure rerank-specific server args
other_args = [
"--is-embedding",
"--enable-metrics",
"--disable-radix-cache",
"--chunked-prefill-size",
"-1",
"--attention-backend",
"torch_native",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=other_args,
)
cls.base_url += "/v1/rerank"
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_rerank(self, query, docs):
response = requests.post(
self.base_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={"query": query, "documents": docs},
)
return response.json()
def test_rerank_single(self):
"""Test single rerank request"""
query = TEST_RERANK_QUERY_DOCS[0]["query"]
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 1)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
def test_rerank_batch(self):
"""Test batch rerank request"""
query = TEST_RERANK_QUERY_DOCS[1]["query"]
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 2)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[1]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[1]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
self.assertTrue(isinstance(response[1]["index"], int))
class TestOpenAIServerIgnoreEOS(CustomTestCase):
@classmethod
def setUpClass(cls):
......