Unverified Commit 0df6d462 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: add KV Event Publishing to vLLM v1 (#1181)

parent 93ca9df1
......@@ -269,6 +269,19 @@ dependencies = [
"zmq",
]
[[package]]
name = "asynchronous-codec"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233"
dependencies = [
"bytes",
"futures-sink",
"futures-util",
"memchr",
"pin-project-lite",
]
[[package]]
name = "atomic"
version = "0.6.0"
......@@ -1285,6 +1298,19 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if 1.0.0",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
......@@ -1606,6 +1632,7 @@ dependencies = [
"rayon",
"regex",
"reqwest",
"rmp-serde",
"rstest 0.18.2",
"rstest_reuse",
"sentencepiece",
......@@ -1626,6 +1653,7 @@ dependencies = [
"uuid 1.16.0",
"validator",
"xxhash-rust",
"zeromq",
]
[[package]]
......@@ -2607,6 +2635,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.2"
......@@ -5275,6 +5309,28 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rmp"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
dependencies = [
"byteorder",
"num-traits",
"paste",
]
[[package]]
name = "rmp-serde"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
dependencies = [
"byteorder",
"rmp",
"serde",
]
[[package]]
name = "router"
version = "0.2.1"
......@@ -6514,6 +6570,7 @@ checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
"futures-io",
"futures-sink",
"pin-project-lite",
"tokio",
......@@ -7783,6 +7840,33 @@ version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zeromq"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a4528179201f6eecf211961a7d3276faa61554c82651ecc66387f68fc3004bd"
dependencies = [
"async-trait",
"asynchronous-codec",
"bytes",
"crossbeam-queue",
"dashmap",
"futures-channel",
"futures-io",
"futures-task",
"futures-util",
"log",
"num-traits",
"once_cell",
"parking_lot",
"rand 0.8.5",
"regex",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"uuid 1.16.0",
]
[[package]]
name = "zeromq-src"
version = "0.2.6+4.3.4"
......
......@@ -115,6 +115,7 @@ fn mock_stats_handler(_stats: EndpointStats) -> serde_json::Value {
let gpu_cache_usage_perc = rand::rng().random_range(0.0..=1.0);
let gpu_prefix_cache_hit_rate = rand::rng().random_range(0.0..=1.0);
let stats = ForwardPassMetrics {
data_parallel_rank: None, // Default for backwards compatibility
request_active_slots,
request_total_slots,
kv_active_blocks,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed. Must be newer than v0.9.0 (currently pre-release)
# 1f079540db5f1080a2f61a730da50d3009934c5a - this commit is working for me
# Steps:
# git clone https://github.com/vllm-project/vllm.git
# cd vllm && git checkout 1f079540db5f1080a2f61a730da50d3009934c5a
# uv pip uninstall ai-dynamo-vllm
# VLLM_USE_PRECOMPILED=1 uv pip install --editable .
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import (
KvEventPublisherFromZmq,
KvEventPublisherFromZmqConfig,
KvMetricsPublisher,
ModelType,
register_llm,
)
from dynamo.runtime import Component, DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
extra_engine_args: str
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the KvMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
self.inner = KvMetricsPublisher()
self.inner.create_endpoint(component)
self.dp_rank = dp_rank
def record(
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]
):
# request_total_slots and kv_total_blocks are properties of model + gpu
# we should only publish them once, not every metric update
# they should be part of some runtime metadata tied to MDC or put in etcd ?
hit_rate = 0
if scheduler_stats.prefix_cache_stats.queries > 0:
hit_rate = (
scheduler_stats.prefix_cache_stats.hits
/ scheduler_stats.prefix_cache_stats.queries
)
# TODO Manage DP Ranks in metrics aggregation.
self.inner.publish(
request_active_slots=scheduler_stats.num_running_reqs,
request_total_slots=0, # TODO - remove from metrics
kv_active_blocks=0, # TODO - need to calculate this
kv_total_blocks=0, # TODO - remove from metrics
num_requests_waiting=scheduler_stats.num_waiting_reqs, # used in current cost function
gpu_cache_usage_perc=scheduler_stats.gpu_cache_usage, # used in current cost function
gpu_prefix_cache_hit_rate=hit_rate,
data_parallel_rank=self.dp_rank,
)
def log_engine_initialized(self) -> None:
pass
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component) -> None:
self.component = component
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
return DynamoStatLoggerPublisher(self.component, dp_rank)
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return self.create_stat_logger(dp_rank=dp_rank)
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
async def generate(self, request):
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
)
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True, publisher="zmq"
),
}
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.kv_block_size > 0:
arg_map["block_size"] = config.kv_block_size
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
logger.info(f"VLLM config: {arg_map}")
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ[
"VLLM_WORKER_MULTIPROC_METHOD"
] = "spawn" # Ensure our publisher makes it to the new process
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# Explicitly pass our custom stat logger for metrics
engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=[StatLoggerFactory(component)],
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
logger.info("VllmWorker has been initialized")
zmq_config = KvEventPublisherFromZmqConfig(
worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size
)
_ = KvEventPublisherFromZmq(component=component, config=zmq_config)
handler = RequestHandler(component, engine_client, default_sampling_params)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -250,6 +250,19 @@ dependencies = [
"zmq",
]
[[package]]
name = "asynchronous-codec"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233"
dependencies = [
"bytes",
"futures-sink",
"futures-util",
"memchr",
"pin-project-lite",
]
[[package]]
name = "atomic"
version = "0.6.0"
......@@ -887,6 +900,19 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if 1.0.0",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
......@@ -1106,6 +1132,7 @@ dependencies = [
"rand 0.9.1",
"rayon",
"regex",
"rmp-serde",
"serde",
"serde_json",
"strum",
......@@ -1123,6 +1150,7 @@ dependencies = [
"uuid",
"validator",
"xxhash-rust",
"zeromq",
]
[[package]]
......@@ -1909,6 +1937,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.2"
......@@ -3818,6 +3852,28 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rmp"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
dependencies = [
"byteorder",
"num-traits",
"paste",
]
[[package]]
name = "rmp-serde"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
dependencies = [
"byteorder",
"rmp",
"serde",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
......@@ -4601,6 +4657,7 @@ checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
"futures-io",
"futures-sink",
"pin-project-lite",
"tokio",
......@@ -5692,6 +5749,33 @@ version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zeromq"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a4528179201f6eecf211961a7d3276faa61554c82651ecc66387f68fc3004bd"
dependencies = [
"async-trait",
"asynchronous-codec",
"bytes",
"crossbeam-queue",
"dashmap",
"futures-channel",
"futures-io",
"futures-task",
"futures-util",
"log",
"num-traits",
"once_cell",
"parking_lot",
"rand 0.8.5",
"regex",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"uuid",
]
[[package]]
name = "zeromq-src"
version = "0.2.6+4.3.4"
......
......@@ -61,6 +61,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::KvEventPublisherFromZmq>()?;
m.add_class::<llm::kv::KvEventPublisherFromZmqConfig>()?;
m.add_class::<llm::kv::KvRecorder>()?;
m.add_class::<llm::nats::NatsQueue>()?;
m.add_class::<http::HttpService>()?;
......
......@@ -14,13 +14,15 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::atomic::AtomicU32;
use super::*;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use rs::traits::events::EventSubscriber;
use tracing;
use llm_rs::kv_router::{indexer::compute_block_hash_for_seq, protocols::*};
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::create_stored_blocks;
#[pyclass]
pub(crate) struct KvRouter {
......@@ -93,6 +95,7 @@ impl KvMetricsPublisher {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (request_active_slots, request_total_slots, kv_active_blocks, kv_total_blocks, num_requests_waiting, gpu_cache_usage_perc, gpu_prefix_cache_hit_rate, data_parallel_rank = 0))]
fn publish(
&self,
_py: Python,
......@@ -103,10 +106,12 @@ impl KvMetricsPublisher {
num_requests_waiting: u64,
gpu_cache_usage_perc: f32,
gpu_prefix_cache_hit_rate: f32,
data_parallel_rank: u32,
) -> PyResult<()> {
self.inner
.publish(
llm_rs::kv_router::protocols::ForwardPassMetrics {
data_parallel_rank: Some(data_parallel_rank),
request_active_slots,
request_total_slots,
kv_active_blocks,
......@@ -121,10 +126,73 @@ impl KvMetricsPublisher {
}
}
#[pyclass]
#[derive(Clone)]
pub struct KvEventPublisherFromZmqConfig {
#[pyo3(get, set)]
pub worker_id: i64,
#[pyo3(get, set)]
pub kv_block_size: usize,
#[pyo3(get, set)]
pub zmq_endpoint: String,
#[pyo3(get, set)]
pub zmq_topic: String,
}
#[pymethods]
impl KvEventPublisherFromZmqConfig {
#[new]
#[pyo3(signature = (
worker_id,
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string()
))]
pub fn new(
worker_id: i64,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
) -> Self {
Self {
worker_id,
kv_block_size,
zmq_endpoint,
zmq_topic,
}
}
}
#[pyclass]
pub(crate) struct KvEventPublisherFromZmq {
inner: llm_rs::kv_router::publisher::KvEventPublisherFromZmq,
}
#[pymethods]
impl KvEventPublisherFromZmq {
#[new]
fn new(component: Component, config: KvEventPublisherFromZmqConfig) -> PyResult<Self> {
let mut inner =
llm_rs::kv_router::publisher::KvEventPublisherFromZmq::new(config.kv_block_size);
inner.start_background_task(
component.inner,
config.worker_id,
config.zmq_endpoint,
config.zmq_topic,
);
Ok(Self { inner })
}
fn shutdown(&mut self) {
self.inner.shutdown()
}
}
#[pyclass]
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
warning_count: u32,
kv_block_size: usize,
warning_count: Arc<AtomicU32>,
}
#[pymethods]
......@@ -132,14 +200,15 @@ impl KvEventPublisher {
#[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner.clone(),
component.inner,
worker_id,
kv_block_size,
)
.map_err(to_pyerr)?;
Ok(Self {
inner: inner.into(),
warning_count: 0,
kv_block_size,
warning_count: Arc::new(AtomicU32::new(0)),
})
}
......@@ -151,19 +220,21 @@ impl KvEventPublisher {
event_id: u64,
token_ids: Vec<u32>,
num_block_tokens: Vec<u64>,
block_hashes: Vec<u64>,
block_hashes: Vec<i64>,
lora_id: u64,
parent_hash: Option<u64>,
parent_hash: Option<i64>,
) -> PyResult<()> {
let event = KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: self.create_stored_blocks(
parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
blocks: create_stored_blocks(
self.kv_block_size,
&token_ids,
&num_block_tokens,
&block_hashes,
lora_id,
&self.warning_count,
),
}),
};
......@@ -171,10 +242,10 @@ impl KvEventPublisher {
self.inner.publish(event).map_err(to_pyerr)
}
fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<u64>) -> PyResult<()> {
fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<i64>) -> PyResult<()> {
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
.iter()
.map(|&v| ExternalSequenceBlockHash(v))
.map(|&h| ExternalSequenceBlockHash::from(h))
.collect();
let event = KvCacheEvent {
event_id,
......@@ -185,50 +256,6 @@ impl KvEventPublisher {
}
}
impl KvEventPublisher {
fn create_stored_block_from_parts(
&self,
block_hash: u64,
token_ids: &[u32],
_lora_id: u64,
) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq(token_ids, self.inner.kv_block_size())[0];
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash),
tokens_hash,
}
}
fn create_stored_blocks(
&mut self,
token_ids: &[u32],
num_block_tokens: &[u64],
block_hashes: &[u64],
lora_id: u64,
) -> Vec<KvCacheStoredBlockData> {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0;
for (num_tokens_it, block_hash_it) in num_block_tokens.iter().zip(block_hashes.iter()) {
if (self.warning_count < 3) && (*num_tokens_it != self.inner.kv_block_size() as u64) {
tracing::warn!(
"Block not published. Block size must be {} tokens to be published. Block size is: {}",
self.inner.kv_block_size(),
*num_tokens_it
);
self.warning_count += 1;
break;
}
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
blocks.push(self.create_stored_block_from_parts(*block_hash_it, tokens, lora_id));
token_offset += *num_tokens_it as usize;
}
blocks
}
}
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
......
......@@ -368,6 +368,10 @@ class KvMetricsPublisher:
request_total_slots: int,
kv_active_blocks: int,
kv_total_blocks: int,
num_requests_waiting: int,
gpu_cache_usage_perc: float,
gpu_prefix_cache_hit_rate: float,
data_parallel_rank: int = 0,
) -> None:
"""
Update the KV metrics being reported.
......@@ -575,6 +579,40 @@ class KvEventPublisher:
"""
...
class KvEventPublisherFromZmqConfig:
def __init__(
self,
worker_id: int,
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = ""
) -> None:
"""
Configuration for the KvEventPublisherFromZmq.
:param worker_id: The worker ID.
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
"""
...
class KvEventPublisherFromZmq:
def __init__(self, component: Component, config: KvEventPublisherFromZmqConfig) -> None:
"""
Initializes a new KvEventPublisherFromZmq instance.
:param component: The component to be used.
:param config: Configuration for the event publisher.
"""
...
def shutdown(self) -> None:
"""
Shuts down the event publisher, stopping any background tasks.
"""
...
class HttpService:
"""
A HTTP service for dynamo applications.
......
......@@ -25,6 +25,8 @@ from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError
from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvEventPublisherFromZmq as KvEventPublisherFromZmq
from dynamo._core import KvEventPublisherFromZmqConfig as KvEventPublisherFromZmqConfig
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
......
......@@ -116,6 +116,10 @@ minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
ggus = "0.4.0"
memmap2 = "0.9.5"
# Publishers
zeromq = "0.4.1"
rmp-serde = "1.3"
[dev-dependencies]
assert_matches = "1.5"
hf-hub = { workspace = true }
......
......@@ -73,7 +73,7 @@ impl KvRouter {
.primary_lease()
.expect("Cannot KV route static workers")
.primary_token();
tracing::info!("KV Routing initialized");
let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
......
......@@ -41,6 +41,7 @@ pub struct WorkerSelectionResult {
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForwardPassMetrics {
pub data_parallel_rank: Option<u32>, // backwards compatible
pub request_active_slots: u64,
pub request_total_slots: u64,
pub kv_active_blocks: u64,
......@@ -65,6 +66,21 @@ pub struct LocalBlockHash(pub u64);
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct ExternalSequenceBlockHash(pub u64);
// Implement From trait for convenient conversion
impl From<u64> for ExternalSequenceBlockHash {
fn from(value: u64) -> Self {
Self(value)
}
}
impl From<i64> for ExternalSequenceBlockHash {
/// Bitwise reinterpretation: preserves all bits, including negatives.
/// This is lossless, but negative i64 values will appear as large u64 values.
fn from(value: i64) -> Self {
Self(value as u64)
}
}
/// Represents a collection of cache events and a shutdown flag.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheEvents {
......
This diff is collapsed.
......@@ -277,8 +277,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
// Calculate normalized metrics
assert!(ep.data.kv_total_blocks > 0);
let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64;
let normalized_waiting = if max_waiting > 0.0 {
ep.data.num_requests_waiting as f64 / max_waiting
} else {
......
......@@ -393,6 +393,7 @@ impl Scheduler {
};
ForwardPassMetrics {
data_parallel_rank: None, // Default for backwards compatibility
request_active_slots: state.running.len() as u64,
request_total_slots: 420, // Dummy value as specified
kv_active_blocks: active_blocks_count,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment