Unverified Commit 568eb100 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

fix: Use Rust Ingress (dynamo-run) for the Frontend (#1391)

parent 2ae40af5
...@@ -21,18 +21,12 @@ import os ...@@ -21,18 +21,12 @@ import os
import signal import signal
import threading import threading
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict
from enum import Enum from enum import Enum
from queue import Queue from queue import Queue
from typing import Any, Optional from typing import Any, Optional
from common.parser import LLMAPIConfig from common.parser import LLMAPIConfig
from common.protocol import ( from common.protocol import DisaggregatedTypeConverter
DisaggregatedTypeConverter,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ManagedThread, ServerType from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams from tensorrt_llm.llmapi import LLM, SamplingParams
...@@ -40,6 +34,7 @@ from tensorrt_llm.llmapi.disagg_utils import ( ...@@ -40,6 +34,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig, CtxGenServerConfig,
parse_disagg_config_file, parse_disagg_config_file,
) )
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
...@@ -65,14 +60,26 @@ def update_args_from_disagg_config( ...@@ -65,14 +60,26 @@ def update_args_from_disagg_config(
return engine_config return engine_config
def get_sampling_params(sampling_params): def _to_signed_i64(value: int | None) -> int | None:
# Removes keys starting with '_' from the sampling params which gets """Convert a Python int to signed 64-bit range by two's complement."""
# added by the LLM API. TRTLLM does not support creating SamplingParams if value is None:
# from a dictionary with keys starting with '_'. return None
cleaned_dict = {
key: value for key, value in sampling_params.items() if not key.startswith("_") if value >= 2**63:
} return value - 2**64
return SamplingParams(**cleaned_dict) if value < -(2**63):
return ((value + 2**63) % 2**64) - 2**63
return value
def get_sampling_params(sampling_params_dict, default_sampling_params):
    """Return a copy of the defaults with request-supplied values applied.

    The defaults object is deep-copied and never modified. An entry of
    ``sampling_params_dict`` is applied only when its value is not ``None``
    and the copied object already has an attribute with that name; every
    other entry is ignored.
    """
    merged = copy.deepcopy(default_sampling_params)
    applicable = (
        (name, override)
        for name, override in sampling_params_dict.items()
        if override is not None and hasattr(merged, name)
    )
    for name, override in applicable:
        setattr(merged, name, override)
    return merged
class BaseTensorrtLLMEngine: class BaseTensorrtLLMEngine:
...@@ -161,6 +168,12 @@ class BaseTensorrtLLMEngine: ...@@ -161,6 +168,12 @@ class BaseTensorrtLLMEngine:
target=asyncio.run, args=(self._run_llm_engine(),) target=asyncio.run, args=(self._run_llm_engine(),)
) )
# Populate default sampling params from the model
tokenizer = tokenizer_factory(self._engine_config.model_name)
self._default_sampling_params = SamplingParams()
self._default_sampling_params._setup(tokenizer)
self._default_sampling_params.stop = None
self.publish_kv_cache_events_thread = None self.publish_kv_cache_events_thread = None
self.publish_stats_thread = None self.publish_stats_thread = None
...@@ -308,13 +321,13 @@ class BaseTensorrtLLMEngine: ...@@ -308,13 +321,13 @@ class BaseTensorrtLLMEngine:
event_id = event["event_id"] event_id = event["event_id"]
data = event["data"] data = event["data"]
if data["type"] == "stored": if data["type"] == "stored":
parent_hash = data["parent_hash"] parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = [] token_ids = []
num_block_tokens = [] num_block_tokens = []
block_hashes = [] block_hashes = []
for block in data["blocks"]: for block in data["blocks"]:
token_num_in_block = len(block["tokens"]) token_num_in_block = len(block["tokens"])
block_hash = block["block_hash"] block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self._kv_block_size: if token_num_in_block > self._kv_block_size:
logger.error( logger.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}" f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
...@@ -350,6 +363,7 @@ class BaseTensorrtLLMEngine: ...@@ -350,6 +363,7 @@ class BaseTensorrtLLMEngine:
elif data["type"] == "removed": elif data["type"] == "removed":
block_hashes = [] block_hashes = []
for block_hash in data["block_hashes"]: for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self._partial_block_hashes: if block_hash in self._partial_block_hashes:
logger.debug( logger.debug(
f"Skipping removing block hash {block_hash} since it is a partial block" f"Skipping removing block hash {block_hash} since it is a partial block"
...@@ -458,7 +472,8 @@ class BaseTensorrtLLMEngine: ...@@ -458,7 +472,8 @@ class BaseTensorrtLLMEngine:
async def _get_remote_prefill_response(self, request): async def _get_remote_prefill_response(self, request):
prefill_request = copy.deepcopy(request) prefill_request = copy.deepcopy(request)
prefill_request.sampling_params["max_tokens"] = 1 # TRTLLM requires max_tokens to be set for prefill requests.
prefill_request.stop_conditions.max_tokens = 1
prefill_request.disaggregated_params = DisaggregatedParams( prefill_request.disaggregated_params = DisaggregatedParams(
request_type=DisaggRequestType.CONTEXT_ONLY.value request_type=DisaggRequestType.CONTEXT_ONLY.value
) )
...@@ -466,7 +481,7 @@ class BaseTensorrtLLMEngine: ...@@ -466,7 +481,7 @@ class BaseTensorrtLLMEngine:
if self._prefill_client is None: if self._prefill_client is None:
raise ValueError("Prefill client not initialized") raise ValueError("Prefill client not initialized")
# TODO: Use smart KV router to determine which prefill worker to use. # TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
ctx_responses = [ ctx_responses = [
ctx_response ctx_response
async for ctx_response in await self._prefill_client.round_robin( async for ctx_response in await self._prefill_client.round_robin(
...@@ -480,17 +495,10 @@ class BaseTensorrtLLMEngine: ...@@ -480,17 +495,10 @@ class BaseTensorrtLLMEngine:
logger.debug( logger.debug(
f"Received response from prefill worker: {ctx_responses[0].data()}" f"Received response from prefill worker: {ctx_responses[0].data()}"
) )
ctx_response_obj = TRTLLMWorkerResponse.model_validate_json( remote_prefill_response = ctx_responses[0]
ctx_responses[0].data() return remote_prefill_response
)
ctx_response_obj.outputs = [
TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
]
assert ctx_response_obj.outputs[0].disaggregated_params is not None
return ctx_response_obj
async def generate(self, request: TRTLLMWorkerRequest): async def generate(self, request):
if self._llm_engine is None: if self._llm_engine is None:
raise RuntimeError("Engine not initialized") raise RuntimeError("Engine not initialized")
...@@ -500,7 +508,7 @@ class BaseTensorrtLLMEngine: ...@@ -500,7 +508,7 @@ class BaseTensorrtLLMEngine:
self._ongoing_request_count += 1 self._ongoing_request_count += 1
try: try:
worker_inputs = request.tokens.tokens worker_inputs = request.token_ids
disaggregated_params = ( disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params( DisaggregatedTypeConverter.to_llm_disaggregated_params(
...@@ -508,24 +516,33 @@ class BaseTensorrtLLMEngine: ...@@ -508,24 +516,33 @@ class BaseTensorrtLLMEngine:
) )
) )
if self._remote_prefill and self._server_type == ServerType.GEN: num_output_tokens_so_far = 0
ctx_response_obj = await self._get_remote_prefill_response(request)
yield TRTLLMWorkerResponse( if self._remote_prefill and self._server_type == ServerType.GEN:
request_id=request.id, ctx_response = await self._get_remote_prefill_response(request)
prompt_token_ids=ctx_response_obj.prompt_token_ids, remote_prefill_response = ctx_response.data()
outputs=[asdict(ctx_response_obj.outputs[0])], if (
finished=ctx_response_obj.finished, remote_prefill_response["finish_reason"] == "stop"
).model_dump_json(exclude_unset=True) or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
worker_inputs = ctx_response_obj.prompt_token_ids # Decode the disaggregated params from the remote prefill response
disaggregated_params = ( disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params( DisaggregatedTypeConverter.to_llm_disaggregated_params(
DisaggregatedParams( DisaggregatedParams(
**ctx_response_obj.outputs[0].disaggregated_params **remote_prefill_response["disaggregated_params"]
) )
) )
) )
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
disaggregated_params.request_type = ( disaggregated_params.request_type = (
DisaggRequestType.GENERATION_ONLY.value DisaggRequestType.GENERATION_ONLY.value
) )
...@@ -534,29 +551,44 @@ class BaseTensorrtLLMEngine: ...@@ -534,29 +551,44 @@ class BaseTensorrtLLMEngine:
f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}" f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
) )
sampling_params = get_sampling_params(request.sampling_params) sampling_params = get_sampling_params(
request.sampling_options.dict(), self._default_sampling_params
)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
async for response in self._llm_engine.generate_async( async for response in self._llm_engine.generate_async(
inputs=worker_inputs, inputs=worker_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
disaggregated_params=disaggregated_params, disaggregated_params=disaggregated_params,
streaming=False streaming=self._server_type != ServerType.CTX,
if self._server_type == ServerType.CTX
else request.streaming,
): ):
# Convert the disaggregated params to OAI format so if response.finished and self._server_type != ServerType.CTX:
# it can be sent over the network. yield {"finish_reason": "stop", "token_ids": []}
response.outputs[ break
0
].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params( if not response.outputs:
response.outputs[0].disaggregated_params yield {"finish_reason": "error", "token_ids": []}
) break
yield TRTLLMWorkerResponse( output = response.outputs[0]
request_id=request.id, next_total_toks = len(output.token_ids)
prompt_token_ids=response.prompt_token_ids, out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
outputs=[asdict(response.outputs[0])], if output.finish_reason:
finished=response.finished, out["finish_reason"] = output.finish_reason
).model_dump_json(exclude_unset=True) if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._server_type == ServerType.CTX:
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).dict()
yield out
num_output_tokens_so_far = next_total_toks
except CppExecutorError: except CppExecutorError:
signal.raise_signal(signal.SIGINT) signal.raise_signal(signal.SIGINT)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Union
from common.parser import LLMAPIConfig
from common.protocol import (
DynamoTRTLLMChatCompletionResponseStreamChoice,
DynamoTRTLLMChatCompletionStreamResponse,
DynamoTRTLLMCompletionResponseStreamChoice,
DynamoTRTLLMCompletionStreamResponse,
Tokens,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ConversationMessage
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, tokenizer_factory
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
logger = logging.getLogger(__name__)
class ChatProcessorMixin:
    """Mixin that wires up a tokenizer plus chat and completions processors."""

    def __init__(
        self, engine_config: LLMAPIConfig, using_engine_generator: bool = False
    ):
        """Create the tokenizer and both processors from ``engine_config``.

        Args:
            engine_config: LLM API configuration; its ``model_name`` selects
                the tokenizer and is handed to both processors.
            using_engine_generator: Forwarded unchanged to ``ChatProcessor``.
        """
        self._engine_config = engine_config
        logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")

        # model name for chat processor
        model_name = self._engine_config.model_name
        self._model_name = model_name
        logger.info(f"Set model name: {self._model_name}")

        tokenizer = tokenizer_factory(model_name)
        self._tokenizer = tokenizer
        self.chat_processor = ChatProcessor(
            model_name, tokenizer, using_engine_generator
        )
        self.completions_processor = CompletionsProcessor(model_name, tokenizer)
def parse_chat_message_content(
    message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
    """Turn one chat message into a list of at most one ConversationMessage.

    Returns an empty list when the message has no content. String content is
    wrapped as-is; a sequence of parts is accepted only if every part has
    type ``"text"``, and the part texts are joined with newlines.

    Raises:
        NotImplementedError: If any content part has a type other than "text".
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    def _text_of(part):
        kind = part["type"]
        if kind != "text":
            raise NotImplementedError(f"{kind} is not supported")
        return part["text"]  # type: ignore

    joined = "\n".join(_text_of(part) for part in content)
    return [ConversationMessage(role=role, content=joined)]
class BaseChatProcessor:
    """Helpers shared by chat processors: role choice, usage info, logprobs."""

    def __init__(
        self,
        model: str,
        tokenizer: TokenizerBase,
    ):
        """Store the model name and the tokenizer used by the helpers."""
        self.tokenizer = tokenizer
        self.model = model

    def _get_role(self, request: ChatCompletionRequest) -> str:
        """Return "assistant" when a generation prompt is added, else the last message's role."""
        if not request.add_generation_prompt:
            return request.messages[-1]["role"]
        return "assistant"

    def _stream_usage_info(
        self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
    ):
        """Return a UsageInfo for a streamed chunk, or None.

        Usage is produced only when the request has stream options with both
        ``include_usage`` and ``continuous_usage_stats`` set.
        """
        options = request.stream_options
        wants_usage = (
            options and options.include_usage and options.continuous_usage_stats
        )
        if not wants_usage:
            return None
        return UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

    def _create_logprobs(
        self, token_ids: List[int], logprobs: List[float]
    ) -> ChatCompletionLogProbs:
        """Pair each token id with its logprob in a ChatCompletionLogProbs.

        Each token id is decoded with the tokenizer; the reported bytes are
        the UTF-8 encoding of the decoded token.
        """
        assert len(token_ids) == len(
            logprobs
        ), "token_ids and logprobs have different lengths"

        entries: List[ChatCompletionLogProbsContent] = []
        for token_id, logprob in zip(token_ids, logprobs):
            decoded = self.tokenizer.decode(token_id)
            # returning multiple logprobs is not supported
            entries.append(
                ChatCompletionLogProbsContent(
                    token=decoded,
                    # NOTE: min logprob -9999.0 for probabilities extremely close to 0
                    logprob=max(logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace")),
                )
            )
        return ChatCompletionLogProbs(content=entries)
# Chat-completions processor built on BaseChatProcessor. `preprocess` turns a
# chat request into a TRTLLMWorkerRequest; `postprocess` turns raw engine or
# worker responses into serialized DynamoTRTLLMChatCompletionStreamResponse
# chunks.
# NOTE(review): this copy of the file has no leading indentation, so block
# nesting (for example whether a `return` sits inside or after a loop) cannot
# be read from the text. Confirm against the indented source before relying
# on the control-flow remarks below.
class ChatProcessor(BaseChatProcessor):
def __init__(
self,
model: str,
tokenizer: TokenizerBase,
using_engine_generator: bool = False,
):
super().__init__(model, tokenizer)
# Read by `postprocess` to choose how a raw response is unpacked.
self.using_engine_generator = using_engine_generator
# Build the serialized opening chunk of a stream: a delta carrying the role
# and text, with finish_reason=None.
def yield_first_chat(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
content: str | None = None,
):
role = self._get_role(request)
num_choices = 1 if request.n is None else request.n
num_tokens = len(response.prompt_token_ids)
# NOTE(review): this assignment overwrites the `content` argument, so a
# caller-supplied value (such as the echo call in
# create_chat_stream_response) is never used — confirm whether intended.
content = response.outputs[0].text_diff
for i in range(num_choices):
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
# No-op branch: nothing is copied onto the choice when disaggregated
# params are present.
if response.outputs[0].disaggregated_params is not None:
# Do not include the disaggregated params in response
# from processor.
pass
# The chunk carries a single choice (the current value of `choice`).
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
# Completion-token count is reported as 0 for the opening chunk.
chunk.usage = self._stream_usage_info(request, num_tokens, 0)
return chunk.model_dump_json()
# Convert one response into one serialized stream chunk. Possible results
# visible below: the opening chunk, a delta chunk for an output, a usage-only
# chunk with empty `choices`, or the literal string "data: [DONE]\n\n".
def create_chat_stream_response(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
conversation: List[Dict[str, Any]],
first_iteration: bool = True,
) -> str:
num_choices = 1 if request.n is None else request.n
# Local list created on every call; it does not carry state from one chunk
# to the next.
finish_reason_sent = [False] * num_choices
role = self._get_role(request)
prompt_tokens = len(response.prompt_token_ids)
if first_iteration:
return self.yield_first_chat(request, request_id, response)
# TODO: Fix this
# Echo handling: looks up the content of the last conversation message when
# its role equals `role`, and passes it to yield_first_chat as `content`.
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
return self.yield_first_chat(
request, request_id, response, content=last_msg_content
)
# Rebinds only this method's local parameter; `postprocess` keeps its own
# first_iteration flag.
first_iteration = False
for output in response.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
# Exact type match: a subclass of ChatCompletionNamedToolChoiceParam would
# not take the tool-call branch.
if (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text, role=role)
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
logger.debug(f"[processor] Chunk: {chunk}")
# `output.length` is passed as the completion-token count.
chunk.usage = self._stream_usage_info(request, prompt_tokens, output.length)
return chunk.model_dump_json()
# TODO: make request.stream_options.include_usage = True when stream=False in rust
# Usage-only chunk: completion tokens are summed over all outputs.
if request.stream_options and request.stream_options.include_usage:
completion_tokens = sum(output.length for output in response.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[],
model=self.model,
usage=final_usage,
)
return final_usage_chunk.model_dump_json()
return "data: [DONE]\n\n"
# Flatten the request messages into a conversation, render it through the
# tokenizer's chat template with tokenize=True, and wrap the result together
# with the sampling parameters in a TRTLLMWorkerRequest.
async def preprocess(self, request):
conversation: List[Any] = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=True,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
# `stop` is cleared after _setup, before the params are serialized with
# asdict. NOTE(review): the reason is not visible here — confirm.
sampling_params.stop = None
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
sampling_params=asdict(sampling_params),
streaming=request.stream,
conversation=conversation,
disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=prompt),
)
# Async generator: for every raw response from `engine_generator`, build a
# TRTLLMWorkerResponse and yield the chunk produced by
# create_chat_stream_response.
async def postprocess(
self,
engine_generator,
request,
conversation,
):
first_iteration = True
async for raw_response in engine_generator:
# With using_engine_generator the raw response is read through attributes;
# otherwise it is parsed from the JSON returned by raw_response.data().
if self.using_engine_generator:
response = TRTLLMWorkerResponse(
request_id=request.id,
prompt=raw_response.prompt,
prompt_token_ids=raw_response.prompt_token_ids,
outputs=[asdict(raw_response.outputs[0])],
finished=raw_response.finished,
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
# Decode only the token ids from index `_last_token_ids_len` onward and
# store the result as the output text.
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"][last_token_ids_len:]
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_chat_stream_response(
request,
request.id,
response,
conversation,
first_iteration=first_iteration,
)
first_iteration = False
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
class CompletionsProcessor:
    """Pre/post-processor for text-completions requests.

    ``preprocess`` builds a ``TRTLLMWorkerRequest`` from a completions request
    and ``postprocess`` turns worker responses into serialized
    ``DynamoTRTLLMCompletionStreamResponse`` chunks.
    """

    def __init__(
        self,
        model: str,
        tokenizer: TokenizerBase,
    ):
        """Store the model name and the tokenizer used for encode/decode."""
        self.model = model
        self.tokenizer = tokenizer

    def create_completion_stream_response(self, request, response):
        """Serialize the first output of ``response`` as one stream chunk.

        When ``request.echo`` is set the prompt is prepended to the chunk
        text. This method keeps no state between calls, and ``postprocess``
        calls it once per response, so the prompt is prepended to every chunk.

        Returns:
            The JSON string of the chunk, or ``None`` if ``response`` has no
            outputs.
        """
        # len(response.outputs) is always 1
        for output in response.outputs:
            text = output.text
            if request.echo:
                echoed_prompt = request.prompt
                # A token-id prompt cannot be concatenated with str; decode
                # it back to text first instead of raising TypeError.
                if not isinstance(echoed_prompt, str):
                    echoed_prompt = self.tokenizer.decode(echoed_prompt)
                text = echoed_prompt + text
            choice = DynamoTRTLLMCompletionResponseStreamChoice(
                text=text,
                index=output.index,
                stop_reason=output.stop_reason,
                finish_reason=output.finish_reason,
            )
            chunk = DynamoTRTLLMCompletionStreamResponse(
                model=self.model,
                choices=[choice],
            )
            return chunk.model_dump_json()

    async def preprocess(self, request):
        """Build the worker request for a completions request.

        Raises:
            ValueError: If the prompt is neither a string nor a list of ints.
        """
        prompt = request.prompt
        is_text = isinstance(prompt, str)
        is_token_ids = isinstance(prompt, list) and all(
            isinstance(x, int) for x in prompt
        )
        if not (is_text or is_token_ids):
            raise ValueError(
                "Invalid prompt type. Only string or list of integers are supported."
            )

        sampling_params = request.to_sampling_params()
        sampling_params._setup(self.tokenizer)
        sampling_params.stop = None

        # A list-of-int prompt is already tokenized: forward it unchanged.
        # Only text prompts go through the tokenizer.
        token_ids = self.tokenizer.encode(prompt) if is_text else prompt

        return TRTLLMWorkerRequest(
            id=request.id,
            model=request.model,
            streaming=request.stream,
            sampling_params=asdict(sampling_params),
            disaggregated_params=request.disaggregated_params,
            tokens=Tokens(tokens=token_ids),
        )

    async def postprocess(
        self,
        engine_generator,
        request,
    ):
        """Yield one serialized completion chunk per raw worker response.

        Each raw response is parsed from JSON, and the token ids from index
        ``_last_token_ids_len`` onward are decoded into the chunk text.
        """
        async for raw_response in engine_generator:
            response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
            first_output = response.outputs[0]
            last_token_ids_len = first_output["_last_token_ids_len"]
            first_output["text"] = self.tokenizer.decode(
                first_output["token_ids"][last_token_ids_len:]
            )
            response.outputs = [TRTLLMWorkerResponseOutput(**first_output)]
            response_data = self.create_completion_stream_response(
                request,
                response,
            )
            logger.debug(f"[postprocessor] Response: {response_data}")
            yield response_data
...@@ -131,6 +131,12 @@ def parse_tensorrt_llm_args( ...@@ -131,6 +131,12 @@ def parse_tensorrt_llm_args(
parser.add_argument( parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file" "--engine_args", type=str, required=True, help="Path to the engine args file"
) )
parser.add_argument(
"--served_model_name",
type=str,
help="Name of the model to serve",
default=None,
)
parser.add_argument( parser.add_argument(
"--llmapi-disaggregated-config", "--llmapi-disaggregated-config",
"-c", "-c",
......
...@@ -12,126 +12,19 @@ ...@@ -12,126 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64 import base64
import time from typing import List, Optional
import uuid
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, TypeAlias, Union
import torch from pydantic import BaseModel, Field
from common.utils import ConversationMessage
from pydantic import BaseModel, ConfigDict, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.serve.openai_protocol import ( from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
CompletionRequest,
CompletionResponseStreamChoice,
DisaggregatedParams,
UsageInfo,
)
# The max_tokens is being deprecated in favor of max_completion_tokens.
# However, TRTLLM protocol might still refer it as max_tokens.
class DynamoTRTLLMCompletionRequest(CompletionRequest):
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
max_completion_tokens: Optional[int] = None
nvext: Optional[dict] = Field(default=None)
class DynamoTRTLLMChatCompletionRequest(ChatCompletionRequest):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}")
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
nvext: Optional[dict] = Field(default=None)
class Tokens(BaseModel): class Tokens(BaseModel):
tokens: list[int] tokens: list[int]
class Request(BaseModel): TokenIdType = int
prompt: str
sampling_params: dict
streaming: bool
class TRTLLMWorkerRequest(BaseModel):
model: str
id: str
prompt: str | None = None
sampling_params: dict
streaming: bool = True
conversation: Optional[List[ConversationMessage]] = Field(default=None)
tokens: Optional[Tokens] = Field(default=None)
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
@dataclass(slots=True)
class Logprob:
"""Holds logprob and vocab rank for a token."""
logprob: float
rank: Optional[int] = None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs: TypeAlias = list[dict[int, Logprob]]
@dataclass
class TRTLLMWorkerResponseOutput:
    """One output sequence of a worker response, with offsets for computing diffs.

    The ``*_diff`` properties return the part of ``text``, ``token_ids`` and
    ``logprobs`` that lies past the matching ``_last_*_len`` offset.
    """

    index: int
    text: str = ""
    token_ids: Optional[List[int]] = field(default_factory=list)
    cumulative_logprob: Optional[float] = None
    logprobs: Optional[TokenLogprobs] = field(default_factory=list)
    prompt_logprobs: Optional[TokenLogprobs] = field(default_factory=list)
    finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None
    stop_reason: Optional[Union[int, str]] = None
    generation_logits: Optional[torch.Tensor] = None
    disaggregated_params: Optional[DisaggregatedParams] = None

    # hidden fields for tracking the diffs
    _last_text_len: int = field(default=0, init=True, repr=False)
    _last_token_ids_len: int = field(default=0, init=True, repr=False)
    _last_logprobs_len: int = field(default=0, init=True, repr=False)
    _incremental_states: Optional[dict] = field(default=None, init=True, repr=False)
    # the result of result_handler passed to postprocess workers
    _postprocess_result: Any = None

    @property
    def length(self) -> int:
        """Number of token ids held, or 0 when ``token_ids`` is None."""
        if self.token_ids is None:
            return 0
        return len(self.token_ids)

    @property
    def text_diff(self) -> str:
        """Text from offset ``_last_text_len`` onward."""
        start = self._last_text_len
        return self.text[start:]

    @property
    def token_ids_diff(self) -> List[int]:
        """Token ids from offset ``_last_token_ids_len`` onward; empty when None."""
        if self.token_ids is None:
            return []
        start = self._last_token_ids_len
        return self.token_ids[start:]

    # Ignoring the mypy error here as this is copied from TensorRT-LLM project.
    # https://github.com/NVIDIA/TensorRT-LLM/blob/19c6e68bec891b66146a09647ee7b70230ef5f67/tensorrt_llm/executor/result.py#L68
    # TODO: Work with the TensorRT-LLM team to get this fixed.
    @property
    def logprobs_diff(self) -> List[float]:  # type: ignore
        """Logprobs from offset ``_last_logprobs_len`` onward; empty when None."""
        if self.logprobs is None:
            return []
        start = self._last_logprobs_len
        return self.logprobs[start:]  # type: ignore
class TRTLLMWorkerResponse(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
request_id: str
prompt: str | None = None
prompt_token_ids: list[int]
outputs: list[dict]
finished: bool
class DisaggregatedTypeConverter: class DisaggregatedTypeConverter:
...@@ -175,36 +68,37 @@ class DisaggregatedTypeConverter: ...@@ -175,36 +68,37 @@ class DisaggregatedTypeConverter:
) )
# Chat Completions # TODO: move these to common for all LLMs once we adopt dynamo-run
# derived from lib/llm/src/protocols/common/preprocessor.rs
class StopConditions(BaseModel):
class DynamoTRTLLMChatCompletionResponseStreamChoice( max_tokens: Optional[int] = None
ChatCompletionResponseStreamChoice stop: Optional[List[str]] = None
): stop_token_ids_hidden: Optional[List[TokenIdType]] = None
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None) min_tokens: Optional[int] = None
ignore_eos: Optional[bool] = None
class DynamoTRTLLMChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}") class SamplingOptions(BaseModel):
object: Literal["chat.completion.chunk"] = "chat.completion.chunk" n: Optional[int] = None
created: int = Field(default_factory=lambda: int(time.time())) best_of: Optional[int] = None
model: str presence_penalty: Optional[float] = None
choices: List[DynamoTRTLLMChatCompletionResponseStreamChoice] frequency_penalty: Optional[float] = None
usage: Optional[UsageInfo] = Field(default=None) repetition_penalty: Optional[float] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
## Completions top_k: Optional[int] = None
min_p: Optional[float] = None
use_beam_search: Optional[bool] = None
length_penalty: Optional[float] = None
seed: Optional[int] = None
class DynamoTRTLLMCompletionResponseStreamChoice(CompletionResponseStreamChoice): class TRTLLMWorkerRequest(BaseModel):
token_ids: List[TokenIdType]
stop_conditions: StopConditions
sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
estimated_prefix_hit_num_blocks: Optional[int] = None
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None) disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DynamoTRTLLMCompletionStreamResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DynamoTRTLLMCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
...@@ -17,7 +17,6 @@ import logging ...@@ -17,7 +17,6 @@ import logging
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from components.processor import Processor
from components.worker import TensorRTLLMWorker from components.worker import TensorRTLLMWorker
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
...@@ -30,88 +29,91 @@ from dynamo.sdk.lib.image import DYNAMO_IMAGE ...@@ -30,88 +29,91 @@ from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_dynamo_run_binary():
    """Find the dynamo-run binary path in SDK or fallback to 'dynamo-run' command."""
    # Look for the binary next to the installed SDK package; if it is not
    # there, return the bare command name.
    candidate = Path(sdk.__file__).parent / "cli/bin/dynamo-run"
    return str(candidate) if candidate.exists() else "dynamo-run"
class FrontendConfig(BaseModel): class FrontendConfig(BaseModel):
"""Configuration for the Frontend service including model and HTTP server settings."""
served_model_name: str served_model_name: str
endpoint_chat: str endpoint: str
endpoint_completions: str port: int = 8000
port: int = 8080 router: str = "round-robin"
block_size: int = 32
# todo this should be called ApiServer
@service( @service(
dynamo={ dynamo={
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
app=FastAPI(title="TensorRT LLM Example"), app=FastAPI(title="TensorRT-LLM Example"),
) )
# todo this should be called ApiServer
class Frontend: class Frontend:
worker = depends(TensorRTLLMWorker) worker = depends(TensorRTLLMWorker)
processor = depends(Processor)
def __init__(self): def __init__(self):
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend")) """Initialize Frontend service with HTTP server and model configuration."""
self.frontend_config = FrontendConfig(
# Chat/completions Endpoint **ServiceConfig.get_parsed_config("Frontend")
subprocess.run(
[
"llmctl",
"http",
"remove",
"chat-models",
frontend_config.served_model_name,
]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"chat-models",
frontend_config.served_model_name,
frontend_config.endpoint_chat,
]
) )
self.process = None
# Completions Endpoint logger.warning(f"Frontend config: {self.frontend_config}")
subprocess.run(
[ self.start_ingress_and_processor()
"llmctl",
"http", def start_ingress_and_processor(self):
"remove", """Starting dynamo-run based ingress and processor"""
"completions", logger.info(
frontend_config.served_model_name, f"Starting HTTP server and processor on port {self.frontend_config.port}"
]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"completions",
frontend_config.served_model_name,
frontend_config.endpoint_completions,
]
) )
dynamo_run_binary = get_dynamo_run_binary()
cmd = [
dynamo_run_binary,
"in=http",
"out=dyn",
"--http-port",
str(self.frontend_config.port),
"--router-mode",
self.frontend_config.router,
]
logger.info(f"Frontend cmd: {cmd}")
logger.info("Starting HTTP server") self.process = subprocess.Popen(
http_binary = get_http_binary_path() cmd,
process = subprocess.Popen( stdout=None,
[http_binary, "-p", str(frontend_config.port)], stdout=None, stderr=None stderr=None,
) )
try:
process.wait() def close(self):
except KeyboardInterrupt: """Clean up resources by terminating the subprocess."""
process.terminate() if self.process is not None:
process.wait() try:
logger.info("Terminating subprocess...")
self.process.terminate()
# Wait for process to terminate with a timeout
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
logger.warning("Subprocess did not terminate gracefully, forcing kill")
self.process.kill()
self.process.wait()
except Exception as e:
logger.error(f"Error while terminating subprocess: {e}")
finally:
self.process = None
def __del__(self):
"""Destructor to ensure subprocess is cleaned up."""
self.close()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import random
import traceback
from argparse import Namespace
from typing import AsyncIterator
from common.protocol import Tokens
from components.worker import TensorRTLLMWorker
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
WorkerId = str
def parse_args(service_name, prefix) -> Namespace:
    """Parse the router options for ``service_name`` from the service config.

    The argument list is not taken from ``sys.argv``; it is produced by
    ``ServiceConfig.get_instance().as_args(service_name, prefix=prefix)``.

    Args:
        service_name: Name of the service section to read from the config.
        prefix: Prefix forwarded unchanged to ``ServiceConfig.as_args``.

    Returns:
        Namespace with ``min_workers``, ``model_name``, ``block_size`` and
        ``custom_router``.
    """

    def _str_to_bool(text: str) -> bool:
        # ``type=bool`` would treat every non-empty string (including "False")
        # as True, so the textual value is interpreted explicitly instead.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {text!r}"
        )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
    # TODO: Read block size
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="KV block size",
    )
    parser.add_argument(
        "--custom-router",
        type=_str_to_bool,
        default=False,
        help="Whether to use custom router or not",
    )

    config = ServiceConfig.get_instance()
    config_args = config.as_args(service_name, prefix=prefix)
    return parser.parse_args(config_args)
@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """Pick a TensorRTLLMWorker instance for a request from KV-cache overlap and load.

    ``generate`` yields a single string of the form ``"<worker_id>_<score>"``;
    an empty worker id means no preference could be established.
    """

    worker = depends(TensorRTLLMWorker)

    def __init__(self):
        logger.info("Initializing KV router.")
        class_name = self.__class__.__name__
        self.args = parse_args(class_name, "")
        # Both are created in async_init(). They are defined here so that
        # generate() can test them for None instead of raising AttributeError
        # when it is called before start-up has finished.
        self.indexer = None
        self.metrics_aggregator = None

    @async_on_start
    async def async_init(self):
        """Connect to the worker endpoint, wait for workers, then build the KV indexer."""
        self.runtime = dynamo_context["runtime"]
        self.workers_client = (
            await self.runtime.namespace("dynamo")
            .component("TensorRTLLMWorker")
            .endpoint("generate")
            .client()
        )

        # Poll until at least ``min_workers`` worker instances are visible.
        while len(self.workers_client.instance_ids()) < self.args.min_workers:
            logger.info(
                f"Waiting for more workers to be ready.\n"
                f" Current: {len(self.workers_client.instance_ids())},"
                f" Required: {self.args.min_workers}"
            )
            await asyncio.sleep(30)

        kv_listener = self.runtime.namespace("dynamo").component("TensorRTLLMWorker")
        await kv_listener.create_service()
        self.indexer = KvIndexer(kv_listener, self.args.block_size)
        self.metrics_aggregator = KvMetricsAggregator(kv_listener)
        logger.info("KV Router initialized")

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """Score every known worker and choose the best one.

        Args:
            scores: Per-worker count of matching KV blocks, or None.
            metrics: Aggregated per-endpoint load metrics, or None.
            token_length: Number of tokens in the request.

        Returns:
            ``""`` when there are no worker instances or every logit is zero;
            otherwise ``(worker_id, score)`` where ``score`` is the matched
            block count scaled by the block size and divided by
            ``token_length``.
        """
        metric_names = (
            "gpu_cache_usage_perc",
            "num_requests_waiting",
            "gpu_prefix_cache_hit_rate",
        )

        worker_scores = {}
        if scores:
            block_size = self.indexer.block_size()
            for worker_id, score in scores.scores.items():
                # score is number of matching blocks we multiply by block_size to get tokens
                # and compare to token_length. The larger the cache hit the better
                worker_scores[worker_id] = score * block_size / token_length
        logger.debug(f"Worker scores: {worker_scores}")

        # Pull metrics for each worker; a metric the endpoint lacks counts as 0.0.
        worker_metrics = {}
        max_waiting = 0.0
        if metrics:
            for endpoint_metrics in metrics.endpoints:
                collected = {
                    name: getattr(endpoint_metrics, name, 0.0)
                    for name in metric_names
                }
                worker_metrics[endpoint_metrics.worker_id] = collected
                max_waiting = max(max_waiting, collected["num_requests_waiting"])
        logger.debug(f"Worker metrics: {worker_metrics}")

        # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
        # and we want all workers to be considered in the logit calculation
        worker_ids = self.workers_client.instance_ids()
        default_metrics = dict.fromkeys(metric_names, 0.0)

        worker_logits = {}
        for worker_id in worker_ids:
            # Use default values if worker not in scores or metrics
            score = worker_scores.get(worker_id, 0.0)
            metrics_dict = worker_metrics.get(worker_id, default_metrics)
            normalized_waiting = (
                metrics_dict["num_requests_waiting"] / max_waiting
                if max_waiting > 0
                else 0.0
            )
            # Have 1 metric that weights towards cache hit
            # 2 metrics that penalize overloaded worker and queuing
            worker_logits[worker_id] = (
                2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
            )
            logger.debug(
                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
            )

        # No candidates, or nothing distinguishes them: let the caller fall back.
        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
            return ""

        # Select the worker with the highest logit; ties are broken at random.
        max_logit = max(worker_logits.values())
        best_workers = [
            wid for wid, logit in worker_logits.items() if logit == max_logit
        ]
        best_worker_id = random.choice(best_workers)

        # Log the metrics for the selected worker
        selected_metrics = worker_metrics.get(best_worker_id, {})
        logger.debug(
            f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
        )
        logger.debug(
            f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
        )
        logger.debug(
            f"GPU Cache Hit Rate: {selected_metrics.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
        )
        logger.debug(
            f"GPU Cache Usage: {selected_metrics.get('gpu_cache_usage_perc', 0.0):.3f}"
        )
        logger.debug(
            f"Requests Waiting: {selected_metrics.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
        )

        return best_worker_id, worker_scores.get(best_worker_id, 0.0)

    @endpoint()
    async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
        """Yield exactly one routing decision formatted as ``"<worker_id>_<score>"``."""
        if self.indexer is None or self.metrics_aggregator is None:
            # Not initialised yet: report "no preference" and stop here rather
            # than falling through and yielding a second decision.
            yield "_0.0"
            return

        lora_id = 0
        try:
            scores = await self.indexer.find_matches_for_request(
                request.tokens, lora_id
            )
            token_length = len(request.tokens)
            metrics = await self.metrics_aggregator.get_metrics()
            schedule_result = self._cost_function(scores, metrics, token_length)
        except Exception:
            # Routing is best effort: any failure degrades to "no preference".
            schedule_result = ""
            logger.warning(f"Error during worker selection: {traceback.format_exc()}")

        if schedule_result == "":
            worker_id = ""
            prefix_hit_rate = 0.0
        else:
            worker_id, prefix_hit_rate = schedule_result

        yield f"{worker_id}_{prefix_hit_rate}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
from common.chat_processor import ChatProcessorMixin
from common.parser import parse_tensorrt_llm_args
from common.protocol import (
DynamoTRTLLMChatCompletionRequest,
DynamoTRTLLMCompletionRequest,
)
from common.utils import RequestType
from components.kv_router import Router
from components.worker import TensorRTLLMWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Processor(ChatProcessorMixin):
    """Pre-process OpenAI-style requests, dispatch them to a worker and post-process the stream.

    Exposes the ``chat/completions`` and ``completions`` endpoints. When the
    router mode is ``"kv"`` the Router service is asked which worker to use;
    otherwise (or when the router has no preference) the request is sent
    round-robin or at random.
    """

    worker = depends(TensorRTLLMWorker)
    router = depends(Router)

    def __init__(
        self,
    ):
        class_name = self.__class__.__name__
        config = ServiceConfig.get_instance()
        config_args = config.as_args(class_name, prefix="")
        args, engine_config = parse_tensorrt_llm_args(config_args)
        self.remote_prefill = args.remote_prefill
        self.router_mode = args.router
        self.min_workers = 1
        self.args = args
        super().__init__(engine_config)

    @async_on_start
    async def async_init(self):
        """Create the worker (and, in kv mode, router) clients and wait for workers."""
        runtime = dynamo_context["runtime"]
        comp_ns, comp_name = TensorRTLLMWorker.dynamo_address()  # type: ignore
        self.worker_client = (
            await runtime.namespace(comp_ns)
            .component(comp_name)
            .endpoint("generate")
            .client()
        )

        # The router client only exists in kv mode; _generate() checks the
        # same mode before touching it.
        if self.args.router == "kv":
            router_ns, router_name = Router.dynamo_address()  # type: ignore
            self.router_client = (
                await runtime.namespace(router_ns)
                .component(router_name)
                .endpoint("generate")
                .client()
            )

        while len(self.worker_client.instance_ids()) < self.min_workers:
            logger.info(
                f"Waiting for workers to be ready.\n"
                f" Current: {len(self.worker_client.instance_ids())},"
                f" Required: {self.min_workers}"
            )
            await asyncio.sleep(30)

    def _apply_nvext(self, raw_request, limit_field):
        """Copy ``ignore_eos`` from ``nvext`` onto the request, if it was supplied.

        ``limit_field`` names the request attribute holding the output-token
        limit (``"max_completion_tokens"`` for chat, ``"max_tokens"`` for
        completions). When ``ignore_eos`` is truthy, ``min_tokens`` is set to
        that limit.
        """
        # min_tokens isn't currently propagated through the Rust OpenAI HTTP frontend,
        # and ignore_eos is passed through the 'nvext' field, so set both when found.
        if not raw_request.nvext:
            return
        ignore_eos = raw_request.nvext.get("ignore_eos")
        if ignore_eos is None:
            # Key absent: leave the request's own ignore_eos untouched instead
            # of overwriting it with None.
            return
        raw_request.ignore_eos = ignore_eos
        # If ignore_eos is True, set min_tokens to max_tokens to guarantee
        # the full expected OSL for consistent benchmarking purposes.
        if ignore_eos:
            limit = getattr(raw_request, limit_field)
            logger.debug(
                f"[preprocessor] `ignore_eos` detected, setting `min_tokens` to `{limit_field}`: {limit}"
            )
            raw_request.min_tokens = limit

    async def _generate(self, raw_request, request_type: RequestType):
        """Pre-process, route and stream one request; yields decoded JSON responses."""
        raw_request.skip_special_tokens = False
        raw_request.add_special_tokens = False
        raw_request.spaces_between_special_tokens = False
        logger.debug(f"[preprocessor] Received request: {raw_request}")

        if request_type == RequestType.CHAT:
            preprocessed_request = await self.chat_processor.preprocess(raw_request)
        else:
            preprocessed_request = await self.completions_processor.preprocess(
                raw_request
            )

        worker_id = ""
        if self.router_mode == "kv":
            router_generator = await self.router_client.generate(
                preprocessed_request.tokens.model_dump_json()
            )
            decision = await router_generator.__anext__()
            decision = decision.data()
            # The decision is "<worker_id>_<rate>"; split on the last "_" only.
            worker_id, _, prefix_hit_rate = decision.rpartition("_")
            prefix_hit_rate = float(prefix_hit_rate)
            logger.info(
                f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
            )

        payload = preprocessed_request.model_dump_json()
        if worker_id == "":
            # No routing preference. The dispatch callable is kept local so
            # concurrent requests do not share mutable state on the instance.
            if self.router_mode == "round-robin":
                send_request = self.worker_client.round_robin
            else:
                # fallback to random
                send_request = self.worker_client.random
            engine_generator = await send_request(payload)
        else:
            engine_generator = await self.worker_client.direct(payload, int(worker_id))

        if request_type == RequestType.CHAT:
            async for response in self.chat_processor.postprocess(
                engine_generator,
                raw_request,
                preprocessed_request.conversation,
            ):
                logger.debug(f"[preprocessor] Response: {response}")
                yield json.loads(response)
        else:
            async for response in self.completions_processor.postprocess(
                engine_generator, raw_request
            ):
                logger.debug(f"[preprocessor] Response: {response}")
                yield json.loads(response)

    @endpoint(name="chat/completions")
    async def generate_chat(self, raw_request: DynamoTRTLLMChatCompletionRequest):
        """Handle a chat-completion request.

        Raises:
            ValueError: If ``max_tokens`` and ``max_completion_tokens`` are both
                set and differ.
        """
        # max_tokens is deprecated, however if the max_tokens is provided instead
        # of max_completion_tokens, we will use the value as max_completion_tokens.
        if raw_request.max_tokens is not None:
            if raw_request.max_completion_tokens is None:
                raw_request.max_completion_tokens = raw_request.max_tokens
            elif raw_request.max_tokens != raw_request.max_completion_tokens:
                raise ValueError(
                    "max_tokens and max_completion_tokens must be the same"
                )

        self._apply_nvext(raw_request, "max_completion_tokens")

        async for response in self._generate(raw_request, RequestType.CHAT):
            yield response

    @endpoint(name="completions")
    async def completions(self, raw_request: DynamoTRTLLMCompletionRequest):
        """Handle a text-completion request."""
        self._apply_nvext(raw_request, "max_tokens")

        async for response in self._generate(raw_request, RequestType.COMPLETION):
            yield response
...@@ -21,6 +21,7 @@ from common.protocol import TRTLLMWorkerRequest ...@@ -21,6 +21,7 @@ from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
...@@ -43,10 +44,12 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine): ...@@ -43,10 +44,12 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="") config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args) args, engine_config = parse_tensorrt_llm_args(config_args)
self.served_model_name = args.served_model_name
worker_id = dynamo_context["endpoints"][0].lease_id() worker_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore
self._min_prefill_workers = args.min_prefill_workers self._min_prefill_workers = args.min_prefill_workers
super().__init__( super().__init__(
namespace_str="dynamo", namespace_str=namespace,
component_str=class_name, component_str=class_name,
worker_id=worker_id, worker_id=worker_id,
engine_config=engine_config, engine_config=engine_config,
...@@ -62,6 +65,24 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine): ...@@ -62,6 +65,24 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
async def async_init(self): async def async_init(self):
self._init_engine() self._init_engine()
runtime = dynamo_context["runtime"]
logger.info("Registering LLM for discovery")
comp_ns, comp_name = TensorRTLLMWorker.dynamo_address() # type: ignore
endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
try:
await register_llm(
ModelType.Backend,
endpoint,
self._engine_config.model_name,
self.served_model_name,
kv_cache_block_size=self._kv_block_size,
)
logger.info("Successfully registered LLM for discovery")
except Exception as e:
logger.error(f"Failed to register LLM for discovery: {e}")
raise
if self._remote_prefill: if self._remote_prefill:
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
......
...@@ -15,20 +15,15 @@ ...@@ -15,20 +15,15 @@
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_chat: dynamo.Processor.chat/completions
port: 8000 port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
router: round-robin router: round-robin
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml" engine_args: "configs/llm_api_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
gpu: 1 gpu: 1
\ No newline at end of file
...@@ -15,19 +15,9 @@ ...@@ -15,19 +15,9 @@
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_chat: dynamo.Processor.chat/completions
port: 8000 port: 8000
Processor:
engine_args: "configs/llm_api_config_router.yaml"
router: kv router: kv
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
TensorRTLLMWorker: TensorRTLLMWorker:
engine_args: "configs/llm_api_config_router.yaml" engine_args: "configs/llm_api_config_router.yaml"
......
...@@ -16,18 +16,12 @@ ...@@ -16,18 +16,12 @@
Frontend: Frontend:
# This is the client-facing model name, you can set this to anything you'd like. # This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint_chat: dynamo.Processor.chat/completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_completions: dynamo.Processor.completions
port: 8000 port: 8000
Processor:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin router: round-robin
# Parallelize preprocessing/tokenization to avoid bottlenecks
ServiceArgs:
workers: 5
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
......
...@@ -16,19 +16,12 @@ ...@@ -16,19 +16,12 @@
Frontend: Frontend:
# This is the client-facing model name, you can set this to anything you'd like. # This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint_chat: dynamo.Processor.chat/completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_completions: dynamo.Processor.completions
port: 8000 port: 8000
Processor:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin router: round-robin
remote-prefill: true
# Parallelize preprocessing/tokenization to avoid bottlenecks
ServiceArgs:
workers: 5
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml" llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml"
remote-prefill: true remote-prefill: true
......
...@@ -15,23 +15,17 @@ ...@@ -15,23 +15,17 @@
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_chat: dynamo.Processor.chat/completions
port: 8000 port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
router: round-robin router: round-robin
remote-prefill: true
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml" engine_args: "configs/llm_api_config.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml" llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
router: round-robin
remote-prefill: true remote-prefill: true
min-prefill-workers: 1 min-prefill-workers: 1
router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
......
...@@ -15,37 +15,27 @@ ...@@ -15,37 +15,27 @@
Frontend: Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions endpoint: dynamo.TensorRTLLMWorker.generate
endpoint_chat: dynamo.Processor.chat/completions
port: 8000 port: 8000
router: kv
Processor:
engine_args: "configs/llm_api_config_disagg_router.yaml"
router: "kv"
remote-prefill: true
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
TensorRTLLMWorker: TensorRTLLMWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml" served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml" llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml"
router: kv
remote-prefill: true remote-prefill: true
min-prefill-workers: 1 min-prefill-workers: 1
router: kv
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
gpu: 1 gpu: 1
TensorRTLLMPrefillWorker: TensorRTLLMPrefillWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml" engine_args: "configs/llm_api_config_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml" llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
gpu: 1 gpu: 1
\ No newline at end of file
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
from components.frontend import Frontend from components.frontend import Frontend
from components.processor import Processor
from components.worker import TensorRTLLMWorker from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(TensorRTLLMWorker) Frontend.link(TensorRTLLMWorker)
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from components.frontend import Frontend from components.frontend import Frontend
from components.kv_router import Router
from components.processor import Processor
from components.worker import TensorRTLLMWorker from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(Router).link(TensorRTLLMWorker) Frontend.link(TensorRTLLMWorker)
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from components.frontend import Frontend from components.frontend import Frontend
from components.prefill_worker import TensorRTLLMPrefillWorker from components.prefill_worker import TensorRTLLMPrefillWorker
from components.processor import Processor
from components.worker import TensorRTLLMWorker from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker) Frontend.link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker)
...@@ -14,11 +14,7 @@ ...@@ -14,11 +14,7 @@
# limitations under the License. # limitations under the License.
from components.frontend import Frontend from components.frontend import Frontend
from components.kv_router import Router
from components.prefill_worker import TensorRTLLMPrefillWorker from components.prefill_worker import TensorRTLLMPrefillWorker
from components.processor import Processor
from components.worker import TensorRTLLMWorker from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(Router).link(TensorRTLLMWorker).link( Frontend.link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker)
TensorRTLLMPrefillWorker
)
...@@ -16,6 +16,18 @@ from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher ...@@ -16,6 +16,18 @@ from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
def _to_signed_i64(value: int | None) -> int | None:
"""Convert a Python int to signed 64-bit range by two's complement."""
if value is None:
return None
if value >= 2**63:
return value - 2**64
if value < -(2**63):
return ((value + 2**63) % 2**64) - 2**63
return value
class ManagedThread(threading.Thread): class ManagedThread(threading.Thread):
""" """
A thread that runs a task and handles errors. A thread that runs a task and handles errors.
...@@ -242,13 +254,13 @@ class Publisher: ...@@ -242,13 +254,13 @@ class Publisher:
event_id = event["event_id"] event_id = event["event_id"]
data = event["data"] data = event["data"]
if data["type"] == "stored": if data["type"] == "stored":
parent_hash = data["parent_hash"] parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = [] token_ids = []
num_block_tokens = [] num_block_tokens = []
block_hashes = [] block_hashes = []
for block in data["blocks"]: for block in data["blocks"]:
token_num_in_block = len(block["tokens"]) token_num_in_block = len(block["tokens"])
block_hash = block["block_hash"] block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self.kv_block_size: if token_num_in_block > self.kv_block_size:
logging.error( logging.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}" f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}"
...@@ -284,6 +296,7 @@ class Publisher: ...@@ -284,6 +296,7 @@ class Publisher:
elif data["type"] == "removed": elif data["type"] == "removed":
block_hashes = [] block_hashes = []
for block_hash in data["block_hashes"]: for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self.partial_block_hashes: if block_hash in self.partial_block_hashes:
logging.debug( logging.debug(
f"Skipping removing block hash {block_hash} since it is a partial block" f"Skipping removing block hash {block_hash} since it is a partial block"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment