Unverified Commit 7adf245b authored by Trevor Morris, committed by GitHub

[Metrics] Add KV events publishing (#6098)

parent 299fd22f
......@@ -25,6 +25,7 @@ runtime_common = [
"interegular",
"llguidance>=0.7.11,<0.8.0",
"modelscope",
"msgspec",
"ninja",
"orjson",
"packaging",
......
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
KV caching events
"""
import atexit
import logging
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
import msgspec
import zmq
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class EventBatch(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
):
ts: float
events: list[Any]
class KVCacheEvent(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
tag=True,
):
"""Base class for all KV cache-related events"""
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
class AllBlocksCleared(KVCacheEvent):
pass
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches."""
@abstractmethod
def publish(self, events: EventBatch) -> None:
"""Emit events in order.
Implementations should guarantee at-least-once delivery and
monotonic ordering (e.g., via sequence numbers).
"""
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the publisher."""
class NullEventPublisher(EventPublisher):
"""No-op implementation (default when disabled)."""
def publish(self, events) -> None:
return
def shutdown(self) -> None:
return
class ZmqEventPublisher(EventPublisher):
"""Reliable PUB/ROUTER publisher with an in-memory replay buffer.
Spawns a separate thread to handle publishing from a queue.
Parameters
----------
endpoint:
PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
connect.
replay_endpoint:
Optional ROUTER address for replay requests. When given, subscribers can
request missed batches by sending the starting sequence number as an
8-byte big-endian integer.
buffer_steps:
Number of past batches to keep for replay.
hwm:
ZeroMQ high-water-mark for PUB socket.
max_queue_size:
Maximum number of event batches to queue in memory before ``publish`` blocks.
topic:
Topic to publish events to.
"""
SHUTDOWN_TIMEOUT: float = 1.0
END_SEQ = (-1).to_bytes(8, "big", signed=True)
def __init__(
self,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
hwm: int = 100_000,
max_queue_size: int = 100_000,
topic: str = "",
) -> None:
# Storage
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
# ZMQ sockets
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint
self._replay_endpoint = replay_endpoint
self._hwm = hwm
self._socket_setup()
# Payload
self._seq_gen = count()
self._topic_bytes = topic.encode("utf-8")
# Thread
self._running = True
logger.info("Starting ZMQ publisher thread")
self._thread = threading.Thread(
target=self._publisher_thread, daemon=True, name="zmq-publisher"
)
self._thread.start()
atexit.register(self.shutdown)
def publish(self, events: EventBatch) -> None:
if not self._running:
raise RuntimeError("Publisher is closed")
self._event_queue.put(events)
def shutdown(self) -> None:
"""Stop the publisher thread and clean up resources."""
self._running = False
self._event_queue.put_nowait(None)
start = time.time()
pending_items = True
while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
pending_items = not self._event_queue.empty()
if pending_items:
time.sleep(0.1)
if pending_items:
logger.warning(
"Warning: Queue still has %s items after %s seconds timeout",
self._event_queue.qsize(),
self.SHUTDOWN_TIMEOUT,
)
if self._thread.is_alive():
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
# Clean up ZMQ resources
try:
if self._pub is not None:
self._pub.close(linger=0)
if self._replay is not None:
self._replay.close(linger=0)
finally:
pass # Do not terminate context; other sockets may use it
def _socket_setup(self) -> None:
"""Initialize sockets
https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
"""
if self._pub is None:
self._pub = self._ctx.socket(zmq.PUB)
self._pub.set_hwm(self._hwm)
# Heuristic: bind if the endpoint is a listening address (wildcard,
# ipc:// or inproc://), otherwise connect.
# Convention: the stable side binds, the volatile side connects.
if (
"*" in self._endpoint
or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://")
):
self._pub.bind(self._endpoint)
else:
self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER
# 1) handles multiple REQ clients (identities)
# 2) lets us send back one request → many replies (streamed events)
# 3) works in our non-blocking poll loop alongside PUB
if self._replay_endpoint is not None:
self._replay = self._ctx.socket(zmq.ROUTER)
self._replay.bind(self._replay_endpoint)
def _publisher_thread(self) -> None:
"""Background thread that processes the event queue."""
self._pack = msgspec.msgpack.Encoder()
assert self._pub is not None # narrows type for mypy
while self._running or self._event_queue.qsize() > 0:
# --- replay (non-critical) ---------------------------------
if self._replay is not None and self._replay.poll(0):
try:
self._service_replay()
except Exception as e:
logger.exception("Error in replay: %s", e)
# --- main queue (critical) ---------------------------------
try:
event = self._event_queue.get(timeout=0.1)
if event is None:
break # Sentinel received, exit thread
except queue.Empty:
continue
try:
seq = next(self._seq_gen)
payload = self._pack.encode(event)
seq_bytes = seq.to_bytes(8, "big")
self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
self._buffer.append((seq, payload))
self._event_queue.task_done()
except Exception as e:
# Publishing failed; back-off a bit to avoid a tight error loop
logger.exception("Error in publisher thread: %s", e)
time.sleep(0.1)
def _service_replay(self) -> None:
"""If a replay request is waiting, send buffered batches."""
assert self._replay is not None # narrows type for mypy
frame = self._replay.recv_multipart()
if len(frame) != 3:
logger.warning("Invalid replay request: %s", frame)
return
client_id, _, start_seq_bytes = frame
start_seq = int.from_bytes(start_seq_bytes, "big")
for seq, buf in self._buffer:
if seq >= start_seq:
# [identity, empty_delim, seq_bytes, payload]
# (identity, empty_delim) are stripped off by the router
# receiving payload is (seq_bytes, payload)
self._replay.send_multipart(
(client_id, b"", seq.to_bytes(8, "big"), buf)
)
# Send end of sequence marker
# receiving payload is (-1, b"")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
class KVEventsConfig(BaseModel):
"""Configuration for KV event publishing."""
publisher: str = "null"
"""The publisher to use for publishing kv events. Can be "null", "zmq".
"""
endpoint: str = "tcp://*:5557"
"""The zmq endpoint to use for publishing kv events.
"""
replay_endpoint: Optional[str] = None
"""The zmq endpoint to use for replaying kv events.
"""
buffer_steps: int = 10_000
"""The number of steps to cache for replay endpoint. Will only save
events from the last N steps for the replay endpoint.
"""
hwm: int = 100_000
"""The zmq high water mark for the event publisher. After queueing N events,
events will start dropping if the consumer is not keeping up.
"""
max_queue_size: int = 100_000
"""The maximum number of events to queue while waiting for publishing.
"""
topic: str = ""
"""The topic to use for the event publisher. Consumers can subscribe to
this topic to receive events.
"""
@classmethod
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
"""Parse the CLI value for the event publisher config."""
return KVEventsConfig.model_validate_json(cli_value)
class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = {
"null": NullEventPublisher,
"zmq": ZmqEventPublisher,
}
@classmethod
def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
if name in cls._registry:
raise KeyError(f"publisher '{name}' already registered")
cls._registry[name] = ctor
@classmethod
def create(cls, config: Optional[str]) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
config = KVEventsConfig.from_cli(config)
config_dict = config.model_dump()
kind = config_dict.pop("publisher", "null")
try:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict)
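For orientation, a minimal end-to-end sketch of the module above (illustrative only, not part of the commit): parse the JSON config, create a publisher through the factory, publish one batch, and shut down. The endpoint, topic, and event values are placeholders.

import time

from sglang.srt.disaggregation.kv_events import (
    BlockStored,
    EventPublisherFactory,
    KVEventBatch,
)

# "zmq" selects ZmqEventPublisher; "null" (or an empty config) disables publishing.
config_json = '{"publisher": "zmq", "endpoint": "tcp://*:5557", "topic": "kv-events"}'
publisher = EventPublisherFactory.create(config_json)

# Events are plain msgspec structs: BlockStored / BlockRemoved / AllBlocksCleared.
batch = KVEventBatch(
    ts=time.time(),
    events=[
        BlockStored(
            block_hashes=[123],
            parent_block_hash=None,
            token_ids=[1, 2, 3],
            block_size=3,
            lora_id=None,
        )
    ],
)
publisher.publish(batch)  # enqueued; the background thread sends (topic, seq, payload)
publisher.shutdown()      # drains the queue, bounded by SHUTDOWN_TIMEOUT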
......@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
......@@ -197,6 +198,7 @@ class Scheduler(
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
self.enable_kv_cache_events = server_args.kv_events_config is not None
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
......@@ -204,7 +206,6 @@ class Scheduler(
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
......@@ -422,6 +423,7 @@ class Scheduler(
# Init metrics stats
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
......@@ -515,6 +517,7 @@ class Scheduler(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events,
)
self.decode_mem_cache_buf_multiplier = (
......@@ -547,6 +550,10 @@ class Scheduler(
},
)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
......@@ -1154,6 +1161,7 @@ class Scheduler(
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
......@@ -1213,6 +1221,7 @@ class Scheduler(
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def check_memory(self):
available_size = (
......@@ -1260,6 +1269,7 @@ class Scheduler(
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
......@@ -2194,6 +2204,13 @@ class Scheduler(
prefix += f" PP{self.pp_rank}"
return prefix
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
......
......@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
def pretty_print(self):
raise NotImplementedError()
def take_events(self):
return []
......@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared,
BlockRemoved,
BlockStored,
KVCacheEvent,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
......@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
enable_kv_cache_events: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
......@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
self.root_node.lock_ref = 1
self.evictable_size_ = 0
self.protected_size_ = 0
self._record_all_cleared_event()
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
......@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
self._record_remove_event(x)
def inc_lock_ref(self, node: TreeNode):
if self.disable:
return 0
......@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
self._record_remove_event(child)
new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
......@@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache):
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node
self._record_store_event(new_node)
self._record_store_event(child)
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
......@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
new_node.value = value
node.children[child_key] = new_node
self.evictable_size_ += len(value)
self._record_store_event(new_node)
return total_prefix_length
def _print_helper(self, node: TreeNode, indent: int):
......@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
return ret_list
def _record_store_event(self, node: TreeNode):
if self.enable_kv_cache_events:
block_hash = hash(tuple(node.key))
parent_block_hash = hash(tuple(node.parent.key))
self.kv_event_queue.append(
BlockStored(
block_hashes=[block_hash],
parent_block_hash=parent_block_hash,
token_ids=node.key,
block_size=len(node.key),
lora_id=None,
)
)
def _record_remove_event(self, node: TreeNode):
if self.enable_kv_cache_events:
block_hash = hash(tuple(node.key))
self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
def _record_all_cleared_event(self):
if self.enable_kv_cache_events:
self.kv_event_queue.append(AllBlocksCleared())
def take_events(self):
"""Atomically takes all events and clears the queue.
Returns:
A list of KV cache events.
"""
if not self.enable_kv_cache_events:
return []
events = self.kv_event_queue
self.kv_event_queue = []
return events
if __name__ == "__main__":
tree = RadixCache(None, None, page_size=1, disable=False)
......
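The store/remove/clear hooks above emit one event per radix tree node, keyed by hash(tuple(node.key)) and linked to the parent node's hash. A hedged consumer-side sketch (illustrative, not part of the commit) that mirrors the cached prefixes from these events:

from typing import Optional

from sglang.srt.disaggregation.kv_events import (
    AllBlocksCleared,
    BlockRemoved,
    BlockStored,
    KVEventBatch,
)

# block_hash -> (parent_block_hash, token_ids); one entry per stored tree node
blocks: dict[int, tuple[Optional[int], list[int]]] = {}

def apply_batch(batch: KVEventBatch) -> None:
    for event in batch.events:
        if isinstance(event, BlockStored):
            # this change emits a single hash per BlockStored event
            for h in event.block_hashes:
                blocks[h] = (event.parent_block_hash, event.token_ids)
        elif isinstance(event, BlockRemoved):
            for h in event.block_hashes:
                blocks.pop(h, None)
        elif isinstance(event, AllBlocksCleared):
            blocks.clear()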
......@@ -103,6 +103,7 @@ class ServerArgs:
collect_tokens_histogram: bool = False
decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None
# API related
api_key: Optional[str] = None
......@@ -814,6 +815,12 @@ class ServerArgs:
default=ServerArgs.collect_tokens_histogram,
help="Collect prompt/generation tokens histogram.",
)
parser.add_argument(
"--kv-events-config",
type=str,
default=None,
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
......
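The new flag takes the KVEventsConfig fields as a JSON string; one hedged example (values illustrative, mirroring the test below):

--kv-events-config '{"publisher": "zmq", "endpoint": "tcp://*:5557", "replay_endpoint": "tcp://*:5558", "topic": "kv-events"}'

Omitted fields fall back to the defaults defined in KVEventsConfig; leaving the flag unset disables event publishing entirely.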
import time
import unittest
import msgspec
import requests
import zmq
from msgspec.msgpack import Decoder
from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared,
BlockRemoved,
BlockStored,
EventBatch,
KVCacheEvent,
KVEventBatch,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestKvEvents(CustomTestCase):
def test_kv_events_enabled(self):
"""Test that kv events are sent and received by subscriber data when enabled"""
# Launch kv events subscriber
decoder = Decoder(type=KVEventBatch)
context = zmq.Context()
sub = context.socket(zmq.SUB)
sub.connect("tcp://localhost:5557")
topic = "kv-events"
sub.setsockopt_string(zmq.SUBSCRIBE, topic)
# Launch sglang server
process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--kv-events-config",
'{"publisher": "zmq", "topic": "kv-events"}',
"--max-total-tokens",
32,
"--cuda-graph-max-bs",
2,
],
)
try:
# Make some requests to generate KV cache events
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
self.assertEqual(response.status_code, 200)
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of Spain is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
# Expected events. These depend on the model used (meta-llama/Llama-3.2-1B-Instruct)
expected_events = [
# <begin> The capital city of France is
BlockStored(
block_hashes=[-6650323075460941099],
parent_block_hash=5740354900026072187,
token_ids=[128000, 791, 6864, 3363, 315, 9822, 374],
block_size=7,
lora_id=None,
),
# Paris. The Eiffel Tower
BlockStored(
block_hashes=[-7584018293207282755],
parent_block_hash=-6650323075460941099,
token_ids=[12366, 13, 578, 469, 3168, 301, 22703],
block_size=7,
lora_id=None,
),
BlockStored(
block_hashes=[-8753497827991233192],
parent_block_hash=5740354900026072187,
token_ids=[0],
block_size=1,
lora_id=None,
),
BlockRemoved(block_hashes=[-6650323075460941099]),
# <begin> The capital
BlockStored(
block_hashes=[-2697055055087824455],
parent_block_hash=5740354900026072187,
token_ids=[128000, 791, 6864],
block_size=3,
lora_id=None,
),
# city of France is
BlockStored(
block_hashes=[-7505627135785778022],
parent_block_hash=-2697055055087824455,
token_ids=[3363, 315, 9822, 374],
block_size=4,
lora_id=None,
),
# of France is
BlockStored(
block_hashes=[-3861108700662737012],
parent_block_hash=-2697055055087824455,
token_ids=[315, 9822, 374],
block_size=3,
lora_id=None,
),
BlockRemoved(block_hashes=[-7584018293207282755]),
BlockRemoved(block_hashes=[-8753497827991233192]),
BlockRemoved(block_hashes=[-7505627135785778022]),
# Paris. The Eiffel Tower is located in Paris. The Eiffel Tower is a famous landmark in Paris
BlockStored(
block_hashes=[-3064341286825792715],
parent_block_hash=-3861108700662737012,
token_ids=[
12366,
13,
578,
469,
3168,
301,
22703,
374,
7559,
304,
12366,
13,
578,
469,
3168,
301,
22703,
374,
264,
11495,
38350,
304,
12366,
],
block_size=23,
lora_id=None,
),
BlockRemoved(block_hashes=[-3861108700662737012]),
# of
BlockStored(
block_hashes=[6115672085296369592],
parent_block_hash=-2697055055087824455,
token_ids=[315],
block_size=1,
lora_id=None,
),
# France is
BlockStored(
block_hashes=[4208810872343132234],
parent_block_hash=6115672085296369592,
token_ids=[9822, 374],
block_size=2,
lora_id=None,
),
# Spain is
BlockStored(
block_hashes=[1675819893649989955],
parent_block_hash=6115672085296369592,
token_ids=[18157, 374],
block_size=2,
lora_id=None,
),
BlockRemoved(block_hashes=[-3064341286825792715]),
# Madrid. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid.
BlockStored(
block_hashes=[-8505834929190027295],
parent_block_hash=1675819893649989955,
token_ids=[
25048,
13,
578,
6864,
315,
9822,
374,
12366,
13,
578,
6864,
315,
15704,
374,
22463,
13,
578,
6864,
315,
18157,
374,
25048,
13,
],
block_size=23,
lora_id=None,
),
]
# Get events
events = []
start = time.time()
max_wait_s = 5
while (
len(events) < len(expected_events)
and (time.time() - start) < max_wait_s
):
_, seq_bytes, payload = sub.recv_multipart()
event_batch = decoder.decode(payload)
for event in event_batch.events:
print(f" - {event}")
events.append(event)
for expected in expected_events:
self.assertIn(expected, events)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
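The test above exercises only the live PUB stream. When a replay_endpoint is configured, a subscriber that notices a gap in the 8-byte sequence numbers could ask the ROUTER socket to resend buffered batches; a hedged sketch follows (the port 5558 and the starting sequence are illustrative):

import zmq
from msgspec.msgpack import Decoder

from sglang.srt.disaggregation.kv_events import KVEventBatch

decoder = Decoder(type=KVEventBatch)
ctx = zmq.Context.instance()

# DEALER lets us send one request and then read a stream of replies from the ROUTER.
sock = ctx.socket(zmq.DEALER)
sock.connect("tcp://localhost:5558")  # illustrative replay_endpoint

start_seq = 0  # first sequence number we are missing
# The server expects [identity, empty delimiter, 8-byte big-endian seq];
# DEALER adds the identity, so the empty delimiter is sent explicitly.
sock.send_multipart((b"", start_seq.to_bytes(8, "big")))

while True:
    _, seq_bytes, payload = sock.recv_multipart()
    if int.from_bytes(seq_bytes, "big", signed=True) == -1:
        break  # END_SEQ marker: replay finished
    batch = decoder.decode(payload)
    print(int.from_bytes(seq_bytes, "big"), batch.events)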