Unverified commit 09988080, authored by Xinyuan Tong, committed by GitHub

Refine OpenAI serving entrypoint to remove batch requests (#7372)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Chang Su <csu272@usc.edu>
parent 794be55a
@@ -20,7 +20,7 @@ import logging
 import os
 from enum import auto
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest
 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -116,7 +116,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None
-def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
...
@@ -2,7 +2,7 @@ import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            request, self._generate_request_id_base(request)
+            request
         )
         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -74,10 +74,7 @@ class OpenAIServingBase(ABC):
     def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
-        request_id: str,
-    ) -> tuple[
-        GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
-    ]:
+    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
...
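The hunk above narrows `_convert_to_internal_request` to a strictly one-request-in, one-request-out contract and drops the separately passed request id. A minimal sketch of that contract, with a hypothetical stand-in for `GenerateReqInput` and a dummy subclass; everything except the method name and the return shape is illustrative, not taken from the repository:

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class FakeGenerateReqInput:
    # Hypothetical stand-in for sglang.srt.managers.io_struct.GenerateReqInput
    text: Optional[str] = None
    sampling_params: Dict[str, Any] = field(default_factory=dict)
    stream: bool = False


class ServingBaseSketch(ABC):
    @abstractmethod
    def _convert_to_internal_request(self, request) -> Tuple[FakeGenerateReqInput, Any]:
        """One OpenAI-style request in, one (internal input, original request) pair out."""


class DummyServing(ServingBaseSketch):
    def _convert_to_internal_request(self, request: Dict[str, Any]) -> Tuple[FakeGenerateReqInput, Dict[str, Any]]:
        # No request_id parameter and no list handling: batching is gone,
        # and the request id is attached elsewhere in the handler.
        adapted = FakeGenerateReqInput(
            text=request["prompt"],
            sampling_params={"max_new_tokens": request.get("max_tokens", 16)},
            stream=bool(request.get("stream", False)),
        )
        return adapted, request


adapted, original = DummyServing()._convert_to_internal_request({"prompt": "Hello"})
print(adapted.text, adapted.stream)  # -> Hello False

The chat, completion, and embedding handlers in the hunks below are updated to this single-request signature.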
@@ -3,7 +3,7 @@ import json
 import logging
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import StreamingResponse
@@ -52,60 +52,13 @@ class OpenAIServingChat(OpenAIServingBase):
     def _request_id_prefix(self) -> str:
         return "chatcmpl-"
-    def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
-        """Validate chat messages format and content"""
-        if not (messages := request.messages):
-            return "Messages cannot be empty"
-        # Check for alternating user/assistant pattern (optional validation)
-        roles = [msg.role for msg in messages]
-        # First message should typically be from user or system
-        if roles[0] not in ["user", "system"]:
-            return "First message should be from 'user' or 'system'"
-        # Check for consecutive assistant messages (which might indicate an error)
-        for i in range(1, len(roles)):
-            if roles[i] == "assistant" and roles[i - 1] == "assistant":
-                # This is actually allowed in some cases, so just warn
-                pass
-        # Validate message content
-        for i, msg in enumerate(messages):
-            if msg.role == "user":
-                if not msg.content:
-                    return f"User message at index {i} has no content"
-            elif msg.role == "assistant":
-                # Assistant messages can have no content if they have tool_calls
-                if not msg.content and not getattr(msg, "tool_calls", None):
-                    return (
-                        f"Assistant message at index {i} has no content or tool calls"
-                    )
-        return None
     def _convert_to_internal_request(
         self,
-        all_requests: List[ChatCompletionRequest],
-        request_ids: List[str],
-    ) -> tuple[
-        GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
-    ]:
+        request: ChatCompletionRequest,
+    ) -> tuple[GenerateReqInput, ChatCompletionRequest]:
         """Convert OpenAI chat completion request to internal format"""
-        input_ids = []
-        prompts = []
-        sampling_params_list = []
-        image_data_list = []
-        audio_data_list = []
-        return_logprobs = []
-        logprob_start_lens = []
-        top_logprobs_nums = []
-        modalities_list = []
-        lora_paths = []
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal
-        for request in all_requests:
         # Process messages and apply chat template
         (
             prompt,
@@ -117,72 +70,38 @@ class OpenAIServingChat(OpenAIServingBase):
             tool_call_constraint,
         ) = self._process_messages(request, is_multimodal)
-            input_ids.append(prompt_ids)
-            prompts.append(prompt)
-            return_logprobs.append(request.logprobs)
-            logprob_start_lens.append(-1)
-            top_logprobs_nums.append(request.top_logprobs or 0)
-            lora_paths.append(request.lora_path)
         # Build sampling parameters
         sampling_params = self._build_sampling_params(
             request, stop, tool_call_constraint
         )
-            sampling_params_list.append(sampling_params)
-            image_data_list.append(image_data)
-            audio_data_list.append(audio_data)
-            modalities_list.append(modalities)
         # Handle single vs multiple requests
-        if len(all_requests) == 1:
         if is_multimodal:
-                prompt_kwargs = {"text": prompts[0]}
+            prompt_kwargs = {"text": prompt}
         else:
-                if isinstance(input_ids[0], str):
+            if isinstance(prompt_ids, str):
-                    prompt_kwargs = {"text": input_ids[0]}
+                prompt_kwargs = {"text": prompt_ids}
             else:
-                    prompt_kwargs = {"input_ids": input_ids[0]}
+                prompt_kwargs = {"input_ids": prompt_ids}
-            sampling_params_list = sampling_params_list[0]
-            image_data_list = image_data_list[0]
-            audio_data_list = audio_data_list[0]
-            return_logprobs = return_logprobs[0]
-            logprob_start_lens = logprob_start_lens[0]
-            top_logprobs_nums = top_logprobs_nums[0]
-            modalities_list = modalities_list[0]
-            lora_paths = lora_paths[0]
-            request_ids = request_ids[0]
-        else:
-            if is_multimodal:
-                prompt_kwargs = {"text": prompts}
-            else:
-                if isinstance(input_ids[0], str):
-                    prompt_kwargs = {"text": input_ids}
-                else:
-                    prompt_kwargs = {"input_ids": input_ids}
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data_list,
-            audio_data=audio_data_list,
-            sampling_params=sampling_params_list,
-            return_logprob=return_logprobs,
-            logprob_start_len=logprob_start_lens,
-            top_logprobs_num=top_logprobs_nums,
-            stream=all_requests[0].stream,
+            image_data=image_data,
+            audio_data=audio_data,
+            sampling_params=sampling_params,
+            return_logprob=request.logprobs,
+            logprob_start_len=-1,
+            top_logprobs_num=request.top_logprobs or 0,
+            stream=request.stream,
             return_text_in_logprobs=True,
-            rid=request_ids,
-            modalities=modalities_list,
-            lora_path=lora_paths,
-            bootstrap_host=all_requests[0].bootstrap_host,
-            bootstrap_port=all_requests[0].bootstrap_port,
-            bootstrap_room=all_requests[0].bootstrap_room,
+            modalities=modalities,
+            lora_path=request.lora_path,
+            bootstrap_host=request.bootstrap_host,
+            bootstrap_port=request.bootstrap_port,
+            bootstrap_room=request.bootstrap_room,
         )
-        return adapted_request, (
-            all_requests if len(all_requests) > 1 else all_requests[0]
-        )
+        return adapted_request, request
     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
@@ -457,14 +376,29 @@ class OpenAIServingChat(OpenAIServingBase):
         raw_request: Request,
     ) -> StreamingResponse:
         """Handle streaming chat completion request"""
+        return StreamingResponse(
+            self._generate_chat_stream(adapted_request, request, raw_request),
+            media_type="text/event-stream",
+            background=self.tokenizer_manager.create_abort_task(adapted_request),
+        )
-        async def generate_stream_resp():
+    async def _generate_chat_stream(
+        self,
+        adapted_request: GenerateReqInput,
+        request: ChatCompletionRequest,
+        raw_request: Request,
+    ) -> AsyncGenerator[str, None]:
+        """Generate streaming chat completion response"""
+        # Parsers for tool calls and reasoning
         parser_dict = {}
         reasoning_parser_dict = {}
-        tool_call_first = True
+        # State tracking for streaming
         is_firsts = {}
         stream_buffers = {}
         n_prev_tokens = {}
+        # Usage tracking
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
@@ -475,10 +409,6 @@ class OpenAIServingChat(OpenAIServingBase):
            ):
                index = content.get("index", 0)
-                is_first = is_firsts.get(index, True)
-                stream_buffer = stream_buffers.get(index, "")
-                n_prev_token = n_prev_tokens.get(index, 0)
                prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                completion_tokens[index] = content["meta_info"]["completion_tokens"]
                cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
@@ -487,21 +417,19 @@ class OpenAIServingChat(OpenAIServingBase):
                choice_logprobs = None
                if request.logprobs:
                    choice_logprobs = self._process_streaming_logprobs(
-                        content, n_prev_token
+                        content, n_prev_tokens.get(index, 0)
                    )
-                    n_prev_token = len(
+                    n_prev_tokens[index] = len(
                        content["meta_info"]["output_token_logprobs"]
                    )
                finish_reason = content["meta_info"]["finish_reason"]
-                finish_reason_type = (
-                    finish_reason["type"] if finish_reason else None
-                )
+                finish_reason_type = finish_reason["type"] if finish_reason else None
                # First chunk with role
-                if is_first:
-                    is_first = False
-                    delta = DeltaMessage(role="assistant")
+                if is_firsts.get(index, True):
+                    is_firsts[index] = False
+                    delta = DeltaMessage(role="assistant", content="")
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=index,
                        delta=delta,
@@ -522,8 +450,9 @@ class OpenAIServingChat(OpenAIServingBase):
                    yield f"data: {chunk.model_dump_json()}\n\n"
                # Process content delta
+                stream_buffer = stream_buffers.get(index, "")
                delta = content["text"][len(stream_buffer) :]
-                new_stream_buffer = stream_buffer + delta
+                stream_buffers[index] = stream_buffer + delta
                # Handle reasoning content
                enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
@@ -552,9 +481,6 @@ class OpenAIServingChat(OpenAIServingBase):
                    yield f"data: {chunk.model_dump_json()}\n\n"
                    if not delta:
-                        stream_buffers[index] = new_stream_buffer
-                        is_firsts[index] = is_first
-                        n_prev_tokens[index] = n_prev_token
                        continue
                # Handle tool calls
@@ -571,8 +497,7 @@ class OpenAIServingChat(OpenAIServingBase):
                else:
                    # Regular content
                    if delta or not (
-                        request.stream_options
-                        and request.stream_options.include_usage
+                        request.stream_options and request.stream_options.include_usage
                    ):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
@@ -598,19 +523,8 @@ class OpenAIServingChat(OpenAIServingBase):
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
-                stream_buffers[index] = new_stream_buffer
-                is_firsts[index] = is_first
-                n_prev_tokens[index] = n_prev_token
-            # Final chunk with usage
-            if request.stream_options and request.stream_options.include_usage:
-                usage = self._calculate_streaming_usage_base(
-                    prompt_tokens, completion_tokens, cached_tokens, request.n
-                )
-            else:
-                usage = None
-            final_chunk = ChatCompletionStreamResponse(
+            # Final chunk with finish_reason
+            finish_reason_chunk = ChatCompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=int(time.time()),
                choices=[
@@ -618,12 +532,34 @@ class OpenAIServingChat(OpenAIServingBase):
                        index=index,
                        delta=DeltaMessage(),
                        finish_reason=finish_reason_type,
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
+                        ),
                    )
                ],
                model=request.model,
+                usage=None,
+            )
+            yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
+            # Additional usage chunk
+            if request.stream_options and request.stream_options.include_usage:
+                usage = self._calculate_streaming_usage_base(
+                    prompt_tokens,
+                    completion_tokens,
+                    cached_tokens,
+                    request.n,
+                )
+                usage_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=int(time.time()),
+                    choices=[],  # Empty choices array as per OpenAI spec
+                    model=request.model,
                    usage=usage,
                )
-            yield f"data: {final_chunk.model_dump_json()}\n\n"
+                yield f"data: {usage_chunk.model_dump_json()}\n\n"
        except Exception as e:
            error = self.create_streaming_error_response(str(e))
@@ -631,12 +567,6 @@ class OpenAIServingChat(OpenAIServingBase):
        yield "data: [DONE]\n\n"
-        return StreamingResponse(
-            generate_stream_resp(),
-            media_type="text/event-stream",
-            background=self.tokenizer_manager.create_abort_task(adapted_request),
-        )
     async def _handle_non_streaming_request(
         self,
         adapted_request: GenerateReqInput,
@@ -658,9 +588,6 @@ class OpenAIServingChat(OpenAIServingBase):
             request,
             ret,
             int(time.time()),
-            cache_report=self.tokenizer_manager.server_args.enable_cache_report,
-            tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
-            reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
         )
         return response
@@ -670,9 +597,6 @@ class OpenAIServingChat(OpenAIServingBase):
         request: ChatCompletionRequest,
         ret: List[Dict[str, Any]],
         created: int,
-        cache_report: bool = False,
-        tool_call_parser: Optional[str] = None,
-        reasoning_parser: Optional[str] = None,
     ) -> ChatCompletionResponse:
         """Build chat completion response from generation results"""
         choices = []
@@ -691,6 +615,7 @@ class OpenAIServingChat(OpenAIServingBase):
             enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
                 "enable_thinking", True
             )
+            reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
             if reasoning_parser and request.separate_reasoning and enable_thinking:
                 try:
                     parser = ReasoningParser(
@@ -708,6 +633,7 @@ class OpenAIServingChat(OpenAIServingBase):
             # Handle tool calls
             tool_calls = None
             if request.tool_choice != "none" and request.tools:
+                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
                 tool_calls, text, finish_reason = self._process_tool_calls(
                     text, request.tools, tool_call_parser, finish_reason
                 )
@@ -731,6 +657,7 @@ class OpenAIServingChat(OpenAIServingBase):
             choices.append(choice_data)
         # Calculate usage
+        cache_report = self.tokenizer_manager.server_args.enable_cache_report
         usage = aggregate_token_usage(ret, request.n, cache_report)
         return ChatCompletionResponse(
@@ -810,7 +737,7 @@ class OpenAIServingChat(OpenAIServingBase):
         text, call_info_list = parser.parse_non_stream(text)
         tool_calls = [
             ToolCall(
-                id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                id=f"call_{uuid.uuid4().hex[:24]}",
                 function=FunctionResponse(
                     name=call_info.name, arguments=call_info.parameters
                 ),
@@ -894,6 +821,16 @@ class OpenAIServingChat(OpenAIServingBase):
         # Yield tool calls
         for call_item in calls:
+            # Tool call ID should be generated only once per tool call
+            if call_item.name:
+                # First chunk: include ID and function name
+                tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+                function_name = call_item.name
+            else:
+                # Subsequent chunks: null ID and name for argument deltas
+                tool_call_id = None
+                function_name = None
             if finish_reason_type == "stop":
                 # Handle remaining arguments
                 latest_delta_len = 0
@@ -912,10 +849,10 @@ class OpenAIServingChat(OpenAIServingBase):
                 finish_reason_type = "tool_calls"
             tool_call = ToolCall(
-                id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                id=tool_call_id,
                 index=call_item.tool_index,
                 function=FunctionResponse(
-                    name=call_item.name,
+                    name=function_name,
                     arguments=call_item.parameters,
                 ),
             )
...
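To make the new chunk ordering in `_generate_chat_stream` concrete: content deltas come first, then a finish_reason chunk that carries `usage=None`, then, only when the client set `stream_options.include_usage`, a usage-only chunk whose `choices` list is empty, and finally the `[DONE]` sentinel. The sketch below mimics that sequence with plain dicts instead of the project's pydantic models; the token counts are made up.

import asyncio
import json
from typing import AsyncGenerator


async def fake_chat_stream(include_usage: bool) -> AsyncGenerator[str, None]:
    # Content chunks
    for delta in ("Hel", "lo"):
        chunk = {"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}]}
        yield f"data: {json.dumps(chunk)}\n\n"
    # Final chunk with finish_reason and usage=None
    finish = {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": None}
    yield f"data: {json.dumps(finish)}\n\n"
    # Separate usage chunk with empty choices, emitted only on request
    if include_usage:
        usage = {
            "choices": [],
            "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
        }
        yield f"data: {json.dumps(usage)}\n\n"
    yield "data: [DONE]\n\n"


async def main() -> None:
    async for event in fake_chat_stream(include_usage=True):
        print(event, end="")


if __name__ == "__main__":
    asyncio.run(main())

The next file diff applies the same single-request conversion and the same streaming-method refactor to the completions handler.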
+import logging
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Union
 from fastapi import Request
 from fastapi.responses import StreamingResponse
@@ -23,6 +24,8 @@ from sglang.srt.entrypoints.openai.utils import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
+logger = logging.getLogger(__name__)
 class OpenAIServingCompletion(OpenAIServingBase):
     """Handler for completion requests"""
@@ -30,134 +33,54 @@ class OpenAIServingCompletion(OpenAIServingBase):
     def _request_id_prefix(self) -> str:
         return "cmpl-"
-    def _validate_request(self, request: CompletionRequest) -> Optional[str]:
-        """Validate completion prompt format and content"""
-        if not (prompt := request.prompt):
-            return "Prompt cannot be None"
-        if isinstance(prompt, str):
-            if not prompt.strip():
-                return "Prompt cannot be empty or whitespace only"
-        elif isinstance(prompt, list):
-            if not prompt:
-                return "Prompt list cannot be empty"
-            # Check if it's a list of strings
-            if all(isinstance(item, str) for item in prompt):
-                for i, item in enumerate(prompt):
-                    if not item.strip():
-                        return f"Prompt at index {i} cannot be empty or whitespace only"
-            # Check if it's a list of token IDs (integers)
-            elif all(isinstance(item, int) for item in prompt):
-                if any(item < 0 for item in prompt):
-                    return "Token IDs must be non-negative"
-            # Check if it's a list of lists (multiple token sequences)
-            elif all(isinstance(item, list) for item in prompt):
-                for i, item in enumerate(prompt):
-                    if not item:
-                        return f"Token sequence at index {i} cannot be empty"
-                    if not all(isinstance(token, int) for token in item):
-                        return f"Token sequence at index {i} must contain only integers"
-                    if any(token < 0 for token in item):
-                        return (
-                            f"Token sequence at index {i} contains negative token IDs"
-                        )
-            else:
-                return "Prompt must be string, list of strings, list of integers, or list of integer lists"
-        else:
-            return "Prompt must be string or list"
-        return None
     def _convert_to_internal_request(
         self,
-        all_requests: List[CompletionRequest],
-        request_ids: List[str],
-    ) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
+        request: CompletionRequest,
+    ) -> tuple[GenerateReqInput, CompletionRequest]:
         """Convert OpenAI completion request to internal format"""
-        # Validate batch requests
-        if len(all_requests) > 1:
-            first_prompt_type = type(all_requests[0].prompt)
-            for request in all_requests:
-                assert (
-                    type(request.prompt) is first_prompt_type
-                ), "All prompts must be of the same type in file input settings"
-                if request.n > 1:
-                    raise ValueError(
-                        "Parallel sampling is not supported for completions from files"
+        # NOTE: with openai API, the prompt's logprobs are always not computed
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use the native /generate API."
            )
-        prompts = []
-        sampling_params_list = []
-        return_logprobs = []
-        logprob_start_lens = []
-        top_logprobs_nums = []
-        lora_paths = []
-        for request in all_requests:
         # Process prompt
         prompt = request.prompt
         if is_completion_template_defined():
             prompt = generate_completion_prompt_from_request(request)
-            prompts.append(prompt)
-            lora_paths.append(request.lora_path)
         # Set logprob start length based on echo and logprobs
         if request.echo and request.logprobs:
-            current_logprob_start_len = 0
+            logprob_start_len = 0
         else:
-            current_logprob_start_len = -1
+            logprob_start_len = -1
         # Build sampling parameters
         sampling_params = self._build_sampling_params(request)
-            sampling_params_list.append(sampling_params)
-            return_logprobs.append(request.logprobs is not None)
-            logprob_start_lens.append(current_logprob_start_len)
-            top_logprobs_nums.append(
-                request.logprobs if request.logprobs is not None else 0
-            )
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
-                prompt_kwargs = {"text": prompts[0]}
-            else:
-                prompt_kwargs = {"input_ids": prompts[0]}
-            sampling_params_list = sampling_params_list[0]
-            return_logprobs = return_logprobs[0]
-            logprob_start_lens = logprob_start_lens[0]
-            top_logprobs_nums = top_logprobs_nums[0]
-            lora_paths = lora_paths[0]
-            request_ids = request_ids[0]
-        else:
-            if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
-                prompt_kwargs = {"text": prompts}
+        # Determine prompt format
+        if isinstance(prompt, str) or (
+            isinstance(prompt, list) and isinstance(prompt[0], str)
+        ):
+            prompt_kwargs = {"text": prompt}
         else:
-                prompt_kwargs = {"input_ids": prompts}
+            prompt_kwargs = {"input_ids": prompt}
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            sampling_params=sampling_params_list,
-            return_logprob=return_logprobs,
-            top_logprobs_num=top_logprobs_nums,
-            logprob_start_len=logprob_start_lens,
+            sampling_params=sampling_params,
+            return_logprob=request.logprobs is not None,
+            top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+            logprob_start_len=logprob_start_len,
             return_text_in_logprobs=True,
-            stream=all_requests[0].stream,
-            rid=request_ids,
-            lora_path=lora_paths,
-            bootstrap_host=all_requests[0].bootstrap_host,
-            bootstrap_port=all_requests[0].bootstrap_port,
-            bootstrap_room=all_requests[0].bootstrap_room,
+            stream=request.stream,
+            lora_path=request.lora_path,
+            bootstrap_host=request.bootstrap_host,
+            bootstrap_port=request.bootstrap_port,
+            bootstrap_room=request.bootstrap_room,
         )
-        return adapted_request, (
-            all_requests if len(all_requests) > 1 else all_requests[0]
-        )
+        return adapted_request, request
     def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
         """Build sampling parameters for the request"""
@@ -184,9 +107,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
             "logit_bias": request.logit_bias,
         }
-        # No additional completion-specific parameters needed currently
-        # (json_schema is already handled in base method)
         return sampling_params
     async def _handle_streaming_request(
@@ -196,11 +116,26 @@ class OpenAIServingCompletion(OpenAIServingBase):
         raw_request: Request,
     ) -> StreamingResponse:
         """Handle streaming completion request"""
+        return StreamingResponse(
+            self._generate_completion_stream(adapted_request, request, raw_request),
+            media_type="text/event-stream",
+            background=self.tokenizer_manager.create_abort_task(adapted_request),
+        )
+    async def _generate_completion_stream(
+        self,
+        adapted_request: GenerateReqInput,
+        request: CompletionRequest,
+        raw_request: Request,
+    ) -> AsyncGenerator[str, None]:
+        """Generate streaming completion response"""
         created = int(time.time())
-        async def generate_stream_resp():
+        # State tracking for streaming
         stream_buffers = {}
         n_prev_tokens = {}
+        # Usage tracking
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
@@ -211,14 +146,12 @@ class OpenAIServingCompletion(OpenAIServingBase):
            ):
                index = content.get("index", 0)
-                stream_buffer = stream_buffers.get(index, "")
-                n_prev_token = n_prev_tokens.get(index, 0)
                text = content["text"]
                prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                completion_tokens[index] = content["meta_info"]["completion_tokens"]
                cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                stream_buffer = stream_buffers.get(index, "")
                # Handle echo for first chunk
                if not stream_buffer:  # The first chunk
                    if request.echo:
@@ -233,30 +166,29 @@ class OpenAIServingCompletion(OpenAIServingBase):
                        input_token_logprobs = content["meta_info"][
                            "input_token_logprobs"
                        ]
-                        input_top_logprobs = content["meta_info"][
-                            "input_top_logprobs"
-                        ]
+                        input_top_logprobs = content["meta_info"]["input_top_logprobs"]
                    else:
                        input_token_logprobs = None
                        input_top_logprobs = None
+                    n_prev_token = n_prev_tokens.get(index, 0)
                    logprobs = to_openai_style_logprobs(
                        input_token_logprobs=input_token_logprobs,
                        input_top_logprobs=input_top_logprobs,
                        output_token_logprobs=content["meta_info"][
                            "output_token_logprobs"
                        ][n_prev_token:],
-                        output_top_logprobs=content["meta_info"][
-                            "output_top_logprobs"
-                        ][n_prev_token:],
+                        output_top_logprobs=content["meta_info"]["output_top_logprobs"][
+                            n_prev_token:
+                        ],
                    )
-                    n_prev_token = len(
+                    n_prev_tokens[index] = len(
                        content["meta_info"]["output_token_logprobs"]
                    )
                # Generate delta
                delta = text[len(stream_buffer) :]
-                stream_buffer = stream_buffer + delta
+                stream_buffers[index] = stream_buffer + delta
                finish_reason = content["meta_info"]["finish_reason"]
                choice_data = CompletionResponseStreamChoice(
@@ -278,15 +210,15 @@ class OpenAIServingCompletion(OpenAIServingBase):
                    model=request.model,
                )
-                stream_buffers[index] = stream_buffer
-                n_prev_tokens[index] = n_prev_token
                yield f"data: {chunk.model_dump_json()}\n\n"
            # Handle final usage chunk
            if request.stream_options and request.stream_options.include_usage:
                usage = self._calculate_streaming_usage_base(
-                    prompt_tokens, completion_tokens, cached_tokens, request.n
+                    prompt_tokens,
+                    completion_tokens,
+                    cached_tokens,
+                    request.n,
                )
                final_usage_chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
@@ -295,9 +227,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                    model=request.model,
                    usage=usage,
                )
-                final_usage_data = final_usage_chunk.model_dump_json(
-                    exclude_none=True
-                )
+                final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
                yield f"data: {final_usage_data}\n\n"
        except Exception as e:
@@ -306,12 +236,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
        yield "data: [DONE]\n\n"
-        return StreamingResponse(
-            generate_stream_resp(),
-            media_type="text/event-stream",
-            background=self.tokenizer_manager.create_abort_task(adapted_request),
-        )
     async def _handle_non_streaming_request(
         self,
         adapted_request: GenerateReqInput,
@@ -334,7 +258,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
             request,
             ret,
             int(time.time()),
-            cache_report=self.tokenizer_manager.server_args.enable_cache_report,
         )
         return response
@@ -344,7 +267,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
         request: CompletionRequest,
         ret: List[Dict[str, Any]],
         created: int,
-        cache_report: bool = False,
     ) -> CompletionResponse:
         """Build completion response from generation results"""
         choices = []
@@ -352,7 +274,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
         # Prepare echo prompts if needed
         echo_prompts = []
-        if (not isinstance(request, list)) and request.echo:
+        if request.echo:
             echo_prompts = self._prepare_echo_prompts(request)
             echo = True
@@ -360,21 +282,13 @@ class OpenAIServingCompletion(OpenAIServingBase):
             text = ret_item["text"]
             # Handle echo
-            if isinstance(request, list) and request[idx].echo:
-                echo = True
-                text = request[idx].prompt + text
-            elif echo and not isinstance(request, list):
+            if echo:
                 prompt_index = idx // request.n
                 text = echo_prompts[prompt_index] + text
             # Handle logprobs
             logprobs = None
-            if isinstance(request, list) and request[idx].logprobs is not None:
-                logprobs = True
-            elif (not isinstance(request, list)) and request.logprobs is not None:
-                logprobs = True
-            if logprobs:
+            if request.logprobs is not None:
                 if echo:
                     input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
                     input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
@@ -407,6 +321,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             choices.append(choice_data)
         # Calculate usage
+        cache_report = self.tokenizer_manager.server_args.enable_cache_report
         usage = aggregate_token_usage(ret, request.n, cache_report)
         return CompletionResponse(
...
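A hedged client-side check of the completions streaming path refactored above: read the SSE stream and expect the usage-bearing chunk (with an empty `choices` list) just before `[DONE]` when `stream_options.include_usage` is sent. The endpoint URL, port, and model name are assumptions for illustration, not values from the diff.

import json

import requests

payload = {
    "model": "my-model",          # assumed model name
    "prompt": "Hello",
    "max_tokens": 8,
    "stream": True,
    "stream_options": {"include_usage": True},
}

# Assumed local address; point this at wherever the OpenAI-compatible server runs.
with requests.post("http://localhost:30000/v1/completions", json=payload, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        data = raw[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        if chunk.get("usage"):
            print("usage:", chunk["usage"])  # final usage-only chunk
        else:
            print("delta:", "".join(c.get("text", "") for c in chunk["choices"]))

The embedding handler and the unit tests in the remaining diffs are adjusted to the same single-request conversion signature.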
@@ -54,34 +54,24 @@ class OpenAIServingEmbedding(OpenAIServingBase):
                     return f"All items in input list must be integers"
                 if item < 0:
                     return f"Token ID at index {i} must be non-negative"
-        elif isinstance(first_item, list):
-            # List of lists (multiple token sequences)
-            for i, item in enumerate(input):
-                if not isinstance(item, list):
-                    return f"Input at index {i} must be a list"
-                if not item:
-                    return f"Input at index {i} cannot be empty"
-                if not all(isinstance(token, int) for token in item):
-                    return f"Input at index {i} must contain only integers"
-                if any(token < 0 for token in item):
-                    return f"Input at index {i} contains negative token IDs"
         # Note: MultimodalEmbeddingInput validation would be handled by Pydantic
         return None
     def _convert_to_internal_request(
         self,
         request: EmbeddingRequest,
-        request_id: str,
-    ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
+    ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
         """Convert OpenAI embedding request to internal format"""
         prompt = request.input
         if isinstance(prompt, str):
             # Single string input
             prompt_kwargs = {"text": prompt}
         elif isinstance(prompt, list):
             if len(prompt) > 0 and isinstance(prompt[0], str):
-                # List of strings
+                # List of strings - if it's a single string in a list, treat as single string
+                if len(prompt) == 1:
+                    prompt_kwargs = {"text": prompt[0]}
+                else:
                     prompt_kwargs = {"text": prompt}
             elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
                 # Handle multimodal embedding inputs
@@ -94,7 +84,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
             generate_prompts = []
             # Check if we have a chat template for multimodal embeddings
-            # This would need to be passed in from the server configuration
             chat_template_name = getattr(
                 self.tokenizer_manager, "chat_template_name", None
             )
@@ -121,6 +110,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         else:
             # Other types (should not happen but handle gracefully)
             prompt_kwargs = {"input_ids": prompt}
         adapted_request = EmbeddingReqInput(
             **prompt_kwargs,
         )
...
@@ -104,52 +104,50 @@ class ServingChatTestCase(unittest.TestCase):
             None,
         )
-        adapted, processed = self.chat._convert_to_internal_request(
-            [self.basic_req], ["rid"]
-        )
+        adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
         self.assertIsInstance(adapted, GenerateReqInput)
         self.assertFalse(adapted.stream)
         self.assertEqual(processed, self.basic_req)
-    # ------------- tool-call branch -------------
-    def test_tool_call_request_conversion(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Weather?"}],
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "get_weather",
-                        "parameters": {"type": "object", "properties": {}},
-                    },
-                }
-            ],
-            tool_choice="auto",
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
-        ):
-            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
-            self.assertEqual(adapted.rid, "rid")
-    def test_tool_choice_none(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Hi"}],
-            tools=[{"type": "function", "function": {"name": "noop"}}],
-            tool_choice="none",
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
-        ):
-            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
-            self.assertEqual(adapted.rid, "rid")
+    # # ------------- tool-call branch -------------
+    # def test_tool_call_request_conversion(self):
+    #     req = ChatCompletionRequest(
+    #         model="x",
+    #         messages=[{"role": "user", "content": "Weather?"}],
+    #         tools=[
+    #             {
+    #                 "type": "function",
+    #                 "function": {
+    #                     "name": "get_weather",
+    #                     "parameters": {"type": "object", "properties": {}},
+    #                 },
+    #             }
+    #         ],
+    #         tool_choice="auto",
+    #     )
+    #     with patch.object(
+    #         self.chat,
+    #         "_process_messages",
+    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
+    #     ):
+    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
+    #         self.assertEqual(adapted.rid, "rid")
+    # def test_tool_choice_none(self):
+    #     req = ChatCompletionRequest(
+    #         model="x",
+    #         messages=[{"role": "user", "content": "Hi"}],
+    #         tools=[{"type": "function", "function": {"name": "noop"}}],
+    #         tool_choice="none",
+    #     )
+    #     with patch.object(
+    #         self.chat,
+    #         "_process_messages",
+    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
+    #     ):
+    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
+    #         self.assertEqual(adapted.rid, "rid")
     # ------------- multimodal branch -------------
     def test_multimodal_request_with_images(self):
...
@@ -36,12 +36,12 @@ class ServingCompletionTestCase(unittest.TestCase):
     # ---------- prompt-handling ----------
     def test_single_string_prompt(self):
         req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
-        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+        internal, _ = self.sc._convert_to_internal_request(req)
         self.assertEqual(internal.text, "Hello world")
     def test_single_token_ids_prompt(self):
         req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
-        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+        internal, _ = self.sc._convert_to_internal_request(req)
         self.assertEqual(internal.input_ids, [1, 2, 3, 4])
     def test_completion_template_handling(self):
@@ -55,7 +55,7 @@ class ServingCompletionTestCase(unittest.TestCase):
             "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
             return_value="processed_prompt",
         ):
-            internal, _ = self.sc._convert_to_internal_request([req], ["id"])
+            internal, _ = self.sc._convert_to_internal_request(req)
             self.assertEqual(internal.text, "processed_prompt")
     # ---------- echo-handling ----------
...
@@ -94,50 +94,42 @@ class ServingEmbeddingTestCase(unittest.TestCase):
     def test_convert_single_string_request(self):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.basic_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.basic_req)
         )
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.basic_req)
     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.list_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.list_req)
         )
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.list_req)
     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.token_ids_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
         )
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
         self.assertEqual(processed_request, self.token_ids_req)
     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
-            self.serving_embedding._convert_to_internal_request(
-                self.multimodal_req, "test-id"
-            )
+            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
         )
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
@@ -147,7 +139,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, None)
+        # self.assertEqual(adapted_request.rid, "test-id")
     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -194,9 +186,10 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
         self.assertEqual(response.usage.total_tokens, 7)
-    async def test_handle_request_success(self):
+    def test_handle_request_success(self):
         """Test successful embedding request handling."""
+        async def run_test():
             # Mock the generate_request to return expected data
             async def mock_generate():
                 yield {
@@ -216,8 +209,12 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             self.assertEqual(len(response.data), 1)
             self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
-    async def test_handle_request_validation_error(self):
+        asyncio.run(run_test())
+    def test_handle_request_validation_error(self):
         """Test handling request with validation error."""
+        async def run_test():
             invalid_request = EmbeddingRequest(model="test-model", input="")
             response = await self.serving_embedding.handle_request(
@@ -227,9 +224,12 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             self.assertIsInstance(response, ORJSONResponse)
             self.assertEqual(response.status_code, 400)
-    async def test_handle_request_generation_error(self):
+        asyncio.run(run_test())
+    def test_handle_request_generation_error(self):
         """Test handling request with generation error."""
+        async def run_test():
             # Mock generate_request to raise an error
             async def mock_generate_error():
                 raise ValueError("Generation failed")
@@ -246,8 +246,12 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             self.assertIsInstance(response, ORJSONResponse)
             self.assertEqual(response.status_code, 400)
-    async def test_handle_request_internal_error(self):
+        asyncio.run(run_test())
+    def test_handle_request_internal_error(self):
         """Test handling request with internal server error."""
+        async def run_test():
             # Mock _convert_to_internal_request to raise an exception
             with patch.object(
                 self.serving_embedding,
@@ -261,6 +265,8 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             self.assertIsInstance(response, ORJSONResponse)
             self.assertEqual(response.status_code, 500)
+        asyncio.run(run_test())
 if __name__ == "__main__":
     unittest.main(verbosity=2)