Commit 8435b993 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

fix: Fix TRTLLM chat to work with latest ToT (#127)


Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 089f8e1b
...@@ -138,6 +138,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin): ...@@ -138,6 +138,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
kv_active_block = 0 kv_active_block = 0
kv_total_blocks = 4 kv_total_blocks = 4
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
if self._kv_metrics_publisher is None: if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!") logger.error("KV metrics publisher not initialized!")
return return
...@@ -147,6 +151,9 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin): ...@@ -147,6 +151,9 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
request_total_slots, request_total_slots,
kv_active_block, kv_active_block,
kv_total_blocks, kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
) )
# Prepare threads for publishing stats but don't start them yet. # Prepare threads for publishing stats but don't start them yet.
...@@ -197,11 +204,20 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin): ...@@ -197,11 +204,20 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
logger.error("KV metrics publisher not initialized!") logger.error("KV metrics publisher not initialized!")
return False return False
# TODO: Remove this once we have the actual values.
# Adding dummy values for now so it doesn't break the metrics.
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
self._kv_metrics_publisher.publish( self._kv_metrics_publisher.publish(
request_active_slots, request_active_slots,
request_total_slots, request_total_slots,
kv_active_block, kv_active_block,
kv_total_blocks, kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
) )
logger.debug( logger.debug(
f"Published stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}" f"Published stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}"
......
...@@ -128,5 +128,11 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any] ...@@ -128,5 +128,11 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]
action="store_true", action="store_true",
help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.", help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
) )
parser.add_argument(
"--kv-block-size",
type=int,
help="KV block size for TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
default=64,
)
args = parser.parse_args() args = parser.parse_args()
return (args, _init_engine_args(args.engine_args)) return (args, _init_engine_args(args.engine_args))
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import asyncio import asyncio
import copy import copy
import enum
import json import json
import traceback import traceback
from typing import AsyncIterator from typing import AsyncIterator
...@@ -22,6 +23,7 @@ from typing import AsyncIterator ...@@ -22,6 +23,7 @@ from typing import AsyncIterator
import uvloop import uvloop
from common.base_engine import ChatProcessorMixin from common.base_engine import ChatProcessorMixin
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import parse_chat_message_content
from common.protocol import ( from common.protocol import (
DisaggChatCompletionRequest, DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse, DisaggChatCompletionStreamResponse,
...@@ -37,6 +39,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker ...@@ -37,6 +39,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug") logger.set_level("debug")
class EndpointType(enum.Enum):
chat = "chat"
completion = "completion"
class Scheduler: class Scheduler:
def __init__(self, kv_router: KvRouter): def __init__(self, kv_router: KvRouter):
self.kv_router = kv_router self.kv_router = kv_router
...@@ -77,13 +84,32 @@ class Router(ChatProcessorMixin): ...@@ -77,13 +84,32 @@ class Router(ChatProcessorMixin):
logger.info("INITIALIZED ROUTER") logger.info("INITIALIZED ROUTER")
async def _get_ctx_resp(self, request, ctx_client): async def _get_ctx_resp(self, request, ctx_client, endpoint_type: EndpointType):
logger.debug(f"Received request {request}") logger.debug(f"Received request {request}")
# NOTE: this will increase TTFT since we are encoding the prompt here # NOTE: this will increase TTFT since we are encoding the prompt here
# prompt is also encoded in the worker. # prompt is also encoded in the worker.
# TODO: we need to implement our own request processing and protocols to send only token ids to llmapi worker. # TODO: we need to implement our own request processing and protocols to send only token ids to llmapi worker.
token_ids = self._tokenizer.encode(request.prompt) if endpoint_type == EndpointType.completion:
token_ids = self._tokenizer.encode(request.prompt)
else:
conversation = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
token_ids = self._tokenizer.apply_chat_template(
conversation=conversation,
tokenize=True,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
worker_id_generator: AsyncIterator = self.scheduler.generate( worker_id_generator: AsyncIterator = self.scheduler.generate(
Tokens(tokens=token_ids).model_dump_json() Tokens(tokens=token_ids).model_dump_json()
) )
...@@ -92,7 +118,7 @@ class Router(ChatProcessorMixin): ...@@ -92,7 +118,7 @@ class Router(ChatProcessorMixin):
await worker_id_generator.__anext__() await worker_id_generator.__anext__()
) # only one worker id is returned ) # only one worker id is returned
request.max_tokens = 1 request.max_completion_tokens = 1
request.disaggregated_params = DisaggregatedParams(request_type="context_only") request.disaggregated_params = DisaggregatedParams(request_type="context_only")
logger.debug(f"[router] Sending request to context server: {request}") logger.debug(f"[router] Sending request to context server: {request}")
...@@ -132,7 +158,9 @@ class Router(ChatProcessorMixin): ...@@ -132,7 +158,9 @@ class Router(ChatProcessorMixin):
gen_req = copy.deepcopy(request) gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client) ctx_resp = await self._get_ctx_resp(
request, self.ctx_completion_client, EndpointType.completion
)
ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp) ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate( gen_req.disaggregated_params = DisaggregatedParams.model_validate(
...@@ -165,7 +193,9 @@ class Router(ChatProcessorMixin): ...@@ -165,7 +193,9 @@ class Router(ChatProcessorMixin):
gen_req = copy.deepcopy(request) gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client) ctx_resp = await self._get_ctx_resp(
request, self.ctx_chat_client, EndpointType.chat
)
ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp) ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate( gen_req.disaggregated_params = DisaggregatedParams.model_validate(
...@@ -184,7 +214,7 @@ class Router(ChatProcessorMixin): ...@@ -184,7 +214,7 @@ class Router(ChatProcessorMixin):
async for response in await self.gen_chat_client.round_robin( async for response in await self.gen_chat_client.round_robin(
gen_req.model_dump_json() gen_req.model_dump_json()
): ):
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate( gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
response.data() response.data()
) )
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True)) yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
...@@ -228,7 +258,7 @@ async def worker(runtime: DistributedRuntime, args, engine_config): ...@@ -228,7 +258,7 @@ async def worker(runtime: DistributedRuntime, args, engine_config):
kv_listener = runtime.namespace("dynamo").component("tensorrt-llm-ctx") kv_listener = runtime.namespace("dynamo").component("tensorrt-llm-ctx")
await kv_listener.create_service() await kv_listener.create_service()
kv_router = KvRouter(runtime, kv_listener) kv_router = KvRouter(runtime, kv_listener, args.kv_block_size)
completions_endpoint = component.endpoint("completions") completions_endpoint = component.endpoint("completions")
chat_endpoint = component.endpoint("chat/completions") chat_endpoint = component.endpoint("chat/completions")
......
...@@ -48,7 +48,7 @@ class Router: ...@@ -48,7 +48,7 @@ class Router:
async def _get_ctx_resp(self, request, ctx_client): async def _get_ctx_resp(self, request, ctx_client):
logger.debug(f"Received request {request}") logger.debug(f"Received request {request}")
request.max_tokens = 1 request.max_completion_tokens = 1
request.disaggregated_params = DisaggregatedParams(request_type="context_only") request.disaggregated_params = DisaggregatedParams(request_type="context_only")
logger.debug(f"[router] Sending request to context server: {request}") logger.debug(f"[router] Sending request to context server: {request}")
ctx_resp = [ ctx_resp = [
...@@ -97,6 +97,9 @@ class Router: ...@@ -97,6 +97,9 @@ class Router:
async for response in await self.gen_completion_client.round_robin( async for response in await self.gen_completion_client.round_robin(
gen_req.model_dump_json() gen_req.model_dump_json()
): ):
logger.debug(
f"[router] Received response from generation server: {response.data()}"
)
gen_resp_obj = DisaggCompletionStreamResponse.model_validate( gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
response.data() response.data()
) )
...@@ -130,7 +133,10 @@ class Router: ...@@ -130,7 +133,10 @@ class Router:
async for response in await self.gen_chat_client.round_robin( async for response in await self.gen_chat_client.round_robin(
gen_req.model_dump_json() gen_req.model_dump_json()
): ):
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate( logger.debug(
f"[router] Received response from generation server: {response.data()}"
)
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
response.data() response.data()
) )
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True)) yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
......
...@@ -136,7 +136,6 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine): ...@@ -136,7 +136,6 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
streaming=request.stream, streaming=request.stream,
disaggregated_params=disaggregated_params, disaggregated_params=disaggregated_params,
): ):
self.generate_event.set()
final_result = result final_result = result
logger.debug(f"Generated result: {result}") logger.debug(f"Generated result: {result}")
if self.server_config.type == "ctx": if self.server_config.type == "ctx":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment