Unverified Commit 9975cb9f authored by Ziqi Fan's avatar Ziqi Fan Committed by GitHub
Browse files

feat: enable KVBM to support PD disagg in Dynamo vLLM (#3352)


Signed-off-by: default avatarZiqi Fan <ziqif@nvidia.com>
parent 0aa0768f
......@@ -40,7 +40,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| [**KV-Aware Routing**](../../../docs/architecture/kv_cache_routing.md) | ✅ | |
| [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | ✅ | |
| [**Load Based Planner**](../../../docs/architecture/load_planner.md) | 🚧 | WIP |
| [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | 🚧 | WIP |
| [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | | |
| [**LMCache**](./LMCache_Integration.md) | ✅ | |
### Large Scale P/D and WideEP Features
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress
python -m dynamo.frontend --http-port=8000 &
# run worker with KVBM enabled
# NOTE: remove --enforce-eager for production use
DYN_KVBM_CPU_CACHE_GB=20 \
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --connector kvbm --enforce-eager
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress with KV router
python -m dynamo.frontend --router-mode kv --http-port=8000 &
# run decode worker on GPU 0, without enabling KVBM
# NOTE: remove --enforce-eager for production use
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --connector nixl --enforce-eager &
# run prefill worker on GPU 1 with KVBM enabled using 20GB of CPU cache
# NOTE: remove --enforce-eager for production use
DYN_KVBM_CPU_CACHE_GB=20 \
CUDA_VISIBLE_DEVICES=1 \
python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \
--is-prefill-worker \
--connector kvbm nixl \
--enforce-eager
......@@ -351,11 +351,12 @@ def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]:
cfg = multi_connectors[0]
return KVTransferConfig(**cfg)
# For multiple connectors, use MultiConnector
# For multiple connectors, use PdConnector
return KVTransferConfig(
kv_connector="MultiConnector",
kv_connector="PdConnector",
kv_role="kv_both",
kv_connector_extra_config={"connectors": multi_connectors},
kv_connector_module_path="dynamo.llm.vllm_integration.connector",
)
......
......@@ -223,7 +223,7 @@ if [ "$ARCH" = "amd64" ]; then
# TODO: Re-enable for arm64 after verifying lmcache compatibility and resolving the build issue.
# Alec: Likely lmcache was compiled witha different version of torch and need to install it from source for arm64
uv pip install lmcache==0.3.3
uv pip install lmcache==0.3.7
echo "✓ LMCache installed"
else
echo "⚠ Skipping LMCache on ARM64 (compatibility issues)"
......
......@@ -25,6 +25,7 @@ To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest
To use KVBM in vLLM, you can follow the steps below:
### Docker Setup
```bash
# start up etcd for KVBM leader/worker registration and discovery
docker compose -f deploy/docker-compose.yml up -d
......@@ -34,26 +35,32 @@ docker compose -f deploy/docker-compose.yml up -d
# launch the container
./container/run.sh --framework vllm -it --mount-workspace --use-nixl-gds
```
# enable kv offloading to CPU memory
# 4 means 4GB of CPU memory would be used
export DYN_KVBM_CPU_CACHE_GB=4
# enable kv offloading to disk
# 8 means 8GB of disk would be used
export DYN_KVBM_DISK_CACHE_GB=8
### Aggregated Serving with KVBM
```bash
cd $DYNAMO_HOME/components/backends/vllm
./launch/agg_kvbm.sh
```
# [DYNAMO] start dynamo frontend
python -m dynamo.frontend --http-port 8000 &
### Disaggregated Serving with KVBM (1P1D)
```bash
# NOTE: need at least 2 GPUs
cd $DYNAMO_HOME/components/backends/vllm
./launch/disagg_kvbm.sh
```
> [!NOTE]
> To tune the size of CPU or disk cache, set `DYN_KVBM_CPU_CACHE_GB` and `DYN_KVBM_DISK_CACHE_GB` accordingly. We only set `DYN_KVBM_CPU_CACHE_GB=20` in both scripts above.
# [DYNAMO] serve an LLM model using KVBM with dynamo
python -m dynamo.vllm \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--connector kvbm &
> [!NOTE]
> `DYN_KVBM_CPU_CACHE_GB` must be set and `DYN_KVBM_DISK_CACHE_GB` is optional.
# make a call to LLM
### Sample Request
```bash
# make a request to verify vLLM with KVBM is started up correctly
# NOTE: change the model name if served with a different one
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"model": "Qwen/Qwen3-0.6B",
"messages": [
{
"role": "user",
......@@ -61,11 +68,11 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json"
}
],
"stream":false,
"max_tokens": 30
"max_tokens": 10
}'
```
Alternatively, can use "vllm serve" with KVBM by replacing the above two [DYNAMO] cmds with below:
Alternatively, can use `vllm serve` directly to use KVBM for aggregated serving:
```bash
vllm serve --kv-transfer-config '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "dynamo.llm.vllm_integration.connector"}' deepseek-ai/DeepSeek-R1-Distill-Llama-8B
```
......
......@@ -2,8 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::block_manager::{
block::BlockId, connector::protocol::WorkerTransferRequest, distributed::BlockTransferRequest,
pool::BlockPoolError,
block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError,
};
pub mod leader;
......@@ -163,11 +162,3 @@ impl ConnectorMetadata {
self.operations.extend(xfer_reqs);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectorOperation {
pub req_id: String,
pub iteration: u64,
pub uuid: uuid::Uuid,
pub xfer_req: BlockTransferRequest,
}
......@@ -3,8 +3,16 @@
# Import connector classes to make them available at the expected paths for vLLM
from .connector.dynamo_connector import DynamoConnector, DynamoConnectorMetadata
from .connector.pd_connector import PdConnector, PdConnectorMetadata
# Create module-level alias for backward compatibility
dynamo_connector = DynamoConnector
pd_connector = PdConnector
__all__ = ["DynamoConnector", "DynamoConnectorMetadata", "dynamo_connector"]
__all__ = [
"DynamoConnector",
"DynamoConnectorMetadata",
"dynamo_connector",
"PdConnector",
"PdConnectorMetadata",
]
......@@ -2,5 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from .dynamo_connector import DynamoConnector, DynamoConnectorMetadata
from .pd_connector import PdConnector, PdConnectorMetadata
__all__ = ["DynamoConnector", "DynamoConnectorMetadata"]
__all__ = [
"DynamoConnector",
"DynamoConnectorMetadata",
"PdConnector",
"PdConnectorMetadata",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnector
from vllm.v1.core.sched.output import SchedulerOutput
from dynamo.llm.vllm_integration.connector.dynamo_connector import DynamoConnector
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class PdConnectorMetadata(MultiKVConnectorMetadata):
pass
class PdConnector(MultiConnector):
"""
A wrapper for using KV offloading Connectors (e.g. KVBM or LMCache) and NIXL Connector for PD disaggregated serving.
The current logic is:
- The first connector must be KVBM or LMCache and would be used by prefill worker to offload and onboard KV blocks.
- The second connector must be NIXL and will be used by decode worker to get KV blocks from prefill worker.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
if len(self._connectors) != 2:
raise ValueError(
f"PdConnector requires exactly two connectors (got {len(self._connectors)})"
)
if not isinstance(self._connectors[0], (DynamoConnector, LMCacheConnectorV1)):
raise TypeError(
f"Expected first connector to be DynamoConnector or LMCacheConnector, "
f"got {type(self._connectors[0]).__name__}"
)
if not isinstance(self._connectors[1], NixlConnector):
raise TypeError(
f"Expected second connector to be NixlConnector, "
f"got {type(self._connectors[1]).__name__}"
)
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(self, connector_metadata: PdConnectorMetadata) -> None:
assert isinstance(connector_metadata, PdConnectorMetadata)
if connector_metadata.extra_async_saves:
self._extra_async_saves.update(connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get the number of matched tokens for the request using Dynamo Connector (KVBM).
"""
return self._connectors[0].get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update the state after allocation using Dynamo Connector (KVBM) and Nixl Connector.
"""
empty_blocks = blocks.new_empty()
# allocate blocks for KV offloading connector to onboard KV blocks
self._connectors[0].update_state_after_alloc(
request, blocks, num_external_tokens
)
# no need to allocate any blocks for NIXL connector since this is in prefill worker side
# and it only needs to wait for decode worker to pull its data.
self._connectors[1].update_state_after_alloc(request, empty_blocks, 0)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> PdConnectorMetadata:
metadata = PdConnectorMetadata(
metadata=tuple(
c.build_connector_meta(scheduler_output) for c in self._connectors
)
)
if self._extra_async_saves:
metadata.extra_async_saves = self._extra_async_saves
self._extra_async_saves = {}
return metadata
......@@ -55,14 +55,9 @@ pub fn load_and_validate_tensors(
// Check the stride, and ensure our tensor is contiguous.
// TODO: We eventually need to be able to handle this.
let stride = tensor.stride();
for i in 1..stride.len() {
if stride[i] > stride[i - 1] {
return Err(anyhow::anyhow!(
"Tensor strides must be monotonically decreasing! Got {:?}",
stride
));
}
}
tracing::debug!("stride: {:?}", stride);
tracing::debug!("stride is monotonically decreasing for NHD layout");
tracing::debug!("stride is NOT monotonically decreasing for HND layout");
// Check that all layer tensors have the same shape.
// TODO: We eventually need to support the weirder models with heterogenous layers.
......
......@@ -31,10 +31,13 @@ Run all kvbm tests:
pytest -v -m "kvbm" -s
```
Run the determinism test file directly:
Run the determinism test file directly inside dynamo repo:
```bash
pytest -v dynamo/tests/kvbm/test_determinism.py -s
pytest -v tests/kvbm/test_determinism_agg.py -s
# disagg needs 2 GPUs to run
pytest -v tests/kvbm/test_determinism_disagg.py -s
```
## Configuration
......
This diff is collapsed.
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Determinism test for KVBM in disaggregated mode.
To make sure KVBM's accuracy, this test suite checks if the model produces
deterministic outputs when same requests are served 1) without KVBM onboarded KV
blocks and 2) with KVBM onboarded KV blocks, when given the same inputs with
fixed seed and temperature=0.
The expected results should be at least 95% match between the two cases.
Compared to aggregated mode, disaggregated mode has some known randomness.
Example reference: https://github.com/vllm-project/vllm/issues/7779#issuecomment-2304967870
"""
import importlib.util
import logging
import os
import signal
import subprocess
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, TextIO
import pytest
import requests
from .common import DeterminismTester, ServerType
from .common import TestDeterminism as BaseTestDeterminism
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
pytestmark = [
pytest.mark.kvbm,
pytest.mark.e2e,
pytest.mark.slow,
pytest.mark.gpu_2,
]
SUCCESS_RATE_THRESHOLD = 0.95
class LLMServerManager:
"""Manages LLM server lifecycle for determinism testing."""
def __init__(
self,
base_url: Optional[str] = None,
port: Optional[int] = None,
cpu_cache_blocks: Optional[int] = None,
gpu_cache_blocks: Optional[int] = None,
log_dir: Optional[Path] = None,
server_type: Optional[str] = ServerType.vllm,
):
self.server_type = server_type
self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
self.base_url = base_url or f"http://localhost:{self.port}"
self.process_frontend: Optional[subprocess.Popen] = None
self.process_prefiller: Optional[subprocess.Popen] = None
self.process_decoder: Optional[subprocess.Popen] = None
self.cpu_cache_blocks = cpu_cache_blocks
self.gpu_cache_blocks = gpu_cache_blocks
# Prepare logging
self.log_dir = log_dir or Path(".")
self.log_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_str = (
f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
)
self.prefiller_log_file = (
self.log_dir / f"{self.server_type}_prefiller_{config_str}_{timestamp}.log"
)
self.prefiller_stdout_file: Optional[TextIO] = None
self.prefiller_stderr_file: Optional[TextIO] = None
self.decoder_log_file = (
self.log_dir / f"{self.server_type}_decoder_{timestamp}.log"
)
self.decoder_stdout_file: Optional[TextIO] = None
self.decoder_stderr_file: Optional[TextIO] = None
# Environment for the process
self.env = os.environ.copy()
self.env.update(
{
"RUST_BACKTRACE": "1",
# DynamoConnector connection settings
"NATS_SERVER": "nats://localhost:4222",
"ETCD_ENDPOINTS": "http://localhost:2379",
}
)
# CPU cache blocks override via env
if cpu_cache_blocks is not None:
self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)
self._set_up_dynamo_config()
if self.server_type == ServerType.vllm:
self._set_up_vllm_config(gpu_cache_blocks)
else:
raise ValueError(
f"{self.server_type} is not supported yet in the KVBM test suite"
)
def _set_up_dynamo_config(self, router_mode: str = "kv"):
self.dynamo_frontend_cmd = [
"python3",
"-m",
"dynamo.frontend",
"--router-mode",
router_mode,
"--http-port",
str(self.port),
]
def _set_up_vllm_config(self, gpu_cache_blocks):
self.env["VLLM_SERVER_DEV_MODE"] = "1"
# Construct decoder command
self.decoder_cmd = [
"python3",
"-m",
"dynamo.vllm",
"--model",
os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
"--block-size",
"16",
"--max-seq-len",
"8000", # required to fit on L4 GPU when using 8b model
"--connector",
"nixl",
]
# Construct prefiller command
self.prefiller_cmd = [
"python3",
"-m",
"dynamo.vllm",
"--model",
os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
"--is-prefill-worker",
"--block-size",
"16",
"--max-seq-len",
"8000", # required to fit on L4 GPU when using 8b model
"--connector",
"kvbm",
"nixl",
]
# GPU blocks override
if gpu_cache_blocks is not None:
self.decoder_cmd.extend(
["--num-gpu-blocks-override", str(gpu_cache_blocks)]
)
self.prefiller_cmd.extend(
["--num-gpu-blocks-override", str(gpu_cache_blocks)]
)
def start_server(self, timeout: int = 300) -> bool:
"""Start LLM server and wait for readiness."""
if self.is_server_running():
self.stop_server()
time.sleep(5)
# Open log files
self.prefiller_stdout_file = open(
self.prefiller_log_file.with_suffix(".stdout.log"), "w"
)
self.prefiller_stderr_file = open(
self.prefiller_log_file.with_suffix(".stderr.log"), "w"
)
if self.prefiller_stdout_file is not None:
self.prefiller_stdout_file.write(
f"=== {self.server_type} Prefiller Started at {datetime.now()} ===\nCommand: {' '.join(self.prefiller_cmd)}\n"
)
self.prefiller_stdout_file.flush()
self.decoder_stdout_file = open(
self.decoder_log_file.with_suffix(".stdout.log"), "w"
)
self.decoder_stderr_file = open(
self.decoder_log_file.with_suffix(".stderr.log"), "w"
)
if self.decoder_stdout_file is not None:
self.decoder_stdout_file.write(
f"=== {self.server_type} Decoder Started at {datetime.now()} ===\nCommand: {' '.join(self.decoder_cmd)}\n"
)
self.decoder_stdout_file.flush()
# Create separate environment configs for different processes
decoder_env = self.env.copy()
decoder_env["CUDA_VISIBLE_DEVICES"] = "0"
prefiller_env = self.env.copy()
prefiller_env["CUDA_VISIBLE_DEVICES"] = "1"
# Launch frontend first
self.process_frontend = subprocess.Popen(
self.dynamo_frontend_cmd,
env=self.env,
preexec_fn=os.setsid,
)
print(f"Frontend process started with PID: {self.process_frontend.pid}")
# Give frontend time to start up
time.sleep(5)
# Launch decoder
self.process_decoder = subprocess.Popen(
self.decoder_cmd,
stdout=self.decoder_stdout_file,
stderr=self.decoder_stderr_file,
env=decoder_env,
preexec_fn=os.setsid,
)
print(f"Decoder process started with PID: {self.process_decoder.pid}")
# Give decoder time to start up
time.sleep(5)
# Launch prefiller
self.process_prefiller = subprocess.Popen(
self.prefiller_cmd,
stdout=self.prefiller_stdout_file,
stderr=self.prefiller_stderr_file,
env=prefiller_env,
preexec_fn=os.setsid,
)
print(f"Prefiller process started with PID: {self.process_prefiller.pid}")
# Give prefiller time to start up
print(
"Sleeping for 30 seconds to wait for decoder and prefiller to start up..."
)
time.sleep(30)
# Wait for health
start_time = time.time()
while time.time() - start_time < timeout:
try:
if self.is_server_running():
return True
if (
self.process_frontend.poll() is not None
or self.process_prefiller.poll() is not None
or self.process_decoder.poll() is not None
):
self.stop_server()
return False
except Exception as e:
print(f"Error checking server status: {e}")
print("Waiting for server to start up:")
print(f"timeout: {timeout}, elapsed: {int(time.time() - start_time)}")
time.sleep(5)
# Timeout
self.stop_server()
return False
def stop_server(self):
"""Stop LLM server and close logs."""
if self.process_frontend:
try:
os.killpg(os.getpgid(self.process_frontend.pid), signal.SIGTERM)
try:
self.process_frontend.wait(timeout=30)
except subprocess.TimeoutExpired:
os.killpg(os.getpgid(self.process_frontend.pid), signal.SIGKILL)
self.process_frontend.wait()
except (ProcessLookupError, OSError):
pass
finally:
self.process_frontend = None
if self.process_prefiller:
try:
os.killpg(os.getpgid(self.process_prefiller.pid), signal.SIGTERM)
try:
self.process_prefiller.wait(timeout=30)
except subprocess.TimeoutExpired:
os.killpg(os.getpgid(self.process_prefiller.pid), signal.SIGKILL)
self.process_prefiller.wait()
except (ProcessLookupError, OSError):
pass
finally:
self.process_prefiller = None
if self.process_decoder:
try:
os.killpg(os.getpgid(self.process_decoder.pid), signal.SIGTERM)
try:
self.process_decoder.wait(timeout=30)
except subprocess.TimeoutExpired:
os.killpg(os.getpgid(self.process_decoder.pid), signal.SIGKILL)
self.process_decoder.wait()
except (ProcessLookupError, OSError):
pass
finally:
self.process_decoder = None
self._close_log_files()
def _close_log_files(self):
if self.prefiller_stdout_file:
self.prefiller_stdout_file.write(
f"\n=== Prefiller Stopped at {datetime.now()} ===\n"
)
self.prefiller_stdout_file.close()
self.prefiller_stdout_file = None
if self.prefiller_stderr_file:
self.prefiller_stderr_file.close()
self.prefiller_stderr_file = None
if self.decoder_stdout_file:
self.decoder_stdout_file.write(
f"\n=== Decoder Stopped at {datetime.now()} ===\n"
)
self.decoder_stdout_file.close()
self.decoder_stdout_file = None
if self.decoder_stderr_file:
self.decoder_stderr_file.close()
self.decoder_stderr_file = None
def is_server_running(self) -> bool:
try:
# First check basic health
response = requests.get(f"{self.base_url}/health", timeout=5)
if response.status_code != 200:
return False
# Then check if the model endpoint is ready with a simple test request
test_payload = {
"model": os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
),
"messages": [{"role": "user", "content": "test"}],
"max_completion_tokens": 1,
"temperature": 0,
}
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Content-Type": "application/json"},
json=test_payload,
timeout=10,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
class DisaggDeterminismTester(DeterminismTester):
"""Disaggregated architecture specific determinism tester."""
def __init__(
self,
base_url: Optional[str] = None,
model_id: Optional[str] = None,
server_type: Optional[str] = ServerType.vllm,
):
super().__init__(base_url, model_id, server_type)
def reset_prefix_cache(self):
"""Reset the prefix cache."""
print("Resetting prefix cache...")
# 150 shakespeare requests (each request is 200 words, and roughly 17 blocks) could evict 150 * 17 = 2550 blocks
shakespeare_count = 150
for seq_idx in range(1, shakespeare_count + 1):
start_word = (seq_idx - 1) * self.word_count
content = self.get_shakespeare_content(start_word)
if content:
print(
f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
)
try:
self.make_request(content)
except Exception as e:
print(f"Resetting request failed: {e}")
print("Cache reset done")
@pytest.fixture(scope="function")
def llm_server(request, runtime_services):
"""Start and stop a LLM server for each test with optional cache block overrides.
To parametrize, use:
@pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 1000}], indirect=True)
"""
logger = logging.getLogger("pytest")
logger.setLevel(logging.INFO)
cpu_blocks = getattr(request, "param", {}).get("cpu_blocks", None)
gpu_blocks = getattr(request, "param", {}).get("gpu_blocks", None)
port = getattr(request, "param", {}).get("port", None)
# Put logs in the per-test directory set up by tests/conftest.py
log_dir = Path(request.node.name)
if importlib.util.find_spec("vllm") is not None:
server_type = ServerType.vllm
else:
raise Exception("vllm module is not available in the current environment.")
server_manager = LLMServerManager(
port=port,
cpu_cache_blocks=cpu_blocks,
gpu_cache_blocks=gpu_blocks,
log_dir=log_dir,
server_type=server_type,
)
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout):
pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
)
yield server_manager
server_manager.stop_server()
@pytest.fixture(scope="function")
def tester(llm_server):
"""Create determinism tester bound to the running server's base URL."""
t = DisaggDeterminismTester(
base_url=llm_server.base_url, server_type=llm_server.server_type
)
t.download_shakespeare_text()
return t
class TestDeterminismDisagg(BaseTestDeterminism):
"""Test class for determinism validation."""
@pytest.mark.parametrize(
"llm_server",
[
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "1000")),
},
],
indirect=True,
)
def test_determinism_disagg_with_cache_reset(
self, tester, llm_server, runtime_services
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester,
llm_server,
runtime_services,
success_rate_threshold=SUCCESS_RATE_THRESHOLD,
)
if __name__ == "__main__":
# Allow running as script
pytest.main([__file__, "-v", "-s"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment