Unverified commit 0d3ff440, authored by Biswa Panda and committed by GitHub
Browse files

fix: enable toggling kv events pub/sub (currently nats based) with --no-kv-events flag (#5237)

parent bb8eaa23
......@@ -246,6 +246,89 @@ class ToolCallingChatPayload(ChatPayload):
logger.info(f"Expected tool '{self.expected_tool_name}' was called")
@dataclass
class CachedTokensChatPayload(ChatPayload):
    """Chat payload that checks cached tokens are reported on repeated requests.

    Used for testing KV router cache-aware routing, where repeated identical
    prompts should result in cached tokens being reported in the ``usage``
    field of the response.

    The payload counts how many responses it has validated. For every response
    after the first it reads ``usage.prompt_tokens_details.cached_tokens`` and
    records whether any of them reached ``min_cached_tokens``. A shortfall on
    an individual request is only logged as a warning; the hard failure is
    deferred to :meth:`final_validation`.
    """

    # NOTE: ``@dataclass`` does not replace an ``__init__`` defined in the
    # class body, so this hand-written constructor is the one that runs.
    def __init__(
        self,
        body: dict,
        repeat_count: int = 3,
        expected_response: Optional[List[str]] = None,
        expected_log: Optional[List[str]] = None,
        timeout: int = 60,
        min_cached_tokens: int = 1,
    ) -> None:
        """Build the payload and reset the cached-token bookkeeping.

        Args:
            body: Request body, forwarded unchanged to ``ChatPayload``.
            repeat_count: Forwarded to ``ChatPayload``; ``final_validation``
                only enforces the cached-token check when this is > 1.
            expected_response: Forwarded to ``ChatPayload``; ``None`` is
                replaced with an empty list.
            expected_log: Forwarded to ``ChatPayload``; ``None`` is replaced
                with an empty list.
            timeout: Forwarded to ``ChatPayload``.
            min_cached_tokens: Smallest ``cached_tokens`` value that counts as
                a cache hit on a repeated request.
        """
        super().__init__(
            body=body,
            repeat_count=repeat_count,
            expected_response=expected_response or [],
            expected_log=expected_log or [],
            timeout=timeout,
        )
        self.min_cached_tokens = min_cached_tokens
        # Number of responses passed through validate() so far.
        self._request_count = 0
        # Set once any repeated request reports enough cached tokens.
        self._cached_tokens_found = False

    def validate(self, response: Any, content: str) -> None:
        """Validate the response and check for cached tokens on repeats.

        Runs the parent class validation first, then inspects the JSON body
        returned by ``response.json()``. A missing or ``null`` ``usage`` /
        ``prompt_tokens_details`` / ``cached_tokens`` entry is treated as
        zero cached tokens rather than an error.

        Args:
            response: Object exposing ``json()`` that returns the decoded
                response body (expected to be a mapping).
            content: Response content, forwarded to the parent validation.
        """
        # First run the standard content validation.
        super().validate(response, content)

        self._request_count += 1
        result = response.json()

        # Expected structure: usage.prompt_tokens_details.cached_tokens.
        # ``dict.get(key, default)`` only covers a *missing* key; a JSON
        # ``null`` arrives as None, so fall back with ``or`` at every level.
        usage = result.get("usage") or {}
        prompt_tokens_details = usage.get("prompt_tokens_details") or {}
        cached_tokens = prompt_tokens_details.get("cached_tokens", 0) or 0

        logger.info(
            f"Request {self._request_count}: prompt_tokens={usage.get('prompt_tokens')}, "
            f"cached_tokens={cached_tokens}, prompt_tokens_details={prompt_tokens_details}"
        )

        # The first request has nothing to hit in the prefix cache, so only
        # requests after it are expected to report cached tokens.
        if self._request_count > 1:
            if cached_tokens >= self.min_cached_tokens:
                self._cached_tokens_found = True
                logger.info(
                    f"✓ Request {self._request_count}: Cached tokens validation PASSED - "
                    f"found {cached_tokens} cached tokens (min required: {self.min_cached_tokens})"
                )
            else:
                # Not fatal here: one hit on any repeated request is enough,
                # which is decided in final_validation().
                logger.warning(
                    f"Request {self._request_count}: cached_tokens={cached_tokens} "
                    f"(expected >= {self.min_cached_tokens})"
                )

    def final_validation(self) -> None:
        """Ensure at least one repeated request reported cached tokens.

        Intended to be called after all requests have been processed.

        Raises:
            AssertionError: If ``repeat_count`` > 1 and no request after the
                first reported ``cached_tokens >= min_cached_tokens``.
        """
        if self.repeat_count > 1 and not self._cached_tokens_found:
            raise AssertionError(
                f"Expected cached_tokens >= {self.min_cached_tokens} in "
                f"prompt_tokens_details for at least one repeated request, "
                f"but none found after {self._request_count} requests. "
                f"Verify that prefix caching is enabled and working correctly."
            )
        logger.info(
            "✓ Final validation PASSED: cached_tokens found in repeated requests"
        )
@dataclass
class LoraTestChatPayload(ChatPayload):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment