Unverified Commit bff2e5f1 authored by Andrew Xia's avatar Andrew Xia Committed by GitHub
Browse files

[gpt-oss][2] fix types for streaming (#24556)


Signed-off-by: default avatarAndrew Xia <axia@meta.com>
parent 3c068c63
......@@ -27,7 +27,6 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import BaseModel
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
......@@ -67,7 +66,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
RerankRequest, RerankResponse,
ResponsesRequest,
ResponsesResponse, ScoreRequest,
ScoreResponse, TokenizeRequest,
ScoreResponse,
StreamingResponsesResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
......@@ -481,8 +482,8 @@ async def show_version():
async def _convert_stream_to_sse_events(
generator: AsyncGenerator[BaseModel,
None]) -> AsyncGenerator[str, None]:
generator: AsyncGenerator[StreamingResponsesResponse, None]
) -> AsyncGenerator[str, None]:
"""Convert the generator to a stream of events in SSE format"""
async for event in generator:
event_type = getattr(event, 'type', 'unknown')
......
......@@ -18,10 +18,19 @@ from openai.types.chat.chat_completion_audio import (
from openai.types.chat.chat_completion_message import (
Annotation as OpenAIAnnotation)
# yapf: enable
from openai.types.responses import (ResponseFunctionToolCall,
ResponseInputItemParam, ResponseOutputItem,
ResponsePrompt, ResponseReasoningItem,
ResponseStatus)
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
ResponseCreatedEvent, ResponseFunctionToolCall, ResponseInProgressEvent,
ResponseInputItemParam, ResponseOutputItem, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponsePrompt, ResponseReasoningItem,
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent,
ResponseStatus, ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent)
# Backward compatibility for OpenAI client versions
try: # For older openai versions (< 1.100.0)
......@@ -251,6 +260,26 @@ ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
ResponseFunctionToolCall]
StreamingResponsesResponse: TypeAlias = Union[
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseCompletedEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
ResponseWebSearchCallCompletedEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterCallCompletedEvent,
]
class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......
......@@ -10,24 +10,28 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Callable, Final, Optional, TypeVar, Union
from typing import Callable, Final, Optional, Union
import jinja2
import openai.types.responses as openai_responses_types
from fastapi import Request
from openai import BaseModel
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.responses import (ResponseCreatedEvent,
ResponseFunctionToolCall,
ResponseInProgressEvent,
ResponseOutputItem,
ResponseOutputItemDoneEvent,
ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus, response_text_delta_event)
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterToolCallParam, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
ResponseCreatedEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch,
ResponseInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem, ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent, ResponseStatus, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent,
response_function_web_search, response_text_delta_event)
from openai.types.responses.response_output_text import (Logprob,
LogprobTopLogprob)
# yapf: enable
......@@ -55,7 +59,8 @@ from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse,
OutputTokensDetails,
RequestResponseMetadata,
ResponsesRequest,
ResponsesResponse, ResponseUsage)
ResponsesResponse, ResponseUsage,
StreamingResponsesResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
......@@ -175,7 +180,7 @@ class OpenAIServingResponses(OpenAIServing):
# HACK(wuhang): This is a hack. We should use a better store.
# FIXME: If enable_store=True, this may cause a memory leak since we
# never remove events from the store.
self.event_store: dict[str, tuple[deque[BaseModel],
self.event_store: dict[str, tuple[deque[StreamingResponsesResponse],
asyncio.Event]] = {}
self.background_tasks: dict[str, asyncio.Task] = {}
......@@ -186,8 +191,8 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[BaseModel, None], ResponsesResponse,
ErrorResponse]:
) -> Union[AsyncGenerator[StreamingResponsesResponse, None],
ResponsesResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
......@@ -814,7 +819,7 @@ class OpenAIServingResponses(OpenAIServing):
*args,
**kwargs,
):
event_deque: deque[BaseModel] = deque()
event_deque: deque[StreamingResponsesResponse] = deque()
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
......@@ -867,7 +872,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
response_id: str,
starting_after: Optional[int] = None,
) -> AsyncGenerator[BaseModel, None]:
) -> AsyncGenerator[StreamingResponsesResponse, None]:
if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}")
......@@ -893,8 +898,8 @@ class OpenAIServingResponses(OpenAIServing):
response_id: str,
starting_after: Optional[int],
stream: Optional[bool],
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[BaseModel,
None]]:
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
StreamingResponsesResponse, None]]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
......@@ -977,9 +982,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = 0
current_output_index = 0
current_item_id = ""
......@@ -1017,13 +1022,11 @@ class OpenAIServingResponses(OpenAIServing):
current_item_id = str(uuid.uuid4())
if delta_message.reasoning_content:
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseReasoningItem(
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
......@@ -1032,13 +1035,11 @@ class OpenAIServingResponses(OpenAIServing):
))
else:
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
......@@ -1047,13 +1048,13 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
......@@ -1104,11 +1105,11 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
......@@ -1119,13 +1120,13 @@ class OpenAIServingResponses(OpenAIServing):
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
......@@ -1148,7 +1149,7 @@ class OpenAIServingResponses(OpenAIServing):
))
elif delta_message.content is not None:
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
......@@ -1204,7 +1205,7 @@ class OpenAIServingResponses(OpenAIServing):
for pm in previous_delta_messages
if pm.content is not None)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
......@@ -1220,7 +1221,7 @@ class OpenAIServingResponses(OpenAIServing):
annotations=[],
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartDoneEvent(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
......@@ -1257,9 +1258,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = -1
current_output_index = 0
current_item_id: str = ""
......@@ -1314,7 +1315,7 @@ class OpenAIServingResponses(OpenAIServing):
annotations=[],
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
......@@ -1324,7 +1325,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
......@@ -1334,7 +1334,7 @@ class OpenAIServingResponses(OpenAIServing):
part=text_content,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
......@@ -1355,13 +1355,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
......@@ -1371,14 +1369,13 @@ class OpenAIServingResponses(OpenAIServing):
))
current_content_index += 1
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
......@@ -1386,7 +1383,7 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
......@@ -1402,13 +1399,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseReasoningItem(
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
......@@ -1417,14 +1412,13 @@ class OpenAIServingResponses(OpenAIServing):
))
current_content_index += 1
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
......@@ -1450,13 +1444,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseCodeInterpreterToolCallParam(
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=None,
......@@ -1466,7 +1458,6 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInProgressEvent(
type=
"response.code_interpreter_call.in_progress",
......@@ -1475,7 +1466,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
sequence_number=-1,
......@@ -1495,14 +1485,12 @@ class OpenAIServingResponses(OpenAIServing):
action = None
parsed_args = json.loads(previous_item.content[0].text)
if function_name == "search":
action = (openai_responses_types.
response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
))
action = (response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
))
elif function_name == "open":
action = (
openai_responses_types.
response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
......@@ -1510,7 +1498,6 @@ class OpenAIServingResponses(OpenAIServing):
))
elif function_name == "find":
action = (
openai_responses_types.
response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
......@@ -1523,12 +1510,11 @@ class OpenAIServingResponses(OpenAIServing):
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
response_function_web_search.
item=response_function_web_search.
ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
......@@ -1538,7 +1524,6 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
......@@ -1546,7 +1531,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
......@@ -1556,7 +1540,6 @@ class OpenAIServingResponses(OpenAIServing):
# enqueue
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
......@@ -1564,12 +1547,11 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseFunctionWebSearch(
item=ResponseFunctionWebSearch(
type="web_search_call",
id=current_item_id,
action=action,
......@@ -1582,7 +1564,6 @@ class OpenAIServingResponses(OpenAIServing):
and previous_item.recipient is not None
and previous_item.recipient.startswith("python")):
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
......@@ -1591,7 +1572,6 @@ class OpenAIServingResponses(OpenAIServing):
code=previous_item.content[0].text,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
......@@ -1599,7 +1579,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
......@@ -1607,12 +1586,11 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseCodeInterpreterToolCallParam(
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=previous_item.content[0].text,
......@@ -1633,7 +1611,7 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> AsyncGenerator[BaseModel, None]:
) -> AsyncGenerator[StreamingResponsesResponse, None]:
# TODO:
# 1. Handle disconnect
......@@ -1641,9 +1619,9 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number = 0
T = TypeVar("T", bound=BaseModel)
def _increment_sequence_number_and_return(event: T) -> T:
def _increment_sequence_number_and_return(
event: StreamingResponsesResponse
) -> StreamingResponsesResponse:
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
......@@ -1705,7 +1683,7 @@ class OpenAIServingResponses(OpenAIServing):
created_time=created_time,
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseCompletedEvent(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment