"examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "30c5a79f9071a4b45692847a1a1fba4ebee3f6eb"
Unverified Commit 3c1c2ac3 authored by Ziqi Fan's avatar Ziqi Fan Committed by GitHub
Browse files

refactor: change trtllm example kv routing use python bindings | deal with...

refactor: change trtllm example kv routing use python bindings | deal with trtllm partial blocks | trtllm event change (#866)
parent 6630fa5c
...@@ -233,8 +233,6 @@ RUN . /opt/dynamo/venv/bin/activate && \ ...@@ -233,8 +233,6 @@ RUN . /opt/dynamo/venv/bin/activate && \
RUN pip install dist/ai_dynamo_runtime*cp312*.whl && \ RUN pip install dist/ai_dynamo_runtime*cp312*.whl && \
pip install dist/ai_dynamo*any.whl pip install dist/ai_dynamo*any.whl
# Tell TRTLLM worker to use the Dynamo LLM C API for KV Cache Routing
ENV DYNAMO_KV_CAPI_PATH="/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
ENV DYNAMO_HOME=/workspace ENV DYNAMO_HOME=/workspace
# Use UCX for TRTLLM KV Cache Transfer # Use UCX for TRTLLM KV Cache Transfer
......
...@@ -143,6 +143,8 @@ dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml ...@@ -143,6 +143,8 @@ dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
We are defining TRTLLM_USE_UCX_KVCACHE so that TRTLLM uses UCX for transferring the KV
cache between the context and generation workers.
NOTE: currently disaggregated serving with KV Routing may not work correctly: the prefix cache hit rate shows 0 even when it should not.
### Client ### Client
See [client](../llm/README.md#client) section to learn how to send request to the deployment. See [client](../llm/README.md#client) section to learn how to send request to the deployment.
......
...@@ -42,9 +42,8 @@ from tensorrt_llm.llmapi.disagg_utils import ( ...@@ -42,9 +42,8 @@ from tensorrt_llm.llmapi.disagg_utils import (
) )
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
from dynamo.llm import KvMetricsPublisher from dynamo.llm import KvEventPublisher, KvMetricsPublisher
from dynamo.sdk import dynamo_context
from .kv_cache_event_publisher import KVCacheEventPublisher
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -137,6 +136,19 @@ class BaseTensorrtLLMEngine: ...@@ -137,6 +136,19 @@ class BaseTensorrtLLMEngine:
if self._publish_stats: if self._publish_stats:
self._kv_metrics_publisher = KvMetricsPublisher() self._kv_metrics_publisher = KvMetricsPublisher()
if self._publish_events:
if self._worker_id is None:
raise ValueError("Worker ID is None!")
runtime = dynamo_context["runtime"]
kv_listener = runtime.namespace(self._namespace_str).component(
self._component_str
)
self._kv_event_publisher = KvEventPublisher(
kv_listener, int(self._worker_id), self._kv_block_size
)
logger.info("KvEventPublisher is initialized")
self._engine_config = engine_config self._engine_config = engine_config
def _init_engine(self): def _init_engine(self):
...@@ -170,11 +182,15 @@ class BaseTensorrtLLMEngine: ...@@ -170,11 +182,15 @@ class BaseTensorrtLLMEngine:
try: try:
if self._publish_stats: if self._publish_stats:
self._init_publish_metrics_thread() self._init_publish_metrics_thread()
except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}")
raise e
try:
if self._publish_events: if self._publish_events:
self._init_publish_kv_cache_events_thread() self._init_publish_kv_cache_events_thread()
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}") logger.error(f"Failed to initialize publish events threads: {e}")
raise e raise e
def _init_publish_metrics_thread(self): def _init_publish_metrics_thread(self):
...@@ -216,20 +232,13 @@ class BaseTensorrtLLMEngine: ...@@ -216,20 +232,13 @@ class BaseTensorrtLLMEngine:
) )
def _init_publish_kv_cache_events_thread(self): def _init_publish_kv_cache_events_thread(self):
if self._worker_id is None: if self._kv_event_publisher is None:
logger.error("Worker ID not initialized!") logger.error("KV event publisher not initialized!")
return return
# TODO: Use python bindings to publish kv cache events once they # A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes.
# are available. # It is used to prevent sending remove event to kv router since partial blocks are not stored.
lib_path = "/opt/dynamo/bindings/lib/libdynamo_llm_capi.so" self._partial_block_hashes = set()
self._kv_cache_events_publisher = KVCacheEventPublisher(
self._namespace_str,
self._component_str,
int(self._worker_id),
lib_path,
self._kv_block_size,
)
# Prepare threads for publishing kv cache events but don't start them yet. # Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events # TRTLLM needs to start generating tokens first before kv cache events
...@@ -295,30 +304,56 @@ class BaseTensorrtLLMEngine: ...@@ -295,30 +304,56 @@ class BaseTensorrtLLMEngine:
return return
events = self._llm_engine.get_kv_cache_events_async(timeout=5) events = self._llm_engine.get_kv_cache_events_async(timeout=5)
async for event_list in events: async for event in events:
for event in event_list: event_id = event["event_id"]
data = event["data"] data = event["data"]
if data["type"] == "stored": if data["type"] == "stored":
parent_hash = data["parent_hash"] parent_hash = data["parent_hash"]
for block in data["blocks"]: token_ids = []
tokens = [] num_block_tokens = []
for token in block["tokens"]: block_hashes = []
tokens.append(int(token["token_id"])) for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
# Note: Currently data does not have lora_id. block_hash = block["block_hash"]
# Using 0 as default value. If later data has if token_num_in_block > self._kv_block_size:
# lora_id, we need to verify if this is correct. logger.error(
lora_id = data.get("lora_id", 0) f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
self._kv_cache_events_publisher.stored_event( )
parent_hash, return
block["block_hash"], if token_num_in_block < self._kv_block_size:
tokens, logger.debug(
lora_id, f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self._kv_block_size}"
)
self._partial_block_hashes.add(block_hash)
break
num_block_tokens.append(token_num_in_block)
block_hashes.append(block_hash)
for token in block["tokens"]:
token_ids.append(int(token["token_id"]))
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
self._kv_event_publisher.publish_stored(
event_id,
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
)
elif data["type"] == "removed":
block_hashes = []
for block_hash in data["block_hashes"]:
if block_hash in self._partial_block_hashes:
logger.debug(
f"Skipping removing block hash {block_hash} since it is a partial block"
) )
parent_hash = block["block_hash"] self._partial_block_hashes.remove(block_hash)
elif data["type"] == "removed": continue
for block_hash in data["block_hashes"]: block_hashes.append(block_hash)
self._kv_cache_events_publisher.removed_event(block_hash) self._kv_event_publisher.publish_removed(event_id, block_hashes)
return True return True
def _start_threads(self): def _start_threads(self):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import logging
from ctypes import c_char_p, c_int64, c_uint32
logger = logging.getLogger(__name__)
class DynamoResult:
    """Integer status codes returned by the dynamo_llm C API calls.

    Mirrors the C-side ``dynamo_llm_result_t`` values checked after each
    ``dynamo_*`` call below.
    """

    # 0 = call succeeded, 1 = call failed (any non-OK value is treated as failure).
    OK = 0
    ERR = 1
class KVCacheEventPublisher:
    """Publishes KV cache block "stored"/"removed" events via the Dynamo LLM C API.

    The shared library (``libdynamo_llm_capi.so``) is loaded with ctypes and
    initialized once per worker. Each published event carries a monotonically
    increasing event id so the receiving KV router can order events.

    If the shared library cannot be loaded, ``__init__`` re-raises the loading
    error. If it loads but ``dynamo_llm_init`` reports failure, an error is
    logged and subsequent publishes may be dropped on the C side.
    """

    def __init__(
        self,
        namespace: str,
        component: str,
        worker_id: int,
        lib_path: str,
        kv_block_size: int,
    ):
        """Load and initialize the Dynamo LLM C API for this worker.

        Args:
            namespace: Dynamo namespace this worker belongs to.
            component: Dynamo component name within the namespace.
            worker_id: Numeric worker id (passed as int64 to the C API).
            lib_path: Filesystem path to ``libdynamo_llm_capi.so``.
            kv_block_size: Number of tokens per KV cache block (uint32).

        Raises:
            Exception: whatever ``ctypes.CDLL`` raises when the library at
                ``lib_path`` cannot be loaded.
        """
        # Stays None if loading fails so the publish methods can no-op
        # instead of crashing the worker.
        self.lib = None
        try:
            self.lib = ctypes.CDLL(lib_path)
            # dynamo_llm_init(namespace, component, worker_id, kv_block_size)
            # BUGFIX: the prototype previously declared only three argument
            # types, so kv_block_size fell back to ctypes' default c_int
            # conversion; declare all four so it is passed as uint32.
            self.lib.dynamo_llm_init.argtypes = [
                c_char_p,
                c_char_p,
                c_int64,
                c_uint32,
            ]
            self.lib.dynamo_llm_init.restype = c_uint32

            result = self.lib.dynamo_llm_init(
                namespace.encode(), component.encode(), worker_id, kv_block_size
            )
            if result == DynamoResult.OK:
                logger.info(
                    "KVCacheEventPublisher initialized successfully. Ready to publish KV Cache Events"
                )
            else:
                # Error level (was info): a failed init means events will
                # never reach the KV router.
                logger.error("KVCacheEventPublisher initialization failed!")
        except Exception:
            logger.exception(f"Failed to load {lib_path}")
            # Bare raise preserves the original traceback.
            raise

        # Prototype for publishing "stored" events; returns dynamo_llm_result_t.
        self.lib.dynamo_kv_event_publish_stored.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint32),  # token_ids
            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
            ctypes.c_uint64,  # lora_id
        ]
        self.lib.dynamo_kv_event_publish_stored.restype = (
            ctypes.c_uint32
        )  # dynamo_llm_result_t

        # Prototype for publishing "removed" events; returns dynamo_llm_result_t.
        self.lib.dynamo_kv_event_publish_removed.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
        ]
        self.lib.dynamo_kv_event_publish_removed.restype = (
            ctypes.c_uint32
        )  # dynamo_llm_result_t

        # Monotonic event id shared by stored and removed events.
        self._event_counter = 0

    def stored_event(self, parent_hash, block_hash, token_ids, lora_id):
        """Publish a "block stored" event for a single KV cache block.

        Args:
            parent_hash: Hash of the parent block, or None for a root block.
            block_hash: Hash of the stored block.
            token_ids: Token ids contained in the block.
            lora_id: LoRA adapter id associated with the block (0 when unused).
        """
        if self.lib is None:
            logger.error("KVCacheEventPublisher not initialized!")
            return

        logger.debug(
            f"Stored parent_hash: {parent_hash}, block_hash: {block_hash}, token_ids: {token_ids}"
        )

        # Build the C arrays in separate names (previously the parameters were
        # shadowed, so the result log below printed a ctypes array object
        # instead of the hash value).
        parent_hash_arr = (
            (ctypes.c_uint64 * 1)(parent_hash) if parent_hash is not None else None
        )
        token_ids_arr = (ctypes.c_uint32 * len(token_ids))(*token_ids)
        # Single block per call, containing all of token_ids.
        num_block_tokens = (ctypes.c_size_t * 1)(len(token_ids))
        block_hash_arr = (ctypes.c_uint64 * 1)(block_hash)

        result = self.lib.dynamo_kv_event_publish_stored(
            self._event_counter,  # uint64_t event_id
            token_ids_arr,  # const uint32_t *token_ids
            num_block_tokens,  # const uintptr_t *num_block_tokens
            block_hash_arr,  # const uint64_t *block_ids
            1,  # uintptr_t num_blocks
            parent_hash_arr,  # const uint64_t *parent_hash
            lora_id,  # uint64_t lora_id
        )
        self._event_counter += 1

        if result == DynamoResult.OK:
            logger.debug(f"Store - Published KV Event: {block_hash}")
        else:
            logger.error(f"Store - Failed to Publish KV Event: {block_hash}")

    def removed_event(self, block_hash):
        """Publish a "block removed" event for a single KV cache block.

        Args:
            block_hash: Hash of the removed block.
        """
        if self.lib is None:
            logger.error("KVCacheEventPublisher not initialized!")
            return

        result = self.lib.dynamo_kv_event_publish_removed(
            self._event_counter,
            (ctypes.c_uint64 * 1)(block_hash),
            1,  # num_blocks: one block removed per call
        )
        self._event_counter += 1

        if result == DynamoResult.OK:
            logger.debug(f"Remove - Published KV Event: {block_hash}")
        else:
            logger.error(f"Remove - Failed to Publish KV Event: {block_hash}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment