Commit 11e3e188 authored by Tanmay Verma, committed by GitHub

feat: LLM API integration with smart routing bits (#55)


Co-authored-by: Shreyas Misra <shreyasm@nvidia.com>
parent ec46ed52
@@ -242,7 +242,7 @@ For example, 2 TP2 generation servers are 2 servers but 4 workers/mpi executor.
cd /workspace/examples/python_rs/llm/tensorrt_llm/
mpirun --allow-run-as-root --oversubscribe -n WORLD_SIZE python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml 1>disagg_workers.log 2>&1 &
```
If using the provided [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml), WORLD_SIZE should be 2 as it has 1 context server (TP=1) and 1 generation server (TP=1).
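With the provided config, for example, the command becomes:
```bash
# WORLD_SIZE = 1 context worker + 1 generation worker = 2
mpirun --allow-run-as-root --oversubscribe -n 2 python3 -m disaggregated.worker \
    --engine_args llm_api_config.yaml \
    -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml \
    1>disagg_workers.log 2>&1 &
```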
2. **Launch the router**
@@ -251,6 +251,8 @@ cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m disaggregated.router 1>router.log 2>&1 &
```
Note: For KV cache aware routing, please refer to the [KV Aware Routing](./docs/kv_aware_routing.md) section.
3. **Send Requests**

Follow the instructions in the [Monolithic Deployment](#3-client) section to send requests to the router.
...
@@ -17,20 +17,26 @@
import asyncio
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Optional

from common.parser import LLMAPIConfig
from common.processor import ChatProcessor, CompletionsProcessor
from common.utils import ManagedThread
from tensorrt_llm._torch import LLM
from tensorrt_llm.logger import logger
from transformers import AutoTokenizer

from dynamo.llm import KvMetricsPublisher

from .kv_cache_event_publisher import KVCacheEventPublisher


class ChatProcessorMixin:
    def __init__(self, engine_config: LLMAPIConfig):
        self._engine_config = engine_config
        logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
        # model name for chat processor
        self._model_name = self._engine_config.model_name
        logger.info(f"Set model name: {self._model_name}")
@@ -49,8 +55,6 @@ class BaseTensorrtLLMEngine:
            self._engine_config.model_name
        )

        if self._engine_config.extra_args.get("tokenizer", None):
            self._tokenizer = AutoTokenizer.from_pretrained(
                self._engine_config.extra_args.get("tokenizer", None)
@@ -59,6 +63,34 @@ class BaseTensorrtLLMEngine:
        self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
        self.completions_processor = CompletionsProcessor(self._model_name)


@dataclass
class TensorrtLLMEngineConfig:
    namespace_str: str = "dynamo"
    component_str: str = "tensorrt-llm"
    engine_config: Optional[LLMAPIConfig] = None
    worker_id: Optional[str] = None
    kv_metrics_publisher: Optional[KvMetricsPublisher] = None
    publish_stats: bool = False
    publish_kv_cache_events: bool = False


class BaseTensorrtLLMEngine(ChatProcessorMixin):
    def __init__(
        self,
        trt_llm_engine_config: TensorrtLLMEngineConfig,
    ):
        super().__init__(trt_llm_engine_config.engine_config)
        self._namespace_str = trt_llm_engine_config.namespace_str
        self._component_str = trt_llm_engine_config.component_str
        self._worker_id = trt_llm_engine_config.worker_id
        self._kv_metrics_publisher = trt_llm_engine_config.kv_metrics_publisher
        self._publish_stats = trt_llm_engine_config.publish_stats
        self._publish_kv_cache_events = trt_llm_engine_config.publish_kv_cache_events
        self._error_queue: Optional[Queue] = None
        self._init_engine()

    def _init_engine(self):
        logger.info("Initializing engine")
        # Run the engine in a separate thread running the AsyncIO event loop.
@@ -68,6 +100,10 @@ class BaseTensorrtLLMEngine:
        self._event_thread = threading.Thread(
            target=asyncio.run, args=(self._run_llm_engine(),)
        )

        self.publish_kv_cache_events_thread = None
        self.publish_stats_thread = None

        self._event_thread.start()
        with self._llm_engine_start_cv:
            while self._llm_engine is None:
@@ -83,6 +119,142 @@ class BaseTensorrtLLMEngine:
            self._event_thread = None
            raise e

        self._error_queue = Queue()

        try:
            if self._publish_stats:
                self._init_publish_metrics_thread()

            if self._publish_kv_cache_events:
                self._init_publish_kv_cache_events_thread()
        except Exception as e:
            logger.error(f"Failed to initialize publish metrics threads: {e}")
            raise e

    def _init_publish_metrics_thread(self):
        # Need to publish stats once so that worker can be selected.
        # Publishing some dummy values...
        request_active_slots = 0
        request_total_slots = 4
        kv_active_block = 0
        kv_total_blocks = 4

        if self._kv_metrics_publisher is None:
            logger.error("KV metrics publisher not initialized!")
            return

        self._kv_metrics_publisher.publish(
            request_active_slots,
            request_total_slots,
            kv_active_block,
            kv_total_blocks,
        )

        # Prepare threads for publishing stats but don't start them yet.
        # TRTLLM needs to start generating tokens first before stats
        # can be retrieved.
        self.publish_stats_thread = ManagedThread(
            self.publish_stats_task,
            error_queue=self._error_queue,
            name="publish_stats_thread",
        )

    def _init_publish_kv_cache_events_thread(self):
        if self._worker_id is None:
            logger.error("Worker ID not initialized!")
            return

        # TODO: Use python bindings to publish kv cache events once they
        # are available.
        lib_path = "/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
        self._kv_cache_events_publisher = KVCacheEventPublisher(
            self._namespace_str, self._component_str, int(self._worker_id), lib_path
        )

        # Prepare threads for publishing kv cache events but don't start them yet.
        # TRTLLM needs to start generating tokens first before kv cache events
        # can be retrieved.
        self.publish_kv_cache_events_thread = ManagedThread(
            self.publish_kv_cache_events_task,
            error_queue=self._error_queue,
            name="publish_kv_cache_events_thread",
        )

    async def publish_stats_task(self):
        """
        Publish stats to the metrics publisher.
        """
        if self._llm_engine is None:
            logger.error("LLM engine not initialized!")
            return False

        stats = self._llm_engine.get_stats_async(timeout=5)
        async for stat in stats:
            request_active_slots = stat["numActiveRequests"]
            request_total_slots = stat["maxNumActiveRequests"]
            kv_active_block = stat["kvCacheStats"]["usedNumBlocks"]
            kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
            if self._kv_metrics_publisher is None:
                logger.error("KV metrics publisher not initialized!")
                return False
            self._kv_metrics_publisher.publish(
                request_active_slots,
                request_total_slots,
                kv_active_block,
                kv_total_blocks,
            )
            logger.debug(
                f"Published stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}"
            )

        return True

    async def publish_kv_cache_events_task(self):
        """
        Publish kv cache events to the events publisher.
        """
        if self._llm_engine is None:
            logger.error("LLM engine not initialized!")
            return False

        events = self._llm_engine.get_kv_cache_events_async(timeout=5)
        async for event_list in events:
            for event in event_list:
                logger.debug(f"Received event from llmapi: {event}")
                id = event["event_id"]
                data = event["data"]
                if data["type"] == "stored":
                    parent_hash = data["parent_hash"]
                    token_ids = []
                    block_hashes = []
                    for block in data["blocks"]:
                        block_hash = block["block_hash"]
                        block_hashes.append(block_hash)
                        for token in block["tokens"]:
                            # TODO: How to handle token_extra_id?
                            token_ids.append(token["token_id"])

                    # Note: Currently data does not have lora_id.
                    # Using 0 as default value. If later data has
                    # lora_id, we need to verify if this is correct.
                    lora_id = data.get("lora_id", 0)

                    # Publish the stored event
                    self._kv_cache_events_publisher.stored_event(
                        id, parent_hash, block_hashes, token_ids, lora_id
                    )
                    logger.debug(
                        f"Published stored event: {id}, parent_hash: {parent_hash}, block_hashes: {block_hashes}, token_ids: {token_ids}"
                    )
                elif data["type"] == "removed":
                    # Publish the removed event
                    block_hashes = []
                    for block_hash in data["block_hashes"]:
                        block_hashes.append(block_hash)
                    self._kv_cache_events_publisher.removed_event(id, block_hashes)
                    logger.debug(
                        f"Published removed event: {id}, block_hashes: {block_hashes}"
                    )

        return True

    async def _run_llm_engine(self):
        # Counter to keep track of ongoing request counts.
        self._ongoing_request_count = 0
@@ -117,6 +289,17 @@ class BaseTensorrtLLMEngine:
        # Wait for the engine shutdown signal.
        await self._llm_engine_shutdown_event.wait()

        # Stop the publishing threads
        if self.publish_stats_thread and self.publish_stats_thread.is_alive():
            self.publish_stats_thread.stop()
            self.publish_stats_thread.join()

        if (
            self.publish_kv_cache_events_thread
            and self.publish_kv_cache_events_thread.is_alive()
        ):
            self.publish_kv_cache_events_thread.stop()
            self.publish_kv_cache_events_thread.join()

        # Wait for the ongoing requests to complete.
        while self._ongoing_request_count > 0:
            logger.info(
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import c_char_p, c_int64, c_uint32

from tensorrt_llm.logger import logger

logger.set_level("debug")


class DynamoResult:
    OK = 0
    ERR = 1


class KVCacheEventPublisher:
    def __init__(self, namespace: str, component: str, worker_id: int, lib_path: str):
        self.lib = None

        try:
            self.lib = ctypes.CDLL(lib_path)
            self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
            self.lib.dynamo_llm_init.restype = c_uint32
            result = self.lib.dynamo_llm_init(
                namespace.encode(), component.encode(), worker_id
            )
            if result == DynamoResult.OK:
                logger.info(
                    "KVCacheEventPublisher initialized successfully. Ready to publish KV Cache Events"
                )
            else:
                logger.error("KVCacheEventPublisher initialization failed!")
        except Exception as e:
            logger.error(f"Failed to load {lib_path}")
            raise e

        self.lib.dynamo_kv_event_publish_stored.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint32),  # token_ids
            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
            ctypes.c_uint64,  # lora_id
        ]
        self.lib.dynamo_kv_event_publish_stored.restype = (
            ctypes.c_uint32
        )  # dynamo_llm_result_t

        self.lib.dynamo_kv_event_publish_removed.argtypes = [
            ctypes.c_uint64,  # event_id
            ctypes.POINTER(ctypes.c_uint64),  # block_ids
            ctypes.c_size_t,  # num_blocks
        ]
        self.lib.dynamo_kv_event_publish_removed.restype = (
            ctypes.c_uint32
        )  # dynamo_llm_result_t

    def stored_event(self, event_id, parent_hash, block_hashes, token_ids, lora_id):
        if self.lib is None:
            logger.error("KVCacheEventPublisher not initialized!")
            return

        logger.debug(
            f"Stored event: {event_id}, parent_hash: {parent_hash}, block_hashes: {block_hashes}, token_ids: {token_ids}"
        )
        parent_hash = (
            (ctypes.c_uint64 * 1)(parent_hash) if parent_hash is not None else None
        )
        block_hash_arr = (ctypes.c_uint64 * len(block_hashes))(*block_hashes)
        block_hash_len = len(block_hashes)
        token_ids_arr = (ctypes.c_uint32 * len(token_ids))(*token_ids)
        num_block_tokens = (ctypes.c_size_t * 1)(len(token_ids))

        # Publish the event
        # TODO: Currently, lora_id is not available in the stored events.
        result = self.lib.dynamo_kv_event_publish_stored(
            event_id,  # uint64_t event_id
            token_ids_arr,  # const uint32_t *token_ids
            num_block_tokens,  # const uintptr_t *num_block_tokens
            block_hash_arr,  # const uint64_t *block_ids
            block_hash_len,  # uintptr_t num_blocks
            parent_hash,  # const uint64_t *parent_hash
            lora_id,  # uint64_t lora_id
        )

        if result == DynamoResult.OK:
            logger.debug(f"Store - Published KV Event: {block_hashes}")
        else:
            logger.error(f"Store - Failed to Publish KV Event: {block_hashes}")

    def removed_event(self, event_id, block_hashes):
        if self.lib is None:
            logger.error("KVCacheEventPublisher not initialized!")
            return

        result = self.lib.dynamo_kv_event_publish_removed(
            event_id,
            (ctypes.c_uint64 * len(block_hashes))(*block_hashes),
            (ctypes.c_size_t * 1)(len(block_hashes)),
        )

        if result == DynamoResult.OK:
            logger.debug(f"Remove - Published KV Event: {block_hashes}")
        else:
            logger.error(f"Remove - Failed to Publish KV Event: {block_hashes}")
@@ -118,5 +118,15 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]
        help="Path to the llmapi disaggregated config file",
        default=None,
    )
    parser.add_argument(
        "--publish-kv-cache-events",
        action="store_true",
        help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
    )
    parser.add_argument(
        "--publish-stats",
        action="store_true",
        help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
    )

    args = parser.parse_args()
    return (args, _init_engine_args(args.engine_args))
@@ -29,6 +29,10 @@ from tensorrt_llm.serve.openai_protocol import (
)


class Tokens(BaseModel):
    tokens: list[int]

class Request(BaseModel):
    prompt: str
    sampling_params: dict
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
import traceback
import weakref
from queue import Queue
from typing import Callable, Optional, Union
from tensorrt_llm.logger import logger

class ManagedThread(threading.Thread):
    def __init__(
        self,
        task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        super().__init__(name=name)
        self.task = task
        self.error_queue = error_queue
        self.kwargs = kwargs
        self.loop = loop
        self.daemon = True

        self.stop_event = threading.Event()

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        self.loop = loop

    def run(self):
        while not self.stop_event.is_set():
            task: Optional[Union[Callable[..., bool], weakref.WeakMethod]] = self.task
            if isinstance(task, weakref.WeakMethod):
                task = task()
                if task is None:
                    # Normally, this should not happen.
                    logger.warning("WeakMethod is expired.")
                    break

            if task is None:
                break

            try:
                if self.loop is None:
                    logger.error("[ManagedThread] Loop not initialized!")
                    break
                future = asyncio.run_coroutine_threadsafe(task(**self.kwargs), self.loop)
                _ = future.result()
            except Exception as e:
                logger.error(
                    f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
                )
                if self.error_queue is not None:
                    self.error_queue.put(e)

        logger.info(f"Thread {self.name} stopped.")

    def stop(self):
        self.stop_event.set()
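A minimal usage sketch of `ManagedThread` (illustrative only; in the workers above, the loop is handed over from the request handler via `set_loop`, and the task is e.g. `publish_stats_task`):
```python
# Illustrative only: run a coroutine repeatedly on a background event loop.
import asyncio
import threading
from queue import Queue

async def poll_once() -> bool:
    await asyncio.sleep(1)  # stand-in for e.g. publish_stats_task
    return True

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

errors: Queue = Queue()
t = ManagedThread(poll_once, error_queue=errors, name="poll_thread", loop=loop)
t.start()
# ... later ...
t.stop()
t.join()
```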
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import json
import traceback
from typing import AsyncIterator
import uvloop
from common.base_engine import ChatProcessorMixin
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
Tokens,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
from dynamo.llm import KvRouter
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")


class Scheduler:
    def __init__(self, kv_router: KvRouter):
        self.kv_router = kv_router

    @dynamo_endpoint(Tokens, str)
    async def generate(self, request) -> AsyncIterator[str]:
        lora_id = 0
        worker_id = None
        try:
            worker_id = await self.kv_router.schedule(request.tokens, lora_id)
        except Exception:
            logger.warning(f"Error during worker selection: {traceback.format_exc()}")
            worker_id = ""

        logger.debug(f"Scheduling to worker_id: {worker_id}")
        yield str(worker_id)


class Router(ChatProcessorMixin):
    def __init__(
        self,
        ctx_chat_client,
        gen_chat_client,
        ctx_completion_client,
        gen_completion_client,
        scheduler: Scheduler,
        engine_config: LLMAPIConfig,
    ):
        self.ctx_chat_client = ctx_chat_client
        self.gen_chat_client = gen_chat_client
        self.ctx_completion_client = ctx_completion_client
        self.gen_completion_client = gen_completion_client
        self.scheduler = scheduler
        # allows us to use the tokenizer
        super().__init__(engine_config)
        logger.info("INITIALIZED ROUTER")

    async def _get_ctx_resp(self, request, ctx_client):
        logger.debug(f"Received request {request}")

        # NOTE: this will increase TTFT since we are encoding the prompt here;
        # the prompt is also encoded in the worker.
        # TODO: we need to implement our own request processing and protocols
        # to send only token ids to the llmapi worker.
        token_ids = self._tokenizer.encode(request.prompt)
        worker_id_generator: AsyncIterator = self.scheduler.generate(
            Tokens(tokens=token_ids).model_dump_json()
        )
        worker_id = (
            await worker_id_generator.__anext__()
        )  # only one worker id is returned

        request.max_tokens = 1
        request.disaggregated_params = DisaggregatedParams(request_type="context_only")
        logger.debug(f"[router] Sending request to context server: {request}")

        if worker_id == "":
            ctx_resp = [
                resp
                async for resp in await ctx_client.random(request.model_dump_json())
            ]
        else:
            ctx_resp = [
                resp
                async for resp in await ctx_client.direct(
                    request.model_dump_json(), int(worker_id)
                )
            ]

        if len(ctx_resp) > 1:
            raise ValueError(
                "Context server returned more than one response. This is currently not supported in disaggregated server."
            )

        logger.debug(
            f"[router] received response from context server: {ctx_resp[0].data()}"
        )
        return ctx_resp[0].data()

    # TODO (shreyasm): The only reason we can't further combine the two methods
    # below is because the disagg params are in different locations.
    # Disagg params should be under the choices field in the response object.
    # This is the case for completions but not for chat.
    @dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
    async def generate_completion(self, request):
        # These settings are needed to satisfy request checks.
        request.skip_special_tokens = False
        request.add_special_tokens = False
        request.spaces_between_special_tokens = False

        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client)
        ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)

        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
            ctx_resp_obj.choices[0].disaggregated_params
        )
        gen_req.disaggregated_params.request_type = "generation_only"

        if request.stream:
            yield json.loads(
                ctx_resp_obj.model_dump_json(
                    exclude_unset=True, exclude={"disaggregated_params"}
                )
            )

        logger.debug(f"[router] Sending request to generation server: {gen_req}")
        async for response in await self.gen_completion_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
                response.data()
            )
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))

    @dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
    async def generate_chat(self, request):
        # These settings are needed to satisfy request checks.
        request.skip_special_tokens = False
        request.add_special_tokens = False
        request.spaces_between_special_tokens = False

        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client)
        ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)

        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
            ctx_resp_obj.disaggregated_params
        )
        gen_req.disaggregated_params.request_type = "generation_only"

        if request.stream:
            yield json.loads(
                ctx_resp_obj.model_dump_json(
                    exclude_unset=True, exclude={"disaggregated_params"}
                )
            )

        logger.debug(f"[router] Sending request to generation server: {gen_req}")
        async for response in await self.gen_chat_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate(
                response.data()
            )
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))


@dynamo_worker()
async def worker(runtime: DistributedRuntime, args, engine_config):
    """
    Instantiate a `backend` component and serve the `generate` endpoint.
    A `Component` can serve multiple endpoints.
    """
    component = runtime.namespace("dynamo").component("router")
    await component.create_service()

    ctx_completion_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-ctx")
        .endpoint("completions")
        .client()
    )
    gen_completion_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-gen")
        .endpoint("completions")
        .client()
    )
    ctx_chat_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-ctx")
        .endpoint("chat/completions")
        .client()
    )
    gen_chat_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-gen")
        .endpoint("chat/completions")
        .client()
    )

    # Only listen to context server for now
    kv_listener = runtime.namespace("dynamo").component("tensorrt-llm-ctx")
    await kv_listener.create_service()
    kv_router = KvRouter(runtime, kv_listener)

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")

    scheduler = Scheduler(kv_router)
    router = Router(
        ctx_chat_client,
        gen_chat_client,
        ctx_completion_client,
        gen_completion_client,
        scheduler,
        engine_config,
    )

    await asyncio.gather(
        completions_endpoint.serve_endpoint(router.generate_completion),
        chat_endpoint.serve_endpoint(router.generate_chat),
    )


if __name__ == "__main__":
    uvloop.install()
    args, engine_config = parse_tensorrt_llm_args()
    asyncio.run(worker(args, engine_config))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This will overwrite the llm_api_config.yaml
hostname: localhost
port: 8000
context_servers:
  num_instances: 4
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.45
    event_buffer_max_size: 1024
    enable_block_reuse: true
  pytorch_backend_config:
    enable_overlap_scheduler: false
    use_cuda_graph: false
    enable_iter_perf_stats: true
  urls:
    - "localhost:8001"
    - "localhost:8002"
    - "localhost:8003"
    - "localhost:8004"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.95
  pytorch_backend_config:
    enable_overlap_scheduler: true
    use_cuda_graph: true
  urls:
    - "localhost:8005"
@@ -13,14 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import os
import signal

import uvloop
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.disagg_processor import ChatProcessor, parse_chat_message_content
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises
@@ -44,6 +43,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest

from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker

logger.set_level("debug")
@@ -66,7 +66,7 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
    def __init__(
        self,
        trt_llm_engine_config: TensorrtLLMEngineConfig,
        disagg_config: DisaggServerConfig,
        instance_idx: int,
        sub_comm,
@@ -77,19 +77,28 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
            instance_idx
        ]
        engine_config = update_args_from_disagg_config(
            trt_llm_engine_config.engine_config, self.server_config
        )
        trt_llm_engine_config.engine_config = engine_config

        # needed for disagg
        self._mpi_session = MpiCommSession(sub_comm, n_workers=sub_comm.Get_size())
        trt_llm_engine_config.engine_config.extra_args[
            "_mpi_session"
        ] = self._mpi_session

        super().__init__(trt_llm_engine_config)

    @dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
    async def generate_chat(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

        # Check if there are any errors in the error queue.
        if self._error_queue.qsize() > 0:
            error = self._error_queue.get()
            raise error

        logger.debug(f"Received request: {request}")
        chat_processor = ChatProcessor(self._model, self._tokenizer, request)
@@ -127,6 +136,7 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
                streaming=request.stream,
                disaggregated_params=disaggregated_params,
            ):
                self.generate_event.set()
                final_result = result
                logger.debug(f"Generated result: {result}")
                if self.server_config.type == "ctx":
@@ -162,13 +172,30 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
        except Exception as e:
            raise RuntimeError("Failed to generate: " + str(e))

        # Start the publishing threads with first request submission
        self._stats_loop = asyncio.get_running_loop()
        if (
            self.publish_kv_cache_events_thread
            and not self.publish_kv_cache_events_thread.is_alive()
        ):
            # Mirror the completions path: hand the running loop to the
            # thread before starting it, otherwise its loop stays unset.
            self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
            self.publish_kv_cache_events_thread.start()

        if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
            self.publish_stats_thread.set_loop(self._stats_loop)
            self.publish_stats_thread.start()

        self._ongoing_request_count -= 1

    @dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
    async def generate_completions(self, request):
        logger.debug(f"[worker] worker_id: {self._worker_id} received request")
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

        # Check if there are any errors in the error queue.
        if self._error_queue.qsize() > 0:
            error = self._error_queue.get()
            raise error

        self._ongoing_request_count += 1
        logger.debug(f"[worker] Received completions request: {request}")
@@ -208,6 +235,21 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
        else:
            raise RuntimeError("Non-streaming is not supported")

        # Start the publishing threads with first request submission
        if (
            self.publish_kv_cache_events_thread
            and not self.publish_kv_cache_events_thread.is_alive()
        ):
            # NOTE: TRTLLM needs the stats to be collected on the same loop as the request handler.
            self._stats_loop = asyncio.get_running_loop()
            self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
            self.publish_kv_cache_events_thread.start()

        if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
            self._stats_loop = asyncio.get_running_loop()
            self.publish_stats_thread.set_loop(self._stats_loop)
            self.publish_stats_thread.start()

        self._ongoing_request_count -= 1
@@ -218,6 +260,8 @@ async def worker(
    disagg_config: DisaggServerConfig,
    instance_idx: int,
    sub_comm,
    publish_stats: bool,
    publish_kv_cache_events: bool,
):
    """
    Instantiate a `backend` component and serve the `generate` endpoint
@@ -226,17 +270,58 @@ async def worker(
    server_type = disagg_config.server_configs[instance_idx].type
    logger.info(f"Starting {server_type} server")

    namespace_str = "dynamo"
    component_str = f"tensorrt-llm-{server_type}"
    component = runtime.namespace(namespace_str).component(component_str)
    await component.create_service()

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")

    if server_type == "gen":
        if publish_stats:
            logger.warning("Stats can only be published for ctx server")
            publish_stats = False

        if publish_kv_cache_events:
            logger.warning("KV cache events can only be published for ctx server")
            publish_kv_cache_events = False

    trt_llm_engine_config = TensorrtLLMEngineConfig(
        namespace_str=namespace_str,
        component_str=component_str,
        engine_config=engine_config,
        publish_stats=publish_stats,
        publish_kv_cache_events=publish_kv_cache_events,
    )

    # NOTE: The current implementation exposes two endpoints. We could refactor
    # this code to expose only one endpoint and handle both completions and
    # chat there. Currently, we use the completions endpoint lease id as the
    # worker id; this might cause issues when smart routing is used with the
    # chat completions endpoint.
    trt_llm_engine_config.worker_id = completions_endpoint.lease_id()

    if publish_stats:
        trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()

    engine = TensorrtLLMEngine(
        trt_llm_engine_config,
        disagg_config,
        instance_idx,
        sub_comm,
    )

    coros = [
        completions_endpoint.serve_endpoint(engine.generate_completions),
        chat_endpoint.serve_endpoint(engine.generate_chat),
    ]
    if publish_stats:
        coros.append(
            trt_llm_engine_config.kv_metrics_publisher.create_endpoint(component)
        )

    await asyncio.gather(*coros)


if __name__ == "__main__":
    uvloop.install()
@@ -262,7 +347,16 @@ if __name__ == "__main__":
    logger.info(f"is_leader: {is_leader}, instance_idx: {instance_idx}")

    if is_leader:
        asyncio.run(
            worker(
                engine_config,
                disagg_config,
                instance_idx,
                sub_comm,
                args.publish_stats,
                args.publish_kv_cache_events,
            )
        )
    else:
        with MPICommExecutor(sub_comm) as executor:
            if not is_leader and executor is not None:
...
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# KV Aware Routing
This document describes how to use the KV aware routing feature in Dynamo with TensorRT-LLM disaggregated serving.

The KV Router is a component that aggregates KV cache events from all the workers and maintains a prefix tree of the cached tokens. It decides which worker to route each request to based on the length of the prefix match and the load on the workers.
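As a rough mental model, the selection could be sketched like this (a toy illustration with hypothetical `prefix_tree` and `active_requests` attributes, not the actual `KvRouter` implementation):

```python
# Toy sketch of KV aware worker selection -- illustrative only.
def pick_worker(tokens: list[int], workers: list) -> object:
    def score(worker) -> int:
        # Longer cached-prefix match is better; higher load is worse.
        matched = worker.prefix_tree.longest_match(tokens)  # hypothetical API
        return matched - worker.active_requests  # hypothetical load metric

    return max(workers, key=score)
```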
## KV Aware Routing with Disaggregated Serving
Follow the instructions in the [README](../README.md) to set up the environment for [disaggregated serving](../README.md#disaggregated-deployment).
All of the steps remain the same except launching the [workers and the router](../README.md#workers).
### 1. Workers
To launch the workers, run the following command:
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
mpirun --allow-run-as-root --oversubscribe -n 5 python3 -m disaggregated.worker --publish-stats --publish-kv-cache-events --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_kv_aware_config.yaml 1>disagg_workers.log 2>&1 &
```
Note the extra arguments `--publish-stats` and `--publish-kv-cache-events`, which make the workers publish stats and KV cache events so the router can make effective routing decisions.

The config file [single_node_kv_aware_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_kv_aware_config.yaml) adds extra configuration to the LLM execution engine to support stats and KV cache event collection:

1. `enable_iter_perf_stats` in `pytorch_backend_config` enables iteration performance stats collection.
2. `event_buffer_max_size` in `kv_cache_config` sets the maximum number of KV cache events that can be buffered.
3. `enable_block_reuse` in `kv_cache_config` enables the KV block reuse feature for improved performance.
Note: The configuration also specifies 4 context servers and 1 generation server.
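These options appear in the context server section of the config, abridged here from the file above:

```yaml
context_servers:
  kv_cache_config:
    event_buffer_max_size: 1024
    enable_block_reuse: true
  pytorch_backend_config:
    enable_iter_perf_stats: true
```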
### 2. Router
To launch the router, run the following command:
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m disaggregated.kv_router --engine_args llm_api_config.yaml 1>kv_router.log 2>&1 &
```
The router will route incoming requests to the appropriate context server based on the published stats and KV cache events.
### 3. Send Requests
Follow the instructions in the [README](../README.md#send-requests) to send requests to the [HTTP server](../README.md#http-server).
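For a quick smoke test, a request along these lines should work once the HTTP server is up (the port and payload here are assumptions for illustration; use the exact values from the README for your deployment):

```bash
# Hypothetical example; adjust host, port, and model name to your setup.
curl -s http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "<model-name>", "messages": [{"role": "user", "content": "Hello!"}], "stream": true}'
```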
@@ -29,7 +29,12 @@ backend: pytorch
kv_cache_config:
  free_gpu_memory_fraction: 0.95
  # Uncomment to enable kv cache event collection
  # event_buffer_max_size: 1024
  # enable_block_reuse: true

pytorch_backend_config:
  enable_overlap_scheduler: false
  use_cuda_graph: false
  # Uncomment to enable iter perf stats
  # enable_iter_perf_stats: true
\ No newline at end of file
@@ -20,7 +20,7 @@ import signal
import uuid

import uvloop
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises, parse_chat_message_content
from tensorrt_llm.executor import CppExecutorError
@@ -42,8 +42,8 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
    Request handler for the generate endpoint
    """

    def __init__(self, trt_llm_engine_config: TensorrtLLMEngineConfig):
        super().__init__(trt_llm_engine_config)

    @dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
    async def generate_chat(self, request):
@@ -146,13 +146,21 @@ async def worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
    Instantiate a `backend` component and serve the `generate` endpoint
    A `Component` can serve multiple endpoints
    """
    namespace_str = "dynamo"
    component_str = "tensorrt-llm"
    component = runtime.namespace(namespace_str).component(component_str)
    await component.create_service()

    completions_endpoint = component.endpoint("completions")
    chat_completions_endpoint = component.endpoint("chat/completions")

    trt_llm_engine_config = TensorrtLLMEngineConfig(
        namespace_str=namespace_str,
        component_str=component_str,
        engine_config=engine_config,
    )
    engine = TensorrtLLMEngine(trt_llm_engine_config)

    await asyncio.gather(
        completions_endpoint.serve_endpoint(engine.generate_completion),
...
@@ -96,6 +96,7 @@ xfail_strict = true
log_cli_level = "INFO"
filterwarnings = [
    "error",
    "ignore:.*cuda*:DeprecationWarning", # Need this to avoid deprecation warnings from CUDA in tensorrt_llm.
    "ignore:.*pkg_resources.*:DeprecationWarning",
    "ignore:.*multipart.*:PendingDeprecationWarning"
]
...