Commit 8435b993 authored by Tanmay Verma, committed by GitHub

fix: Fix TRTLLM chat to work with latest ToT (#127)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 089f8e1b
@@ -138,6 +138,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
        kv_active_block = 0
        kv_total_blocks = 4
        num_requests_waiting = 0
        gpu_cache_usage_perc = 0.0
        gpu_prefix_cache_hit_rate = 0.0

        if self._kv_metrics_publisher is None:
            logger.error("KV metrics publisher not initialized!")
            return
@@ -147,6 +151,9 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
            request_total_slots,
            kv_active_block,
            kv_total_blocks,
            num_requests_waiting,
            gpu_cache_usage_perc,
            gpu_prefix_cache_hit_rate,
        )

        # Prepare threads for publishing stats but don't start them yet.
@@ -197,11 +204,20 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
            logger.error("KV metrics publisher not initialized!")
            return False

        # TODO: Remove this once we have the actual values.
        # Adding dummy values for now so it doesn't break the metrics.
        num_requests_waiting = 0
        gpu_cache_usage_perc = 0.0
        gpu_prefix_cache_hit_rate = 0.0
        self._kv_metrics_publisher.publish(
            request_active_slots,
            request_total_slots,
            kv_active_block,
            kv_total_blocks,
            num_requests_waiting,
            gpu_cache_usage_perc,
            gpu_prefix_cache_hit_rate,
        )
        logger.debug(
            f"Published stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}"
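For orientation, here is a minimal, self-contained sketch of the seven-argument publish() ordering that the hunks above now rely on, including the placeholder values the TODO mentions. The class name is hypothetical; it is not the real dynamo KV metrics publisher.

# Hypothetical stand-in for illustration only; the real publisher lives in dynamo.
class FakeKvMetricsPublisher:
    def publish(
        self,
        request_active_slots: int,
        request_total_slots: int,
        kv_active_block: int,
        kv_total_blocks: int,
        num_requests_waiting: int,
        gpu_cache_usage_perc: float,
        gpu_prefix_cache_hit_rate: float,
    ) -> None:
        # A real publisher would forward these metrics; here we only echo them
        # so the positional ordering used by the callers above is visible.
        print(
            f"slots {request_active_slots}/{request_total_slots}, "
            f"blocks {kv_active_block}/{kv_total_blocks}, "
            f"waiting {num_requests_waiting}, "
            f"cache {gpu_cache_usage_perc}, prefix hit {gpu_prefix_cache_hit_rate}"
        )


# Same dummy values the commit publishes until real metrics are wired in.
FakeKvMetricsPublisher().publish(0, 1, 0, 4, 0, 0.0, 0.0)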
@@ -128,5 +128,11 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
        action="store_true",
        help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
    )
    parser.add_argument(
        "--kv-block-size",
        type=int,
        help="KV block size for TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
        default=64,
    )
    args = parser.parse_args()
    return (args, _init_engine_args(args.engine_args))
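A minimal, runnable sketch of how the new --kv-block-size flag behaves in isolation. The parser name and the hard-coded argument list are illustrative; only the flag itself comes from the diff above.

import argparse

# Stand-alone illustration of the new flag; defaults to 64 as in the commit.
parser = argparse.ArgumentParser(description="illustrative TRT-LLM worker args")
parser.add_argument(
    "--kv-block-size",
    type=int,
    default=64,
    help="KV block size for TensorRT-LLM.",
)

args = parser.parse_args(["--kv-block-size", "32"])
print(args.kv_block_size)  # -> 32; omitting the flag would yield the default 64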
@@ -15,6 +15,7 @@
import asyncio
import copy
import enum
import json
import traceback
from typing import AsyncIterator
@@ -22,6 +23,7 @@ from typing import AsyncIterator
import uvloop
from common.base_engine import ChatProcessorMixin
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import parse_chat_message_content
from common.protocol import (
    DisaggChatCompletionRequest,
    DisaggChatCompletionStreamResponse,
@@ -37,6 +39,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")


class EndpointType(enum.Enum):
    chat = "chat"
    completion = "completion"


class Scheduler:
    def __init__(self, kv_router: KvRouter):
        self.kv_router = kv_router
@@ -77,13 +84,32 @@ class Router(ChatProcessorMixin):
        logger.info("INITIALIZED ROUTER")

    async def _get_ctx_resp(self, request, ctx_client):
    async def _get_ctx_resp(self, request, ctx_client, endpoint_type: EndpointType):
        logger.debug(f"Received request {request}")

        # NOTE: this will increase TTFT since we are encoding the prompt here
        # prompt is also encoded in the worker.
        # TODO: we need to implement our own request processing and protocols to send only token ids to llmapi worker.
        if endpoint_type == EndpointType.completion:
            token_ids = self._tokenizer.encode(request.prompt)
        else:
            conversation = []
            for message in request.messages:
                conversation.extend(parse_chat_message_content(message))
            tool_dicts = (
                None
                if request.tools is None
                else [tool.model_dump() for tool in request.tools]
            )
            token_ids = self._tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=True,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )

        worker_id_generator: AsyncIterator = self.scheduler.generate(
            Tokens(tokens=token_ids).model_dump_json()
        )
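The chat branch above renders the messages through the tokenizer's chat template to get token ids for routing. A small sketch of that pattern using a Hugging Face tokenizer follows; it assumes transformers is installed and uses an arbitrary example model that ships a chat template, not necessarily the one the router is configured with.

# Sketch only: the router's real tokenizer, tools, and documents handling may differ.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a KV cache?"},
]

# tokenize=True returns token ids directly, which is what the router forwards
# to the scheduler as Tokens(tokens=token_ids).
token_ids = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
)
print(len(token_ids), token_ids[:8])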
@@ -92,7 +118,7 @@ class Router(ChatProcessorMixin):
            await worker_id_generator.__anext__()
        )  # only one worker id is returned

        request.max_tokens = 1
        request.max_completion_tokens = 1
        request.disaggregated_params = DisaggregatedParams(request_type="context_only")

        logger.debug(f"[router] Sending request to context server: {request}")
@@ -132,7 +158,9 @@ class Router(ChatProcessorMixin):
        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client)
        ctx_resp = await self._get_ctx_resp(
            request, self.ctx_completion_client, EndpointType.completion
        )
        ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)

        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
@@ -165,7 +193,9 @@ class Router(ChatProcessorMixin):
        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client)
        ctx_resp = await self._get_ctx_resp(
            request, self.ctx_chat_client, EndpointType.chat
        )
        ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)

        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
@@ -184,7 +214,7 @@ class Router(ChatProcessorMixin):
        async for response in await self.gen_chat_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate(
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
                response.data()
            )
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
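The switch from model_validate to model_validate_json in the chat path is the crux of the fix: the generation server's payload arrives as a JSON string, not a dict. A minimal pydantic v2 sketch with a toy model (not the real DisaggChatCompletionStreamResponse) illustrates the difference:

from pydantic import BaseModel


class ToyStreamResponse(BaseModel):
    id: str
    content: str


payload = '{"id": "chatcmpl-1", "content": "hello"}'

# model_validate expects a dict or model instance and rejects a raw JSON string:
# ToyStreamResponse.model_validate(payload)  # -> ValidationError

# model_validate_json parses the string itself, matching the JSON coming back
# from response.data() in the hunk above.
obj = ToyStreamResponse.model_validate_json(payload)
print(obj.model_dump_json(exclude_unset=True))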
@@ -228,7 +258,7 @@ async def worker(runtime: DistributedRuntime, args, engine_config):
    kv_listener = runtime.namespace("dynamo").component("tensorrt-llm-ctx")
    await kv_listener.create_service()

    kv_router = KvRouter(runtime, kv_listener)
    kv_router = KvRouter(runtime, kv_listener, args.kv_block_size)

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")
@@ -48,7 +48,7 @@ class Router:
    async def _get_ctx_resp(self, request, ctx_client):
        logger.debug(f"Received request {request}")

        request.max_tokens = 1
        request.max_completion_tokens = 1
        request.disaggregated_params = DisaggregatedParams(request_type="context_only")

        logger.debug(f"[router] Sending request to context server: {request}")
        ctx_resp = [
@@ -97,6 +97,9 @@ class Router:
        async for response in await self.gen_completion_client.round_robin(
            gen_req.model_dump_json()
        ):
            logger.debug(
                f"[router] Received response from generation server: {response.data()}"
            )
            gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
                response.data()
            )
@@ -130,7 +133,10 @@ class Router:
        async for response in await self.gen_chat_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate(
            logger.debug(
                f"[router] Received response from generation server: {response.data()}"
            )
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
                response.data()
            )
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@@ -136,7 +136,6 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
            streaming=request.stream,
            disaggregated_params=disaggregated_params,
        ):
            self.generate_event.set()
            final_result = result
            logger.debug(f"Generated result: {result}")
            if self.server_config.type == "ctx":