Unverified Commit 9975cb9f authored by Ziqi Fan's avatar Ziqi Fan Committed by GitHub
Browse files

feat: enable KVBM to support PD disagg in Dynamo vLLM (#3352)


Signed-off-by: default avatarZiqi Fan <ziqif@nvidia.com>
parent 0aa0768f
...@@ -40,7 +40,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -40,7 +40,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| [**KV-Aware Routing**](../../../docs/architecture/kv_cache_routing.md) | ✅ | | | [**KV-Aware Routing**](../../../docs/architecture/kv_cache_routing.md) | ✅ | |
| [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | ✅ | | | [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | ✅ | |
| [**Load Based Planner**](../../../docs/architecture/load_planner.md) | 🚧 | WIP | | [**Load Based Planner**](../../../docs/architecture/load_planner.md) | 🚧 | WIP |
| [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | 🚧 | WIP | | [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | | |
| [**LMCache**](./LMCache_Integration.md) | ✅ | | | [**LMCache**](./LMCache_Integration.md) | ✅ | |
### Large Scale P/D and WideEP Features ### Large Scale P/D and WideEP Features
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Aggregated serving with KVBM: starts the Dynamo frontend in the background
# and a single vLLM worker (with the kvbm connector) in the foreground.
# Abort the script as soon as any command fails.
set -e
# On exit, signal every process in this script's process group (kill 0) so the
# backgrounded frontend does not outlive the script.
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress
python -m dynamo.frontend --http-port=8000 &
# run worker with KVBM enabled
# NOTE: remove --enforce-eager for production use
# DYN_KVBM_CPU_CACHE_GB sets the size (in GB) of the CPU cache KVBM offloads to.
DYN_KVBM_CPU_CACHE_GB=20 \
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --connector kvbm --enforce-eager
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Disaggregated serving (1 prefill + 1 decode worker) with KVBM on the prefill
# side. Uses GPU 0 and GPU 1, so at least two GPUs are required.
# Abort the script as soon as any command fails.
set -e
# On exit, signal every process in this script's process group (kill 0) so the
# backgrounded frontend and decode worker do not outlive the script.
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress with KV router
python -m dynamo.frontend --router-mode kv --http-port=8000 &
# run decode worker on GPU 0, without enabling KVBM
# NOTE: remove --enforce-eager for production use
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --connector nixl --enforce-eager &
# run prefill worker on GPU 1 with KVBM enabled using 20GB of CPU cache
# NOTE: remove --enforce-eager for production use
# Connector order matters: the KV offloading connector (kvbm) is listed first
# and nixl second. This worker is not backgrounded, so it keeps the script
# (and therefore the other two processes) alive until it exits.
DYN_KVBM_CPU_CACHE_GB=20 \
CUDA_VISIBLE_DEVICES=1 \
python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \
--is-prefill-worker \
--connector kvbm nixl \
--enforce-eager
...@@ -351,11 +351,12 @@ def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]: ...@@ -351,11 +351,12 @@ def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]:
cfg = multi_connectors[0] cfg = multi_connectors[0]
return KVTransferConfig(**cfg) return KVTransferConfig(**cfg)
# For multiple connectors, use MultiConnector # For multiple connectors, use PdConnector
return KVTransferConfig( return KVTransferConfig(
kv_connector="MultiConnector", kv_connector="PdConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"connectors": multi_connectors}, kv_connector_extra_config={"connectors": multi_connectors},
kv_connector_module_path="dynamo.llm.vllm_integration.connector",
) )
......
...@@ -223,7 +223,7 @@ if [ "$ARCH" = "amd64" ]; then ...@@ -223,7 +223,7 @@ if [ "$ARCH" = "amd64" ]; then
# TODO: Re-enable for arm64 after verifying lmcache compatibility and resolving the build issue. # TODO: Re-enable for arm64 after verifying lmcache compatibility and resolving the build issue.
# Alec: Likely lmcache was compiled with a different version of torch and need to install it from source for arm64 # Alec: Likely lmcache was compiled with a different version of torch and need to install it from source for arm64
uv pip install lmcache==0.3.3 uv pip install lmcache==0.3.7
echo "✓ LMCache installed" echo "✓ LMCache installed"
else else
echo "⚠ Skipping LMCache on ARM64 (compatibility issues)" echo "⚠ Skipping LMCache on ARM64 (compatibility issues)"
......
...@@ -25,6 +25,7 @@ To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest ...@@ -25,6 +25,7 @@ To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest
To use KVBM in vLLM, you can follow the steps below: To use KVBM in vLLM, you can follow the steps below:
### Docker Setup
```bash ```bash
# start up etcd for KVBM leader/worker registration and discovery # start up etcd for KVBM leader/worker registration and discovery
docker compose -f deploy/docker-compose.yml up -d docker compose -f deploy/docker-compose.yml up -d
...@@ -34,26 +35,32 @@ docker compose -f deploy/docker-compose.yml up -d ...@@ -34,26 +35,32 @@ docker compose -f deploy/docker-compose.yml up -d
# launch the container # launch the container
./container/run.sh --framework vllm -it --mount-workspace --use-nixl-gds ./container/run.sh --framework vllm -it --mount-workspace --use-nixl-gds
```
# enable kv offloading to CPU memory ### Aggregated Serving with KVBM
# 4 means 4GB of CPU memory would be used ```bash
export DYN_KVBM_CPU_CACHE_GB=4 cd $DYNAMO_HOME/components/backends/vllm
./launch/agg_kvbm.sh
# enable kv offloading to disk ```
# 8 means 8GB of disk would be used
export DYN_KVBM_DISK_CACHE_GB=8
# [DYNAMO] start dynamo frontend ### Disaggregated Serving with KVBM (1P1D)
python -m dynamo.frontend --http-port 8000 & ```bash
# NOTE: need at least 2 GPUs
cd $DYNAMO_HOME/components/backends/vllm
./launch/disagg_kvbm.sh
```
> [!NOTE]
> To tune the size of CPU or disk cache, set `DYN_KVBM_CPU_CACHE_GB` and `DYN_KVBM_DISK_CACHE_GB` accordingly. We only set `DYN_KVBM_CPU_CACHE_GB=20` in both scripts above.
# [DYNAMO] serve an LLM model using KVBM with dynamo > [!NOTE]
python -m dynamo.vllm \ > `DYN_KVBM_CPU_CACHE_GB` must be set and `DYN_KVBM_DISK_CACHE_GB` is optional.
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--connector kvbm &
# make a call to LLM ### Sample Request
```bash
# make a request to verify vLLM with KVBM is started up correctly
# NOTE: change the model name if served with a different one
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "model": "Qwen/Qwen3-0.6B",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
...@@ -61,11 +68,11 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" ...@@ -61,11 +68,11 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json"
} }
], ],
"stream":false, "stream":false,
"max_tokens": 30 "max_tokens": 10
}' }'
``` ```
Alternatively, you can use "vllm serve" with KVBM by replacing the above two [DYNAMO] cmds with below: Alternatively, you can use `vllm serve` directly to use KVBM for aggregated serving:
```bash ```bash
vllm serve --kv-transfer-config '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "dynamo.llm.vllm_integration.connector"}' deepseek-ai/DeepSeek-R1-Distill-Llama-8B vllm serve --kv-transfer-config '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "dynamo.llm.vllm_integration.connector"}' deepseek-ai/DeepSeek-R1-Distill-Llama-8B
``` ```
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::block_manager::{ use dynamo_llm::block_manager::{
block::BlockId, connector::protocol::WorkerTransferRequest, distributed::BlockTransferRequest, block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError,
pool::BlockPoolError,
}; };
pub mod leader; pub mod leader;
...@@ -163,11 +162,3 @@ impl ConnectorMetadata { ...@@ -163,11 +162,3 @@ impl ConnectorMetadata {
self.operations.extend(xfer_reqs); self.operations.extend(xfer_reqs);
} }
} }
/// A block transfer request bundled with the identifiers that tag it.
///
/// Derives `Serialize`/`Deserialize`, so values can be passed across a
/// serialization boundary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectorOperation {
    /// Request identifier (presumably the id of the inference request this
    /// operation belongs to; confirm against callers).
    pub req_id: String,
    /// Iteration counter associated with the operation (NOTE(review): looks
    /// like the scheduler iteration; confirm).
    pub iteration: u64,
    /// Unique identifier for this operation.
    pub uuid: uuid::Uuid,
    /// The block transfer to perform.
    pub xfer_req: BlockTransferRequest,
}
...@@ -3,8 +3,16 @@ ...@@ -3,8 +3,16 @@
# Import connector classes to make them available at the expected paths for vLLM # Import connector classes to make them available at the expected paths for vLLM
from .connector.dynamo_connector import DynamoConnector, DynamoConnectorMetadata from .connector.dynamo_connector import DynamoConnector, DynamoConnectorMetadata
from .connector.pd_connector import PdConnector, PdConnectorMetadata
# Create module-level alias for backward compatibility # Create module-level alias for backward compatibility
dynamo_connector = DynamoConnector dynamo_connector = DynamoConnector
pd_connector = PdConnector
__all__ = ["DynamoConnector", "DynamoConnectorMetadata", "dynamo_connector"] __all__ = [
"DynamoConnector",
"DynamoConnectorMetadata",
"dynamo_connector",
"PdConnector",
"PdConnectorMetadata",
]
...@@ -2,5 +2,11 @@ ...@@ -2,5 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from .dynamo_connector import DynamoConnector, DynamoConnectorMetadata from .dynamo_connector import DynamoConnector, DynamoConnectorMetadata
from .pd_connector import PdConnector, PdConnectorMetadata
__all__ = ["DynamoConnector", "DynamoConnectorMetadata"] __all__ = [
"DynamoConnector",
"DynamoConnectorMetadata",
"PdConnector",
"PdConnectorMetadata",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnector
from vllm.v1.core.sched.output import SchedulerOutput
from dynamo.llm.vllm_integration.connector.dynamo_connector import DynamoConnector
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class PdConnectorMetadata(MultiKVConnectorMetadata):
    """Connector metadata produced and consumed by ``PdConnector``.

    Adds no fields of its own; it gives ``PdConnector`` a distinct metadata
    type on top of ``MultiKVConnectorMetadata``.
    """

    pass
class PdConnector(MultiConnector):
    """
    Pair a KV offloading connector with the NIXL connector for PD disaggregated serving.

    Exactly two child connectors are wrapped, in this order:
    - index 0: a KV offloading connector (KVBM or LMCache), used by the prefill
      worker to offload and onboard KV blocks;
    - index 1: the NIXL connector, used by the decode worker to get KV blocks
      from the prefill worker.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        """Build the child connectors via ``MultiConnector`` and validate them.

        Raises:
            ValueError: if the number of configured connectors is not two.
            TypeError: if the connectors are not (offloading, NIXL) in that order.
        """
        super().__init__(vllm_config=vllm_config, role=role)

        if len(self._connectors) != 2:
            raise ValueError(
                f"PdConnector requires exactly two connectors (got {len(self._connectors)})"
            )

        offload_connector, transfer_connector = self._connectors

        offload_types = (DynamoConnector, LMCacheConnectorV1)
        if not isinstance(offload_connector, offload_types):
            raise TypeError(
                f"Expected first connector to be DynamoConnector or LMCacheConnector, "
                f"got {type(self._connectors[0]).__name__}"
            )

        if not isinstance(transfer_connector, NixlConnector):
            raise TypeError(
                f"Expected second connector to be NixlConnector, "
                f"got {type(self._connectors[1]).__name__}"
            )

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: PdConnectorMetadata) -> None:
        """Record pending async saves and hand each child its own metadata."""
        assert isinstance(connector_metadata, PdConnectorMetadata)

        pending_saves = connector_metadata.extra_async_saves
        if pending_saves:
            self._extra_async_saves.update(pending_saves)

        # metadata entries are positional: one per child connector, same order.
        for connector, child_metadata in zip(
            self._connectors, connector_metadata.metadata
        ):
            connector.bind_connector_metadata(child_metadata)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Ask only the first (KV offloading) connector how many tokens it matched.
        """
        offload_connector = self._connectors[0]
        return offload_connector.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Forward the allocation result to both child connectors.

        The offloading connector receives the real blocks so it can onboard KV
        data; the NIXL connector receives an empty allocation.
        """
        no_blocks = blocks.new_empty()
        offload_connector, transfer_connector = self._connectors[0], self._connectors[1]

        # allocate blocks for KV offloading connector to onboard KV blocks
        offload_connector.update_state_after_alloc(request, blocks, num_external_tokens)

        # This runs on the prefill worker, where NIXL only waits for the decode
        # worker to pull its data, so it needs no blocks of its own.
        transfer_connector.update_state_after_alloc(request, no_blocks, 0)

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> PdConnectorMetadata:
        """Collect each child's metadata and attach any pending async saves."""
        per_connector = tuple(
            connector.build_connector_meta(scheduler_output)
            for connector in self._connectors
        )
        metadata = PdConnectorMetadata(metadata=per_connector)

        if self._extra_async_saves:
            # Hand the accumulated saves over and start a fresh accumulator.
            metadata.extra_async_saves = self._extra_async_saves
            self._extra_async_saves = {}

        return metadata
...@@ -55,14 +55,9 @@ pub fn load_and_validate_tensors( ...@@ -55,14 +55,9 @@ pub fn load_and_validate_tensors(
// Check the stride, and ensure our tensor is contiguous. // Check the stride, and ensure our tensor is contiguous.
// TODO: We eventually need to be able to handle this. // TODO: We eventually need to be able to handle this.
let stride = tensor.stride(); let stride = tensor.stride();
for i in 1..stride.len() { tracing::debug!("stride: {:?}", stride);
if stride[i] > stride[i - 1] { tracing::debug!("stride is monotonically decreasing for NHD layout");
return Err(anyhow::anyhow!( tracing::debug!("stride is NOT monotonically decreasing for HND layout");
"Tensor strides must be monotonically decreasing! Got {:?}",
stride
));
}
}
// Check that all layer tensors have the same shape. // Check that all layer tensors have the same shape.
// TODO: We eventually need to support the weirder models with heterogenous layers. // TODO: We eventually need to support the weirder models with heterogenous layers.
......
...@@ -31,10 +31,13 @@ Run all kvbm tests: ...@@ -31,10 +31,13 @@ Run all kvbm tests:
pytest -v -m "kvbm" -s pytest -v -m "kvbm" -s
``` ```
Run the determinism test file directly: Run the determinism test file directly inside dynamo repo:
```bash ```bash
pytest -v dynamo/tests/kvbm/test_determinism.py -s pytest -v tests/kvbm/test_determinism_agg.py -s
# disagg needs 2 GPUs to run
pytest -v tests/kvbm/test_determinism_disagg.py -s
``` ```
## Configuration ## Configuration
......
...@@ -3,270 +3,29 @@ ...@@ -3,270 +3,29 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Determinism test for language model API using pytest. Common functionality for KVBM determinism tests.
This test suite checks if the model produces deterministic outputs This module contains shared classes and functions used by both
when given the same inputs with fixed seed and temperature=0. aggregated and disaggregated determinism tests.
The test uses comprehensive server warmup (sending all test prompts
before validation) to avoid server initialization effects that could
impact determinism measurements.
""" """
import importlib.util
import logging
import os import os
import signal
import subprocess
import time import time
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO, Tuple from typing import Dict, List, Optional, Tuple
import pytest import pytest
import requests import requests
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
pytestmark = [
pytest.mark.kvbm,
pytest.mark.e2e,
pytest.mark.slow,
pytest.mark.gpu_1,
]
class ServerType(str, Enum): class ServerType(str, Enum):
vllm = "vllm" vllm = "vllm"
trtllm = "trtllm" trtllm = "trtllm"
class LLMServerManager:
    """Manages LLM server lifecycle for determinism testing.

    Builds the serve command and environment for either a vLLM or a TRT-LLM
    server with the KVBM connector enabled, launches it in its own session
    (so the whole process group can be signalled), waits for readiness and
    tears it down again. Server stdout/stderr are written to per-run log files
    under ``log_dir``.

    Environment variables read:
        KVBM_SERVER_PORT: default port when ``port`` is not given ("8000").
        KVBM_MODEL_ID: model to serve
            ("deepseek-ai/DeepSeek-R1-Distill-Llama-8B" by default).
        KVBM_TRTLLM_LLMAPI_CONFIG_PATH: where the generated TRT-LLM config is
            written ("/tmp/kvbm_llm_api_config.yaml" by default).
        DYN_LOG: forwarded to the server; a default is supplied when unset.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        port: Optional[int] = None,
        cpu_cache_blocks: Optional[int] = None,
        gpu_cache_blocks: Optional[int] = None,
        log_dir: Optional[Path] = None,
        server_type: Optional[str] = ServerType.vllm,
    ):
        """Prepare command line, environment and log paths (does not launch).

        Args:
            base_url: URL used for health checks; defaults to
                ``http://localhost:<port>``.
            port: port the server listens on.
            cpu_cache_blocks: if given, exported as
                ``DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS``.
            gpu_cache_blocks: if given, overrides the GPU KV cache size in the
                backend-specific way.
            log_dir: directory for server logs (created if missing); defaults
                to the current directory.
            server_type: ``ServerType.vllm`` or ``ServerType.trtllm``.

        Raises:
            ValueError: if ``server_type`` is not a supported backend.
        """
        self.server_type = server_type
        self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
        self.base_url = base_url or f"http://localhost:{self.port}"
        self.process: Optional[subprocess.Popen] = None
        self.cpu_cache_blocks = cpu_cache_blocks
        self.gpu_cache_blocks = gpu_cache_blocks

        # Prepare logging
        self.log_dir = log_dir or Path(".")
        self.log_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        config_str = (
            f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
        )
        self.server_log_file = (
            self.log_dir / f"{self.server_type}_server_{config_str}_{timestamp}.log"
        )
        self.server_stdout_file: Optional[TextIO] = None
        self.server_stderr_file: Optional[TextIO] = None

        # Environment for the process
        self.env = os.environ.copy()
        self.env.update(
            {
                "RUST_BACKTRACE": "1",
                "DYN_LOG": os.environ.get(
                    "DYN_LOG", "debug,dynamo_llm::block_manager::layout=error"
                ),
                # DynamoConnector connection settings
                "NATS_SERVER": "nats://localhost:4222",
                "ETCD_ENDPOINTS": "http://localhost:2379",
            }
        )

        # CPU cache blocks override via env
        if cpu_cache_blocks is not None:
            self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)

        if self.server_type == ServerType.vllm:
            self._set_up_vllm_config(gpu_cache_blocks)
        elif self.server_type == ServerType.trtllm:
            self._set_up_trtllm_config(gpu_cache_blocks)
        else:
            raise ValueError(
                f"{self.server_type} is not supported yet in the KVBM test suite"
            )

    def _set_up_vllm_config(self, gpu_cache_blocks):
        """Build the ``vllm serve`` command with the DynamoConnector enabled."""
        self.env["VLLM_SERVER_DEV_MODE"] = "1"

        # Construct serve command
        self.server_cmd = [
            "vllm",
            "serve",
            "--block-size",
            "16",
            "--port",
            str(self.port),
            "--kv-transfer-config",
            '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "dynamo.llm.vllm_integration.connector"}',
            os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
            "--max-seq-len",
            "8000",  # required to fit on L4 GPU when using 8b model
        ]

        # GPU blocks override
        if gpu_cache_blocks is not None:
            self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)])

    def _set_up_trtllm_config(self, gpu_cache_blocks):
        """Write the TRT-LLM LLM-API YAML config and build the serve command."""
        config_path = os.environ.get(
            "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml"
        )
        llm_api_config: dict[str, Any] = {}
        # explicitly disable CUDA graph since Connector API doesn't support
        # CUDA graph yet in TRTLLM
        llm_api_config["cuda_graph_config"] = None
        llm_api_config["kv_cache_config"] = {
            "enable_partial_reuse": False,
            # Set a small GPU fraction so that we can evict/reset the
            # on-device kv cache faster
            "free_gpu_memory_fraction": 0.10,
        }
        llm_api_config["kv_connector_config"] = {
            "connector_module": "dynamo.llm.trtllm_integration.connector",
            "connector_scheduler_class": "DynamoKVBMConnectorLeader",
            "connector_worker_class": "DynamoKVBMConnectorWorker",
        }

        # GPU blocks override: an explicit token budget replaces the memory
        # fraction.
        if gpu_cache_blocks is not None:
            del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"]
            llm_api_config["kv_cache_config"]["max_tokens"] = (
                int(gpu_cache_blocks) * 32
            )  # TRTLLM defaults 32 tokens per block

        # Construct serve command
        self.server_cmd = [
            "trtllm-serve",
            os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
            "--host",
            "localhost",
            "--port",
            str(self.port),
            "--backend",
            "pytorch",
            "--extra_llm_api_options",
            config_path,
        ]

        # Imported lazily so yaml is only required for the TRT-LLM path.
        import yaml

        with open(config_path, "w") as f:
            yaml.dump(llm_api_config, f, default_flow_style=False, sort_keys=False)

    def start_server(self, timeout: int = 300) -> bool:
        """Start LLM server and wait for readiness.

        Args:
            timeout: maximum number of seconds to wait for the server to
                answer health and completion requests.

        Returns:
            True once the server is ready; False if the process exited early
            or the timeout elapsed (the server is stopped in the latter case).

        Raises:
            Exception: whatever ``subprocess.Popen`` raises when the command
                cannot be launched; log files are closed before re-raising.
        """
        if self.is_server_running():
            self.stop_server()
            time.sleep(2)

        # Open log files
        self.server_stdout_file = open(
            self.server_log_file.with_suffix(".stdout.log"), "w"
        )
        self.server_stderr_file = open(
            self.server_log_file.with_suffix(".stderr.log"), "w"
        )
        self.server_stdout_file.write(
            f"=== {self.server_type} Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n"
        )
        self.server_stdout_file.flush()

        # Launch. start_new_session=True makes the child call setsid(), giving
        # it its own process group for stop_server() to signal. It replaces
        # preexec_fn=os.setsid, which is not safe in the presence of threads.
        try:
            self.process = subprocess.Popen(
                self.server_cmd,
                stdout=self.server_stdout_file,
                stderr=self.server_stderr_file,
                env=self.env,
                start_new_session=True,
            )
        except Exception:
            # Launch failed (e.g. executable missing): don't leak log handles.
            self._close_log_files()
            raise

        # Wait for health; a monotonic clock keeps the timeout immune to
        # wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.is_server_running():
                return True
            if self.process.poll() is not None:
                # Server process died before becoming ready.
                self._close_log_files()
                return False
            time.sleep(5)

        # Timeout
        self.stop_server()
        return False

    def stop_server(self):
        """Stop LLM server and close logs.

        Sends SIGTERM to the server's process group, escalating to SIGKILL if
        it has not exited after 30 seconds. Safe to call when nothing runs.
        """
        if self.process:
            try:
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
                try:
                    self.process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                    self.process.wait()
            except (ProcessLookupError, OSError):
                # Process (group) is already gone; nothing left to stop.
                pass
            finally:
                self.process = None
        self._close_log_files()

    def _close_log_files(self):
        """Write the stop marker and close any open log files (idempotent)."""
        if self.server_stdout_file:
            self.server_stdout_file.write(
                f"\n=== Server Stopped at {datetime.now()} ===\n"
            )
            self.server_stdout_file.close()
            self.server_stdout_file = None
        if self.server_stderr_file:
            self.server_stderr_file.close()
            self.server_stderr_file = None

    def is_server_running(self) -> bool:
        """Return True if the server is healthy and can serve a completion.

        Checks ``/health`` first, then sends a one-token chat completion to
        confirm the model endpoint is ready. Any request error counts as
        "not running".
        """
        try:
            # First check basic health
            response = requests.get(f"{self.base_url}/health", timeout=5)
            if response.status_code != 200:
                return False

            # Then check if the model endpoint is ready with a simple test request
            test_payload = {
                "model": os.environ.get(
                    "KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
                ),
                "messages": [{"role": "user", "content": "test"}],
                "max_completion_tokens": 1,
                "temperature": 0,
            }

            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers={"Content-Type": "application/json"},
                json=test_payload,
                timeout=10,
            )
            return response.status_code == 200

        except requests.exceptions.RequestException:
            return False
class DeterminismTester: class DeterminismTester:
"""Test class for model determinism validation.""" """Test class for model determinism validation."""
...@@ -366,33 +125,6 @@ class DeterminismTester: ...@@ -366,33 +125,6 @@ class DeterminismTester:
data = response.json() data = response.json()
return data["choices"][0]["message"]["content"] return data["choices"][0]["message"]["content"]
def reset_prefix_cache(self):
    """Clear the server's prefix cache.

    For every server type except TRT-LLM this POSTs to the server's
    ``/reset_prefix_cache`` endpoint and raises on an HTTP error. TRT-LLM has
    no such endpoint, so the cache is flushed indirectly by sending a long
    series of distinct Shakespeare prompts; failures of individual flush
    requests are reported and ignored.
    """
    print("Resetting prefix cache...")

    if self.server_type != ServerType.trtllm:
        http_timeout = int(os.environ.get("KVBM_HTTP_TIMEOUT", "30"))
        response = requests.post(
            f"{self.base_url}/reset_prefix_cache",
            timeout=http_timeout,
        )
        response.raise_for_status()
    else:
        # TRTLLM doesn't support reset_prefix_cache endpoint API
        # 300 shakespeare content could evict the 0.1 x 80G (~1700 blocks) on-device cache
        shakespeare_count = 300
        for offset in range(shakespeare_count):
            seq_idx = offset + 1
            start_word = offset * self.word_count
            content = self.get_shakespeare_content(start_word)
            if not content:
                continue
            print(
                f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
            )
            try:
                self.make_request(content)
            except Exception as e:
                # Best effort: one failed flush request must not abort the reset.
                print(f"Resetting request failed: {e}")

    print("Cache reset done")
def warmup_server(self): def warmup_server(self):
"""Perform comprehensive server warmup with all test prompts.""" """Perform comprehensive server warmup with all test prompts."""
print("=" * 70) print("=" * 70)
...@@ -710,51 +442,6 @@ class DeterminismTester: ...@@ -710,51 +442,6 @@ class DeterminismTester:
return success_rate == 1.0 return success_rate == 1.0
@pytest.fixture(scope="function")
def llm_server(request, runtime_services):
    """Provide a running LLM server for one test, stopping it afterwards.

    The backend is whichever of ``vllm`` / ``tensorrt_llm`` is importable
    (vllm is preferred). Optional overrides are taken from the indirect
    parameter, e.g.:
      @pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True)
    Recognised keys are ``cpu_blocks``, ``gpu_blocks`` and ``port``.
    """
    logger = logging.getLogger("pytest")
    logger.setLevel(logging.INFO)

    # Unparametrized use has no ``param`` attribute; treat it as no overrides.
    overrides = getattr(request, "param", {})
    cpu_blocks = overrides.get("cpu_blocks", None)
    gpu_blocks = overrides.get("gpu_blocks", None)
    port = overrides.get("port", None)

    # Put logs in the per-test directory set up by tests/conftest.py
    log_dir = Path(request.node.name)

    # Pick the first backend whose package is installed.
    server_type = None
    for module_name, candidate in (
        ("vllm", ServerType.vllm),
        ("tensorrt_llm", ServerType.trtllm),
    ):
        if importlib.util.find_spec(module_name) is not None:
            server_type = candidate
            break
    if server_type is None:
        raise Exception(
            "Neither the vllm nor the tensorrt_llm module is available in the current environment."
        )

    server_manager = LLMServerManager(
        port=port,
        cpu_cache_blocks=cpu_blocks,
        gpu_cache_blocks=gpu_blocks,
        log_dir=log_dir,
        server_type=server_type,
    )

    start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
    started = server_manager.start_server(timeout=start_timeout)
    if not started:
        pytest.fail(
            f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
        )

    yield server_manager

    server_manager.stop_server()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def tester(llm_server): def tester(llm_server):
"""Create determinism tester bound to the running server's base URL.""" """Create determinism tester bound to the running server's base URL."""
...@@ -768,14 +455,9 @@ def tester(llm_server): ...@@ -768,14 +455,9 @@ def tester(llm_server):
class TestDeterminism: class TestDeterminism:
"""Test class for determinism validation.""" """Test class for determinism validation."""
@pytest.mark.parametrize( def base_test_determinism_with_cache_reset(
"llm_server", self, tester, llm_server, runtime_services, success_rate_threshold=1.0
[ ):
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
],
indirect=True,
)
def test_determinism_with_cache_reset(self, tester, llm_server, runtime_services):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup.""" """Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
print("\n" + "=" * 70) print("\n" + "=" * 70)
print("STARTING DETERMINISM TEST (WITH CACHE RESET)") print("STARTING DETERMINISM TEST (WITH CACHE RESET)")
...@@ -885,6 +567,12 @@ class TestDeterminism: ...@@ -885,6 +567,12 @@ class TestDeterminism:
print(f"Total comparisons: {total_passed + total_failed}") print(f"Total comparisons: {total_passed + total_failed}")
print(f"Passed (deterministic): {total_passed}") print(f"Passed (deterministic): {total_passed}")
print(f"Failed (non-deterministic): {total_failed}") print(f"Failed (non-deterministic): {total_failed}")
success_rate = (
total_passed / (total_passed + total_failed)
if total_passed + total_failed > 0
else 0
)
print(f"Success rate: {success_rate:.1%}")
print( print(
"Test compared responses before cache reset (with warmup) vs after cache reset (no warmup)." "Test compared responses before cache reset (with warmup) vs after cache reset (no warmup)."
) )
...@@ -893,215 +581,5 @@ class TestDeterminism: ...@@ -893,215 +581,5 @@ class TestDeterminism:
pytest.skip("No tests were completed - insufficient data") pytest.skip("No tests were completed - insufficient data")
assert ( assert (
total_failed == 0 success_rate >= success_rate_threshold
), f"Model is not deterministic across cache reset: {total_failed} comparisons failed" ), f"Model is not deterministic across cache reset: {total_failed} comparisons failed, success rate {success_rate:.1%} lower than expected {success_rate_threshold*100}%"
@pytest.mark.parametrize(
    "llm_server",
    [
        {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))},
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "num_concurrent",
    [int(x) for x in os.environ.get("KVBM_CONCURRENT_REQUESTS", "3").split(",")],
)
@pytest.mark.parametrize(
    "max_tokens",
    [int(x) for x in os.environ.get("KVBM_MAX_TOKENS", "10").split(",")],
)
@pytest.mark.parametrize(
    "num_prompts",
    [int(x) for x in os.environ.get("KVBM_IFEVAL_PROMPTS", "120").split(",")],
)
@pytest.mark.skip(reason="Flaky test: DIS-665")
def test_concurrent_determinism_with_ifeval(
    self,
    tester,
    llm_server,
    runtime_services,
    num_concurrent,
    max_tokens,
    num_prompts,
):
    """Simple concurrent determinism test: send IFEval prompts concurrently, with cache reset.

    Phase 1 warms the KV caches with every prompt and then sends all prompts
    with up to ``num_concurrent`` requests in flight. The prefix cache is then
    reset and phase 2 sends the same prompts again without warmup. The test
    passes only if every prompt answered in both phases produced an identical
    response.

    ``KVBM_MAX_TOKENS`` is overridden for the duration of the test and is
    always restored, including when the test is skipped or fails part-way.
    """
    print("\n" + "=" * 70)
    print("CONCURRENT DETERMINISM TEST WITH IFEVAL")
    print("=" * 70)

    # Override max_tokens for this test iteration
    original_max_tokens = os.environ.get("KVBM_MAX_TOKENS")
    os.environ["KVBM_MAX_TOKENS"] = str(max_tokens)

    try:
        print(
            f"Using KVBM_MAX_TOKENS={max_tokens} (parametrized, original: {original_max_tokens or '48'})"
        )

        # Configuration comes from parametrize
        print(
            f"Configuration: {num_concurrent} concurrent requests, {max_tokens} max tokens"
        )

        # Load IFEval prompts
        ifeval_prompts = tester.download_ifeval_dataset()
        if not ifeval_prompts:
            pytest.skip("IFEval dataset not available")

        # Use parametrized number of IFEval prompts
        test_prompts = ifeval_prompts[:num_prompts]
        print(
            f"Using {len(test_prompts)} IFEval prompts for concurrent testing (parametrized: {num_prompts})"
        )
        print(f"Concurrency level: {num_concurrent} simultaneous requests")

        # Show sample prompts
        print("\nSample prompts:")
        for i, prompt in enumerate(test_prompts[:3]):
            print(f"  {i+1}. {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
        if len(test_prompts) > 3:
            print(f"  ... and {len(test_prompts) - 3} more")

        def run_concurrent_test(phase_name, do_warmup=False):
            """Run one phase of concurrent testing."""
            print(f"\n=== {phase_name} ===")

            if do_warmup:
                # KV Cache warmup - send ALL test prompts to compute KV caches
                print(
                    f"Warming up KV caches with all {len(test_prompts)} test prompts..."
                )
                warmup_failed = 0

                for i, prompt in enumerate(test_prompts):
                    if (
                        i % 5 == 0 or i == len(test_prompts) - 1
                    ):  # Progress every 5 prompts
                        print(f"  Warmup progress: {i+1}/{len(test_prompts)}")

                    try:
                        tester.make_request(prompt)
                    except Exception as e:
                        # Warmup is best effort; count and report, don't abort.
                        warmup_failed += 1
                        if warmup_failed <= 3:  # Show first few failures
                            print(f"  Warmup failed for prompt {i}: {e}")

                if warmup_failed > 0:
                    print(
                        f"Warmup completed with {warmup_failed} failures out of {len(test_prompts)} prompts"
                    )
                else:
                    print(
                        f"Warmup completed successfully - all {len(test_prompts)} KV caches computed"
                    )

                # Wait for 10 seconds to make sure all transfers are complete
                time.sleep(10)
            else:
                print("Skipping warmup (already done in previous phase)")

            # Run concurrent requests
            print(
                f"Sending {len(test_prompts)} requests with {num_concurrent} max concurrent..."
            )
            start_time = time.time()

            def make_request_wrapper(prompt_and_idx):
                # Never raises: failures are returned as records so one bad
                # request cannot abort the whole executor.map().
                idx, prompt = prompt_and_idx
                try:
                    response = tester.make_request(prompt)
                    return {
                        "idx": idx,
                        "prompt": prompt,
                        "response": response,
                        "success": True,
                    }
                except Exception as e:
                    return {
                        "idx": idx,
                        "prompt": prompt,
                        "error": str(e),
                        "success": False,
                    }

            # Execute all requests concurrently
            with ThreadPoolExecutor(max_workers=num_concurrent) as executor:
                results = list(
                    executor.map(make_request_wrapper, enumerate(test_prompts))
                )

            elapsed = time.time() - start_time
            successful = [r for r in results if r["success"]]
            failed = [r for r in results if not r["success"]]

            print(
                f"Completed in {elapsed:.2f}s - Success: {len(successful)}, Failed: {len(failed)}"
            )

            if failed:
                for fail in failed[:3]:  # Show first few failures
                    print(f"  Failed: {fail['error']}")

            return successful

        # Phase 1: Before cache reset
        results_before = run_concurrent_test(
            "PHASE 1: BEFORE CACHE RESET", do_warmup=True
        )

        # Reset cache
        print("\n" + "=" * 50)
        print("RESETTING CACHE")
        print("=" * 50)
        tester.reset_prefix_cache()

        # Phase 2: After cache reset
        results_after = run_concurrent_test("PHASE 2: AFTER CACHE RESET")

        # Compare results between phases
        print("\n" + "=" * 70)
        print("DETERMINISM ANALYSIS")
        print("=" * 70)

        # Create lookup for before results
        before_responses = {r["idx"]: r["response"] for r in results_before}
        after_responses = {r["idx"]: r["response"] for r in results_after}

        deterministic_count = 0
        total_compared = 0

        # Only prompts that succeeded in both phases can be compared.
        for idx in before_responses:
            if idx in after_responses:
                total_compared += 1
                before_resp = before_responses[idx]
                after_resp = after_responses[idx]

                if before_resp == after_resp:
                    deterministic_count += 1
                    print(f"   Prompt {idx}: DETERMINISTIC")
                else:
                    print(f"   Prompt {idx}: NON-DETERMINISTIC")
                    print(f"     Before: {before_resp}")
                    print(f"     After:  {after_resp}")

        # Final assessment
        success_rate = deterministic_count / total_compared if total_compared > 0 else 0
        print("\n=== FINAL RESULT ===")
        print(f"Prompts compared: {total_compared}")
        print(f"Deterministic: {deterministic_count}")
        print(f"Success rate: {success_rate:.1%}")
        print(f"Concurrent requests: {num_concurrent}")
    finally:
        # Restore original max_tokens setting even on skip/failure so the
        # override cannot leak into later tests.
        if original_max_tokens is not None:
            os.environ["KVBM_MAX_TOKENS"] = original_max_tokens
        else:
            os.environ.pop("KVBM_MAX_TOKENS", None)

    assert (
        success_rate == 1.0
    ), f"Determinism failed: {deterministic_count}/{total_compared} prompts deterministic"
if __name__ == "__main__":
# Allow running as script
pytest.main([__file__, "-v", "-s"])
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Determinism test for KVBM in aggregated mode.
To verify KVBM's accuracy, this test suite checks that the model produces
deterministic outputs when the same requests are served 1) without KVBM-onboarded KV
blocks and 2) with KVBM-onboarded KV blocks, given the same inputs with a
fixed seed and temperature=0.
The expected result is a 100% match between the two cases. Compared to
disaggregated mode, aggregated mode has fewer sources of randomness.
"""
import importlib.util
import logging
import os
import signal
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, TextIO
import pytest
import requests
from .common import DeterminismTester, ServerType
from .common import TestDeterminism as BaseTestDeterminism
# Module-level marker list: pytest applies every marker in `pytestmark` to each
# test collected from this file.
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
# NOTE(review): `gpu_1` presumably selects single-GPU runners -- confirm against
# the repository's marker definitions.
pytestmark = [
    pytest.mark.kvbm,
    pytest.mark.e2e,
    pytest.mark.slow,
    pytest.mark.gpu_1,
]
class LLMServerManager:
    """Manages LLM server lifecycle for determinism testing.

    The constructor prepares the launch command and environment for the chosen
    backend (``ServerType.vllm`` or ``ServerType.trtllm``). ``start_server``
    spawns the process in its own session and polls the HTTP endpoints until
    the server answers; ``stop_server`` signals the whole process group and
    closes the log files.
    """

    # Model served when the KVBM_MODEL_ID environment variable is not set.
    _DEFAULT_MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

    def __init__(
        self,
        base_url: Optional[str] = None,
        port: Optional[int] = None,
        cpu_cache_blocks: Optional[int] = None,
        gpu_cache_blocks: Optional[int] = None,
        log_dir: Optional[Path] = None,
        server_type: Optional[str] = ServerType.vllm,
    ):
        """Prepare command line, environment and log paths (nothing is started).

        Args:
            base_url: Base URL used for health checks; defaults to
                ``http://localhost:<port>``.
            port: HTTP port; defaults to ``KVBM_SERVER_PORT`` or 8000.
            cpu_cache_blocks: If given, exported to the server as
                ``DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS``.
            gpu_cache_blocks: If given, forwarded to the backend-specific
                GPU cache override.
            log_dir: Directory for server logs (created if missing);
                defaults to the current directory.
            server_type: Backend to launch.

        Raises:
            ValueError: If ``server_type`` is neither vllm nor trtllm.
        """
        self.server_type = server_type
        self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
        self.base_url = base_url or f"http://localhost:{self.port}"
        self.process: Optional[subprocess.Popen] = None
        self.cpu_cache_blocks = cpu_cache_blocks
        self.gpu_cache_blocks = gpu_cache_blocks

        # Prepare logging
        self.log_dir = log_dir or Path(".")
        self.log_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        config_str = (
            f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
        )
        self.server_log_file = (
            self.log_dir / f"{self.server_type}_server_{config_str}_{timestamp}.log"
        )
        self.server_stdout_file: Optional[TextIO] = None
        self.server_stderr_file: Optional[TextIO] = None

        # Environment for the process
        self.env = os.environ.copy()
        self.env.update(
            {
                "RUST_BACKTRACE": "1",
                # DynamoConnector connection settings
                "NATS_SERVER": "nats://localhost:4222",
                "ETCD_ENDPOINTS": "http://localhost:2379",
            }
        )

        # CPU cache blocks override via env
        if cpu_cache_blocks is not None:
            self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)

        if self.server_type == ServerType.vllm:
            self._set_up_vllm_config(gpu_cache_blocks)
        elif self.server_type == ServerType.trtllm:
            self._set_up_trtllm_config(gpu_cache_blocks)
        else:
            raise ValueError(
                f"{self.server_type} is not supported yet in the KVBM test suite"
            )

    @classmethod
    def _model_id(cls) -> str:
        """Return the model id from ``KVBM_MODEL_ID`` or the built-in default."""
        return os.environ.get("KVBM_MODEL_ID", cls._DEFAULT_MODEL_ID)

    def _set_up_vllm_config(self, gpu_cache_blocks):
        """Build the ``vllm serve`` command with the DynamoConnector configured."""
        # NOTE(review): presumably needed to expose /reset_prefix_cache -- confirm.
        self.env["VLLM_SERVER_DEV_MODE"] = "1"

        # Construct serve command
        self.server_cmd = [
            "vllm",
            "serve",
            "--block-size",
            "16",
            "--port",
            str(self.port),
            "--kv-transfer-config",
            '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "dynamo.llm.vllm_integration.connector"}',
            self._model_id(),
            "--max-seq-len",
            "8000",  # required to fit on L4 GPU when using 8b model
        ]

        # GPU blocks override
        if gpu_cache_blocks is not None:
            self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)])

    def _set_up_trtllm_config(self, gpu_cache_blocks):
        """Write the TRTLLM LLM-API YAML config and build the ``trtllm-serve`` command.

        The config file path comes from ``KVBM_TRTLLM_LLMAPI_CONFIG_PATH``
        (default ``/tmp/kvbm_llm_api_config.yaml``) and is overwritten.
        """
        config_path = os.environ.get(
            "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml"
        )

        llm_api_config: Dict[str, Any] = {
            # explicitly disable CUDA graph since Connector API doesn't support
            # CUDA graph yet in TRTLLM
            "cuda_graph_config": None,
            "kv_cache_config": {
                "enable_partial_reuse": False,
                # Set a small GPU fraction so that we can evict/reset the
                # on-device kv cache faster
                "free_gpu_memory_fraction": 0.10,
            },
            "kv_connector_config": {
                "connector_module": "dynamo.llm.trtllm_integration.connector",
                "connector_scheduler_class": "DynamoKVBMConnectorLeader",
                "connector_worker_class": "DynamoKVBMConnectorWorker",
            },
        }

        # GPU blocks override: an explicit token budget replaces the memory fraction
        if gpu_cache_blocks is not None:
            del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"]
            llm_api_config["kv_cache_config"]["max_tokens"] = (
                int(gpu_cache_blocks) * 32
            )  # TRTLLM defaults 32 tokens per block

        # Construct serve command
        self.server_cmd = [
            "trtllm-serve",
            self._model_id(),
            "--host",
            "localhost",
            "--port",
            str(self.port),
            "--backend",
            "pytorch",
            "--extra_llm_api_options",
            config_path,
        ]

        # Imported lazily so the module can be imported without PyYAML installed
        # when only the vllm backend is used.
        import yaml

        with open(config_path, "w") as f:
            yaml.dump(llm_api_config, f, default_flow_style=False, sort_keys=False)

    def start_server(self, timeout: int = 300) -> bool:
        """Start LLM server and wait for readiness.

        Args:
            timeout: Maximum number of seconds to wait for the server to
                answer health and completion requests.

        Returns:
            True once the server is ready; False if the process exits early or
            the timeout elapses (the server is stopped in the latter case).

        Raises:
            OSError: If the log files cannot be opened or the server executable
                cannot be launched; log files are closed before re-raising.
        """
        # Something already answers on our URL: stop what we own and give it a
        # moment to release the port.
        if self.is_server_running():
            self.stop_server()
            time.sleep(2)

        try:
            # Open log files
            self.server_stdout_file = open(
                self.server_log_file.with_suffix(".stdout.log"), "w"
            )
            self.server_stderr_file = open(
                self.server_log_file.with_suffix(".stderr.log"), "w"
            )
            self.server_stdout_file.write(
                f"=== {self.server_type} Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n"
            )
            self.server_stdout_file.flush()

            # Launch in a new session so stop_server can signal the whole
            # process group. start_new_session=True is the thread-safe
            # equivalent of preexec_fn=os.setsid.
            self.process = subprocess.Popen(
                self.server_cmd,
                stdout=self.server_stdout_file,
                stderr=self.server_stderr_file,
                env=self.env,
                start_new_session=True,
            )
        except BaseException:
            # Do not leak open log handles when the launch itself fails.
            self._close_log_files()
            raise

        # Wait for health
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.is_server_running():
                return True
            if self.process.poll() is not None:
                # The server died before becoming ready; poll() already reaped
                # it, so just drop the handle and close the logs.
                self.process = None
                self._close_log_files()
                return False
            time.sleep(5)

        # Timeout
        self.stop_server()
        return False

    def stop_server(self):
        """Stop LLM server and close logs.

        Sends SIGTERM to the server's process group, escalating to SIGKILL if
        the process has not exited after 30 seconds. Safe to call when no
        server is running.
        """
        if self.process:
            try:
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
                try:
                    self.process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                    self.process.wait()
            except (ProcessLookupError, OSError):
                # Best effort: the process (group) is already gone.
                pass
            finally:
                self.process = None
        self._close_log_files()

    def _close_log_files(self):
        """Write the stop marker to the stdout log and close both log files."""
        if self.server_stdout_file:
            self.server_stdout_file.write(
                f"\n=== Server Stopped at {datetime.now()} ===\n"
            )
            self.server_stdout_file.close()
            self.server_stdout_file = None
        if self.server_stderr_file:
            self.server_stderr_file.close()
            self.server_stderr_file = None

    def is_server_running(self) -> bool:
        """Return True if the server is healthy and can serve a chat completion.

        Any ``requests`` error (connection refused, timeout, ...) is treated as
        "not running".
        """
        try:
            # First check basic health
            response = requests.get(f"{self.base_url}/health", timeout=5)
            if response.status_code != 200:
                return False

            # Then check if the model endpoint is ready with a simple test request
            test_payload = {
                "model": self._model_id(),
                "messages": [{"role": "user", "content": "test"}],
                "max_completion_tokens": 1,
                "temperature": 0,
            }
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers={"Content-Type": "application/json"},
                json=test_payload,
                timeout=10,
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
class AggDeterminismTester(DeterminismTester):
    """Determinism tester specialised for the aggregated serving architecture."""

    def __init__(
        self,
        base_url: Optional[str] = None,
        model_id: Optional[str] = None,
        server_type: Optional[str] = ServerType.vllm,
    ):
        """Forward all arguments unchanged to ``DeterminismTester``."""
        super().__init__(base_url, model_id, server_type)

    def reset_prefix_cache(self):
        """Reset the prefix cache.

        Non-TRTLLM servers are reset through the ``/reset_prefix_cache``
        endpoint (an HTTP error status raises). TRTLLM has no such endpoint,
        so the on-device cache is flushed by sending a series of unrelated
        Shakespeare requests instead; individual request failures are only
        printed.
        """
        print("Resetting prefix cache...")

        if self.server_type != ServerType.trtllm:
            http_timeout = int(os.environ.get("KVBM_HTTP_TIMEOUT", "30"))
            reset_response = requests.post(
                f"{self.base_url}/reset_prefix_cache",
                timeout=http_timeout,
            )
            reset_response.raise_for_status()
        else:
            # TRTLLM doesn't support reset_prefix_cache endpoint API
            # 300 shakespeare content could evict the 0.1 x 80G (~1700 blocks) on-device cache
            shakespeare_count = 300
            for offset in range(shakespeare_count):
                seq_idx = offset + 1
                start_word = offset * self.word_count
                content = self.get_shakespeare_content(start_word)
                if not content:
                    continue
                print(
                    f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
                )
                try:
                    self.make_request(content)
                except Exception as e:
                    print(f"Resetting request failed: {e}")

        print("Cache reset done")
@pytest.fixture(scope="function")
def llm_server(request, runtime_services):
    """Provide a freshly started LLM server to one test and stop it afterwards.

    Cache block counts and the port can be overridden per test through
    indirect parametrization:
    @pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True)
    """
    logging.getLogger("pytest").setLevel(logging.INFO)

    overrides = getattr(request, "param", {})
    cpu_blocks = overrides.get("cpu_blocks", None)
    gpu_blocks = overrides.get("gpu_blocks", None)
    port = overrides.get("port", None)

    # Put logs in the per-test directory set up by tests/conftest.py
    log_dir = Path(request.node.name)

    def _is_installed(module_name):
        return importlib.util.find_spec(module_name) is not None

    # Prefer vllm when both backends are importable.
    if _is_installed("vllm"):
        server_type = ServerType.vllm
    elif _is_installed("tensorrt_llm"):
        server_type = ServerType.trtllm
    else:
        raise Exception(
            "Neither the vllm nor the tensorrt_llm module is available in the current environment."
        )

    server_manager = LLMServerManager(
        port=port,
        cpu_cache_blocks=cpu_blocks,
        gpu_cache_blocks=gpu_blocks,
        log_dir=log_dir,
        server_type=server_type,
    )

    start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
    started = server_manager.start_server(timeout=start_timeout)
    if not started:
        pytest.fail(
            f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
        )

    yield server_manager

    # Teardown: runs after the test regardless of its outcome.
    server_manager.stop_server()
@pytest.fixture(scope="function")
def tester(llm_server):
    """Build an aggregated-mode tester pointed at the running server.

    The Shakespeare corpus is downloaded up front so the test can draw prompts
    from it.
    """
    determinism_tester = AggDeterminismTester(
        base_url=llm_server.base_url,
        server_type=llm_server.server_type,
    )
    determinism_tester.download_shakespeare_text()
    return determinism_tester
class TestDeterminismAgg(BaseTestDeterminism):
    """Test class for determinism validation in aggregated mode."""

    @pytest.mark.parametrize(
        "llm_server",
        [
            {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
        ],
        indirect=True,
    )
    def test_determinism_agg_with_cache_reset(
        self, tester, llm_server, runtime_services
    ):
        """Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
        # Call the base class implementation
        super().base_test_determinism_with_cache_reset(
            tester, llm_server, runtime_services
        )

    @staticmethod
    def _warm_up_kv_caches(tester, test_prompts):
        """Send every prompt once sequentially so its KV cache gets computed."""
        # KV Cache warmup - send ALL test prompts to compute KV caches
        print(f"Warming up KV caches with all {len(test_prompts)} test prompts...")
        warmup_failed = 0
        for i, prompt in enumerate(test_prompts):
            if (
                i % 5 == 0 or i == len(test_prompts) - 1
            ):  # Progress every 5 prompts
                print(f" Warmup progress: {i+1}/{len(test_prompts)}")
            try:
                tester.make_request(prompt)
            except Exception as e:
                warmup_failed += 1
                if warmup_failed <= 3:  # Show first few failures
                    print(f" Warmup failed for prompt {i}: {e}")

        if warmup_failed > 0:
            print(
                f"Warmup completed with {warmup_failed} failures out of {len(test_prompts)} prompts"
            )
        else:
            print(
                f"Warmup completed successfully - all {len(test_prompts)} KV caches computed"
            )

        # Wait for 10 seconds to make sure all transfers are complete
        time.sleep(10)

    def _run_concurrent_phase(
        self, tester, test_prompts, num_concurrent, phase_name, do_warmup=False
    ):
        """Run one phase of concurrent testing.

        Returns the list of successful results; each is a dict with the keys
        ``idx``, ``prompt``, ``response`` and ``success``. Failed requests are
        reported on stdout and excluded from the returned list.
        """
        print(f"\n=== {phase_name} ===")
        if do_warmup:
            self._warm_up_kv_caches(tester, test_prompts)
        else:
            print("Skipping warmup (already done in previous phase)")

        # Run concurrent requests
        print(
            f"Sending {len(test_prompts)} requests with {num_concurrent} max concurrent..."
        )
        start_time = time.time()

        def make_request_wrapper(prompt_and_idx):
            # Never raise from a worker thread: failures are captured per prompt.
            idx, prompt = prompt_and_idx
            try:
                response = tester.make_request(prompt)
                return {
                    "idx": idx,
                    "prompt": prompt,
                    "response": response,
                    "success": True,
                }
            except Exception as e:
                return {
                    "idx": idx,
                    "prompt": prompt,
                    "error": str(e),
                    "success": False,
                }

        # Execute all requests concurrently
        with ThreadPoolExecutor(max_workers=num_concurrent) as executor:
            results = list(
                executor.map(make_request_wrapper, enumerate(test_prompts))
            )

        elapsed = time.time() - start_time
        successful = [r for r in results if r["success"]]
        failed = [r for r in results if not r["success"]]
        print(
            f"Completed in {elapsed:.2f}s - Success: {len(successful)}, Failed: {len(failed)}"
        )
        if failed:
            for fail in failed[:3]:  # Show first few failures
                print(f" Failed: {fail['error']}")
        return successful

    @pytest.mark.parametrize(
        "llm_server",
        [
            {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))},
        ],
        indirect=True,
    )
    @pytest.mark.parametrize(
        "num_concurrent",
        [int(x) for x in os.environ.get("KVBM_CONCURRENT_REQUESTS", "3").split(",")],
    )
    @pytest.mark.parametrize(
        "max_tokens",
        [int(x) for x in os.environ.get("KVBM_MAX_TOKENS", "10").split(",")],
    )
    @pytest.mark.parametrize(
        "num_prompts",
        [int(x) for x in os.environ.get("KVBM_IFEVAL_PROMPTS", "120").split(",")],
    )
    @pytest.mark.skip(reason="Flaky test: DIS-665")
    def test_concurrent_determinism_with_ifeval(
        self,
        tester,
        llm_server,
        runtime_services,
        num_concurrent,
        max_tokens,
        num_prompts,
    ):
        """Simple concurrent determinism test: send IFEval prompts concurrently, with cache reset.

        Responses collected before and after a prefix-cache reset are compared
        per prompt; the test passes only if every compared prompt produced the
        same response in both phases. ``KVBM_MAX_TOKENS`` is overridden for the
        duration of the test and always restored, even when the test is
        skipped or raises.
        """
        print("\n" + "=" * 70)
        print("CONCURRENT DETERMINISM TEST WITH IFEVAL")
        print("=" * 70)

        # Override max_tokens for this test iteration
        original_max_tokens = os.environ.get("KVBM_MAX_TOKENS")
        os.environ["KVBM_MAX_TOKENS"] = str(max_tokens)
        try:
            print(
                f"Using KVBM_MAX_TOKENS={max_tokens} (parametrized, original: {original_max_tokens or '48'})"
            )

            # Configuration comes from parametrize
            print(
                f"Configuration: {num_concurrent} concurrent requests, {max_tokens} max tokens"
            )

            # Load IFEval prompts
            ifeval_prompts = tester.download_ifeval_dataset()
            if not ifeval_prompts:
                pytest.skip("IFEval dataset not available")

            # Use parametrized number of IFEval prompts
            test_prompts = ifeval_prompts[:num_prompts]
            print(
                f"Using {len(test_prompts)} IFEval prompts for concurrent testing (parametrized: {num_prompts})"
            )
            print(f"Concurrency level: {num_concurrent} simultaneous requests")

            # Show sample prompts
            print("\nSample prompts:")
            for i, prompt in enumerate(test_prompts[:3]):
                print(f" {i+1}. {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
            if len(test_prompts) > 3:
                print(f" ... and {len(test_prompts) - 3} more")

            # Phase 1: Before cache reset
            results_before = self._run_concurrent_phase(
                tester,
                test_prompts,
                num_concurrent,
                "PHASE 1: BEFORE CACHE RESET",
                do_warmup=True,
            )

            # Reset cache
            print("\n" + "=" * 50)
            print("RESETTING CACHE")
            print("=" * 50)
            tester.reset_prefix_cache()

            # Phase 2: After cache reset
            results_after = self._run_concurrent_phase(
                tester, test_prompts, num_concurrent, "PHASE 2: AFTER CACHE RESET"
            )

            # Compare results between phases
            print("\n" + "=" * 70)
            print("DETERMINISM ANALYSIS")
            print("=" * 70)

            # Create lookups keyed by prompt index; only prompts that succeeded
            # in both phases are compared.
            before_responses = {r["idx"]: r["response"] for r in results_before}
            after_responses = {r["idx"]: r["response"] for r in results_after}

            deterministic_count = 0
            total_compared = 0
            for idx in before_responses:
                if idx in after_responses:
                    total_compared += 1
                    before_resp = before_responses[idx]
                    after_resp = after_responses[idx]
                    if before_resp == after_resp:
                        deterministic_count += 1
                        print(f" Prompt {idx}: DETERMINISTIC")
                    else:
                        print(f" Prompt {idx}: NON-DETERMINISTIC")
                        print(f" Before: {before_resp}")
                        print(f" After: {after_resp}")

            # Final assessment
            success_rate = (
                deterministic_count / total_compared if total_compared > 0 else 0
            )
            print("\n=== FINAL RESULT ===")
            print(f"Prompts compared: {total_compared}")
            print(f"Deterministic: {deterministic_count}")
            print(f"Success rate: {success_rate:.1%}")
            print(f"Concurrent requests: {num_concurrent}")
        finally:
            # Restore original max_tokens setting, even on skip or failure, so
            # the override never leaks into later tests.
            if original_max_tokens is not None:
                os.environ["KVBM_MAX_TOKENS"] = original_max_tokens
            else:
                os.environ.pop("KVBM_MAX_TOKENS", None)

        assert (
            success_rate == 1.0
        ), f"Determinism failed: {deterministic_count}/{total_compared} prompts deterministic"
if __name__ == "__main__":
    # Allow running as script. pytest.main() returns the session's exit status;
    # propagate it so a failing run is visible to the calling shell/CI instead
    # of the interpreter always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Determinism test for KVBM in disaggregated mode.
To verify KVBM's accuracy, this test suite checks that the model produces
deterministic outputs when the same requests are served 1) without KVBM-onboarded
KV blocks and 2) with KVBM-onboarded KV blocks, given the same inputs with a
fixed seed and temperature=0.
The results are expected to match at least 95% between the two cases.
Compared to aggregated mode, disaggregated mode has some known randomness.
Example reference: https://github.com/vllm-project/vllm/issues/7779#issuecomment-2304967870
"""
import importlib.util
import logging
import os
import signal
import subprocess
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, TextIO
import pytest
import requests
from .common import DeterminismTester, ServerType
from .common import TestDeterminism as BaseTestDeterminism
# Module-level marker list: pytest applies every marker in `pytestmark` to each
# test collected from this file.
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
# NOTE(review): `gpu_2` presumably selects two-GPU runners -- confirm against
# the repository's marker definitions.
pytestmark = [
    pytest.mark.kvbm,
    pytest.mark.e2e,
    pytest.mark.slow,
    pytest.mark.gpu_2,
]
# Passed as `success_rate_threshold` to the base determinism test in this file.
SUCCESS_RATE_THRESHOLD = 0.95
class LLMServerManager:
    """Manages LLM server lifecycle for determinism testing (disaggregated mode).

    Three processes are managed together: the Dynamo frontend, a decode worker
    and a prefill worker. ``start_server`` launches them in that order, each in
    its own session, and polls the frontend until it serves requests;
    ``stop_server`` terminates all three process groups and closes the logs.
    """

    # Model served when the KVBM_MODEL_ID environment variable is not set.
    _DEFAULT_MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

    def __init__(
        self,
        base_url: Optional[str] = None,
        port: Optional[int] = None,
        cpu_cache_blocks: Optional[int] = None,
        gpu_cache_blocks: Optional[int] = None,
        log_dir: Optional[Path] = None,
        server_type: Optional[str] = ServerType.vllm,
    ):
        """Prepare command lines, environment and log paths (nothing is started).

        Args:
            base_url: Base URL used for health checks; defaults to
                ``http://localhost:<port>``.
            port: Frontend HTTP port; defaults to ``KVBM_SERVER_PORT`` or 8000.
            cpu_cache_blocks: If given, exported to the workers as
                ``DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS``.
            gpu_cache_blocks: If given, passed to both workers as
                ``--num-gpu-blocks-override``.
            log_dir: Directory for worker logs (created if missing);
                defaults to the current directory.
            server_type: Backend to launch; only vllm is supported here.

        Raises:
            ValueError: If ``server_type`` is not vllm.
        """
        self.server_type = server_type
        self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
        self.base_url = base_url or f"http://localhost:{self.port}"
        self.process_frontend: Optional[subprocess.Popen] = None
        self.process_prefiller: Optional[subprocess.Popen] = None
        self.process_decoder: Optional[subprocess.Popen] = None
        self.cpu_cache_blocks = cpu_cache_blocks
        self.gpu_cache_blocks = gpu_cache_blocks

        # Prepare logging
        self.log_dir = log_dir or Path(".")
        self.log_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        config_str = (
            f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
        )
        self.prefiller_log_file = (
            self.log_dir / f"{self.server_type}_prefiller_{config_str}_{timestamp}.log"
        )
        self.prefiller_stdout_file: Optional[TextIO] = None
        self.prefiller_stderr_file: Optional[TextIO] = None
        self.decoder_log_file = (
            self.log_dir / f"{self.server_type}_decoder_{timestamp}.log"
        )
        self.decoder_stdout_file: Optional[TextIO] = None
        self.decoder_stderr_file: Optional[TextIO] = None

        # Environment for the process
        self.env = os.environ.copy()
        self.env.update(
            {
                "RUST_BACKTRACE": "1",
                # DynamoConnector connection settings
                "NATS_SERVER": "nats://localhost:4222",
                "ETCD_ENDPOINTS": "http://localhost:2379",
            }
        )

        # CPU cache blocks override via env
        if cpu_cache_blocks is not None:
            self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)

        self._set_up_dynamo_config()

        if self.server_type == ServerType.vllm:
            self._set_up_vllm_config(gpu_cache_blocks)
        else:
            raise ValueError(
                f"{self.server_type} is not supported yet in the KVBM test suite"
            )

    @classmethod
    def _model_id(cls) -> str:
        """Return the model id from ``KVBM_MODEL_ID`` or the built-in default."""
        return os.environ.get("KVBM_MODEL_ID", cls._DEFAULT_MODEL_ID)

    def _set_up_dynamo_config(self, router_mode: str = "kv"):
        """Build the ``dynamo.frontend`` command for the configured port."""
        self.dynamo_frontend_cmd = [
            "python3",
            "-m",
            "dynamo.frontend",
            "--router-mode",
            router_mode,
            "--http-port",
            str(self.port),
        ]

    def _set_up_vllm_config(self, gpu_cache_blocks):
        """Build the decode and prefill ``dynamo.vllm`` worker commands.

        The decoder uses only the ``nixl`` connector; the prefiller is started
        with ``--is-prefill-worker`` and both the ``kvbm`` and ``nixl``
        connectors.
        """
        self.env["VLLM_SERVER_DEV_MODE"] = "1"

        # Construct decoder command
        self.decoder_cmd = [
            "python3",
            "-m",
            "dynamo.vllm",
            "--model",
            self._model_id(),
            "--block-size",
            "16",
            "--max-seq-len",
            "8000",  # required to fit on L4 GPU when using 8b model
            "--connector",
            "nixl",
        ]

        # Construct prefiller command
        self.prefiller_cmd = [
            "python3",
            "-m",
            "dynamo.vllm",
            "--model",
            self._model_id(),
            "--is-prefill-worker",
            "--block-size",
            "16",
            "--max-seq-len",
            "8000",  # required to fit on L4 GPU when using 8b model
            "--connector",
            "kvbm",
            "nixl",
        ]

        # GPU blocks override
        if gpu_cache_blocks is not None:
            self.decoder_cmd.extend(
                ["--num-gpu-blocks-override", str(gpu_cache_blocks)]
            )
            self.prefiller_cmd.extend(
                ["--num-gpu-blocks-override", str(gpu_cache_blocks)]
            )

    def start_server(self, timeout: int = 300) -> bool:
        """Start LLM server and wait for readiness.

        Launches the frontend, then the decoder (GPU 0), then the prefiller
        (GPU 1), with fixed pauses in between, and polls until the frontend
        serves a chat completion.

        Args:
            timeout: Maximum number of seconds to poll for readiness after the
                fixed start-up pauses.

        Returns:
            True once the server is ready; False if any process exits early or
            the timeout elapses (everything is stopped in both cases).

        Raises:
            OSError: If a log file cannot be opened or a process cannot be
                launched; already-started processes are stopped and log files
                closed before re-raising.
        """
        if self.is_server_running():
            self.stop_server()
            time.sleep(5)

        try:
            # Open log files
            self.prefiller_stdout_file = open(
                self.prefiller_log_file.with_suffix(".stdout.log"), "w"
            )
            self.prefiller_stderr_file = open(
                self.prefiller_log_file.with_suffix(".stderr.log"), "w"
            )
            self.prefiller_stdout_file.write(
                f"=== {self.server_type} Prefiller Started at {datetime.now()} ===\nCommand: {' '.join(self.prefiller_cmd)}\n"
            )
            self.prefiller_stdout_file.flush()

            self.decoder_stdout_file = open(
                self.decoder_log_file.with_suffix(".stdout.log"), "w"
            )
            self.decoder_stderr_file = open(
                self.decoder_log_file.with_suffix(".stderr.log"), "w"
            )
            self.decoder_stdout_file.write(
                f"=== {self.server_type} Decoder Started at {datetime.now()} ===\nCommand: {' '.join(self.decoder_cmd)}\n"
            )
            self.decoder_stdout_file.flush()

            # Create separate environment configs for different processes:
            # each worker is pinned to its own GPU.
            decoder_env = self.env.copy()
            decoder_env["CUDA_VISIBLE_DEVICES"] = "0"
            prefiller_env = self.env.copy()
            prefiller_env["CUDA_VISIBLE_DEVICES"] = "1"

            # Launch frontend first. Every process gets its own session so
            # stop_server can signal its whole process group;
            # start_new_session=True is the thread-safe equivalent of
            # preexec_fn=os.setsid.
            self.process_frontend = subprocess.Popen(
                self.dynamo_frontend_cmd,
                env=self.env,
                start_new_session=True,
            )
            print(f"Frontend process started with PID: {self.process_frontend.pid}")
            # Give frontend time to start up
            time.sleep(5)

            # Launch decoder
            self.process_decoder = subprocess.Popen(
                self.decoder_cmd,
                stdout=self.decoder_stdout_file,
                stderr=self.decoder_stderr_file,
                env=decoder_env,
                start_new_session=True,
            )
            print(f"Decoder process started with PID: {self.process_decoder.pid}")
            # Give decoder time to start up
            time.sleep(5)

            # Launch prefiller
            self.process_prefiller = subprocess.Popen(
                self.prefiller_cmd,
                stdout=self.prefiller_stdout_file,
                stderr=self.prefiller_stderr_file,
                env=prefiller_env,
                start_new_session=True,
            )
            print(f"Prefiller process started with PID: {self.process_prefiller.pid}")
            # Give prefiller time to start up
            print(
                "Sleeping for 30 seconds to wait for decoder and prefiller to start up..."
            )
            time.sleep(30)
        except BaseException:
            # A partially started deployment must not outlive a failed launch:
            # stop whatever is already running and close the log files.
            self.stop_server()
            raise

        # Wait for health
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            try:
                if self.is_server_running():
                    return True
                if (
                    self.process_frontend.poll() is not None
                    or self.process_prefiller.poll() is not None
                    or self.process_decoder.poll() is not None
                ):
                    # One component died; the deployment cannot become ready.
                    self.stop_server()
                    return False
            except Exception as e:
                print(f"Error checking server status: {e}")
            print("Waiting for server to start up:")
            print(f"timeout: {timeout}, elapsed: {int(time.monotonic() - start_time)}")
            time.sleep(5)

        # Timeout
        self.stop_server()
        return False

    @staticmethod
    def _terminate_process_group(process: Optional[subprocess.Popen]) -> None:
        """SIGTERM the process group of ``process``; SIGKILL it after 30 seconds."""
        if not process:
            return
        try:
            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                process.wait()
        except (ProcessLookupError, OSError):
            # Best effort: the process (group) is already gone.
            pass

    def stop_server(self):
        """Stop LLM server and close logs.

        Terminates the frontend, prefiller and decoder (in that order) and
        closes all log files. Safe to call when nothing is running.
        """
        for attr in ("process_frontend", "process_prefiller", "process_decoder"):
            try:
                self._terminate_process_group(getattr(self, attr))
            finally:
                setattr(self, attr, None)
        self._close_log_files()

    def _close_log_files(self):
        """Write stop markers to the stdout logs and close all log files."""
        if self.prefiller_stdout_file:
            self.prefiller_stdout_file.write(
                f"\n=== Prefiller Stopped at {datetime.now()} ===\n"
            )
            self.prefiller_stdout_file.close()
            self.prefiller_stdout_file = None
        if self.prefiller_stderr_file:
            self.prefiller_stderr_file.close()
            self.prefiller_stderr_file = None
        if self.decoder_stdout_file:
            self.decoder_stdout_file.write(
                f"\n=== Decoder Stopped at {datetime.now()} ===\n"
            )
            self.decoder_stdout_file.close()
            self.decoder_stdout_file = None
        if self.decoder_stderr_file:
            self.decoder_stderr_file.close()
            self.decoder_stderr_file = None

    def is_server_running(self) -> bool:
        """Return True if the frontend is healthy and can serve a chat completion.

        Any ``requests`` error (connection refused, timeout, ...) is treated as
        "not running".
        """
        try:
            # First check basic health
            response = requests.get(f"{self.base_url}/health", timeout=5)
            if response.status_code != 200:
                return False

            # Then check if the model endpoint is ready with a simple test request
            test_payload = {
                "model": self._model_id(),
                "messages": [{"role": "user", "content": "test"}],
                "max_completion_tokens": 1,
                "temperature": 0,
            }
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers={"Content-Type": "application/json"},
                json=test_payload,
                timeout=10,
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
class DisaggDeterminismTester(DeterminismTester):
    """Determinism tester specialised for the disaggregated serving architecture."""

    def __init__(
        self,
        base_url: Optional[str] = None,
        model_id: Optional[str] = None,
        server_type: Optional[str] = ServerType.vllm,
    ):
        """Forward all arguments unchanged to ``DeterminismTester``."""
        super().__init__(base_url, model_id, server_type)

    def reset_prefix_cache(self):
        """Reset the prefix cache.

        No reset endpoint is called here; instead the cache is flushed by
        sending a series of unrelated Shakespeare requests. Individual request
        failures are only printed.
        """
        print("Resetting prefix cache...")

        # 150 shakespeare requests (each request is 200 words, and roughly 17 blocks) could evict 150 * 17 = 2550 blocks
        shakespeare_count = 150
        for offset in range(shakespeare_count):
            seq_idx = offset + 1
            start_word = offset * self.word_count
            content = self.get_shakespeare_content(start_word)
            if not content:
                continue
            print(
                f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
            )
            try:
                self.make_request(content)
            except Exception as e:
                print(f"Resetting request failed: {e}")

        print("Cache reset done")
@pytest.fixture(scope="function")
def llm_server(request, runtime_services):
    """Provide a freshly started disaggregated deployment to one test and stop it afterwards.

    Cache block counts and the port can be overridden per test through
    indirect parametrization:
    @pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 1000}], indirect=True)
    """
    logging.getLogger("pytest").setLevel(logging.INFO)

    overrides = getattr(request, "param", {})
    cpu_blocks = overrides.get("cpu_blocks", None)
    gpu_blocks = overrides.get("gpu_blocks", None)
    port = overrides.get("port", None)

    # Put logs in the per-test directory set up by tests/conftest.py
    log_dir = Path(request.node.name)

    # Only the vllm backend is supported in disaggregated mode.
    if importlib.util.find_spec("vllm") is None:
        raise Exception("vllm module is not available in the current environment.")
    server_type = ServerType.vllm

    server_manager = LLMServerManager(
        port=port,
        cpu_cache_blocks=cpu_blocks,
        gpu_cache_blocks=gpu_blocks,
        log_dir=log_dir,
        server_type=server_type,
    )

    start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
    started = server_manager.start_server(timeout=start_timeout)
    if not started:
        pytest.fail(
            f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
        )

    yield server_manager

    # Teardown: runs after the test regardless of its outcome.
    server_manager.stop_server()
@pytest.fixture(scope="function")
def tester(llm_server):
    """Build a disaggregated-mode tester pointed at the running server.

    The Shakespeare corpus is downloaded up front so the test can draw prompts
    from it.
    """
    determinism_tester = DisaggDeterminismTester(
        base_url=llm_server.base_url,
        server_type=llm_server.server_type,
    )
    determinism_tester.download_shakespeare_text()
    return determinism_tester
class TestDeterminismDisagg(BaseTestDeterminism):
    """Determinism validation for KVBM in disaggregated mode."""

    @pytest.mark.parametrize(
        "llm_server",
        [
            dict(
                cpu_blocks=int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
                gpu_blocks=int(os.environ.get("KVBM_GPU_BLOCKS", "1000")),
            ),
        ],
        indirect=True,
    )
    def test_determinism_disagg_with_cache_reset(
        self, tester, llm_server, runtime_services
    ):
        """Test determinism across cache reset: run test with warmup, reset cache, run again without warmup.

        Delegates to the shared base implementation, passing the module-level
        success-rate threshold as ``success_rate_threshold``.
        """
        threshold = SUCCESS_RATE_THRESHOLD
        super().base_test_determinism_with_cache_reset(
            tester,
            llm_server,
            runtime_services,
            success_rate_threshold=threshold,
        )
if __name__ == "__main__":
    # Allow running as script. pytest.main() returns the session's exit status;
    # propagate it so a failing run is visible to the calling shell/CI instead
    # of the interpreter always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment