Unverified commit 9ccbf6b6, authored by RioS, committed by GitHub
Browse files

[responsesAPI]add extra body parameters (#30532)


Signed-off-by: Ri0S <aa248424@gmail.com>
parent ae2e503d
...@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel):
max_tool_calls: int | None = None max_tool_calls: int | None = None
metadata: Metadata | None = None metadata: Metadata | None = None
model: str | None = None model: str | None = None
logit_bias: dict[str, float] | None = None
parallel_tool_calls: bool | None = True parallel_tool_calls: bool | None = True
previous_response_id: str | None = None previous_response_id: str | None = None
prompt: ResponsePrompt | None = None prompt: ResponsePrompt | None = None
...@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel):
tools: list[Tool] = Field(default_factory=list) tools: list[Tool] = Field(default_factory=list)
top_logprobs: int | None = 0 top_logprobs: int | None = 0
top_p: float | None = None top_p: float | None = None
top_k: int | None = None
truncation: Literal["auto", "disabled"] | None = "disabled" truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None user: str | None = None
...@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel):
_DEFAULT_SAMPLING_PARAMS = { _DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
"top_k": 0,
} }
def to_sampling_params( def to_sampling_params(
...@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel):
top_p = default_sampling_params.get( top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
) )
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
stop_token_ids = default_sampling_params.get("stop_token_ids") stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output # Structured output
...@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel):
return SamplingParams.from_optional( return SamplingParams.from_optional(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k,
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
...@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel):
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
), ),
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
) )
def is_include_output_logprobs(self) -> bool: def is_include_output_logprobs(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment