Unverified Commit ba34f131 authored by Olga Andreeva's avatar Olga Andreeva Committed by GitHub
Browse files

feat(kvbm): add MLA support for DeepSeek-V2 and extend determinism tests (#7786)


Signed-off-by: default avatarOlga Andreeva <oandreeva@nvidia.com>
Co-authored-by: default avatarClaude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: default avatarwz1qqx <ziqi.wang@novita.ai>
parent fd5cc288
......@@ -7,6 +7,7 @@ Implementation of vLLM KV cache manager protocol.
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
......@@ -31,6 +32,39 @@ if TYPE_CHECKING:
from kvbm.vllm_integration.rust import KvConnectorWorker as RustKvConnectorWorker
@dataclass
class KvTensorLayout:
"""Semantic description of the KV cache tensor layout.
Python derives this from VllmConfig (which has the model architecture)
so that Rust does not need to guess from raw tensor shapes.
For MLA models, outer_dim and inner_dim are set explicitly because Rust's
shape-based inference cannot distinguish the fused KV latent axis from the
block/page axis. For standard attention the fields are None, which tells
Rust to fall back to its own contiguity-based inference (already correct).
"""
outer_dim: Optional[int] # None = let Rust detect; 1 for MLA
inner_dim: Optional[int] # None = let Rust detect; head_size for MLA
@classmethod
def from_vllm_config(
cls, vllm_config: "VllmConfig", shape: "torch.Size", use_mla: bool = False
) -> "KvTensorLayout":
if use_mla:
# MLA tensors are 3D: [num_blocks, page_size, head_size]
# No outer_dim axis — K and V are fused into a single latent.
return cls(outer_dim=1, inner_dim=shape[-1])
else:
# Standard attention: Rust already infers outer_dim and inner_dim
# correctly from tensor strides/contiguity. Don't guess here —
# the block dimension can be at position 0 or 1 depending on the
# attention backend, so shape[1] is not reliably outer_dim.
return cls(outer_dim=None, inner_dim=None)
DistributedRuntime = None
if is_dyn_runtime_enabled():
from dynamo.runtime import DistributedRuntime
......@@ -103,12 +137,18 @@ class KvConnectorWorker:
"Hybrid models with different KV cache shapes are not supported yet."
)
# Extract parameters
# TODO: Assume the block dimension is within the first 2. This will break if you're doing something weird like having 1 or 2 device blocks.
num_device_blocks = max(shape[0], shape[1])
page_size = cache_config.block_size
use_mla = getattr(self.vllm_config.model_config, "use_mla", False)
# MLA tensors are [num_blocks, page_size, head_size] — block dim is always axis 0.
# Standard attention tensors are [outer_dim, num_blocks, ...] or [num_blocks, outer_dim, ...]
# — block dim is whichever axis is larger, which is unambiguous as long as
# num_blocks >> outer_dim (always true in practice).
num_device_blocks = shape[0] if use_mla else max(shape[0], shape[1])
device_id = first_tensor.device.index
layout = KvTensorLayout.from_vllm_config(self.vllm_config, shape, use_mla)
# Determine cache dtype
if cache_config.cache_dtype == "auto":
kv_cache_dtype = self.vllm_config.model_config.dtype
......@@ -123,6 +163,8 @@ class KvConnectorWorker:
kv_cache_dtype.itemsize,
ordered_kv_caches,
raw_event_handles,
outer_dim=layout.outer_dim,
inner_dim=layout.inner_dim,
)
def bind_connector_metadata(self, data: bytes) -> None:
......
......@@ -37,6 +37,8 @@ pub trait Worker: Send + Sync {
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
outer_dim: Option<usize>,
inner_dim: Option<usize>,
) -> anyhow::Result<()>;
fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> anyhow::Result<()>;
......@@ -132,6 +134,8 @@ impl Worker for KvConnectorWorker {
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
outer_dim: Option<usize>,
inner_dim: Option<usize>,
) -> anyhow::Result<()> {
if self.kvbm_worker.get().is_some() {
tracing::warn!("kvbm worker already registered");
......@@ -194,7 +198,7 @@ impl Worker for KvConnectorWorker {
}
};
let config = KvbmWorkerConfig::builder()
let mut config_builder = KvbmWorkerConfig::builder()
.cancel_token(get_current_cancel_token())
.num_device_blocks(num_device_blocks)
.page_size(page_size)
......@@ -206,8 +210,14 @@ impl Worker for KvConnectorWorker {
.host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
.disk_layout_type(disk_layout_type.unwrap_or(LayoutType::FullyContiguous))
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.build()?;
.leader_ack_url(get_leader_zmq_ack_url());
if let Some(od) = outer_dim {
config_builder = config_builder.outer_dim(Some(od));
}
if let Some(id) = inner_dim {
config_builder = config_builder.inner_dim(Some(id));
}
let config = config_builder.build()?;
let worker = get_current_tokio_handle().block_on(async move {
let worker = KvbmWorker::new(config, false).await?;
......@@ -485,7 +495,7 @@ impl PyKvConnectorWorker {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches, raw_event_handles, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
#[pyo3(signature = (num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches, raw_event_handles, device_layout_type=None, host_layout_type=None, disk_layout_type=None, outer_dim=None, inner_dim=None))]
pub fn register_kv_caches(
&mut self,
num_device_blocks: usize,
......@@ -497,6 +507,8 @@ impl PyKvConnectorWorker {
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
outer_dim: Option<usize>,
inner_dim: Option<usize>,
) -> PyResult<()> {
// Convert Python tensors to Rust VllmTensor objects
let mut rust_kv_caches = Vec::new();
......@@ -516,6 +528,8 @@ impl PyKvConnectorWorker {
device_layout_type.map(|py_layout| py_layout.into()),
host_layout_type.map(|py_layout| py_layout.into()),
disk_layout_type.map(|py_layout| py_layout.into()),
outer_dim,
inner_dim,
)
.map_err(to_pyerr)
}
......
......@@ -25,6 +25,7 @@ use nixl_sys::Agent as NixlAgent;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use validator::Validate;
use tokio::runtime::Handle;
use tokio_util::sync::CancellationToken;
......@@ -394,13 +395,23 @@ impl Handler for BlockTransferDispatch {
}
}
#[derive(Builder, Clone)]
fn validate_page_size(value: usize) -> Result<(), validator::ValidationError> {
if !value.is_power_of_two() {
return Err(validator::ValidationError::new(
"page_size_not_power_of_two",
));
}
Ok(())
}
#[derive(Builder, Clone, Validate)]
#[builder(pattern = "owned")]
pub struct KvbmWorkerConfig {
cancel_token: CancellationToken,
num_device_blocks: usize,
#[validate(custom(function = "validate_page_size"), range(max = 1024))]
#[builder(default = "32")]
page_size: usize,
......@@ -422,6 +433,18 @@ pub struct KvbmWorkerConfig {
#[builder(default = "LayoutType::FullyContiguous")]
disk_layout_type: LayoutType,
/// Explicit outer dimension (1 for MLA, 2 for standard K/V).
/// When set, skips shape inference. Python should always provide this.
#[validate(range(min = 1, max = 2))]
#[builder(default = "None")]
pub outer_dim: Option<usize>,
/// Explicit inner dimension (head_size for MLA, num_heads * head_dim for standard).
/// When set, skips shape inference. Python should always provide this.
#[validate(range(min = 1))]
#[builder(default = "None")]
pub inner_dim: Option<usize>,
#[builder(default = "None")]
scheduler_client: Option<TransferSchedulerClient>,
......@@ -444,6 +467,33 @@ impl KvbmWorkerConfig {
pub fn builder() -> KvbmWorkerConfigBuilder {
KvbmWorkerConfigBuilder::default()
}
/// Validate configuration contract before use.
///
/// Field-level rules (`outer_dim`, `inner_dim`, `page_size`) are enforced via
/// `#[validate]` attributes on the struct. This method additionally checks the
/// cross-field coupling invariant: `outer_dim` and `inner_dim` must both be
/// `Some` or both be `None`.
pub fn validate(&self) -> anyhow::Result<()> {
// Run derive-based field validators (#[validate] attributes).
<Self as Validate>::validate(self)
.map_err(|e| anyhow::anyhow!("KvbmWorkerConfig validation failed: {}", e))?;
// Cross-field: outer_dim and inner_dim must be coupled (both Some or both None).
match (self.outer_dim, self.inner_dim) {
(Some(_), None) | (None, Some(_)) => {
anyhow::bail!(
"outer_dim and inner_dim must be provided together (both Some or both None); \
got outer_dim={:?}, inner_dim={:?}",
self.outer_dim,
self.inner_dim
);
}
_ => {}
}
Ok(())
}
}
pub struct KvbmWorker {
......@@ -460,6 +510,8 @@ impl KvbmWorker {
config.dtype_width_bytes
);
config.validate()?;
if config.num_device_blocks == 0 {
return Err(anyhow::anyhow!("num_device_blocks must be greater than 0"));
}
......@@ -476,8 +528,10 @@ impl KvbmWorker {
let (layout_type, num_layers, outer_dim, inner_dim) = match config.device_layout_type {
LayoutType::FullyContiguous => {
let num_layers = shape[1];
let outer_dim = shape[2];
let inner_dim = shape[3..].iter().product::<usize>() / config.page_size;
let outer_dim = config.outer_dim.unwrap_or(shape[2]);
let inner_dim = config
.inner_dim
.unwrap_or_else(|| shape[3..].iter().product::<usize>() / config.page_size);
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
num_layers,
......@@ -494,21 +548,37 @@ impl KvbmWorker {
)
}
LayoutType::LayerSeparate { outer_contiguous } => {
// Use the already-detected layout type from config (no re-detection needed)
let layout_type = config.device_layout_type;
let num_layers = device_tensors.len();
// Extract outer_dim based on the provided outer_contiguous value
let outer_dim = if outer_contiguous {
shape[0] // Outer contiguous: [outer_dim, n_blocks, ...]
let (outer_dim, inner_dim) = match (config.outer_dim, config.inner_dim) {
// Explicit dims provided by caller (e.g. Python via KvTensorLayout) — use as-is.
(Some(od), Some(id)) => (od, id),
// No explicit dims: infer from shape.
// outer_dim valid range is [1, 2] (1 = MLA fused, 2 = standard K/V split).
// If the candidate dimension exceeds 2 the tensor has no explicit K/V axis
// (e.g. MLA models produce [n_blocks, page_size, latent_dim]) — fall back to
// outer_dim=1 and compute inner_dim from all dims after n_blocks.
_ => {
let candidate = if outer_contiguous { shape[0] } else { shape[1] };
if candidate <= 2 {
// Standard layout: outer_dim is encoded in the shape.
// outer_contiguous=true: [outer_dim, n_blocks, page_size, inner_dim]
// outer_contiguous=false: [n_blocks, outer_dim, page_size, inner_dim]
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
(candidate, inner_dim)
} else {
shape[1] // Block contiguous: [n_blocks, outer_dim, ...]
// MLA-style: no explicit K/V split, treat as outer_dim=1.
let dims_start = if outer_contiguous { 2 } else { 1 };
let inner_dim =
shape[dims_start..].iter().product::<usize>() / config.page_size;
(1, inner_dim)
}
}
};
let num_layers = device_tensors.len();
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, page_size={}, inner_dim={}",
"Layout: num_layers={}, outer_dim={}, outer_contiguous={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
outer_contiguous,
......@@ -799,3 +869,136 @@ impl Drop for KvbmWorker {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_util::sync::CancellationToken;
fn base_config() -> KvbmWorkerConfig {
KvbmWorkerConfig::builder()
.cancel_token(CancellationToken::new())
.num_device_blocks(1)
.build()
.expect("base config should build")
}
// --- outer_dim ---
#[test]
fn validate_outer_dim_none_is_ok() {
let mut cfg = base_config();
cfg.outer_dim = None;
cfg.inner_dim = None;
assert!(cfg.validate().is_ok());
}
#[test]
fn validate_outer_dim_1_is_ok() {
let mut cfg = base_config();
cfg.outer_dim = Some(1);
cfg.inner_dim = Some(64);
assert!(cfg.validate().is_ok());
}
#[test]
fn validate_outer_dim_2_is_ok() {
let mut cfg = base_config();
cfg.outer_dim = Some(2);
cfg.inner_dim = Some(64);
assert!(cfg.validate().is_ok());
}
#[test]
fn validate_outer_dim_3_is_err() {
let mut cfg = base_config();
cfg.outer_dim = Some(3);
cfg.inner_dim = Some(64);
let err = cfg.validate().unwrap_err().to_string();
assert!(
err.contains("outer_dim") && err.contains("range"),
"got: {err}"
);
}
#[test]
fn validate_outer_dim_0_is_err() {
let mut cfg = base_config();
cfg.outer_dim = Some(0);
cfg.inner_dim = Some(64);
let err = cfg.validate().unwrap_err().to_string();
assert!(
err.contains("outer_dim") && err.contains("range"),
"got: {err}"
);
}
// --- inner_dim ---
#[test]
fn validate_inner_dim_zero_is_err() {
let mut cfg = base_config();
cfg.outer_dim = Some(2);
cfg.inner_dim = Some(0);
let err = cfg.validate().unwrap_err().to_string();
assert!(
err.contains("inner_dim") && err.contains("range"),
"got: {err}"
);
}
// --- coupling ---
#[test]
fn validate_outer_some_inner_none_is_err() {
let mut cfg = base_config();
cfg.outer_dim = Some(2);
cfg.inner_dim = None;
let err = cfg.validate().unwrap_err().to_string();
assert!(err.contains("provided together"), "got: {err}");
}
#[test]
fn validate_outer_none_inner_some_is_err() {
let mut cfg = base_config();
cfg.outer_dim = None;
cfg.inner_dim = Some(64);
let err = cfg.validate().unwrap_err().to_string();
assert!(err.contains("provided together"), "got: {err}");
}
// --- page_size ---
#[test]
fn validate_page_size_power_of_two_is_ok() {
for size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
let mut cfg = base_config();
cfg.page_size = size;
assert!(cfg.validate().is_ok(), "expected ok for page_size={size}");
}
}
#[test]
fn validate_page_size_not_power_of_two_is_err() {
for size in [3, 5, 6, 7, 100, 300] {
let mut cfg = base_config();
cfg.page_size = size;
let err = cfg.validate().unwrap_err().to_string();
assert!(
err.contains("page_size_not_power_of_two"),
"expected power-of-two error for page_size={size}, got: {err}"
);
}
}
#[test]
fn validate_page_size_exceeds_max_is_err() {
let mut cfg = base_config();
cfg.page_size = 2048;
let err = cfg.validate().unwrap_err().to_string();
assert!(
err.contains("page_size") && err.contains("range"),
"got: {err}"
);
}
}
......@@ -2210,4 +2210,160 @@ pub mod tests {
);
}
}
// ============================================================================
// MLA (outer_dim=1) TESTS
// ============================================================================
mod mla_outer_dim_1_tests {
use super::*;
const MLA_NUM_BLOCKS: usize = 100;
const MLA_NUM_LAYERS: usize = 61;
const MLA_OUTER_DIM: usize = 1;
const MLA_PAGE_SIZE: usize = 64;
const MLA_INNER_DIM: usize = 576;
const MLA_DTYPE_WIDTH_BYTES: usize = 2;
fn setup_mla_layer_separate(
is_outer_contiguous: bool,
) -> Result<LayerSeparate<NullDeviceStorage>, LayoutError> {
let config = LayoutConfig {
num_blocks: MLA_NUM_BLOCKS,
num_layers: MLA_NUM_LAYERS,
outer_dim: MLA_OUTER_DIM,
page_size: MLA_PAGE_SIZE,
inner_dim: MLA_INNER_DIM,
alignment: 1,
dtype_width_bytes: MLA_DTYPE_WIDTH_BYTES,
};
let ls_config = LayerSeparateConfig::new(config.clone(), is_outer_contiguous)?;
let required_size = ls_config.required_allocation_size();
let storages = (0..MLA_NUM_LAYERS)
.map(|_| NullDeviceStorage::new(required_size as u64))
.collect();
LayerSeparate::new(config, storages, is_outer_contiguous)
}
#[test]
fn test_mla_ls_creation_block_contiguous() {
let layout =
setup_mla_layer_separate(false).expect("MLA LayerSeparate creation should succeed");
assert_eq!(layout.num_blocks(), MLA_NUM_BLOCKS);
assert_eq!(layout.num_layers(), MLA_NUM_LAYERS);
assert_eq!(layout.outer_dim(), 1);
assert_eq!(layout.page_size(), MLA_PAGE_SIZE);
assert_eq!(layout.inner_dim(), MLA_INNER_DIM);
assert_eq!(
layout.layout_type(),
LayoutType::LayerSeparate {
outer_contiguous: false
}
);
}
#[test]
fn test_mla_ls_creation_outer_contiguous() {
let layout =
setup_mla_layer_separate(true).expect("MLA LayerSeparate creation should succeed");
assert_eq!(layout.outer_dim(), 1);
assert_eq!(
layout.layout_type(),
LayoutType::LayerSeparate {
outer_contiguous: true
}
);
}
#[test]
fn test_mla_ls_memory_region_size() {
let layout = setup_mla_layer_separate(false).expect("Layout setup failed");
let region = layout.memory_region(0, 0, 0).unwrap();
let expected_size = MLA_PAGE_SIZE * MLA_INNER_DIM * MLA_DTYPE_WIDTH_BYTES;
assert_eq!(
region.size, expected_size,
"MLA memory region size should be page_size * inner_dim * dtype_bytes = {} * {} * {} = {}",
MLA_PAGE_SIZE, MLA_INNER_DIM, MLA_DTYPE_WIDTH_BYTES, expected_size
);
}
#[test]
fn test_mla_ls_block_stride() {
let layout = setup_mla_layer_separate(false).expect("Layout setup failed");
let region_0 = layout.memory_region(0, 0, 0).unwrap();
let region_1 = layout.memory_region(1, 0, 0).unwrap();
// With outer_dim=1 and block_contiguous, block_stride = outer_dim_stride * outer_dim
// = (page_size * inner_dim * dtype) * 1
let expected_stride = MLA_PAGE_SIZE * MLA_INNER_DIM * MLA_DTYPE_WIDTH_BYTES;
assert_eq!(
region_1.addr - region_0.addr,
expected_stride,
"MLA block stride should equal memory_region_size when outer_dim=1"
);
}
#[test]
fn test_mla_ls_outer_idx_must_be_zero() {
let layout = setup_mla_layer_separate(false).expect("Layout setup failed");
assert!(
layout.memory_region(0, 0, 0).is_ok(),
"outer_idx=0 should be valid for outer_dim=1"
);
assert!(
layout.memory_region(0, 0, 1).is_err(),
"outer_idx=1 should be invalid for outer_dim=1"
);
}
#[test]
fn test_mla_ls_verify_all_regions() {
let layout = setup_mla_layer_separate(false).expect("Layout setup failed");
assert!(
layout.verify_memory_regions().is_ok(),
"MLA LayerSeparate memory region verification should pass"
);
for block_idx in 0..MLA_NUM_BLOCKS {
for layer_idx in 0..MLA_NUM_LAYERS {
let matches = layout
.verify_memory_region(block_idx, layer_idx, 0)
.expect("Verification should not error");
assert!(
matches,
"MLA region ({}, {}, 0) should match",
block_idx, layer_idx
);
}
}
}
#[test]
fn test_mla_ls_address_calculation() {
let layout = setup_mla_layer_separate(false).expect("Layout setup failed");
for block_idx in 0..MLA_NUM_BLOCKS {
for layer_idx in 0..MLA_NUM_LAYERS {
let actual = layout.memory_region(block_idx, layer_idx, 0).unwrap();
let expected_addr = layout
.expected_memory_address(block_idx, layer_idx, 0)
.unwrap();
assert_eq!(
actual.addr, expected_addr,
"MLA address mismatch at ({}, {}, 0)",
block_idx, layer_idx
);
}
}
}
}
}
......@@ -12,6 +12,7 @@ aggregated and disaggregated determinism tests.
import importlib.util
import os
import re
import tempfile
import time
from collections import defaultdict
from difflib import SequenceMatcher
......@@ -734,9 +735,15 @@ class TestDeterminism:
"""
import subprocess
model = os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
model = llm_server.model_config.model_id
# NOTE: with large models (e.g. DeepSeek-V2-Lite), vllm bench decodes all
# prompts through the tokenizer before sending requests. 2000 x 4000-token
# prompts take ~160s to decode, exceeding KVBM_BENCH_STARTUP_WAIT (120s).
# Reduce via KVBM_BENCH_NUM_PROMPTS / KVBM_BENCH_INPUT_LEN if needed.
num_prompts = int(os.environ.get("KVBM_BENCH_NUM_PROMPTS", "2000"))
input_len = int(os.environ.get("KVBM_BENCH_INPUT_LEN", "4000"))
output_len = int(os.environ.get("KVBM_BENCH_OUTPUT_LEN", "180"))
concurrency = int(os.environ.get("KVBM_BENCH_CONCURRENCY", "7"))
bench_cmd = [
"vllm",
"bench",
......@@ -750,18 +757,20 @@ class TestDeterminism:
"--dataset-name",
"random",
"--random-input-len",
"4000",
str(input_len),
"--random-output-len",
"180",
str(output_len),
"--max-concurrency",
"7",
str(concurrency),
"--num-prompts",
"2000",
str(num_prompts),
]
print(f"\nStarting vllm bench: {' '.join(bench_cmd)}")
bench_log = os.path.join(str(Path(".")), "vllm_bench_semantic.log")
bench_file = open(bench_log, "w")
bench_fd, bench_log = tempfile.mkstemp(
suffix=".log", prefix="vllm_bench_semantic_"
)
bench_file = os.fdopen(bench_fd, "w")
bench_process = subprocess.Popen(
bench_cmd,
stdout=bench_file,
......
......@@ -27,6 +27,7 @@ import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO
......@@ -43,19 +44,70 @@ from .common import check_module_available
HAS_VLLM_BENCH = check_module_available("vllm")
@dataclass
class KvbmModelConfig:
"""Describes a model and the vLLM serving flags needed for KVBM testing."""
model_id: str
block_size: Optional[int] = None # None = let vllm decide
attention_backend: Optional[str] = None # None = let vllm decide
max_model_len: int = 8000
# Set False for MLA models: VLLM_BATCH_INVARIANT=1 disables prefix caching
# for TRITON_MLA in vLLM 0.17.1, defeating KV offload testing.
batch_invariant: bool = True
@property
def short_name(self) -> str:
return self.model_id.split("/")[-1]
@property
def use_mla(self) -> bool:
"""True when the model uses Multi-head Latent Attention (e.g. TRITON_MLA)."""
return self.attention_backend is not None and "MLA" in self.attention_backend
# Models exercised by this test suite.
# CI iterates over all entries; add a new entry to test an additional model.
_MODEL_CONFIGS: List[KvbmModelConfig] = [
KvbmModelConfig(
model_id=os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
),
block_size=16,
attention_backend="FLASH_ATTN",
),
KvbmModelConfig(
model_id="deepseek-ai/DeepSeek-V2-Lite",
# TRITON_MLA works on all devices; on H100 set KVBM_MLA_BACKEND=FLASH_ATTN_MLA
attention_backend=os.environ.get("KVBM_MLA_BACKEND", "TRITON_MLA"),
# VLLM_BATCH_INVARIANT=1 disables prefix caching for TRITON_MLA in vLLM 0.17.1
batch_invariant=False,
),
]
# KVBM env vars that drive test duration (used to compute timeouts below).
_KVBM_MAX_ITERATIONS = int(os.environ.get("KVBM_MAX_ITERATIONS", "100"))
_KVBM_NUM_ITERATIONS = int(os.environ.get("KVBM_NUM_ITERATIONS", "15"))
_KVBM_REQUEST_DELAY = int(os.environ.get("KVBM_REQUEST_DELAY", "30"))
# Server startup timeout (env-configurable; larger models like DeepSeek-V2-Lite
# may need 600s+).
_SERVER_START_TIMEOUT = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600"))
# Compute timeouts from the same env vars that control test duration.
# test_determinism_agg_with_cache_reset: runs warmup + 2 phases of KVBM_MAX_ITERATIONS,
# each iteration ~4s (request + overhead), plus ~50s setup/teardown.
_CACHE_RESET_TIMEOUT = 2 * (_KVBM_MAX_ITERATIONS * 4 + 50)
# Each formula adds _SERVER_START_TIMEOUT so the pytest timeout covers both
# the server startup and the actual test body.
#
# test_determinism_agg_with_cache_reset: warmup + 2 phases of KVBM_MAX_ITERATIONS,
# each iteration ~4s (request + overhead), plus ~50s teardown.
_CACHE_RESET_TIMEOUT = _SERVER_START_TIMEOUT + 2 * (_KVBM_MAX_ITERATIONS * 4 + 50)
# test_concurrent_determinism_under_load: dominated by
# (KVBM_NUM_ITERATIONS - 1) * KVBM_REQUEST_DELAY seconds of sleep,
# plus ~150s overhead (server startup, benchmark ramp, teardown).
_CONCURRENT_TIMEOUT = 2 * ((_KVBM_NUM_ITERATIONS - 1) * _KVBM_REQUEST_DELAY + 150)
# plus ~150s overhead (benchmark ramp, teardown).
_CONCURRENT_TIMEOUT = _SERVER_START_TIMEOUT + 2 * (
(_KVBM_NUM_ITERATIONS - 1) * _KVBM_REQUEST_DELAY + 150
)
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
......@@ -78,8 +130,10 @@ class LLMServerManager:
gpu_cache_blocks: Optional[int] = None,
log_dir: Optional[Path] = None,
server_type: Optional[str] = ServerType.vllm,
model_config: Optional[KvbmModelConfig] = None,
):
self.server_type = server_type
self.model_config = model_config or _MODEL_CONFIGS[0]
# Use provided port, env var, or allocate a dynamic port to avoid conflicts
if port is not None:
self.port = port
......@@ -121,8 +175,6 @@ class LLMServerManager:
# Enable KVBM metrics for monitoring offload/onboard
"DYN_KVBM_METRICS": "true",
"DYN_KVBM_METRICS_PORT": str(self.metrics_port),
# Enable vLLM batch invariant for deterministic batching
"VLLM_BATCH_INVARIANT": "1",
}
)
......@@ -131,7 +183,7 @@ class LLMServerManager:
self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)
if self.server_type == ServerType.vllm:
self._set_up_vllm_config(gpu_cache_blocks)
self._set_up_vllm_config(gpu_cache_blocks, self.model_config)
elif self.server_type == ServerType.trtllm:
self._set_up_trtllm_config(gpu_cache_blocks)
else:
......@@ -139,27 +191,33 @@ class LLMServerManager:
f"{self.server_type} is not supported yet in the KVBM test suite"
)
def _set_up_vllm_config(self, gpu_cache_blocks):
def _set_up_vllm_config(self, gpu_cache_blocks, model_config: KvbmModelConfig):
self.env["VLLM_SERVER_DEV_MODE"] = "1"
if model_config.batch_invariant:
self.env["VLLM_BATCH_INVARIANT"] = "1"
else:
self.env.pop("VLLM_BATCH_INVARIANT", None)
# Construct serve command
self.server_cmd = [
"vllm",
"serve",
"--block-size",
"16",
"--port",
str(self.port),
"--kv-transfer-config",
'{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "kvbm.vllm_integration.connector"}',
os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
"--attention-config.backend",
"FLASH_ATTN",
model_config.model_id,
"--max-model-len",
"8000", # required to fit on L4 GPU when using 8b model
str(model_config.max_model_len),
]
# GPU blocks override
if model_config.block_size is not None:
self.server_cmd.extend(["--block-size", str(model_config.block_size)])
if model_config.attention_backend is not None:
self.server_cmd.extend(
["--attention-config.backend", model_config.attention_backend]
)
if gpu_cache_blocks is not None:
self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)])
......@@ -332,9 +390,7 @@ class LLMServerManager:
# Then check if the model endpoint is ready with a simple test request
test_payload = {
"model": os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
),
"model": self.model_config.model_id,
"messages": [{"role": "user", "content": "test"}],
"max_completion_tokens": 1,
"temperature": 0,
......@@ -404,6 +460,7 @@ def llm_server(request, runtime_services):
cpu_blocks = getattr(request, "param", {}).get("cpu_blocks", None)
gpu_blocks = getattr(request, "param", {}).get("gpu_blocks", None)
port = getattr(request, "param", {}).get("port", None)
model_config = getattr(request, "param", {}).get("model_config", None)
# Put logs in the per-test directory set up by tests/conftest.py
log_dir = Path(resolve_test_output_path(request.node.name))
......@@ -423,9 +480,10 @@ def llm_server(request, runtime_services):
gpu_cache_blocks=gpu_blocks,
log_dir=log_dir,
server_type=server_type,
model_config=model_config,
)
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
start_timeout = _SERVER_START_TIMEOUT
if not server_manager.start_server(timeout=start_timeout):
pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
......@@ -441,6 +499,7 @@ def tester(llm_server):
"""Create determinism tester bound to the running server's base URL."""
t = AggDeterminismTester(
base_url=llm_server.base_url,
model_id=llm_server.model_config.model_id,
server_type=llm_server.server_type,
)
t.download_shakespeare_text()
......@@ -456,9 +515,12 @@ class TestDeterminismAgg(BaseTestDeterminism):
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "2048")),
},
"model_config": cfg,
}
for cfg in _MODEL_CONFIGS
],
indirect=True,
ids=[cfg.short_name for cfg in _MODEL_CONFIGS],
)
@pytest.mark.kvbm
@pytest.mark.timeout(
......@@ -479,9 +541,12 @@ class TestDeterminismAgg(BaseTestDeterminism):
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "30000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "2048")),
},
"model_config": cfg,
}
for cfg in _MODEL_CONFIGS
],
indirect=True,
ids=[cfg.short_name for cfg in _MODEL_CONFIGS],
)
@pytest.mark.kvbm_concurrency
@pytest.mark.skipif(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment