Unverified Commit dfada85e authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Frontend] Expose custom args in OpenAI APIs (#16862)


Signed-off-by: default avatarAndrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: default avatarAndrew Feldman <afeldman@redhat.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent ed333497
...@@ -4,12 +4,12 @@ import argparse ...@@ -4,12 +4,12 @@ import argparse
import itertools import itertools
import torch import torch
import triton
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton, moe_align_block_size_triton,
) )
from vllm.triton_utils import triton
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
......
...@@ -326,7 +326,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -326,7 +326,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
) )
chat_template_kwargs: Optional[dict[str, Any]] = Field( chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None, default=None,
description=("Additional kwargs to pass to the template renderer. " description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."), "Will be accessible by the chat template."),
) )
mm_processor_kwargs: Optional[dict[str, Any]] = Field( mm_processor_kwargs: Optional[dict[str, Any]] = Field(
...@@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.") description="KVTransfer parameters used for disaggregated serving.")
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:chat-completion-extra-params] # --8<-- [end:chat-completion-extra-params]
# Default sampling parameters for chat completion requests # Default sampling parameters for chat completion requests
...@@ -523,6 +530,10 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -523,6 +530,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
structural_tag=self.structural_tag, structural_tag=self.structural_tag,
) )
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
best_of=self.best_of, best_of=self.best_of,
...@@ -553,8 +564,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -553,8 +564,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
bad_words= self.bad_words, bad_words= self.bad_words,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params} extra_args=extra_args or None,
if self.kv_transfer_params else None)) )
def _get_guided_json_from_tool( def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]: self) -> Optional[Union[str, dict, BaseModel]]:
...@@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel):
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.") description="KVTransfer parameters used for disaggregated serving.")
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:completion-extra-params] # --8<-- [end:completion-extra-params]
# Default sampling parameters for completion requests # Default sampling parameters for completion requests
...@@ -968,6 +985,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -968,6 +985,10 @@ class CompletionRequest(OpenAIBaseModel):
whitespace_pattern=self.guided_whitespace_pattern, whitespace_pattern=self.guided_whitespace_pattern,
) )
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
best_of=self.best_of, best_of=self.best_of,
...@@ -997,8 +1018,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -997,8 +1018,8 @@ class CompletionRequest(OpenAIBaseModel):
guided_decoding=guided_decoding, guided_decoding=guided_decoding,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params} extra_args=extra_args or None,
if self.kv_transfer_params else None)) )
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -1117,7 +1138,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1117,7 +1138,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
) )
chat_template_kwargs: Optional[dict[str, Any]] = Field( chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None, default=None,
description=("Additional kwargs to pass to the template renderer. " description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."), "Will be accessible by the chat template."),
) )
mm_processor_kwargs: Optional[dict[str, Any]] = Field( mm_processor_kwargs: Optional[dict[str, Any]] = Field(
...@@ -1623,7 +1645,8 @@ class TokenizeChatRequest(OpenAIBaseModel): ...@@ -1623,7 +1645,8 @@ class TokenizeChatRequest(OpenAIBaseModel):
) )
chat_template_kwargs: Optional[dict[str, Any]] = Field( chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None, default=None,
description=("Additional kwargs to pass to the template renderer. " description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."), "Will be accessible by the chat template."),
) )
mm_processor_kwargs: Optional[dict[str, Any]] = Field( mm_processor_kwargs: Optional[dict[str, Any]] = Field(
...@@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data. # Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False stream_continuous_usage_stats: Optional[bool] = False
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:transcription-extra-params] # --8<-- [end:transcription-extra-params]
# --8<-- [start:transcription-sampling-params] # --8<-- [start:transcription-sampling-params]
...@@ -1823,7 +1852,8 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1823,7 +1852,8 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
output_kind=RequestOutputKind.DELTA output_kind=RequestOutputKind.DELTA
if self.stream \ if self.stream \
else RequestOutputKind.FINAL_ONLY) else RequestOutputKind.FINAL_ONLY,
extra_args=self.vllm_xargs)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
......
...@@ -198,8 +198,8 @@ class SamplingParams( ...@@ -198,8 +198,8 @@ class SamplingParams(
processor which only retains scores for the given token ids. processor which only retains scores for the given token ids.
Defaults to None. Defaults to None.
extra_args: Arbitrary additional args, that can be used by custom extra_args: Arbitrary additional args, that can be used by custom
sampling implementations. Not used by any in-tree sampling sampling implementations, plugins, etc. Not used by any in-tree
implementations. sampling implementations.
""" """
n: int = 1 n: int = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment