Unverified Commit b5e14b2b authored by ybyang, committed by GitHub

[1/2][feature] support openai like classification api (#11618)

parent d513ee93
# Classification API
This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.
## Overview
The classification API allows you to classify text inputs using classification models. The implementation follows the same request/response format as the classification API in vLLM 0.7.0.
## API Endpoint
```
POST /v1/classify
```
## Request Format
```json
{
  "model": "model_name",
  "input": "text to classify"
}
```
### Parameters
- `model` (string, required): The name of the classification model to use
- `input` (string, list of strings, or list of token IDs, required): The text to classify
- `user` (string, optional): User identifier for tracking
- `rid` (string, optional): Request ID for tracking
- `priority` (integer, optional): Request priority
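
For example, a request that also sets the optional tracking fields could look like the sketch below. It assumes a server running locally on port 8000 (as in the examples later in this document); the `user` and `rid` values are placeholders.

```python
import requests

# Request body using the optional parameters documented above.
payload = {
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "The service was quick and friendly.",
    "user": "user-123",        # optional: caller identifier (placeholder value)
    "rid": "my-request-001",   # optional: request ID for tracking (placeholder value)
    "priority": 0,             # optional: request priority
}
response = requests.post("http://127.0.0.1:8000/v1/classify", json=payload)
print(response.json())
```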
## Response Format
```json
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [0.565970778465271, 0.4340292513370514],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```
### Response Fields
- `id`: Unique identifier for the classification request
- `object`: Always "list"
- `created`: Unix timestamp when the request was created
- `model`: The model used for classification
- `data`: Array of classification results
  - `index`: Index of the result
  - `label`: Predicted class label
  - `probs`: Array of probabilities for each class
  - `num_classes`: Total number of classes
- `usage`: Token usage information
  - `prompt_tokens`: Number of input tokens
  - `total_tokens`: Total number of tokens
  - `completion_tokens`: Number of completion tokens (always 0 for classification)
  - `prompt_tokens_details`: Additional token details (optional)
## Example Usage
### Using curl
```bash
curl -v "http://127.0.0.1:8000/v1/classify" \
-H "Content-Type: application/json" \
-d '{
"model": "jason9693/Qwen2.5-1.5B-apeach",
"input": "Loved the new café—coffee was great."
}'
```
### Using Python
```python
import requests
import json

# Make classification request
response = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    headers={"Content-Type": "application/json"},
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "Loved the new café—coffee was great."
    }
)

# Parse response
result = response.json()
print(json.dumps(result, indent=2))
```
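
To act on the result programmatically, read the fields described under Response Fields. A minimal sketch, continuing from the snippet above (it assumes `result` holds the parsed response and that the request succeeded):

```python
# Continues from the previous snippet: `result` is the parsed JSON response.
first = result["data"][0]
print(f"label={first['label']}, num_classes={first['num_classes']}")

# `probs` is softmax-normalized over the model's classes, so the predicted
# class is simply the argmax.
predicted_index = max(range(first["num_classes"]), key=lambda i: first["probs"][i])
print(f"top probability: {first['probs'][predicted_index]:.3f}")
```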
## Supported Models
The classification API works with any classification model supported by SGLang, including:
### Classification Models (Multi-class)
- `LlamaForSequenceClassification` - Multi-class classification
- `Qwen2ForSequenceClassification` - Multi-class classification
- `Qwen3ForSequenceClassification` - Multi-class classification
- `BertForSequenceClassification` - Multi-class classification
- `Gemma2ForSequenceClassification` - Multi-class classification
**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort.
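
The lookup order can be illustrated with a small sketch. This is not the exact server code; it only mirrors the fallback chain described above, using a hypothetical `config` dict loaded from the model's `config.json`:

```python
def resolve_labels(config: dict, num_classes: int) -> dict:
    """Mirror the documented fallback chain: id2label -> LABEL_i -> Class_i."""
    id2label = config.get("id2label")
    if id2label:
        # config.json stores keys as strings ("0", "1", ...), so normalize to int.
        return {int(k): v for k, v in id2label.items()}
    num_labels = config.get("num_labels")
    if num_labels:
        return {i: f"LABEL_{i}" for i in range(num_labels)}
    return {i: f"Class_{i}" for i in range(num_classes)}


# Example: a 2-class config without id2label falls back to generic names.
print(resolve_labels({"num_labels": 2}, num_classes=2))  # {0: 'LABEL_0', 1: 'LABEL_1'}
```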
### Reward Models (Single score)
- `InternLM2ForRewardModel` - Single reward score
- `Qwen2ForRewardModel` - Single reward score
- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model
**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks.
## Error Handling
The API returns appropriate HTTP status codes and error messages:
- `400 Bad Request`: Invalid request format or missing required fields
- `500 Internal Server Error`: Server-side processing error
Error response format:
```json
{
  "error": "Error message",
  "type": "error_type",
  "code": 400
}
```
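
On the client side, you can branch on the HTTP status code before reading the payload. A minimal sketch, assuming the error body shape shown above and a server on port 8000:

```python
import requests

response = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    json={"model": "jason9693/Qwen2.5-1.5B-apeach", "input": ""},  # empty input should be rejected
)

if response.status_code == 200:
    for item in response.json()["data"]:
        print(item["label"], item["probs"])
else:
    # Error responses carry "error", "type", and "code" fields.
    err = response.json()
    print(f"request failed ({response.status_code}): {err.get('error')}")
```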
## Implementation Details
The classification API is implemented using:
1. **Rust Router**: Handles routing and request/response models in `sgl-router/src/protocols/spec.rs`
2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`
3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`
## Testing
Use the provided test script to verify the implementation:
```bash
python test_classify_api.py
```
## Compatibility
This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.
**python/sglang/srt/entrypoints/http_server.py**

```diff
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationM
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
+    ClassifyRequest,
     CompletionRequest,
     DetokenizeRequest,
     EmbeddingRequest,
@@ -62,6 +63,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     V1RerankReqInput,
 )
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_classify import OpenAIServingClassify
 from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
 from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
@@ -228,6 +230,9 @@ async def lifespan(fast_api_app: FastAPI):
     fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
         _global_state.tokenizer_manager, _global_state.template_manager
     )
+    fast_api_app.state.openai_serving_classify = OpenAIServingClassify(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
     fast_api_app.state.openai_serving_score = OpenAIServingScore(
         _global_state.tokenizer_manager
     )
@@ -1082,6 +1087,18 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
     )


+@app.post(
+    "/v1/classify",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_classify(request: ClassifyRequest, raw_request: Request):
+    """OpenAI-compatible classification endpoint."""
+    return await raw_request.app.state.openai_serving_classify.handle_request(
+        request, raw_request
+    )
+
+
 @app.post(
     "/v1/tokenize",
     response_class=ORJSONResponse,
```
**python/sglang/srt/entrypoints/openai/protocol.py**

```diff
@@ -761,6 +761,37 @@ class EmbeddingObject(BaseModel):
     object: str = "embedding"


+ClassifyInput = Union[str, List[str], List[int]]
+
+
+class ClassifyRequest(BaseModel):
+    # OpenAI-compatible classification request
+    model: str = DEFAULT_MODEL_NAME
+    input: ClassifyInput
+    user: Optional[str] = None
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None
+
+
+class ClassifyData(BaseModel):
+    index: int
+    label: str
+    probs: List[float]
+    num_classes: int
+
+
+class ClassifyResponse(BaseModel):
+    id: str
+    object: str = "list"
+    created: int
+    model: str
+    data: List[ClassifyData]
+    usage: UsageInfo
+
+
 class EmbeddingResponse(BaseModel):
     data: List[EmbeddingObject]
     model: str
@@ -844,6 +875,7 @@ OpenAIServingRequest = Union[
     ChatCompletionRequest,
     CompletionRequest,
     EmbeddingRequest,
+    ClassifyRequest,
     ScoringRequest,
     V1RerankReqInput,
     TokenizeRequest,
```
**python/sglang/srt/entrypoints/openai/serving_classify.py** (new file)

```python
from __future__ import annotations

import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from fastapi import Request
from fastapi.responses import ORJSONResponse

from sglang.srt.entrypoints.openai.protocol import (
    ClassifyRequest,
    ClassifyResponse,
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput

if TYPE_CHECKING:
    from sglang.srt.managers.template_manager import TemplateManager
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)


class OpenAIServingClassify(OpenAIServingBase):
    """Handler for /v1/classify requests."""

    def __init__(
        self,
        tokenizer_manager: TokenizerManager,
        template_manager: TemplateManager,
    ):
        super().__init__(tokenizer_manager)
        self.template_manager = template_manager
        self.id2label = self._get_id2label_mapping()
        self.model_name = (
            self.tokenizer_manager.served_model_name
            if self.tokenizer_manager.served_model_name
            else self.tokenizer_manager.server_args.model_path
        )
        if not self.id2label:
            raise ValueError("id2label mapping is missing")

    def _request_id_prefix(self) -> str:
        return "classify-"

    def _convert_to_internal_request(
        self,
        request: ClassifyRequest,
        raw_request: Request = None,
    ) -> tuple[EmbeddingReqInput, ClassifyRequest]:
        """Convert an OpenAI classification request to the internal format."""
        prompt = request.input

        if isinstance(prompt, str):
            # Single string input
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list):
            if len(prompt) > 0 and isinstance(prompt[0], str):
                prompt_kwargs = {"text": prompt}
            else:
                # List of integers (token IDs) or empty list
                prompt_kwargs = {"input_ids": prompt}
        else:
            # Other types (should not happen but handle gracefully)
            prompt_kwargs = {"input_ids": prompt}

        adapted_request = EmbeddingReqInput(
            **prompt_kwargs,
            rid=request.rid,
            priority=request.priority,
        )

        return adapted_request, request

    def _validate_request(self, request: ClassifyRequest) -> Optional[str]:
        """Validate that the input is not empty or whitespace only."""
        if not (input := request.input):
            return "Input cannot be empty"

        # Handle single string
        if isinstance(input, str):
            if not input.strip():
                return "Input cannot be empty or whitespace only"
            return None

        # Handle list inputs
        if isinstance(input, list):
            # Check first element to determine type
            first_item = input[0]

            if isinstance(first_item, str):
                # List of strings
                for i, item in enumerate(input):
                    if not isinstance(item, str):
                        return "All items in input list must be strings"
                    if not item.strip():
                        return f"Input at index {i} cannot be empty or whitespace only"
            elif isinstance(first_item, int):
                # List of integers (token IDs)
                for i, item in enumerate(input):
                    if not isinstance(item, int):
                        return "All items in input list must be integers"
                    if item < 0:
                        return f"Token ID at index {i} must be non-negative"
        return None

    def _get_id2label_mapping(self) -> Optional[Dict[int, str]]:
        """Get id2label mapping from model config."""
        try:
            hf_config = self.tokenizer_manager.model_config.hf_config
            # Check for id2label in hf_config
            if hf_config.id2label:
                return hf_config.id2label
            # Check for num_labels and create default mapping if needed
            if hasattr(hf_config, "num_labels") and hf_config.num_labels:
                num_labels = hf_config.num_labels
                # Create default mapping: {0: "LABEL_0", 1: "LABEL_1", ...}
                return {i: f"LABEL_{i}" for i in range(num_labels)}
        except Exception as e:
            logger.warning(f"Failed to get id2label mapping: {e}")
        return None

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: ClassifyRequest,
        raw_request: Request,
    ) -> Union[ClassifyResponse, ErrorResponse, ORJSONResponse]:
        """Handle non-streaming classification request."""
        try:
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        response = self._build_classify_response(ret)
        return response

    def _build_classify_response(self, ret: List[Dict[str, Any]]) -> ClassifyResponse:
        request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}"
        created_time = int(time.time())
        classify_objects = []
        prompt_tokens = 0
        total_latency = 0.0

        for i, item in enumerate(ret):
            embedding = item.get("embedding", [])
            meta_info = item.get("meta_info", {})

            prompt_tokens += meta_info.get("prompt_tokens", 0)
            total_latency += meta_info.get("e2e_latency", 0.0)

            if embedding:
                try:
                    embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
                    probs = F.softmax(embedding_tensor, dim=0).tolist()
                    predicted_class = torch.argmax(embedding_tensor).item()
                    label = self.id2label[predicted_class]
                except Exception as e:
                    logger.error(f"Error processing embedding for item {i}: {e}")
                    probs = [1.0]
                    label = "Default"
            else:
                probs = [1.0]
                label = "Default"

            classify_obj = {
                "index": i,
                "label": label,
                "probs": probs,
                "num_classes": len(probs),
            }
            classify_objects.append(classify_obj)

        response = {
            "id": request_id,
            "object": "list",
            "created": created_time,
            "model": self.model_name,
            "data": classify_objects,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens,
                "completion_tokens": 0,
                "prompt_tokens_details": None,
            },
        }
        return ClassifyResponse(**response)
```