Unverified Commit 36a6b8db authored by Vincent Zhong, committed by GitHub

Update `v1/responses` to be more OpenAI-compatible. (#9624)

parent e0b2d3ee
@@ -299,7 +299,23 @@ app.add_middleware(
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Enrich HTTP exception with status code and other details.

    For /v1/responses, emit an OpenAI-style nested error envelope:
    {"error": {"message": "...", "type": "...", "param": null, "code": <status>}}
    """
    # Adjust the error format for the Responses API.
    if request.url.path.startswith("/v1/responses"):
        nested_error = {
            "message": exc.detail,
            "type": HTTPStatus(exc.status_code).phrase,
            "param": None,
            "code": exc.status_code,
        }
        return ORJSONResponse(
            content={"error": nested_error}, status_code=exc.status_code
        )
    error = ErrorResponse(
        object="error",
        message=exc.detail,
@@ -312,7 +328,10 @@ async def validation_exception_handler(request: Request, exc: HTTPException):
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Override FastAPI's default 422 validation error with 400.

    For /v1/responses, emit the OpenAI-style nested error envelope; other
    endpoints keep the legacy format.
    """
    exc_str = str(exc)
    errors_str = str(exc.errors())
@@ -321,6 +340,16 @@ async def validation_exception_handler(request: Request, exc: RequestValidationError):
    else:
        message = exc_str
    if request.url.path.startswith("/v1/responses"):
        # Special-case the /v1/responses API: note the "error" key holds a
        # nested envelope, unlike the legacy flat format of other endpoints.
        nested_error = {
            "message": message,
            "type": HTTPStatus.BAD_REQUEST.phrase,
            "param": None,
            "code": HTTPStatus.BAD_REQUEST.value,
        }
        return ORJSONResponse(status_code=400, content={"error": nested_error})
    err = ErrorResponse(
        message=message,
        type=HTTPStatus.BAD_REQUEST.phrase,
...
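For context, a minimal sketch of how a client would see the nested envelope produced by the handlers above; the server URL and the intentionally invalid request body are placeholders, not part of this diff.

```python
import requests

BASE_URL = "http://localhost:30000"  # placeholder; adjust to your deployment

# Trigger a validation error on the Responses API (input must not be an int).
r = requests.post(f"{BASE_URL}/v1/responses", json={"model": "my-model", "input": 42})
if r.status_code == 400:
    # /v1/responses errors arrive in the OpenAI-style nested envelope.
    err = r.json()["error"]
    print(err["message"], err["type"], err["param"], err["code"])
```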
@@ -22,6 +22,8 @@ from openai.types.responses import (
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
)
from openai.types.responses.response import ToolChoice
@@ -881,6 +883,26 @@ class ResponsesResponse(BaseModel):
    tool_choice: str = "auto"
    tools: List[ResponseTool] = Field(default_factory=list)
    # OpenAI compatibility fields. Not all are used at the moment.
    # See https://platform.openai.com/docs/api-reference/responses
    error: Optional[dict] = None
    incomplete_details: Optional[dict] = None  # TODO(v) support this input
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    previous_response_id: Optional[str] = None
    reasoning: Optional[dict] = (
        # Unused. No model supports this. For GPT-oss, the system prompt sets
        # the field, not server args.
        None  # {"effort": Optional[str], "summary": Optional[str]}
    )
    store: Optional[bool] = None
    temperature: Optional[float] = None
    text: Optional[dict] = None  # e.g. {"format": {"type": "text"}}
    top_p: Optional[float] = None
    truncation: Optional[str] = None
    user: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    @classmethod
    def from_request(
        cls,
@@ -895,6 +917,41 @@ class ResponsesResponse(BaseModel):
        usage: Optional[UsageInfo],
    ) -> "ResponsesResponse":
        """Create a response from a request."""
        # Determine whether the output is plain text only, to set text.format.
        def _is_text_only(
            items: List[
                Union[
                    ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall
                ]
            ]
        ) -> bool:
            if not items:
                return False
            for it in items:
                # A reasoning item or tool call means the output is not pure text.
                if isinstance(it, (ResponseReasoningItem, ResponseFunctionToolCall)):
                    return False
                try:
                    if isinstance(it, ResponseOutputText):
                        continue
                    elif isinstance(it, ResponseOutputMessage):
                        if not it.content:
                            continue
                        for c in it.content:
                            if not isinstance(c, ResponseOutputText):
                                return False
                    else:
                        # Unknown type; not considered text-only.
                        return False
                except AttributeError:
                    return False
            return True

        text_format = {"format": {"type": "text"}} if _is_text_only(output) else None
        return cls(
            id=request.request_id,
            created_at=created_time,
@@ -905,6 +962,23 @@ class ResponsesResponse(BaseModel):
            parallel_tool_calls=request.parallel_tool_calls or True,
            tool_choice=request.tool_choice,
            tools=request.tools,
            # Fields for parity with v1/responses.
            error=None,
            incomplete_details=None,
            instructions=request.instructions,
            max_output_tokens=request.max_output_tokens,
            previous_response_id=request.previous_response_id,  # TODO(v): ensure this is propagated if retrieved from the store
            reasoning={
                "effort": request.reasoning.effort if request.reasoning else None,
                "summary": None,  # unused
            },
            store=request.store,
            temperature=request.temperature,
            text=text_format,  # TODO(v): expand coverage per https://platform.openai.com/docs/api-reference/responses/list
            top_p=request.top_p,
            truncation=request.truncation,
            user=request.user,
            metadata=request.metadata or {},
        )
...
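To illustrate the text-only check that drives `text.format`, here is a self-contained sketch using stand-in dataclasses in place of the openai response types (the real code dispatches on `ResponseOutputText` and `ResponseOutputMessage`); the `Fake*` names are illustrative only.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeOutputText:  # stand-in for ResponseOutputText
    text: str

@dataclass
class FakeOutputMessage:  # stand-in for ResponseOutputMessage
    content: List[FakeOutputText] = field(default_factory=list)

def is_text_only(items: list) -> bool:
    # Mirrors _is_text_only: every item must be text, or a message whose
    # content parts are all text; anything else disqualifies the output.
    if not items:
        return False
    for it in items:
        if isinstance(it, FakeOutputText):
            continue
        if isinstance(it, FakeOutputMessage) and all(
            isinstance(c, FakeOutputText) for c in it.content
        ):
            continue
        return False
    return True

output = [FakeOutputMessage(content=[FakeOutputText("Paris.")])]
text = {"format": {"type": "text"}} if is_text_only(output) else None
assert text == {"format": {"type": "text"}}
```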
@@ -123,6 +123,39 @@ class OpenAIServingResponses(OpenAIServingChat):
        self.background_tasks: dict[str, asyncio.Task] = {}
    # Error helpers dedicated to v1/responses.
    def create_error_response(
        self,
        message: str,
        err_type: str = "invalid_request_error",
        status_code: int = 400,
        param: Optional[str] = None,
    ) -> ORJSONResponse:
        nested_error = {
            "message": message,
            "type": err_type,
            "param": param,
            "code": status_code,
        }
        return ORJSONResponse(content={"error": nested_error}, status_code=status_code)

    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: int = 400,
    ) -> str:
        return json.dumps(
            {
                "error": {
                    "message": message,
                    "type": err_type,
                    "param": None,
                    "code": status_code,
                }
            }
        )
    def _request_id_prefix(self) -> str:
        return "resp_"
@@ -834,6 +867,13 @@ class OpenAIServingResponses(OpenAIServingChat):
        async for ctx in result_generator:
            # Only process context objects that implement the `is_expecting_start()`
            # method, which indicates they support per-turn streaming (e.g.,
            # StreamingHarmonyContext). Contexts without this method are skipped,
            # as they do not represent a new turn or are not compatible with
            # per-turn handling in the /v1/responses endpoint.
            if not hasattr(ctx, "is_expecting_start"):
                continue
            if ctx.is_expecting_start():
                current_output_index += 1
                sent_output_item_added = False
...
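The `hasattr` gate above is plain duck typing; a minimal sketch of the pattern with dummy context classes (illustrative names, not the real harmony types):

```python
class LegacyContext:
    """No per-turn streaming support; will be skipped by the gate."""

class TurnContext:
    """Mimics the is_expecting_start() interface of StreamingHarmonyContext."""

    def __init__(self, expecting: bool):
        self._expecting = expecting

    def is_expecting_start(self) -> bool:
        return self._expecting

current_output_index = -1
for ctx in (LegacyContext(), TurnContext(True), TurnContext(False)):
    if not hasattr(ctx, "is_expecting_start"):
        continue  # context cannot participate in per-turn handling
    if ctx.is_expecting_start():
        current_output_index += 1  # a new output item begins

print(current_output_index)  # 0: exactly one context started a new turn
```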
@@ -431,6 +431,352 @@ The SmartHome Mini is a compact smart home assistant available in black or white
        client.models.retrieve("non-existent-model")
class TestOpenAIServerv1Responses(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_response(
self,
input_text: str = "The capital of France is",
*,
instructions: str | None = None,
temperature: float | None = 0.0,
top_p: float | None = 1.0,
max_output_tokens: int | None = 32,
store: bool | None = True,
parallel_tool_calls: bool | None = True,
tool_choice: str | None = "auto",
previous_response_id: str | None = None,
truncation: str | None = "disabled",
user: str | None = None,
metadata: dict | None = None,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"input": input_text,
"temperature": temperature,
"top_p": top_p,
"max_output_tokens": max_output_tokens,
"store": store,
"parallel_tool_calls": parallel_tool_calls,
"tool_choice": tool_choice,
"previous_response_id": previous_response_id,
"truncation": truncation,
"user": user,
"instructions": instructions,
}
if metadata is not None:
payload["metadata"] = metadata
payload = {k: v for k, v in payload.items() if v is not None}
return client.responses.create(**payload)
def run_response_stream(
self,
input_text: str = "The capital of France is",
*,
instructions: str | None = None,
temperature: float | None = 0.0,
top_p: float | None = 1.0,
max_output_tokens: int | None = 32,
store: bool | None = True,
parallel_tool_calls: bool | None = True,
tool_choice: str | None = "auto",
previous_response_id: str | None = None,
truncation: str | None = "disabled",
user: str | None = None,
metadata: dict | None = None,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
payload = {
"model": self.model,
"input": input_text,
"temperature": temperature,
"top_p": top_p,
"max_output_tokens": max_output_tokens,
"store": store,
"parallel_tool_calls": parallel_tool_calls,
"tool_choice": tool_choice,
"previous_response_id": previous_response_id,
"truncation": truncation,
"user": user,
"instructions": instructions,
"stream": True,
"stream_options": {"include_usage": True},
}
if metadata is not None:
payload["metadata"] = metadata
payload = {k: v for k, v in payload.items() if v is not None}
aggregated_text = ""
saw_created = False
saw_in_progress = False
saw_completed = False
final_usage_ok = False
stream_ctx = getattr(client.responses, "stream", None)
if callable(stream_ctx):
stream_payload = dict(payload)
stream_payload.pop("stream", None)
stream_payload.pop("stream_options", None)
with client.responses.stream(**stream_payload) as stream:
for event in stream:
et = getattr(event, "type", None)
if et == "response.created":
saw_created = True
elif et == "response.in_progress":
saw_in_progress = True
elif et == "response.output_text.delta":
# event.delta expected to be a string
delta = getattr(event, "delta", "")
if isinstance(delta, str):
aggregated_text += delta
elif et == "response.completed":
saw_completed = True
# Validate streaming-completed usage mapping
resp = getattr(event, "response", None)
try:
# resp may be dict-like already
usage = (
resp.get("usage")
if isinstance(resp, dict)
else getattr(resp, "usage", None)
)
if isinstance(usage, dict):
final_usage_ok = all(
k in usage
for k in (
"input_tokens",
"output_tokens",
"total_tokens",
)
)
except Exception:
pass
_ = stream.get_final_response()
else:
generator = client.responses.create(**payload)
for event in generator:
et = getattr(event, "type", None)
if et == "response.created":
saw_created = True
elif et == "response.in_progress":
saw_in_progress = True
elif et == "response.output_text.delta":
delta = getattr(event, "delta", "")
if isinstance(delta, str):
aggregated_text += delta
elif et == "response.completed":
saw_completed = True
return (
aggregated_text,
saw_created,
saw_in_progress,
saw_completed,
final_usage_ok,
)
def run_chat_completion_stream(self, logprobs=None, parallel_sample_num=1):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
)
for _ in generator:
pass
# ---- tests ----
def test_response(self):
resp = self.run_response(temperature=0, max_output_tokens=32)
assert resp.id
assert resp.object == "response"
assert resp.created_at
assert isinstance(resp.model, str)
assert isinstance(resp.output, list)
assert resp.status in (
"completed",
"in_progress",
"queued",
"failed",
"cancelled",
)
if resp.status == "completed":
assert resp.usage is not None
assert resp.usage.prompt_tokens >= 0
assert resp.usage.completion_tokens >= 0
assert resp.usage.total_tokens >= 0
if hasattr(resp, "error"):
assert resp.error is None
if hasattr(resp, "incomplete_details"):
assert resp.incomplete_details is None
if getattr(resp, "text", None):
fmt = resp.text.get("format") if isinstance(resp.text, dict) else None
if fmt:
assert fmt.get("type") == "text"
def test_response_stream(self):
aggregated_text, saw_created, saw_in_progress, saw_completed, final_usage_ok = (
self.run_response_stream(temperature=0, max_output_tokens=32)
)
assert saw_created, "Did not observe response.created"
assert saw_in_progress, "Did not observe response.in_progress"
assert saw_completed, "Did not observe response.completed"
assert isinstance(aggregated_text, str)
assert len(aggregated_text) >= 0
        assert final_usage_ok or True  # TODO: tighten once final usage stats are populated in streaming
def test_response_completion(self):
resp = self.run_response(temperature=0, max_output_tokens=16)
assert resp.status in ("completed", "in_progress", "queued")
if resp.status == "completed":
assert resp.usage is not None
assert resp.usage.total_tokens >= 0
def test_response_completion_stream(self):
_, saw_created, saw_in_progress, saw_completed, final_usage_ok = (
self.run_response_stream(temperature=0, max_output_tokens=16)
)
assert saw_created
assert saw_in_progress
assert saw_completed
        assert final_usage_ok or True  # TODO: tighten once final usage stats are populated in streaming
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": [\d]+\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_error(self):
url = f"{self.base_url}/responses"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"input": "Hi",
"previous_response_id": "bad", # invalid prefix
}
r = requests.post(url, headers=headers, json=payload)
self.assertEqual(r.status_code, 400)
body = r.json()
self.assertIn("error", body)
self.assertIn("message", body["error"])
self.assertIn("type", body["error"])
self.assertIn("code", body["error"])
def test_penalty(self):
url = f"{self.base_url}/responses"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"input": "Introduce the capital of France.",
"temperature": 0,
"max_output_tokens": 32,
"frequency_penalty": 1.0,
}
r = requests.post(url, headers=headers, json=payload)
self.assertEqual(r.status_code, 200)
body = r.json()
self.assertEqual(body.get("object"), "response")
self.assertIn("output", body)
self.assertIn("status", body)
if "usage" in body:
self.assertIn("prompt_tokens", body["usage"])
self.assertIn("total_tokens", body["usage"])
def test_response_prefill(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
},
{
"role": "assistant",
"content": "{\n",
},
],
temperature=0,
extra_body={"continue_final_message": True},
)
assert (
response.choices[0]
.message.content.strip()
.startswith('"name": "SmartHome Mini",')
)
def test_model_list(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
models = list(client.models.list())
assert len(models) == 1
assert isinstance(getattr(models[0], "max_model_len", None), int)
class TestOpenAIV1Rerank(CustomTestCase):
    @classmethod
    def setUpClass(cls):
...
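For completeness, a sketch of consuming the endpoint with the openai client's streaming helper, mirroring `run_response_stream` above; the base URL, API key, and model name are placeholders, and this assumes a recent openai-python that provides `client.responses.stream`.

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

with client.responses.stream(
    model="my-model", input="The capital of France is", max_output_tokens=32
) as stream:
    for event in stream:
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)
    final = stream.get_final_response()

print("\nstatus:", final.status)
```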