Unverified commit 9acaa8d1, authored by Tanmay Verma, committed by GitHub
Browse files

feat: Add metrics and event publishers (#1192)

parent b8272a98
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from tensorrt_llm import LLM
logging.basicConfig(level=logging.DEBUG)
class TensorRTLLMEngine:
    """Owns a TensorRT-LLM ``LLM`` instance that is created on demand.

    Args:
        engine_args: mapping of keyword arguments for ``LLM``. It must contain
            a ``"model"`` entry, which is passed as the ``model`` argument;
            every other entry is forwarded unchanged as a keyword argument.
    """

    def __init__(self, engine_args):
        self.engine_args = engine_args
        self._llm: Optional[LLM] = None

    async def initialize(self):
        """Create the ``LLM`` if this engine does not hold one yet.

        Calling this again while an ``LLM`` is held is a no-op.

        Raises:
            KeyError: if ``engine_args`` has no ``"model"`` entry.
        """
        if self._llm is None:
            # Work on a shallow copy so that neither the caller's mapping nor
            # self.engine_args loses its "model" entry. Popping in place made
            # initialize() after cleanup() fail with KeyError.
            llm_kwargs = dict(self.engine_args)
            model = llm_kwargs.pop("model")
            self._llm = LLM(
                model=model,
                **llm_kwargs,
            )

    async def cleanup(self):
        """Shut the ``LLM`` down, if one is held, and drop the reference.

        An exception raised by ``shutdown()`` is logged instead of being
        propagated; the reference is cleared in either case.
        """
        if self._llm is None:
            return
        try:
            self._llm.shutdown()
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")
        finally:
            self._llm = None

    @property
    def llm(self):
        """The held ``LLM`` instance.

        Raises:
            RuntimeError: if no ``LLM`` is currently held.
        """
        if self._llm is None:
            raise RuntimeError("Engine not initialized")
        return self._llm
@asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
    """Async context manager that yields an initialized ``TensorRTLLMEngine``.

    The engine is built from ``engine_args`` and initialized before control is
    handed to the ``async with`` body. Any ``Exception`` raised during
    initialization or inside the body is logged and re-raised. The engine's
    ``cleanup()`` is awaited on every exit path.
    """
    trtllm_engine = TensorRTLLMEngine(engine_args)
    try:
        await trtllm_engine.initialize()
        yield trtllm_engine
    except Exception as err:
        # Log for visibility, then let the caller see the original exception.
        logging.error(f"Error in engine context: {err}")
        raise
    finally:
        await trtllm_engine.cleanup()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import concurrent.futures
import logging
import threading
import traceback
import weakref
from queue import Queue
from typing import Callable, Optional, Union
from dynamo.llm import KvEventPublisher, KvMetricsPublisher
logging.basicConfig(level=logging.DEBUG)
class ManagedThread(threading.Thread):
    """
    A daemon thread that keeps running an async task on an event loop and
    reports errors.

    Each iteration submits ``task(**kwargs)`` to ``loop`` with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it completes. The
    task's return value is discarded. The loop ends when ``stop()`` is called,
    the task is cancelled, the task (or its weak reference) is gone, or no
    event loop has been set.
    """

    def __init__(
        self,
        task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        super().__init__(name=name)
        self.daemon = True
        self.task = task
        self.kwargs = kwargs
        self.loop = loop
        self.error_queue = error_queue
        self._stop_event = threading.Event()
        # Future of the submission currently being waited on, if any.
        self._current_future: Optional[concurrent.futures.Future] = None

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop that tasks are submitted to."""
        self.loop = loop

    def _resolve_task(self):
        """Return the callable to run, or None when there is nothing to run."""
        target = self.task
        if isinstance(target, weakref.WeakMethod):
            target = target()
            if target is None:
                # Normally, this should not happen.
                logging.warning("WeakMethod is expired.")
        return target

    def run(self):
        """Thread body: submit the task repeatedly until stopped."""
        while not self._stop_event.is_set():
            target = self._resolve_task()
            if target is None:
                break

            try:
                if self.loop is None:
                    logging.error("[ManagedThread] Loop not initialized!")
                    break
                submitted = asyncio.run_coroutine_threadsafe(
                    target(**self.kwargs), self.loop
                )
                self._current_future = submitted
                # Block until the coroutine finishes; the result is unused.
                submitted.result()
            except (asyncio.CancelledError, concurrent.futures.CancelledError):
                logging.debug(f"Thread {self.name} was cancelled")
                break
            except Exception as exc:
                logging.error(
                    f"Error in thread {self.name}: {exc}\n{traceback.format_exc()}"
                )
                if self.error_queue is not None:
                    self.error_queue.put(exc)
                # Not fatal: the loop goes on and submits the task again.

        logging.info(f"Thread {self.name} stopped.")

    def stop(self):
        """Request the loop to end and cancel the submission in flight."""
        self._stop_event.set()
        pending = self._current_future
        if pending is not None and not pending.done():
            pending.cancel()
class Publishers:
    """
    A class to retrieve stats and kv cache events from the TRTLLM engine and
    publish them to the metrics and events publishers.

    The constructor calls ``asyncio.create_task`` (through ``_setup``), so it
    has to run while an event loop is running. The publishing threads are only
    prepared by the constructor: ``start_publish_threads`` starts them and
    ``cleanup`` stops them. Exceptions raised inside the threads end up in
    ``error_queue`` and can be polled with ``check_error_queue``.
    """

    def __init__(self, component, engine, kv_listener, worker_id, kv_block_size):
        self.component = component
        self.engine = engine
        self.kv_listener = kv_listener
        self.worker_id = worker_id
        self.kv_block_size = kv_block_size

        # Needed by the events and metrics publishers
        self.metrics_publisher = None
        self.kv_event_publisher = None
        self.publish_kv_cache_events_thread = None
        self.publish_stats_thread = None
        # Hashes of partial blocks (i.e. blocks containing fewer than
        # kv_block_size tokens). Partial blocks are never published as stored,
        # so their remove events must not be sent to the kv router either.
        self.partial_block_hashes = set()
        self.error_queue: Queue = Queue()
        self._stop_event = threading.Event()
        # Strong reference to the endpoint-creation task; the event loop only
        # keeps weak references to tasks.
        self._endpoint_task = None
        # Event loop the publishing threads submit their coroutines to.
        self._stats_loop = None

        self._setup()

    async def _create_metrics_publisher_endpoint(self):
        """Create the metrics publisher's endpoint on ``self.component``."""
        logging.debug("Creating metrics publisher endpoint")
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return
        await self.metrics_publisher.create_endpoint(self.component)

    @staticmethod
    def _on_metrics_endpoint_done(task):
        """Log how the endpoint-creation task ended."""
        if task.cancelled():
            logging.warning("metrics publisher endpoint creation was cancelled")
            return
        # Reading the exception also marks it as retrieved for asyncio.
        error = task.exception()
        if error is not None:
            logging.error(f"Failed to create metrics publisher endpoint: {error}")
            return
        logging.debug("metrics publisher endpoint created")

    def _setup(self):
        """Create both publishers and prepare (but do not start) the threads."""
        # Setup the metrics publisher
        self.metrics_publisher = KvMetricsPublisher()
        self._init_publish_metrics_thread()
        self._endpoint_task = asyncio.create_task(
            self._create_metrics_publisher_endpoint()
        )
        self._endpoint_task.add_done_callback(self._on_metrics_endpoint_done)

        # Setup the kv cache events publisher
        self.kv_event_publisher = KvEventPublisher(
            self.kv_listener, self.worker_id, self.kv_block_size
        )
        self._init_publish_kv_cache_events_thread()

    def _init_publish_metrics_thread(self):
        """Publish one set of placeholder metrics and prepare the stats thread."""
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return

        # Need to publish stats once so that worker can be selected.
        # Publishing some dummy values...
        request_active_slots = 0
        request_total_slots = 4
        kv_active_block = 0
        kv_total_blocks = 4
        num_requests_waiting = 0
        gpu_cache_usage_perc = 0.0
        gpu_prefix_cache_hit_rate = 0.0

        self.metrics_publisher.publish(
            request_active_slots,
            request_total_slots,
            kv_active_block,
            kv_total_blocks,
            num_requests_waiting,
            gpu_cache_usage_perc,
            gpu_prefix_cache_hit_rate,
        )

        # Prepare threads for publishing stats but don't start them yet.
        # TRTLLM needs to start generating tokens first before stats
        # can be retrieved.
        self.publish_stats_thread = ManagedThread(
            self._publish_stats_task,
            error_queue=self.error_queue,
            name="publish_stats_thread",
        )

    def _init_publish_kv_cache_events_thread(self):
        """Prepare the kv cache events thread."""
        if self.kv_event_publisher is None:
            logging.error("KV event publisher not initialized!")
            return

        # Prepare threads for publishing kv cache events but don't start them yet.
        # TRTLLM needs to start generating tokens first before kv cache events
        # can be retrieved.
        self.publish_kv_cache_events_thread = ManagedThread(
            self._publish_kv_cache_events_task,
            error_queue=self.error_queue,
            name="publish_kv_cache_events_thread",
        )

    async def _publish_stats_task(self):
        """
        Publish stats to the metrics publisher.

        Iterates ``engine.llm.get_stats_async(timeout=5)`` and publishes one
        set of metrics per stat.

        Returns:
            True once the stats iterator is exhausted, False when the engine
            or the metrics publisher is missing.
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return False

        stats = self.engine.llm.get_stats_async(timeout=5)
        async for stat in stats:
            kv_stats = stat["kvCacheStats"]
            request_active_slots = stat["numActiveRequests"]
            request_total_slots = stat["maxNumActiveRequests"]
            kv_active_block = kv_stats["usedNumBlocks"]
            kv_total_blocks = kv_stats["maxNumBlocks"]
            reused_blocks = kv_stats["reusedBlocks"]
            free_num_blocks = kv_stats["freeNumBlocks"]
            alloc_total_blocks = kv_stats["allocTotalBlocks"]
            alloc_new_blocks = kv_stats["allocNewBlocks"]
            # NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
            num_requests_waiting = (
                stat["numQueuedRequests"]
                + stat["inflightBatchingStats"]["numPausedRequests"]
            )
            # Report 0.0 instead of raising ZeroDivisionError when the engine
            # reports no blocks at all.
            gpu_cache_usage_perc = (
                alloc_total_blocks / kv_total_blocks if kv_total_blocks else 0.0
            )
            gpu_prefix_cache_hit_rate = kv_stats["cacheHitRate"]

            logging.debug(
                f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {free_num_blocks}, allocTotalBlocks: {alloc_total_blocks}, allocNewBlocks: {alloc_new_blocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}"
            )

            self.metrics_publisher.publish(
                request_active_slots,
                request_total_slots,
                kv_active_block,
                kv_total_blocks,
                num_requests_waiting,
                gpu_cache_usage_perc,
                gpu_prefix_cache_hit_rate,
            )

        return True

    def _publish_stored_event(self, event_id, data):
        """Publish a "stored" event covering the leading full blocks of ``data``.

        Blocks are taken in order up to (not including) the first partial
        block, whose hash is remembered in ``partial_block_hashes``.

        Returns:
            False, without publishing, when a block holds more than
            ``kv_block_size`` tokens; True otherwise.
        """
        parent_hash = data["parent_hash"]
        token_ids = []
        num_block_tokens = []
        block_hashes = []
        for block in data["blocks"]:
            token_num_in_block = len(block["tokens"])
            block_hash = block["block_hash"]
            if token_num_in_block > self.kv_block_size:
                logging.error(
                    f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}"
                )
                return False
            if token_num_in_block < self.kv_block_size:
                logging.debug(
                    f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self.kv_block_size}"
                )
                self.partial_block_hashes.add(block_hash)
                break
            num_block_tokens.append(token_num_in_block)
            block_hashes.append(block_hash)
            token_ids.extend(int(token["token_id"]) for token in block["tokens"])

        # Note: Currently data does not have lora_id.
        # Using 0 as default value. If later data has
        # lora_id, we need to verify if this is correct.
        lora_id = data.get("lora_id", 0)

        logging.debug(
            f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
        )
        self.kv_event_publisher.publish_stored(
            event_id,
            token_ids,
            num_block_tokens,
            block_hashes,
            lora_id,
            parent_hash,
        )
        return True

    def _publish_removed_event(self, event_id, data):
        """Publish a "removed" event, leaving out hashes of partial blocks."""
        block_hashes = []
        for block_hash in data["block_hashes"]:
            if block_hash in self.partial_block_hashes:
                logging.debug(
                    f"Skipping removing block hash {block_hash} since it is a partial block"
                )
                self.partial_block_hashes.remove(block_hash)
                continue
            block_hashes.append(block_hash)

        logging.debug(
            f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
        )
        self.kv_event_publisher.publish_removed(event_id, block_hashes)

    async def _publish_kv_cache_events_task(self):
        """
        Publish kv cache events to the events publisher.

        Iterates ``engine.llm.get_kv_cache_events_async(timeout=5)`` and
        forwards "stored" and "removed" events; other event types are ignored.

        Returns:
            True once the events iterator is exhausted, False when the engine
            or the event publisher is missing or an oversized block is seen.
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        if self.kv_event_publisher is None:
            logging.error("KV event publisher not initialized!")
            return False

        events = self.engine.llm.get_kv_cache_events_async(timeout=5)
        async for event in events:
            logging.debug(f"KV cache event received: {event}")
            event_id = event["event_id"]
            data = event["data"]
            event_type = data["type"]
            if event_type == "stored":
                if not self._publish_stored_event(event_id, data):
                    return False
            elif event_type == "removed":
                self._publish_removed_event(event_id, data)

        return True

    def _start_thread(self, thread, started_message):
        """Bind ``thread`` to the running loop and start it if it is not alive."""
        if thread is None or thread.is_alive():
            return
        # REVISIT
        # [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
        self._stats_loop = asyncio.get_running_loop()
        thread.set_loop(self._stats_loop)
        thread.start()
        logging.debug(started_message)

    def start_publish_threads(self):
        """Start the prepared publishing threads on the running event loop.

        Must be called from a coroutine or callback, because the running loop
        is looked up with ``asyncio.get_running_loop()``.
        """
        self._start_thread(
            self.publish_kv_cache_events_thread, "Started kv cache events thread"
        )
        self._start_thread(self.publish_stats_thread, "Started stats thread")

    def check_error_queue(self):
        """Return the next error reported by a publishing thread, or None."""
        if not self.error_queue.empty():
            logging.error("Error in publishers error queue")
            return self.error_queue.get()
        return None

    @staticmethod
    async def _stop_thread(thread, label, timeout):
        """Stop ``thread`` and wait up to ``timeout`` seconds for it to end."""
        if thread is None or not thread.is_alive():
            return
        thread.stop()
        # join() blocks, so run it in the default executor: the event loop
        # stays free to finish or cancel the coroutine the thread submitted.
        await asyncio.get_running_loop().run_in_executor(None, thread.join, timeout)
        if thread.is_alive():
            logging.warning(f"{label} did not stop within timeout")

    async def cleanup(self):
        """Cleanup threads and resources"""
        self._stop_event.set()
        # Add timeout to prevent hanging
        cleanup_timeout = 5.0  # seconds

        await self._stop_thread(
            self.publish_stats_thread, "Stats thread", cleanup_timeout
        )
        await self._stop_thread(
            self.publish_kv_cache_events_thread,
            "KV cache events thread",
            cleanup_timeout,
        )
...@@ -12,17 +12,18 @@ import argparse ...@@ -12,17 +12,18 @@ import argparse
import asyncio import asyncio
import logging import logging
import sys import sys
from contextlib import asynccontextmanager from typing import Optional
from typing import AsyncGenerator, Optional
import uvloop import uvloop
# Import TRTLLM and related modules # Import TRTLLM and related modules
from tensorrt_llm import LLM, LlmArgs, SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from trtllm.engine import get_llm_engine
from trtllm.publishers import Publishers
from dynamo.llm import KvMetricsPublisher, ModelType, register_llm from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line # Only used if you run it manually from the command line
...@@ -30,6 +31,8 @@ DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" ...@@ -30,6 +31,8 @@ DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet. # Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -45,6 +48,7 @@ class Config: ...@@ -45,6 +48,7 @@ class Config:
tensor_parallel_size: int tensor_parallel_size: int
kv_block_size: int kv_block_size: int
extra_engine_args: str extra_engine_args: str
publish_events_and_metrics: bool
class RequestHandler: class RequestHandler:
...@@ -52,34 +56,19 @@ class RequestHandler: ...@@ -52,34 +56,19 @@ class RequestHandler:
Request handler for the generate endpoint Request handler for the generate endpoint
""" """
def __init__(self, component, engine, default_sampling_params): def __init__(self, component, engine, default_sampling_params, publishers):
self.engine = engine self.engine = engine
self.component = component self.component = component
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
self.metrics_publisher = KvMetricsPublisher() self.publishers = publishers
self.first_generation = True
def setup_kv_metrics(self):
# Initially send dummy metrics to kick start,
# TRTLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request): async def generate(self, request):
# Check if there is an error in the publishers error queue
publishers_error = self.publishers.check_error_queue()
if publishers_error:
raise publishers_error
inputs = request["token_ids"] inputs = request["token_ids"]
sampling_params = self.default_sampling_params sampling_params = self.default_sampling_params
...@@ -98,6 +87,12 @@ class RequestHandler: ...@@ -98,6 +87,12 @@ class RequestHandler:
async for res in self.engine.llm.generate_async( async for res in self.engine.llm.generate_async(
inputs=inputs, sampling_params=sampling_params, streaming=True inputs=inputs, sampling_params=sampling_params, streaming=True
): ):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publishers:
self.publishers.start_publish_threads()
self.first_generation = False
if res.finished: if res.finished:
yield {"finish_reason": "stop", "token_ids": []} yield {"finish_reason": "stop", "token_ids": []}
break break
...@@ -122,50 +117,6 @@ async def worker(runtime: DistributedRuntime): ...@@ -122,50 +117,6 @@ async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
class AsyncLLMEngine:
def __init__(self, engine_args):
self.engine_args = engine_args
self._llm: Optional[LLM] = None
self._initialized = False
async def initialize(self):
if not self._initialized:
model = self.engine_args.pop("model")
self._llm = LLM(
model=model,
**self.engine_args,
)
self._initialized = True
async def cleanup(self):
if self._initialized:
try:
self._llm.shutdown()
except Exception as e:
logging.error(f"Error during cleanup: {e}")
finally:
self._initialized = False
@property
def llm(self):
if not self._initialized:
raise RuntimeError("Engine not initialized")
return self._llm
@asynccontextmanager
async def get_llm_engine(engine_args: LlmArgs) -> AsyncGenerator[AsyncLLMEngine, None]:
engine = AsyncLLMEngine(engine_args)
try:
await engine.initialize()
yield engine
except Exception as e:
logging.error(f"Error in engine context: {e}")
raise
finally:
await engine.cleanup()
async def init(runtime: DistributedRuntime, config: Config): async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
...@@ -187,8 +138,28 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -187,8 +138,28 @@ async def init(runtime: DistributedRuntime, config: Config):
} }
if config.extra_engine_args != "": if config.extra_engine_args != "":
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args) arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
if "kv_cache_config" not in arg_map:
kv_cache_config = {}
kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = arg_map["kv_cache_config"]
if not kv_cache_config.event_buffer_max_size:
kv_cache_config.event_buffer_max_size = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map["kv_cache_config"] = kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map:
arg_map["backend"] = "pytorch"
elif arg_map["backend"] != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys.exit(1)
logging.debug(f"TRTLLM engine args: {arg_map}") logging.info(f"TRTLLM engine args: {arg_map}")
engine_args = arg_map engine_args = arg_map
# Populate default sampling params from the model # Populate default sampling params from the model
...@@ -202,12 +173,29 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -202,12 +173,29 @@ async def init(runtime: DistributedRuntime, config: Config):
await register_llm( await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name ModelType.Backend, endpoint, config.model_path, config.model_name
) )
handler = RequestHandler(component, engine, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) publishers = None
# after the lease is revoked if config.publish_events_and_metrics:
await endpoint.serve_endpoint(handler.generate) kv_listener = runtime.namespace(config.namespace).component(
config.component
)
publishers = Publishers(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
)
handler = RequestHandler(component, engine, default_sampling_params, publishers)
try:
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
finally:
if publishers:
await publishers.cleanup()
def cmd_line_args(): def cmd_line_args():
...@@ -235,6 +223,8 @@ def cmd_line_args(): ...@@ -235,6 +223,8 @@ def cmd_line_args():
parser.add_argument( parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use." "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
) )
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument( parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block." "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
) )
...@@ -244,6 +234,11 @@ def cmd_line_args(): ...@@ -244,6 +234,11 @@ def cmd_line_args():
default="", default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.", help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
) )
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="Publish events and metrics to the dynamo components.",
)
args = parser.parse_args() args = parser.parse_args()
config = Config() config = Config()
...@@ -270,10 +265,14 @@ def cmd_line_args(): ...@@ -270,10 +265,14 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
return config return config
if __name__ == "__main__": if __name__ == "__main__":
uvloop.install() uvloop.install()
asyncio.run(worker()) try:
asyncio.run(worker())
except KeyboardInterrupt:
logging.info("Received SIGINT, shutting down...")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment