Unverified commit 71f24ef8, authored by Xinyuan Tong and committed by GitHub

feat: add cache_salt support to request (#10718)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
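For context, the end-to-end effect of this change can be sketched from the client side. The snippet below is an illustrative usage example, not part of the commit: the server URL, API key, and model name are placeholders, and `cache_salt` is passed through `extra_body` because it is an SGLang-specific extension of the OpenAI request schema. Requests that share a salt can reuse each other's cached prefix; requests with different salts keep their prefix caches isolated.

```python
# Illustrative usage sketch (assumed local SGLang server and placeholder model name).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

common = dict(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    messages=[{"role": "user", "content": "What is the capital of Japan?"}],
    temperature=0,
    max_tokens=10,
)

# Two requests with the same salt: the second one should report cached prompt tokens.
client.chat.completions.create(**common, extra_body={"cache_salt": "tenant-a"})
hit = client.chat.completions.create(**common, extra_body={"cache_salt": "tenant-a"})
print(hit.usage.prompt_tokens_details.cached_tokens)

# A different salt keeps its prefix cache separate from "tenant-a".
miss = client.chat.completions.create(**common, extra_body={"cache_salt": "tenant-b"})
print(miss.usage.prompt_tokens_details.cached_tokens)
```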
parent b1f0fc1c
@@ -228,6 +228,10 @@ class CompletionRequest(BaseModel):
    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None
    # Cache salt for request caching
    cache_salt: Optional[Union[List[str], str]] = None
    # Priority for the request
    priority: Optional[int] = None
@@ -545,6 +549,10 @@ class ChatCompletionRequest(BaseModel):
    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None
    # Cache salt for request caching
    cache_salt: Optional[Union[List[str], str]] = None
    # Priority for the request
    priority: Optional[int] = None
@@ -778,6 +786,13 @@ class ResponsesRequest(BaseModel):
        description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
    )
    priority: int = Field(default=0, description="Request priority")
    extra_key: Optional[str] = Field(
        default=None,
        description="Extra key for classifying the request (e.g. cache_salt)",
    )
    cache_salt: Optional[str] = Field(
        default=None, description="Cache salt for request caching"
    )

    # SGLang-specific sampling parameters
    frequency_penalty: float = 0.0
...
@@ -86,6 +86,19 @@ class OpenAIServingBase(ABC):
        return f"{self._request_id_prefix()}{uuid.uuid4().hex}"

    def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
        """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
        parts = []
        for key in ["cache_salt", "extra_key"]:
            value = getattr(request, key, None)
            if value:
                if not isinstance(value, str):
                    raise TypeError(
                        f"Value of {key} must be a string, but got {type(value).__name__}"
                    )
                parts.append(value)
        return "".join(parts) if parts else None

    @abstractmethod
    def _convert_to_internal_request(
        self,
...
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            customer_labels=customer_labels,
        )
...
@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            customer_labels=customer_labels,
        )
...
@@ -245,6 +245,7 @@ class OpenAIServingResponses(OpenAIServingChat):
            sampling_params=sampling_params,
            stream=request.stream,
            rid=request.request_id,
            extra_key=self._compute_extra_key(request),
            background=request.background,
        )
@@ -1250,6 +1251,7 @@ class OpenAIServingResponses(OpenAIServingChat):
            sampling_params=sampling_params,
            stream=adapted_request.stream,
            rid=request_id,
            extra_key=adapted_request.extra_key,
            return_logprob=adapted_request.return_logprob,
            logprob_start_len=adapted_request.logprob_start_len,
            top_logprobs_num=adapted_request.top_logprobs_num,
...
@@ -84,6 +84,8 @@ class GenerateReqInput:
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
@@ -606,6 +608,9 @@ class TokenizedGenerateReqInput:
    # Priority for the request
    priority: Optional[int] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[str] = None
    # Image gen grpc migration
    return_bytes: bool = False
...
@@ -491,7 +491,7 @@ class Req:
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
-        # extra key for classifying the request (e.g. lora_id, cache_salt)
+        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
...
@@ -750,6 +750,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
                priority=obj.priority,
                extra_key=obj.extra_key,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
...
@@ -207,6 +207,84 @@ class TestCacheReport(CustomTestCase):
        # asyncio.run(run_test())

    def test_cache_salt_effectiveness(self):
        print("=" * 100)
        print("Testing cache_salt effectiveness")

        # Use a unique message to avoid interference with other tests
        test_message = "What is the capital of Japan?"

        # First request with cache_salt "salt1"
        response1 = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": test_message}],
            temperature=0,
            max_tokens=10,
            extra_body={"cache_salt": "salt1"},
        )
        cached_tokens_1_first = int(response1.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens_1 = int(response1.usage.prompt_tokens)
        print(
            f"First request with salt1 - cached_tokens: {cached_tokens_1_first}, prompt_tokens: {prompt_tokens_1}"
        )

        # Second request with the same cache_salt "salt1" - should get a cache hit
        response2 = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": test_message}],
            temperature=0,
            max_tokens=10,
            extra_body={"cache_salt": "salt1"},
        )
        cached_tokens_1_second = int(
            response2.usage.prompt_tokens_details.cached_tokens
        )
        print(
            f"Second request with salt1 - cached_tokens: {cached_tokens_1_second}, prompt_tokens: {prompt_tokens_1}"
        )

        # Verify the cache hit for the same salt
        assert (
            cached_tokens_1_second > cached_tokens_1_first
        ), "Should have cache hit with same cache_salt"
        assert (
            cached_tokens_1_second == prompt_tokens_1 - 1
        ), "Should cache all prompt tokens except the last one"

        # Third request with a different cache_salt "salt2" - should not get a cache hit
        response3 = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": test_message}],
            temperature=0,
            max_tokens=10,
            extra_body={"cache_salt": "salt2"},
        )
        cached_tokens_2_first = int(response3.usage.prompt_tokens_details.cached_tokens)
        print(f"First request with salt2 - cached_tokens: {cached_tokens_2_first}")

        # Verify no cache hit for a different salt (should be similar to the first request with salt1)
        assert (
            cached_tokens_2_first <= cached_tokens_1_first + self.min_cached
        ), "Different cache_salt should not share cache"

        # Fourth request with the same cache_salt "salt2" - should now get a cache hit
        response4 = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": test_message}],
            temperature=0,
            max_tokens=10,
            extra_body={"cache_salt": "salt2"},
        )
        cached_tokens_2_second = int(
            response4.usage.prompt_tokens_details.cached_tokens
        )
        print(f"Second request with salt2 - cached_tokens: {cached_tokens_2_second}")

        # Verify the cache hit for salt2
        assert (
            cached_tokens_2_second > cached_tokens_2_first
        ), "Should have cache hit with same cache_salt for salt2"


if __name__ == "__main__":
    unittest.main()
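As a reading aid for the diff above, the combination rule implemented by `_compute_extra_key` can be summarized in a standalone sketch (an illustrative re-implementation, not the server code): `cache_salt` and `extra_key` are concatenated with the salt first, either value alone passes through unchanged, non-string values raise `TypeError`, and the result is `None` when neither is set.

```python
# Illustrative sketch of the key-combination rule; free function for clarity only.
from typing import Optional


def compute_extra_key(cache_salt: Optional[str], extra_key: Optional[str]) -> Optional[str]:
    """Concatenate cache_salt and extra_key (salt first); None if neither is provided."""
    parts = []
    for name, value in (("cache_salt", cache_salt), ("extra_key", extra_key)):
        if value:
            if not isinstance(value, str):
                raise TypeError(
                    f"Value of {name} must be a string, but got {type(value).__name__}"
                )
            parts.append(value)
    return "".join(parts) if parts else None


# Expected behavior under the rule above:
assert compute_extra_key("salt1", None) == "salt1"
assert compute_extra_key(None, "tenant-a") == "tenant-a"
assert compute_extra_key("salt1", "tenant-a") == "salt1tenant-a"
assert compute_extra_key(None, None) is None
```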