Unverified Commit 19c86306 authored by Marko Rosenmueller's avatar Marko Rosenmueller Committed by GitHub
Browse files

[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)


Signed-off-by: default avatarMarko Rosenmueller <5467316+dr75@users.noreply.github.com>
parent f29fd8a7
......@@ -1540,6 +1540,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
......
......@@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
# --8<-- [end:responses-extra-params]
_DEFAULT_SAMPLING_PARAMS = {
......@@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
raise ValueError("prompt template is not supported")
return data
@model_validator(mode="before")
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......@@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
......@@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
"At least one of `prompt` or `prompt_embeds` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data
class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......
......@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
......@@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
......@@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
......@@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
if stream_options:
......@@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
if res.prompt is not None:
prompt_text = res.prompt
else:
......@@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
......@@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
return CompletionResponse(
......
......@@ -811,6 +811,12 @@ class OpenAIServing:
prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text
]
cache_salt = request.cache_salt if (
hasattr(request, "cache_salt")
and request.cache_salt is not None) else None
if cache_salt:
for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt
# This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate
......@@ -828,6 +834,9 @@ class OpenAIServing:
prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds
]
if cache_salt:
for prompt_embed in engine_prompts_embeds:
prompt_embed["cache_salt"] = cache_salt
request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment