# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
TensorRT-LLM KV Event Publisher Module

This module contains the Publisher class that retrieves KV cache events from TensorRT-LLM
and publishes them either to ZMQ (for the consolidator) or to NATS (directly to the router).

Key Components:
- ZmqKvEventPublisher: Pure Python ZMQ PUBLISHER that publishes TensorRT-LLM KV events
  to ZMQ (so the consolidator can subscribe). This is different from KvEventPublisher
  in dynamo.llm, which is a Rust-based class that can optionally subscribe from a ZMQ
  source and publishes to NATS.
- Publisher: Main class that coordinates event publishing (ZMQ or NATS) and metrics publishing.

Event Flow:
- With Consolidator: Engine → ZmqKvEventPublisher (ZMQ PUB) → Consolidator → KvEventPublisher (dynamo.llm, ZMQ SUB) → NATS → Router
- Without Consolidator: Engine → KvEventPublisher (NATS PUB) → Router
"""

22
23
24
25
import asyncio
import concurrent.futures
import logging
import threading
26
import time
27
28
import traceback
import weakref
29
from contextlib import asynccontextmanager
30
from queue import Queue
31
from typing import Awaitable, Callable, Dict, Optional, Union
32

33
34
import msgpack
import zmq
35
from prometheus_client import CollectorRegistry
36

37
from dynamo.common.utils.prometheus import LLMBackendMetrics
38
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
39
40
41

# Request DEBUG-level logging on the root logger as soon as this module is imported.
# NOTE(review): this is an import-time side effect of a library module; confirm it is
# intentional rather than something the application entry point should configure.
logging.basicConfig(level=logging.DEBUG)

42
43
44
45
46
# Dedicated Prometheus registry for dynamo_component metrics.
# Using a separate CollectorRegistry instance keeps these metrics isolated so they
# can be exposed via their own callback.
DYNAMO_COMPONENT_REGISTRY = CollectorRegistry()


47
48
49
50
51
52
53
54
55
56
# Use non-blocking RPC calls; control overhead with backoff sleeps.
# Timeouts passed to the engine's get_stats_async / get_kv_cache_events_async calls.
_STATS_TIMEOUT_SEC = 0.01
_KV_EVENTS_TIMEOUT_SEC = 0.0
# Stats polling: sleep bounds (seconds) and multiplier applied to the sleep after
# each poll that yielded no data; the sleep resets to the minimum once data arrives.
_PUBLISH_MIN_SLEEP_SEC = 0.01
_PUBLISH_MAX_SLEEP_SEC = 0.1
_PUBLISH_BACKOFF_FACTOR = 2.0
# KV cache event polling: same backoff scheme with its own bounds and multiplier.
_KV_EVENTS_MIN_SLEEP_SEC = 0.005
_KV_EVENTS_MAX_SLEEP_SEC = 0.02
_KV_EVENTS_BACKOFF_FACTOR = 1.5

57

58
59
60
61
62
63
64
65
66
67
68
69
def _to_signed_i64(value: int | None) -> int | None:
    """Convert a Python int to signed 64-bit range by two's complement."""
    if value is None:
        return None

    if value >= 2**63:
        return value - 2**64
    if value < -(2**63):
        return ((value + 2**63) % 2**64) - 2**63
    return value


70
71
72
73
74
class ZmqKvEventPublisher:
    """
    Pure Python ZMQ PUBLISHER for TensorRT-LLM KV events.

    This class publishes TensorRT-LLM's KV cache events to ZMQ so that the consolidator
75
76
77
    can subscribe to them. This is different from KvEventPublisher in dynamo.llm,
    which is a Rust-based class that can optionally subscribe from a ZMQ source
    and publishes to NATS.
78
79
80
81

    Event Format: [timestamp, [events], data_parallel_rank]
    Message Format: multipart ZMQ message [topic, sequence, payload] where payload is
    msgpack-serialized batch.
82
83
    When attention DP is enabled for DeepSeek-style models, `data_parallel_rank` is set to the attention DP rank.
    Otherwise, it defaults to 0.
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105

    Usage:
        Used by Publisher class when consolidator is enabled (zmq_endpoint provided).
        Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
    """

    def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""):
        """
        Initialize ZMQ publisher.

        Args:
            zmq_endpoint: ZMQ endpoint to bind to (e.g., "tcp://*:20081")
            kv_block_size: Size of KV cache blocks in tokens
            topic: ZMQ topic to publish on (empty string for all topics)
        """
        self.zmq_endpoint = zmq_endpoint
        self.kv_block_size = kv_block_size
        self.topic = topic
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PUB)
        self.socket.bind(zmq_endpoint)
        self.sequence = 0
106
107
108
        self.data_parallel_rank = (
            0  # TensorRT-LLM doesn't use DP for now (but does support attention DP)
        )
109
110
111
112
113
114
115
116
117
118
119
120
        logging.info(
            f"TensorRT-LLM: ZMQ KV event publisher initialized - bound to {zmq_endpoint} "
            f"with topic '{topic}', kv_block_size={kv_block_size}"
        )

    def publish_stored(
        self,
        token_ids: list[int],
        num_block_tokens: list[int],
        block_hashes: list[int],
        lora_id: int = 0,
        parent_hash: Optional[int] = None,
121
        block_mm_infos: Optional[list[dict | None]] = None,
122
        attention_dp_rank: int = 0,
123
    ):
124
125
126
127
        """Publish a BlockStored event.

        Note: event_id is managed internally via self.sequence counter.
        """
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
        # Convert block hashes to signed i64 format
        block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]
        parent_hash_signed = (
            _to_signed_i64(parent_hash) if parent_hash is not None else None
        )

        # Create event in the same format as vLLM's ZmqEventPublisher:
        # All blocks should have the same size (kv_block_size)
        event = {
            "type": "BlockStored",
            "block_hashes": block_hashes_signed,
            "parent_block_hash": parent_hash_signed,
            "token_ids": token_ids,
            "block_size": self.kv_block_size,
            "lora_id": lora_id if lora_id != 0 else None,
        }

145
146
147
148
        # Add multimodal info if present
        if block_mm_infos is not None:
            event["block_mm_infos"] = block_mm_infos

149
        self._publish_event(event, attention_dp_rank)
150

151
    def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0):
152
153
154
155
        """Publish a BlockRemoved event.

        Note: event_id is managed internally via self.sequence counter.
        """
156
157
158
159
160
161
162
163
        # Convert block hashes to signed i64 format (vLLM compatibility)
        block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]

        event = {
            "type": "BlockRemoved",
            "block_hashes": block_hashes_signed,
        }

164
        self._publish_event(event, attention_dp_rank)
165
166
167
168
169
170

    def publish_all_cleared(self):
        """Publish an AllBlocksCleared event."""
        event = {"type": "AllBlocksCleared"}
        self._publish_event(event)

171
    def _publish_event(self, event: dict, attention_dp_rank: int = 0):
172
173
174
        """Publish a single event to ZMQ in vLLM batch format."""
        try:
            # Create batch in vLLM format: [timestamp, [events], data_parallel_rank]
175
            # The third element (data_parallel_rank) is used by the router for dp_rank routing
176
            timestamp = time.time()
177
            batch = [timestamp, [event], attention_dp_rank]
178
179
            event_type = event.get("type", "Unknown")
            logging.debug(
180
                f"TensorRT-LLM: ZMQ publisher sending {event_type} event (dp_rank={attention_dp_rank}) to {self.zmq_endpoint}"
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
            )

            # Serialize with msgpack (vLLM uses msgpack/rmp_serde compatible format)
            payload = msgpack.packb(batch, use_bin_type=True)

            # Create multipart message: [topic, sequence, payload]
            # Format matches what consolidator expects: 3 frames [topic, sequence, payload]
            sequence_bytes = self.sequence.to_bytes(8, byteorder="big")
            self.sequence += 1

            # Send multipart message (blocking send to ensure delivery)
            # Topic is empty string for "all topics" (vLLM compatibility)
            self.socket.send_multipart(
                [self.topic.encode(), sequence_bytes, payload], flags=0
            )
        except Exception as e:
            logging.error(f"Failed to publish ZMQ event: {e}", exc_info=True)

    def shutdown(self):
        """Shutdown the ZMQ publisher."""
        if self.socket:
            self.socket.close()
        if self.ctx:
            self.ctx.term()
        logging.info("ZMQ KV event publisher shut down")


208
209
210
211
212
213
214
class ManagedThread(threading.Thread):
    """
    A daemon thread that repeatedly runs an async task on an external event loop
    and handles errors.

    Each iteration calls ``task(**kwargs)`` to obtain a coroutine, submits it to
    ``loop`` with ``asyncio.run_coroutine_threadsafe`` and blocks this thread until
    it finishes. The thread stops when:

    - ``stop()`` is called (the in-flight coroutine's future is cancelled),
    - the task returns a falsy value (``False``/``None``),
    - the task reference is ``None`` or an expired ``WeakMethod``,
    - no loop is set, or the task does not return a coroutine.

    An exception raised by the task is logged, put on ``error_queue`` (if given),
    and the task is run again.
    """

    def __init__(
        self,
        task: Optional[Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        """
        Args:
            task: Callable returning a coroutine that resolves to ``True`` to be
                run again or a falsy value to stop; may be a ``weakref.WeakMethod``.
            error_queue: Optional queue that receives exceptions raised by the task.
            name: Thread name, used in log messages.
            loop: Event loop the coroutine is submitted to; can be set later with
                ``set_loop`` before ``start()``.
            **kwargs: Keyword arguments forwarded to ``task`` on every call.
        """
        super().__init__(name=name)
        self.task = task
        self.error_queue = error_queue
        self.kwargs = kwargs
        self.loop = loop
        self.daemon = True
        # Future of the coroutine currently submitted to the loop, so stop() can cancel it.
        self._current_future: Optional[concurrent.futures.Future] = None

        self._stop_event = threading.Event()

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop that task coroutines are submitted to."""
        self.loop = loop

    def run(self):
        """Run the task until stopped, cancelled, or the task returns a falsy value."""
        while not self._stop_event.is_set():
            task: Optional[
                Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]
            ] = self.task
            if isinstance(task, weakref.WeakMethod):
                task = task()
                if task is None:
                    # Normally, this should not happen.
                    logging.warning("WeakMethod is expired.")
                    break

            if task is None:
                break

            try:
                if self.loop is None:
                    logging.error("[ManagedThread] Loop not initialized!")
                    break

                # Call the task function to get the coroutine
                coro = task(**self.kwargs)
                if not asyncio.iscoroutine(coro):
                    logging.error(f"Task {task} did not return a coroutine")
                    break

                self._current_future = asyncio.run_coroutine_threadsafe(coro, self.loop)
                keep_running = self._current_future.result()
            except (asyncio.CancelledError, concurrent.futures.CancelledError):
                logging.debug(f"Thread {self.name} was cancelled")
                break
            except Exception as e:
                logging.error(
                    f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
                )
                if self.error_queue is not None:
                    self.error_queue.put(e)
            else:
                # Honor the task's boolean contract: a falsy result means the task
                # cannot or should not continue. Without this check such a task
                # would be re-invoked immediately in a tight loop.
                if not keep_running:
                    logging.debug(
                        f"Thread {self.name} task returned {keep_running!r}; stopping"
                    )
                    break

        logging.info(f"Thread {self.name} stopped.")

    def stop(self):
        """Request the thread to stop and cancel the in-flight coroutine, if any."""
        self._stop_event.set()
        if self._current_future and not self._current_future.done():
            self._current_future.cancel()


280
class Publisher:
    """
    Main publisher class for TensorRT-LLM KV events and metrics.

    Retrieves KV cache events and stats from TensorRT-LLM engine and publishes them:
    - KV Events: Routes to either ZMQ (if consolidator enabled) or NATS (if no consolidator)
    - Metrics: Always publishes to NATS via WorkerMetricsPublisher

    Publisher Selection Logic:
    - If zmq_endpoint provided: Uses ZmqKvEventPublisher (ZMQ PUB) → Consolidator → NATS
    - If zmq_endpoint None: Uses KvEventPublisher (NATS PUB) → Router directly

    Note: The ZmqKvEventPublisher used here is the pure Python ZMQ publisher defined
    in this module, not the Rust-based KvEventPublisher from dynamo.llm (which is
    used in main.py as the worker-side subscriber from consolidator to NATS).

    Lifecycle: construct, ``initialize()`` (inside a running event loop),
    ``start()`` once the engine is producing data, and ``await cleanup()`` at shutdown.
    """

    def __init__(
        self,
        endpoint,
        engine,
        worker_id,
        kv_block_size,
        metrics_labels,
        component_gauges: "LLMBackendMetrics",
        zmq_endpoint: Optional[str] = None,
        enable_local_indexer: bool = False,
        metrics_collector=None,
    ):
        """
        Args:
            endpoint: Endpoint passed to the metrics and NATS KV event publishers.
            engine: Engine wrapper; must provide ``get_attention_dp_size()`` and
                an ``llm`` attribute with ``get_stats_async`` /
                ``get_kv_cache_events_async``.
            worker_id: Worker id passed to ``KvEventPublisher``.
            kv_block_size: Number of tokens in a full KV cache block.
            metrics_labels: Stored on the instance; not otherwise used in this class.
            component_gauges: Prometheus gauges updated with block/cache usage.
            zmq_endpoint: When set, a ``ZmqKvEventPublisher`` is bound to it and
                KV events go to ZMQ instead of NATS.
            enable_local_indexer: Forwarded to ``KvEventPublisher``.
            metrics_collector: Optional object; if it has ``log_iteration_stats``
                each stats record is forwarded to it.
        """
        self.endpoint = endpoint
        self.engine = engine
        self.worker_id = worker_id
        self.kv_block_size = kv_block_size
        self.max_window_size = None
        self.metrics_labels = metrics_labels
        self.component_gauges = component_gauges
        self.enable_local_indexer = enable_local_indexer
        self.metrics_collector = metrics_collector
        self.attention_dp_size = engine.get_attention_dp_size()

        # The first few kv events from the model engine are always "created" type events.
        # Use these events to capture the max_window_size of the model.
        # When the first event that is not a "created" type is received, the publisher will set this to False to stop processing "created" type events.
        self.processing_initial_created_events = True

        # Needed by the events and metrics publishers
        self.metrics_publisher = None
        self.kv_event_publishers: Optional[
            Dict[int, KvEventPublisher]
        ] = None  # One per attention_dp_rank
        self.zmq_kv_event_publisher = None  # ZMQ publisher for consolidator
        self.publish_kv_cache_events_thread: Optional[ManagedThread] = None
        self.publish_stats_thread: Optional[ManagedThread] = None
        # A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes.
        # It is used to prevent sending remove event to kv router since partial blocks are not stored.
        self.partial_block_hashes: set[int] = set()
        self.error_queue: Queue = Queue()
        self._stop_event = threading.Event()
        # Track the last engine event_id to assert consecutive event IDs from the engine
        self._last_engine_event_id: Optional[int] = None
        # Strong reference to the endpoint-creation task so it cannot be garbage
        # collected while still pending.
        self._create_endpoint_task: Optional[asyncio.Task] = None

        # Initialize ZMQ publisher if endpoint is provided (consolidator enabled)
        if zmq_endpoint:
            logging.info(
                f"TensorRT-LLM: Initializing ZMQ KV event publisher with endpoint={zmq_endpoint}"
            )
            self.zmq_kv_event_publisher = ZmqKvEventPublisher(
                zmq_endpoint, self.kv_block_size
            )
        else:
            logging.info(
                "TensorRT-LLM: ZMQ endpoint not provided, ZMQ publisher will not be initialized"
            )

    async def _create_metrics_publisher_endpoint(self):
        """Create the metrics publisher's endpoint; logs an error if it is not set up."""
        logging.debug("Creating metrics publisher endpoint")
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return
        await self.metrics_publisher.create_endpoint(self.endpoint)

    @staticmethod
    def _on_metrics_endpoint_created(task: "asyncio.Task") -> None:
        """Log the outcome of the endpoint-creation task instead of assuming success."""
        if task.cancelled():
            logging.debug("metrics publisher endpoint creation was cancelled")
            return
        error = task.exception()
        if error is not None:
            logging.error(f"Failed to create metrics publisher endpoint: {error}")
            return
        logging.debug("metrics publisher endpoint created")

    def initialize(self):
        """Create the metrics/KV event publishers and prepare (but do not start) threads.

        Must be called from within a running event loop because it schedules the
        metrics endpoint creation with ``asyncio.create_task``.
        """
        # Setup the metrics publisher
        self.metrics_publisher = WorkerMetricsPublisher()
        self._init_publish_metrics_thread()
        # The event loop only keeps weak references to tasks, so hold the task on
        # the instance until it completes.
        self._create_endpoint_task = asyncio.create_task(
            self._create_metrics_publisher_endpoint()
        )
        self._create_endpoint_task.add_done_callback(
            self._on_metrics_endpoint_created
        )

        # Setup the kv cache events publisher
        # Publisher selection based on consolidator configuration:
        # - With consolidator: Use ZmqKvEventPublisher (this module) → ZMQ → Consolidator → NATS → Router
        # - Without consolidator: Use KvEventPublisher → NATS → Router (direct)
        # Note: The worker-side KvEventPublisher (from dynamo.llm) that subscribes from
        # consolidator and publishes to NATS is created separately in main.py, not here.
        if self.zmq_kv_event_publisher:
            logging.info(
                "KV Event Consolidator enabled - using ZMQ publisher only. "
                "Consolidator will publish consolidated events to NATS."
            )
            self.kv_event_publishers = None
        else:
            # No consolidator: use NATS publisher (router subscribes directly)
            # Create one KvEventPublisher per attention_dp_rank (similar to vLLM's DP pattern)
            self.kv_event_publishers = {}
            for rank in range(self.attention_dp_size):
                self.kv_event_publishers[rank] = KvEventPublisher(
                    endpoint=self.endpoint,
                    worker_id=self.worker_id,
                    kv_block_size=self.kv_block_size,
                    dp_rank=rank,
                    enable_local_indexer=self.enable_local_indexer,
                )
            logging.info(
                f"Created {self.attention_dp_size} KV event publisher(s) for attention DP ranks"
            )

        # Always initialize the thread - it routes to either ZMQ or NATS publisher
        self._init_publish_kv_cache_events_thread()

    def _init_publish_metrics_thread(self):
        """Publish initial zero metrics and prepare the stats publishing thread."""
        # Need to publish stats once so that worker can be selected.
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return

        # Publish initial metrics with 0 active blocks
        # TRT-LLM doesn't use data parallelism currently (dp_rank="0")
        self.metrics_publisher.publish(None, 0)
        self.component_gauges.set_total_blocks("0", 0)
        self.component_gauges.set_gpu_cache_usage("0", 0.0)

        # Prepare threads for publishing stats but don't start them yet.
        # TRTLLM needs to start generating tokens first before stats
        # can be retrieved.
        self.publish_stats_thread = ManagedThread(
            self._publish_stats_task,
            error_queue=self.error_queue,
            name="publish_stats_thread",
        )

    def _init_publish_kv_cache_events_thread(self):
        """Prepare the KV cache events publishing thread."""
        # The _publish_kv_cache_events_task will route to the appropriate publisher
        # Prepare threads for publishing kv cache events but don't start them yet.
        # TRTLLM needs to start generating tokens first before kv cache events
        # can be retrieved.
        self.publish_kv_cache_events_thread = ManagedThread(
            self._publish_kv_cache_events_task,
            error_queue=self.error_queue,
            name="publish_kv_cache_events_thread",
        )

    async def _polling_loop(
        self,
        fetch_fn,
        handler_fn,
        min_sleep: float,
        max_sleep: float,
        backoff_factor: float,
    ):
        """Poll ``fetch_fn`` until the publisher is stopped, with idle backoff.

        Args:
            fetch_fn: Zero-argument callable returning an async iterable of items.
            handler_fn: Synchronous callable invoked with each item.
            min_sleep: Sleep (seconds) used after the first empty poll.
            max_sleep: Upper bound for the sleep.
            backoff_factor: Multiplier applied to the sleep after each empty poll.

        Timeouts/empty-queue errors are treated as "no data"; any other exception
        is logged as a warning and polling continues.
        """
        sleep_s = min_sleep
        while not self._stop_event.is_set():
            had_data = False
            try:
                async for item in fetch_fn():
                    had_data = True
                    handler_fn(item)
            except (asyncio.TimeoutError, TimeoutError, asyncio.QueueEmpty):
                pass
            except Exception as e:
                logging.warning(f"Publisher polling loop error: {e}", exc_info=True)

            if not had_data:
                await asyncio.sleep(sleep_s)
                sleep_s = min(max_sleep, sleep_s * backoff_factor)
            else:
                sleep_s = min_sleep

    async def _publish_stats_task(self):
        """
        Publish stats to the metrics publisher.

        Returns:
            ``False`` if the engine or metrics publisher is missing, otherwise
            ``True`` once the polling loop exits (publisher stopped).
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return False

        def handle_stat(stat):
            kv_active_blocks = stat["kvCacheStats"]["usedNumBlocks"]
            kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
            logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}")
            # TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
            self.metrics_publisher.publish(None, kv_active_blocks)

            # Publish Prometheus metrics
            self.component_gauges.set_total_blocks("0", kv_total_blocks)

            # Calculate and publish GPU cache usage as a fraction of total blocks
            gpu_cache_usage = (
                kv_active_blocks / kv_total_blocks if kv_total_blocks > 0 else 0.0
            )
            self.component_gauges.set_gpu_cache_usage("0", gpu_cache_usage)

            # Log iteration stats to TRT-LLM MetricsCollector (PR #11243)
            # This populates trtllm_kv_cache_hit_rate and trtllm_kv_cache_utilization gauges
            if self.metrics_collector and hasattr(
                self.metrics_collector, "log_iteration_stats"
            ):
                try:
                    self.metrics_collector.log_iteration_stats(stat)
                except Exception as e:
                    logging.warning(f"Failed to log iteration stats: {e}")

        await self._polling_loop(
            lambda: self.engine.llm.get_stats_async(timeout=_STATS_TIMEOUT_SEC),
            handle_stat,
            _PUBLISH_MIN_SLEEP_SEC,
            _PUBLISH_MAX_SLEEP_SEC,
            _PUBLISH_BACKOFF_FACTOR,
        )
        return True

    async def _publish_kv_cache_events_task(self):
        """
        Publish kv cache events to the events publisher.
        Routes to ZMQ (if kv event consolidation is enabled) or NATS (if no kv event consolidation).

        Returns:
            ``False`` if the engine or every KV event publisher is missing,
            otherwise ``True`` once the polling loop exits (publisher stopped).
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        # Check that at least one publisher is available
        if not self.kv_event_publishers and self.zmq_kv_event_publisher is None:
            logging.error("No KV event publisher initialized (neither NATS nor ZMQ)!")
            return False

        await self._polling_loop(
            lambda: self.engine.llm.get_kv_cache_events_async(
                timeout=_KV_EVENTS_TIMEOUT_SEC
            ),
            self._handle_kv_event,
            _KV_EVENTS_MIN_SLEEP_SEC,
            _KV_EVENTS_MAX_SLEEP_SEC,
            _KV_EVENTS_BACKOFF_FACTOR,
        )
        return True

    def _track_engine_event_id(self, event_id):
        """Warn if ``event_id`` does not directly follow the previously seen id."""
        if self._last_engine_event_id is not None:
            expected_id = self._last_engine_event_id + 1
            if event_id != expected_id:
                logging.warning(
                    f"Non-consecutive engine event_id: expected {expected_id}, got {event_id}"
                )
        self._last_engine_event_id = event_id

    @staticmethod
    def _extract_block_mm_info(block) -> Optional[dict]:
        """Build the multimodal info dict for one block, or ``None`` if it has none.

        Expected block shape:
        {"mm_keys": [{"type":"mm_key","hash":"<hex>","start_offset":N}]}
        Only the first 16 hex characters of each hash are used.
        """
        mm_hashes = [
            int(mm_key["hash"][:16], 16)
            for mm_key in block.get("mm_keys", [])
            if mm_key.get("type") == "mm_key" and mm_key.get("hash")
        ]
        if not mm_hashes:
            return None
        return {
            "mm_objects": [{"mm_hash": mm_hash, "offsets": []} for mm_hash in mm_hashes]
        }

    def _handle_kv_event(self, event):
        """Translate one engine KV cache event and publish it to ZMQ or NATS.

        "stored" and "removed" events are published; "created" events are only
        used (while still in the initial phase) to learn ``max_window_size``.
        """
        logging.debug(f"KV cache event received: {event}")

        # Check for consecutive event IDs from the engine. This is done before the
        # window-size filter so that ids of dropped events are still recorded;
        # otherwise every dropped event produces a false gap warning on the next
        # kept event.
        # NOTE(review): assumes engine event ids form one sequence across all
        # window sizes — confirm against the engine's event manager.
        event_id = event["event_id"]
        self._track_engine_event_id(event_id)

        # drop the events that are not emitted from the global attention layer.
        if self.should_drop_event(event):
            return

        data = event["data"]
        if data["type"] == "stored":
            self._handle_stored_event(event, event_id, data)
        elif data["type"] == "removed":
            self._handle_removed_event(event, event_id, data)
        elif data["type"] == "created" and self.processing_initial_created_events:
            self.update_max_window_size(event)

    def _handle_stored_event(self, event, event_id, data):
        """Publish the full blocks of a "stored" event; remember the first partial block."""
        self.processing_initial_created_events = False
        parent_hash = _to_signed_i64(data["parent_hash"])
        token_ids: list[int] = []
        num_block_tokens: list[int] = []
        block_hashes: list[int] = []
        block_mm_infos: list[dict | None] = []
        for block in data["blocks"]:
            token_num_in_block = len(block["tokens"])
            block_hash = _to_signed_i64(block["block_hash"])
            if token_num_in_block > self.kv_block_size:
                # Oversized block: abandon the whole event without publishing.
                logging.error(
                    f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}"
                )
                return
            if block_hash is None:
                logging.warning(
                    f"Skipping block with None hash containing {token_num_in_block} tokens"
                )
                continue
            if token_num_in_block < self.kv_block_size:
                # Partial block: record its hash so a later "removed" for it is
                # suppressed, and ignore this and any following blocks.
                logging.debug(
                    f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self.kv_block_size}"
                )
                self.partial_block_hashes.add(block_hash)
                break
            num_block_tokens.append(token_num_in_block)
            block_hashes.append(block_hash)
            token_ids.extend(int(token["token_id"]) for token in block["tokens"])

            # Extract multimodal hash info for this block
            block_mm_infos.append(self._extract_block_mm_info(block))

        # Note: Currently data does not have lora_id.
        # Using 0 as default value. If later data has
        # lora_id, we need to verify if this is correct.
        lora_id = data.get("lora_id", 0)

        # Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
        # Default to 0 for backwards compatibility with older TRT-LLM versions
        attention_dp_rank = event.get("attention_dp_rank", 0)

        logging.debug(
            f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
        )
        # Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
        # Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
        if self.zmq_kv_event_publisher:
            # Consolidator enabled: publish to ZMQ only
            self.zmq_kv_event_publisher.publish_stored(
                token_ids,
                num_block_tokens,
                block_hashes,
                lora_id,
                parent_hash,
                block_mm_infos,
                attention_dp_rank,
            )
        elif self.kv_event_publishers:
            # No consolidator: publish to NATS (router subscribes directly)
            # Route to correct publisher based on attention_dp_rank
            publisher = self.kv_event_publishers.get(attention_dp_rank)
            if publisher:
                publisher.publish_stored(
                    token_ids,
                    num_block_tokens,
                    block_hashes,
                    lora_id,
                    parent_hash,
                    block_mm_infos,
                )
            else:
                logging.warning(
                    f"No publisher for attention_dp_rank={attention_dp_rank}, "
                    f"available ranks: {list(self.kv_event_publishers.keys())}"
                )

    def _handle_removed_event(self, event, event_id, data):
        """Publish a "removed" event, omitting hashes of blocks that were never stored."""
        self.processing_initial_created_events = False
        removed_block_hashes: list[int] = []
        for block_hash in data["block_hashes"]:
            block_hash = _to_signed_i64(block_hash)
            if block_hash is None:
                continue
            if block_hash in self.partial_block_hashes:
                logging.debug(
                    f"Skipping removing block hash {block_hash} since it is a partial block"
                )
                self.partial_block_hashes.remove(block_hash)
                continue
            removed_block_hashes.append(block_hash)

        # Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
        attention_dp_rank = event.get("attention_dp_rank", 0)

        logging.debug(
            f"publish removed event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, block_hashes: {removed_block_hashes}"
        )
        # Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
        # Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
        if self.zmq_kv_event_publisher:
            # Consolidator enabled: publish to ZMQ only
            self.zmq_kv_event_publisher.publish_removed(
                removed_block_hashes, attention_dp_rank
            )
        elif self.kv_event_publishers:
            # No consolidator: publish to NATS (router subscribes directly)
            # Route to correct publisher based on attention_dp_rank
            publisher = self.kv_event_publishers.get(attention_dp_rank)
            if publisher:
                publisher.publish_removed(removed_block_hashes)
            else:
                logging.warning(
                    f"No publisher for attention_dp_rank={attention_dp_rank}, "
                    f"available ranks: {list(self.kv_event_publishers.keys())}"
                )

    def start(self):
        """Start the prepared publishing threads, bound to the currently running loop.

        Threads that are already alive are left untouched. Must be called from
        within a running event loop when there is a thread to start.
        """
        if (
            self.publish_kv_cache_events_thread
            and not self.publish_kv_cache_events_thread.is_alive()
        ):
            # REVISIT
            # [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
            self._stats_loop = asyncio.get_running_loop()
            self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
            self.publish_kv_cache_events_thread.start()
            logging.debug("Started kv cache events thread")

        if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
            self._stats_loop = asyncio.get_running_loop()
            self.publish_stats_thread.set_loop(self._stats_loop)
            self.publish_stats_thread.start()
            logging.debug("Started stats thread")

    def check_error_queue(self):
        """Return one queued publisher-thread exception, or ``None`` if the queue is empty."""
        if not self.error_queue.empty():
            logging.error("Error in publishers error queue")
            return self.error_queue.get()
        return None

    async def cleanup(self):
        """Cleanup threads and resources"""
        self._stop_event.set()
        # Add timeout to prevent hanging
        cleanup_timeout = 5.0  # seconds

        if self.publish_stats_thread and self.publish_stats_thread.is_alive():
            self.publish_stats_thread.stop()
            self.publish_stats_thread.join(timeout=cleanup_timeout)
            if self.publish_stats_thread.is_alive():
                logging.warning("Stats thread did not stop within timeout")

        if (
            self.publish_kv_cache_events_thread
            and self.publish_kv_cache_events_thread.is_alive()
        ):
            self.publish_kv_cache_events_thread.stop()
            self.publish_kv_cache_events_thread.join(timeout=cleanup_timeout)
            if self.publish_kv_cache_events_thread.is_alive():
                logging.warning("KV cache events thread did not stop within timeout")

        # Shutdown ZMQ publisher if it exists
        if self.zmq_kv_event_publisher:
            self.zmq_kv_event_publisher.shutdown()

    def update_max_window_size(self, event):
        """Raise ``max_window_size`` to the event's ``window_size`` if it is larger."""
        if "window_size" in event:
            window_size = event["window_size"]
            if self.max_window_size is None or window_size > self.max_window_size:
                self.max_window_size = window_size
                logging.debug(
                    f"kv events max_window_size has been updated to {self.max_window_size}"
                )

    # The global attention layer will emit the KV event with the max_window_size.
    # We only want to keep the KV event that has the max_window_size to ensure
    # the accuracy of KV routing.
    # TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
    # so we can use the "created" event to identify the max_window_size of the global
    # attention layer in the model engine.
    def should_drop_event(self, event):
        """Return ``True`` if the event should be ignored based on its window size."""
        # There are two cases for KV event filtering:
        #
        # 1. If "window_size" is NOT in the KV event:
        #    "window_size" was added to KV events only recently, so some older versions of TRTLLM
        #    might not include it. In this case, the publisher will assume that all events are
        #    from the global attention layer.
        #
        # 2. If "window_size" is present in the KV event:
        #    The publisher will not drop any KV events until all initial "created" KV events
        #    have been processed in order to capture the max_window_size.
        #    After processing all "created" events, the publisher will only accept KV events
        #    whose window_size is equal to the max_window_size to ensure accurate routing.
        if "window_size" not in event or self.processing_initial_created_events:
            return False

        if event["window_size"] != self.max_window_size:
            return True

        return False

770
771

@asynccontextmanager
async def get_publisher(
    endpoint,
    engine,
    worker_id,
    kv_block_size,
    metrics_labels,
    component_gauges: LLMBackendMetrics,
    zmq_endpoint: Optional[str] = None,
    enable_local_indexer: bool = False,
    metrics_collector=None,
):
    """Async context manager yielding an initialized ``Publisher``.

    All arguments are forwarded unchanged to ``Publisher``. The publisher is
    initialized on entry and ``cleanup()`` is always awaited on exit. An
    exception raised during initialization or inside the ``async with`` body is
    logged and re-raised.
    """
    kv_publisher = Publisher(
        endpoint=endpoint,
        engine=engine,
        worker_id=worker_id,
        kv_block_size=kv_block_size,
        metrics_labels=metrics_labels,
        component_gauges=component_gauges,
        zmq_endpoint=zmq_endpoint,
        enable_local_indexer=enable_local_indexer,
        metrics_collector=metrics_collector,
    )
    try:
        kv_publisher.initialize()
        yield kv_publisher
    except Exception as exc:
        # Log for visibility, then let the caller see the original exception.
        logging.error(f"Error in engine context: {exc}")
        raise
    finally:
        await kv_publisher.cleanup()