Unverified Commit cfb2fb5a authored by woodx's avatar woodx Committed by GitHub
Browse files

[OAI refactor] Add rerank and score serving (#7399)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 22bfed75
......@@ -534,6 +534,22 @@ class ScoringResponse(BaseModel):
object: str = "scoring"
class V1RerankReqInput(BaseModel):
    """Request body for the v1 rerank endpoint.

    Carries one query and the candidate documents to be scored against it.
    """

    # Query text; every entry in `documents` is scored against it.
    query: str
    # Candidate documents to rerank.
    documents: List[str]
class RerankResponse(BaseModel):
    """A single rerank result: the relevance score for one document."""

    # Relevance score of `document` against the query; results are sorted
    # by this value in descending order (highest relevance first).
    score: float
    # The original document text that was scored.
    document: str
    # Position of the document in the request's `documents` list.
    index: int
    # Optional per-item metadata returned by the backend.
    meta_info: Optional[dict] = None
# Union of every request type handled by the OpenAI-compatible serving layer.
# NOTE: the scraped diff left the pre-change member line in place, which both
# duplicated members and (missing a trailing comma) made the union a syntax
# error; this is the merged post-change union.
OpenAIServingRequest = Union[
    ChatCompletionRequest,
    CompletionRequest,
    EmbeddingRequest,
    ScoringRequest,
    V1RerankReqInput,
]
import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
RerankResponse,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
logger = logging.getLogger(__name__)
class OpenAIServingRerank(OpenAIServingBase):
    """Serves rerank requests.

    Each document is scored against the query through a cross-encoder
    embedding call, and results are returned ordered by relevance.
    """

    def _request_id_prefix(self) -> str:
        """Prefix used when generating request ids for rerank calls."""
        return "rerank-"

    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
        """Return an error message for a malformed request, or None if valid."""
        query = request.query
        if not query:
            return "Query cannot be empty"
        if isinstance(query, str) and not query.strip():
            return "Query cannot be empty or whitespace only"

        if not request.documents:
            return "Documents cannot be empty"
        for doc in request.documents:
            if not doc:
                return "Each document must be a non-empty string"
            if isinstance(doc, str) and not doc.strip():
                return "Each document cannot be empty or whitespace only"
        return None

    def _convert_to_internal_request(
        self, request: V1RerankReqInput
    ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
        """Adapt a rerank request into a cross-encoder embedding request.

        Every document is paired with the query so the model scores each
        (query, document) pair.
        """
        query_doc_pairs = [[request.query, doc] for doc in request.documents]
        adapted_request = EmbeddingReqInput(
            text=query_doc_pairs,
            is_cross_encoder_request=True,
        )
        return adapted_request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: V1RerankReqInput,
        raw_request: Request,
    ) -> Union[RerankResponse, ErrorResponse]:
        """Run the adapted request through the tokenizer manager and build
        the rerank response; ValueErrors become error responses."""
        try:
            generator = self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            )
            results = await generator.__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        # A single result may come back as a bare dict; normalize to a list.
        if not isinstance(results, list):
            results = [results]
        return self._build_rerank_response(results, request)

    def _build_rerank_response(
        self, ret: List[Dict[str, Any]], request: V1RerankReqInput
    ) -> List[RerankResponse]:
        """Pair each result with its source document and sort by score,
        highest relevance first."""
        responses = [
            RerankResponse(
                score=item["embedding"],
                document=request.documents[idx],
                index=idx,
                meta_info=item["meta_info"],
            )
            for idx, item in enumerate(ret)
        ]
        # Descending score order; sort is stable, so ties keep request order.
        return sorted(responses, key=lambda r: r.score, reverse=True)
import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
ScoringRequest,
ScoringResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
logger = logging.getLogger(__name__)
class OpenAIServingScore(OpenAIServingBase):
    """Serves scoring requests by delegating to the tokenizer manager's
    dedicated `score_request` path."""

    def _request_id_prefix(self) -> str:
        """Prefix used when generating request ids for scoring calls."""
        return "score-"

    def _convert_to_internal_request(
        self,
        request: ScoringRequest,
    ) -> tuple[ScoringRequest, ScoringRequest]:
        """Pass the request through unchanged.

        No adaptation is needed: the tokenizer_manager exposes a specialized
        score_request method that consumes the ScoringRequest directly
        instead of going through GenerateReqInput.
        """
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: ScoringRequest,
        request: ScoringRequest,
        raw_request: Request,
    ) -> Union[ScoringResponse, ErrorResponse]:
        """Score the request's items against its query and wrap the result.

        ValueErrors raised while scoring are converted into error responses.
        """
        try:
            # Delegate straight to the specialized scoring path.
            scores = await self.tokenizer_manager.score_request(
                query=request.query,
                items=request.items,
                label_token_ids=request.label_token_ids,
                apply_softmax=request.apply_softmax,
                item_first=request.item_first,
                request=raw_request,
            )
            # Only scores and the model name are returned — no usage info.
            return ScoringResponse(scores=scores, model=request.model)
        except ValueError as e:
            return self.create_error_response(str(e))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.