Unverified Commit 4c4b6f7a authored by Daniel Mescheder's avatar Daniel Mescheder Committed by GitHub
Browse files

[Frontend] Add sampling parameters to Responses API (#32609)


Signed-off-by: default avatarDaniel Mescheder <dmesch@amazon.com>
Co-authored-by: default avatarDaniel Mescheder <dmesch@amazon.com>
parent 10546f92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""
import pytest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
class TestResponsesRequestSamplingParams:
"""Test that ResponsesRequest correctly maps parameters to SamplingParams."""
def test_basic_sampling_params(self):
"""Test basic sampling parameters are correctly mapped."""
request = ResponsesRequest(
model="test-model",
input="test input",
temperature=0.8,
top_p=0.95,
top_k=50,
max_output_tokens=100,
)
sampling_params = request.to_sampling_params(default_max_tokens=1000)
assert sampling_params.temperature == 0.8
assert sampling_params.top_p == 0.95
assert sampling_params.top_k == 50
assert sampling_params.max_tokens == 100
def test_extra_sampling_params(self):
"""Test extra sampling parameters are correctly mapped."""
request = ResponsesRequest(
model="test-model",
input="test input",
repetition_penalty=1.2,
seed=42,
stop=["END", "STOP"],
ignore_eos=True,
vllm_xargs={"custom": "value"},
)
sampling_params = request.to_sampling_params(default_max_tokens=1000)
assert sampling_params.repetition_penalty == 1.2
assert sampling_params.seed == 42
assert sampling_params.stop == ["END", "STOP"]
assert sampling_params.ignore_eos is True
assert sampling_params.extra_args == {"custom": "value"}
def test_stop_string_conversion(self):
"""Test that single stop string is converted to list."""
request = ResponsesRequest(
model="test-model",
input="test input",
stop="STOP",
)
sampling_params = request.to_sampling_params(default_max_tokens=1000)
assert sampling_params.stop == ["STOP"]
def test_default_values(self):
"""Test default values for optional parameters."""
request = ResponsesRequest(
model="test-model",
input="test input",
)
sampling_params = request.to_sampling_params(default_max_tokens=1000)
assert sampling_params.repetition_penalty == 1.0 # None → 1.0
assert sampling_params.stop == [] # Empty list
assert sampling_params.extra_args == {} # Empty dict
def test_seed_bounds_validation(self):
"""Test that seed values outside torch.long bounds are rejected."""
import torch
from pydantic import ValidationError
# Test seed below minimum
with pytest.raises(ValidationError) as exc_info:
ResponsesRequest(
model="test-model",
input="test input",
seed=torch.iinfo(torch.long).min - 1,
)
assert "greater_than_equal" in str(exc_info.value).lower()
# Test seed above maximum
with pytest.raises(ValidationError) as exc_info:
ResponsesRequest(
model="test-model",
input="test input",
seed=torch.iinfo(torch.long).max + 1,
)
assert "less_than_equal" in str(exc_info.value).lower()
# Test valid seed at boundaries
request_min = ResponsesRequest(
model="test-model",
input="test input",
seed=torch.iinfo(torch.long).min,
)
assert request_min.seed == torch.iinfo(torch.long).min
request_max = ResponsesRequest(
model="test-model",
input="test input",
seed=torch.iinfo(torch.long).max,
)
assert request_max.seed == torch.iinfo(torch.long).max
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
......@@ -147,3 +146,27 @@ async def test_max_tokens(client: OpenAI, model_name: str):
assert response is not None
assert response.status == "incomplete"
assert response.incomplete_details.reason == "max_output_tokens"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_extra_sampling_params(client: OpenAI, model_name: str):
"""Test that extra sampling parameters are accepted and work."""
# Test with multiple sampling parameters - just verify they're accepted
response = await client.responses.create(
model=model_name,
input="Write a short sentence",
max_output_tokens=50,
temperature=0.7,
top_p=0.9,
extra_body={
"top_k": 40,
"repetition_penalty": 1.2,
"seed": 42,
},
)
# Verify request succeeded and parameters were accepted
assert response.status in ["completed", "incomplete"]
assert len(response.output) > 0
assert response.output[0].content[0].text # Has text output
......@@ -6,6 +6,7 @@
import time
from typing import Any, Literal, TypeAlias
import torch
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
......@@ -77,6 +78,8 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class InputTokensDetails(OpenAIBaseModel):
cached_tokens: int
......@@ -230,6 +233,18 @@ class ResponsesRequest(OpenAIBaseModel):
# this cannot be used in conjunction with previous_response_id
# TODO: consider supporting non harmony messages as well
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
repetition_penalty: float | None = None
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
ignore_eos: bool = False
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:responses-extra-params]
def build_chat_params(
......@@ -297,6 +312,10 @@ class ResponsesRequest(OpenAIBaseModel):
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)
stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output
......@@ -313,7 +332,10 @@ class ResponsesRequest(OpenAIBaseModel):
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")
# TODO: add more parameters
stop = self.stop if self.stop else []
if isinstance(stop, str):
stop = [stop]
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
......@@ -321,11 +343,16 @@ class ResponsesRequest(OpenAIBaseModel):
max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids,
stop=stop,
repetition_penalty=repetition_penalty,
seed=self.seed,
ignore_eos=self.ignore_eos,
output_kind=(
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {},
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment