Unverified Commit 568eb100 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

fix: Use Rust Ingress (dynamo-run) for the Frontend (#1391)

parent 2ae40af5
......@@ -21,18 +21,12 @@ import os
import signal
import threading
from contextlib import asynccontextmanager
from dataclasses import asdict
from enum import Enum
from queue import Queue
from typing import Any, Optional
from common.parser import LLMAPIConfig
from common.protocol import (
DisaggregatedTypeConverter,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.protocol import DisaggregatedTypeConverter
from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
......@@ -40,6 +34,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
parse_disagg_config_file,
)
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
......@@ -65,14 +60,26 @@ def update_args_from_disagg_config(
return engine_config
def get_sampling_params(sampling_params):
    """Build a ``SamplingParams`` from a dict, ignoring '_'-prefixed keys.

    The LLM API adds keys starting with '_' to the sampling params, and
    TRTLLM does not support creating SamplingParams from a dictionary that
    contains such keys, so they are filtered out before construction.
    """
    public_kwargs = {}
    for name, setting in sampling_params.items():
        if name.startswith("_"):
            continue
        public_kwargs[name] = setting
    return SamplingParams(**public_kwargs)
def _to_signed_i64(value: int | None) -> int | None:
"""Convert a Python int to signed 64-bit range by two's complement."""
if value is None:
return None
if value >= 2**63:
return value - 2**64
if value < -(2**63):
return ((value + 2**63) % 2**64) - 2**63
return value
def get_sampling_params(sampling_params_dict, default_sampling_params):
    """Return a copy of the defaults with the request's overrides applied.

    ``default_sampling_params`` is deep-copied and never mutated. An entry
    of ``sampling_params_dict`` is applied only when its value is not None
    and the copied object already has an attribute with that name; every
    other entry is ignored.
    """
    merged = copy.deepcopy(default_sampling_params)
    for name, override in sampling_params_dict.items():
        if override is not None and hasattr(merged, name):
            setattr(merged, name, override)
    return merged
class BaseTensorrtLLMEngine:
......@@ -161,6 +168,12 @@ class BaseTensorrtLLMEngine:
target=asyncio.run, args=(self._run_llm_engine(),)
)
# Populate default sampling params from the model
tokenizer = tokenizer_factory(self._engine_config.model_name)
self._default_sampling_params = SamplingParams()
self._default_sampling_params._setup(tokenizer)
self._default_sampling_params.stop = None
self.publish_kv_cache_events_thread = None
self.publish_stats_thread = None
......@@ -308,13 +321,13 @@ class BaseTensorrtLLMEngine:
event_id = event["event_id"]
data = event["data"]
if data["type"] == "stored":
parent_hash = data["parent_hash"]
parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = []
num_block_tokens = []
block_hashes = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = block["block_hash"]
block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self._kv_block_size:
logger.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
......@@ -350,6 +363,7 @@ class BaseTensorrtLLMEngine:
elif data["type"] == "removed":
block_hashes = []
for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self._partial_block_hashes:
logger.debug(
f"Skipping removing block hash {block_hash} since it is a partial block"
......@@ -458,7 +472,8 @@ class BaseTensorrtLLMEngine:
async def _get_remote_prefill_response(self, request):
prefill_request = copy.deepcopy(request)
prefill_request.sampling_params["max_tokens"] = 1
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request.stop_conditions.max_tokens = 1
prefill_request.disaggregated_params = DisaggregatedParams(
request_type=DisaggRequestType.CONTEXT_ONLY.value
)
......@@ -466,7 +481,7 @@ class BaseTensorrtLLMEngine:
if self._prefill_client is None:
raise ValueError("Prefill client not initialized")
# TODO: Use smart KV router to determine which prefill worker to use.
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
ctx_responses = [
ctx_response
async for ctx_response in await self._prefill_client.round_robin(
......@@ -480,17 +495,10 @@ class BaseTensorrtLLMEngine:
logger.debug(
f"Received response from prefill worker: {ctx_responses[0].data()}"
)
ctx_response_obj = TRTLLMWorkerResponse.model_validate_json(
ctx_responses[0].data()
)
ctx_response_obj.outputs = [
TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
]
assert ctx_response_obj.outputs[0].disaggregated_params is not None
return ctx_response_obj
remote_prefill_response = ctx_responses[0]
return remote_prefill_response
async def generate(self, request: TRTLLMWorkerRequest):
async def generate(self, request):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
......@@ -500,7 +508,7 @@ class BaseTensorrtLLMEngine:
self._ongoing_request_count += 1
try:
worker_inputs = request.tokens.tokens
worker_inputs = request.token_ids
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
......@@ -508,24 +516,33 @@ class BaseTensorrtLLMEngine:
)
)
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response_obj = await self._get_remote_prefill_response(request)
num_output_tokens_so_far = 0
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt_token_ids=ctx_response_obj.prompt_token_ids,
outputs=[asdict(ctx_response_obj.outputs[0])],
finished=ctx_response_obj.finished,
).model_dump_json(exclude_unset=True)
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response = await self._get_remote_prefill_response(request)
remote_prefill_response = ctx_response.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
worker_inputs = ctx_response_obj.prompt_token_ids
# Decode the disaggregated params from the remote prefill response
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
DisaggregatedParams(
**ctx_response_obj.outputs[0].disaggregated_params
**remote_prefill_response["disaggregated_params"]
)
)
)
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
disaggregated_params.request_type = (
DisaggRequestType.GENERATION_ONLY.value
)
......@@ -534,29 +551,44 @@ class BaseTensorrtLLMEngine:
f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
)
sampling_params = get_sampling_params(request.sampling_params)
sampling_params = get_sampling_params(
request.sampling_options.dict(), self._default_sampling_params
)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
async for response in self._llm_engine.generate_async(
inputs=worker_inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=False
if self._server_type == ServerType.CTX
else request.streaming,
streaming=self._server_type != ServerType.CTX,
):
# Convert the disaggregated params to OAI format so
# it can be sent over the network.
response.outputs[
0
].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params(
response.outputs[0].disaggregated_params
)
if response.finished and self._server_type != ServerType.CTX:
yield {"finish_reason": "stop", "token_ids": []}
break
if not response.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt_token_ids=response.prompt_token_ids,
outputs=[asdict(response.outputs[0])],
finished=response.finished,
).model_dump_json(exclude_unset=True)
output = response.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._server_type == ServerType.CTX:
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).dict()
yield out
num_output_tokens_so_far = next_total_toks
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Union
from common.parser import LLMAPIConfig
from common.protocol import (
DynamoTRTLLMChatCompletionResponseStreamChoice,
DynamoTRTLLMChatCompletionStreamResponse,
DynamoTRTLLMCompletionResponseStreamChoice,
DynamoTRTLLMCompletionStreamResponse,
Tokens,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ConversationMessage
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, tokenizer_factory
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
logger = logging.getLogger(__name__)
class ChatProcessorMixin:
    """Mixin that creates the tokenizer plus chat and completions processors."""

    def __init__(
        self, engine_config: LLMAPIConfig, using_engine_generator: bool = False
    ):
        """Set up the tokenizer and processors for the configured model.

        Args:
            engine_config: LLM API config; its ``model_name`` selects the
                tokenizer and is passed to both processors.
            using_engine_generator: Forwarded to ``ChatProcessor``.
        """
        self._engine_config = engine_config
        logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")

        # model name for chat processor
        model_name = self._engine_config.model_name
        self._model_name = model_name
        logger.info(f"Set model name: {self._model_name}")

        tokenizer = tokenizer_factory(model_name)
        self._tokenizer = tokenizer

        self.chat_processor = ChatProcessor(
            model_name, tokenizer, using_engine_generator
        )
        self.completions_processor = CompletionsProcessor(model_name, tokenizer)
def parse_chat_message_content(
    message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
    """Convert one chat message into a list of ``ConversationMessage``.

    Returns an empty list when the message has no content. String content
    yields a single message. A list of parts must consist of "text" parts
    only; their texts are joined with newlines into a single message.

    Raises:
        NotImplementedError: if a content part has a type other than "text".
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    pieces: List[str] = []
    for part in content:
        part_type = part["type"]
        if part_type != "text":
            raise NotImplementedError(f"{part_type} is not supported")
        pieces.append(part["text"])  # type: ignore
    return [ConversationMessage(role=role, content="\n".join(pieces))]
class BaseChatProcessor:
    """Helpers shared by processors that build OpenAI-style stream chunks."""

    def __init__(
        self,
        model: str,
        tokenizer: TokenizerBase,
    ):
        """Remember the model name and the tokenizer used for decoding."""
        self.tokenizer = tokenizer
        self.model = model

    def _get_role(self, request: ChatCompletionRequest) -> str:
        """Return "assistant" when a generation prompt is added, else the last message's role."""
        if not request.add_generation_prompt:
            return request.messages[-1]["role"]
        return "assistant"

    def _stream_usage_info(
        self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
    ):
        """Return per-chunk ``UsageInfo`` or None.

        Usage is produced only when the request's stream options enable both
        ``include_usage`` and ``continuous_usage_stats``.
        """
        options = request.stream_options
        if not (
            options and options.include_usage and options.continuous_usage_stats
        ):
            return None
        return UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

    def _create_logprobs(
        self, token_ids: List[int], logprobs: List[float]
    ) -> ChatCompletionLogProbs:
        """Pair each token id with its logprob as ``ChatCompletionLogProbs``.

        Each token id is decoded with the tokenizer; the logprob is clamped
        to a minimum of -9999.0. Both lists must have the same length.
        """
        assert len(token_ids) == len(
            logprobs
        ), "token_ids and logprobs have different lengths"

        entries: List[ChatCompletionLogProbsContent] = []
        for tok_id, tok_logprob in zip(token_ids, logprobs):
            decoded = self.tokenizer.decode(tok_id)
            # returning multiple logprobs is not supported
            entries.append(
                ChatCompletionLogProbsContent(
                    token=decoded,
                    # NOTE: min logprob -9999.0 for probabilities extremely close to 0
                    logprob=max(tok_logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace")),
                )
            )
        return ChatCompletionLogProbs(content=entries)
# Chat-completions processor: preprocess() turns a chat request into a
# TRTLLMWorkerRequest, postprocess() turns engine/worker output into JSON
# chat stream chunks via create_chat_stream_response().
# NOTE(review): leading indentation is not visible in this view of the file,
# so the comments below describe individual statements and do not assert how
# they are nested.
class ChatProcessor(BaseChatProcessor):
def __init__(
self,
model: str,
tokenizer: TokenizerBase,
using_engine_generator: bool = False,
):
super().__init__(model, tokenizer)
# Read by postprocess(): selects between building the response from the raw
# engine output's attributes and parsing JSON from raw_response.data().
self.using_engine_generator = using_engine_generator
# Builds the first streamed chunk (role + first text diff) and returns it as JSON.
def yield_first_chat(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
content: str | None = None,
):
role = self._get_role(request)
num_choices = 1 if request.n is None else request.n
num_tokens = len(response.prompt_token_ids)
# NOTE(review): this assignment overwrites the `content` argument before it
# is read, so a value passed by the caller is never used.
content = response.outputs[0].text_diff
for i in range(num_choices):
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
# This branch is a no-op (`pass`); the choice is built without
# disaggregated_params either way.
if response.outputs[0].disaggregated_params is not None:
# Do not include the disaggregated params in response
# from processor.
pass
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
# Zero completion tokens are reported for the first chunk.
chunk.usage = self._stream_usage_info(request, num_tokens, 0)
return chunk.model_dump_json()
# Builds one streamed chunk for `response` and returns it as a string: a JSON
# chunk, a final usage chunk, or the literal "data: [DONE]\n\n".
def create_chat_stream_response(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
conversation: List[Dict[str, Any]],
first_iteration: bool = True,
) -> str:
num_choices = 1 if request.n is None else request.n
finish_reason_sent = [False] * num_choices
role = self._get_role(request)
prompt_tokens = len(response.prompt_token_ids)
if first_iteration:
return self.yield_first_chat(request, request_id, response)
# TODO: Fix this
# NOTE(review): whether the echo handling below runs on non-first iterations
# or is unreachable depends on nesting that is not visible here — confirm
# against the properly indented file.
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
return self.yield_first_chat(
request, request_id, response, content=last_msg_content
)
# Rebinds the local parameter only; it is not read again in this method.
first_iteration = False
for output in response.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
# Exact type match (not isinstance): only a named tool choice routes the
# text diff into a tool-call delta.
if (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text, role=role)
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
logger.debug(f"[processor] Chunk: {chunk}")
chunk.usage = self._stream_usage_info(request, prompt_tokens, output.length)
return chunk.model_dump_json()
# TODO: make request.stream_options.include_usage = True when stream=False in rust
# Final usage chunk: sums output.length over all outputs and carries an empty
# choices list.
if request.stream_options and request.stream_options.include_usage:
completion_tokens = sum(output.length for output in response.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[],
model=self.model,
usage=final_usage,
)
return final_usage_chunk.model_dump_json()
return "data: [DONE]\n\n"
# Converts a chat request into a TRTLLMWorkerRequest: flattens the messages,
# applies the chat template, and serialises the sampling params to a dict.
async def preprocess(self, request):
conversation: List[Any] = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
# tokenize=True is passed, so `prompt` is presumably a list of token ids
# (it is wrapped in Tokens(tokens=...) below) — TODO confirm.
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=True,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
# `stop` is cleared after _setup() and before serialisation.
# NOTE(review): the reason is not shown in this file — confirm.
sampling_params.stop = None
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
sampling_params=asdict(sampling_params),
streaming=request.stream,
conversation=conversation,
disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=prompt),
)
# Async generator: wraps each raw engine/worker response in a
# TRTLLMWorkerResponse and yields the chunk built by
# create_chat_stream_response().
async def postprocess(
self,
engine_generator,
request,
conversation,
):
first_iteration = True
async for raw_response in engine_generator:
if self.using_engine_generator:
# Raw output is read through attributes (prompt, prompt_token_ids, outputs,
# finished) and its first output is converted with dataclasses.asdict.
response = TRTLLMWorkerResponse(
request_id=request.id,
prompt=raw_response.prompt,
prompt_token_ids=raw_response.prompt_token_ids,
outputs=[asdict(raw_response.outputs[0])],
finished=raw_response.finished,
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else:
# Raw output is JSON obtained from raw_response.data(); the text is
# rebuilt here by decoding only the token ids after `_last_token_ids_len`.
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"][last_token_ids_len:]
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_chat_stream_response(
request,
request.id,
response,
conversation,
first_iteration=first_iteration,
)
first_iteration = False
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
# Text-completions processor: preprocess() builds a TRTLLMWorkerRequest from a
# completion request, postprocess() turns worker responses into JSON
# completion stream chunks.
# NOTE(review): leading indentation is not visible in this view of the file,
# so the comments below do not assert how statements are nested.
class CompletionsProcessor:
def __init__(
self,
model: str,
tokenizer: TokenizerBase,
):
self.model = model
self.tokenizer = tokenizer
# Builds a completion stream chunk from a response output and returns it as JSON.
def create_completion_stream_response(self, request, response):
num_choices = 1 if request.n is None else request.n
# NOTE(review): `echoed` is only ever read; no entry is set to True in this
# method.
echoed = [False] * num_choices
# len(response.outputs) is always 1
for gen_idx, output in enumerate(response.outputs):
text = output.text
# With echo enabled the prompt is prepended by `+`.
# NOTE(review): preprocess() also accepts a list-of-int prompt; `+` with a
# str would not work for that case — confirm.
if request.echo and not echoed[gen_idx]:
text = request.prompt + text
choice = DynamoTRTLLMCompletionResponseStreamChoice(
text=text,
index=output.index,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
)
chunk = DynamoTRTLLMCompletionStreamResponse(
model=self.model,
choices=[choice],
)
return chunk.model_dump_json()
# Validates the prompt type, serialises the sampling params, and returns a
# TRTLLMWorkerRequest. Raises ValueError for any prompt that is neither a str
# nor a list whose items are all ints.
async def preprocess(self, request):
if isinstance(request.prompt, str) or (
isinstance(request.prompt, list)
and all(isinstance(x, int) for x in request.prompt)
):
prompt = request.prompt
else:
raise ValueError(
"Invalid prompt type. Only string or list of integers are supported."
)
sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
# `stop` is cleared after _setup() and before serialisation.
# NOTE(review): the reason is not shown in this file — confirm.
sampling_params.stop = None
# NOTE(review): a list-of-int prompt is passed through tokenizer.encode()
# just like a string — confirm this is intended.
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
streaming=request.stream,
sampling_params=asdict(sampling_params),
disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=self.tokenizer.encode(prompt)),
)
# Async generator: parses each raw worker response from JSON, rebuilds the
# text by decoding the token ids after `_last_token_ids_len`, and yields the
# chunk built by create_completion_stream_response().
async def postprocess(
self,
engine_generator,
request,
):
async for raw_response in engine_generator:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"][last_token_ids_len:]
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_completion_stream_response(
request,
response,
)
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
......@@ -131,6 +131,12 @@ def parse_tensorrt_llm_args(
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
parser.add_argument(
"--served_model_name",
type=str,
help="Name of the model to serve",
default=None,
)
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
......
......@@ -12,126 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, TypeAlias, Union
from typing import List, Optional
import torch
from common.utils import ConversationMessage
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
CompletionRequest,
CompletionResponseStreamChoice,
DisaggregatedParams,
UsageInfo,
)
# The max_tokens is being deprecated in favor of max_completion_tokens.
# However, TRTLLM protocol might still refer it as max_tokens.
# Completion request model: the CompletionRequest imported from
# tensorrt_llm.serve.openai_protocol, extended with the fields below.
class DynamoTRTLLMCompletionRequest(CompletionRequest):
# Request id; defaults to "cmpl-" followed by a random uuid4 hex string.
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
# Optional, defaults to None.
max_completion_tokens: Optional[int] = None
# Optional free-form dict, defaults to None.
# NOTE(review): its schema is not defined in this file — confirm with callers.
nvext: Optional[dict] = Field(default=None)
# Chat-completion request model: the ChatCompletionRequest imported from
# tensorrt_llm.serve.openai_protocol, extended with the fields below.
class DynamoTRTLLMChatCompletionRequest(ChatCompletionRequest):
# Request id; defaults to "chatcmpl-" followed by a random uuid4 hex string.
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}")
# Both token-limit spellings are declared as optional fields defaulting to None.
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
# Optional disaggregated-serving parameters, defaults to None.
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
# Optional free-form dict, defaults to None.
# NOTE(review): its schema is not defined in this file — confirm with callers.
nvext: Optional[dict] = Field(default=None)
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
# Pydantic wrapper around a list of integer token ids.
class Tokens(BaseModel):
# Required; no default.
tokens: list[int]
# Plain request model; all three fields are required (no defaults declared).
class Request(BaseModel):
prompt: str
sampling_params: dict
streaming: bool
# Request model sent to the worker. `model`, `id` and `sampling_params` are
# required; every other field has a default.
class TRTLLMWorkerRequest(BaseModel):
model: str
id: str
# Optional raw prompt text, defaults to None.
prompt: str | None = None
# Sampling parameters as a plain dict.
sampling_params: dict
# Defaults to True.
streaming: bool = True
# Optional list of conversation messages, defaults to None.
conversation: Optional[List[ConversationMessage]] = Field(default=None)
# Optional token ids wrapped in Tokens, defaults to None.
tokens: Optional[Tokens] = Field(default=None)
# Optional disaggregated-serving parameters, defaults to None.
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
@dataclass(slots=True)
class Logprob:
"""Holds logprob and vocab rank for a token."""
# Required log-probability value.
logprob: float
# Optional rank, defaults to None.
rank: Optional[int] = None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs: TypeAlias = list[dict[int, Logprob]]
@dataclass
class TRTLLMWorkerResponseOutput:
    """One generated output of a worker response, with diff helpers.

    The ``_last_*_len`` fields record how much of ``text``, ``token_ids``
    and ``logprobs`` has already been reported; the ``*_diff`` properties
    return only the part after those offsets.
    """

    index: int
    text: str = ""
    token_ids: Optional[List[int]] = field(default_factory=list)
    cumulative_logprob: Optional[float] = None
    logprobs: Optional[TokenLogprobs] = field(default_factory=list)
    prompt_logprobs: Optional[TokenLogprobs] = field(default_factory=list)
    finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None
    stop_reason: Optional[Union[int, str]] = None
    generation_logits: Optional[torch.Tensor] = None
    disaggregated_params: Optional[DisaggregatedParams] = None

    # hidden fields for tracking the diffs
    _last_text_len: int = field(default=0, init=True, repr=False)
    _last_token_ids_len: int = field(default=0, init=True, repr=False)
    _last_logprobs_len: int = field(default=0, init=True, repr=False)
    _incremental_states: Optional[dict] = field(default=None, init=True, repr=False)

    # the result of result_handler passed to postprocess workers
    _postprocess_result: Any = None

    @property
    def length(self) -> int:
        """Number of token ids, or 0 when ``token_ids`` is None."""
        if self.token_ids is None:
            return 0
        return len(self.token_ids)

    @property
    def text_diff(self) -> str:
        """Text after the ``_last_text_len`` offset."""
        start = self._last_text_len
        return self.text[start:]

    @property
    def token_ids_diff(self) -> List[int]:
        """Token ids after the ``_last_token_ids_len`` offset ([] when None)."""
        if self.token_ids is None:
            return []
        start = self._last_token_ids_len
        return self.token_ids[start:]

    # Ignoring the mypy error here as this is copied from TensorRT-LLM project.
    # https://github.com/NVIDIA/TensorRT-LLM/blob/19c6e68bec891b66146a09647ee7b70230ef5f67/tensorrt_llm/executor/result.py#L68
    # TODO: Work with the TensorRT-LLM team to get this fixed.
    @property
    def logprobs_diff(self) -> List[float]:  # type: ignore
        """Logprobs after the ``_last_logprobs_len`` offset ([] when None)."""
        if self.logprobs is None:
            return []
        start = self._last_logprobs_len
        return self.logprobs[start:]  # type: ignore
# Response model returned by the worker.
class TRTLLMWorkerResponse(BaseModel):
# Unknown fields are rejected (extra="forbid"); arbitrary (non-pydantic)
# types are allowed as field values.
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
request_id: str
# Optional prompt text, defaults to None.
prompt: str | None = None
prompt_token_ids: list[int]
# Each output is declared as a plain dict.
outputs: list[dict]
finished: bool
TokenIdType = int
class DisaggregatedTypeConverter:
......@@ -175,36 +68,37 @@ class DisaggregatedTypeConverter:
)
# Chat Completions
# Chat stream choice: the ChatCompletionResponseStreamChoice imported from
# tensorrt_llm.serve.openai_protocol plus one optional field.
class DynamoTRTLLMChatCompletionResponseStreamChoice(
ChatCompletionResponseStreamChoice
):
# Optional disaggregated-serving parameters, defaults to None.
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
# One streamed chat-completion chunk.
class DynamoTRTLLMChatCompletionStreamResponse(BaseModel):
# Defaults to "chatcmpl-" followed by a random uuid4 hex string.
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}")
# Fixed to the literal "chat.completion.chunk".
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
# Defaults to int(time.time()) evaluated when the model is instantiated.
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DynamoTRTLLMChatCompletionResponseStreamChoice]
# Optional usage statistics, defaults to None.
usage: Optional[UsageInfo] = Field(default=None)
## Completions
# TODO: move these to common for all LLMs once we adopt dynamo-run
# derived from lib/llm/src/protocols/common/preprocessor.rs
# Stop-condition settings of a worker request. Every field is optional and
# defaults to None.
class StopConditions(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids_hidden: Optional[List[TokenIdType]] = None
min_tokens: Optional[int] = None
ignore_eos: Optional[bool] = None
# Sampling settings of a worker request. Every field is optional and defaults
# to None.
class SamplingOptions(BaseModel):
n: Optional[int] = None
best_of: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
use_beam_search: Optional[bool] = None
length_penalty: Optional[float] = None
seed: Optional[int] = None
class DynamoTRTLLMCompletionResponseStreamChoice(CompletionResponseStreamChoice):
# Token-level request model for the worker. `token_ids`, `stop_conditions` and
# `sampling_options` are required; the remaining fields have defaults.
class TRTLLMWorkerRequest(BaseModel):
# Token ids of the request input.
token_ids: List[TokenIdType]
stop_conditions: StopConditions
sampling_options: SamplingOptions
# Defaults to an empty list.
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
# Optional string, defaults to None.
# NOTE(review): meaning not shown in this file — confirm against the Rust
# preprocessor protocol.
mdc_sum: Optional[str] = None
# Defaults to an empty list.
annotations: List[str] = Field(default_factory=list)
# Optional, defaults to None.
estimated_prefix_hit_num_blocks: Optional[int] = None
# Optional disaggregated-serving parameters, defaults to None.
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
# One streamed text-completion chunk.
class DynamoTRTLLMCompletionStreamResponse(BaseModel):
# Unknown fields are rejected.
model_config = ConfigDict(extra="forbid")
# Defaults to "cmpl-" followed by a random uuid4 hex string.
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
# Defaults to "text_completion".
object: str = "text_completion"
# Defaults to int(time.time()) evaluated when the model is instantiated.
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DynamoTRTLLMCompletionResponseStreamChoice]
# Optional usage statistics, defaults to None.
usage: Optional[UsageInfo] = Field(default=None)
......@@ -17,7 +17,6 @@ import logging
import subprocess
from pathlib import Path
from components.processor import Processor
from components.worker import TensorRTLLMWorker
from fastapi import FastAPI
from pydantic import BaseModel
......@@ -30,88 +29,91 @@ from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
def get_http_binary_path():
def get_dynamo_run_binary():
"""Find the dynamo-run binary path in SDK or fallback to 'dynamo-run' command."""
sdk_path = Path(sdk.__file__)
binary_path = sdk_path.parent / "cli/bin/http"
binary_path = sdk_path.parent / "cli/bin/dynamo-run"
if not binary_path.exists():
return "http"
return "dynamo-run"
else:
return str(binary_path)
class FrontendConfig(BaseModel):
"""Configuration for the Frontend service including model and HTTP server settings."""
served_model_name: str
endpoint_chat: str
endpoint_completions: str
port: int = 8080
endpoint: str
port: int = 8000
router: str = "round-robin"
block_size: int = 32
# todo this should be called ApiServer
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="TensorRT LLM Example"),
app=FastAPI(title="TensorRT-LLM Example"),
)
# todo this should be called ApiServer
class Frontend:
worker = depends(TensorRTLLMWorker)
processor = depends(Processor)
def __init__(self):
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
# Chat/completions Endpoint
subprocess.run(
[
"llmctl",
"http",
"remove",
"chat-models",
frontend_config.served_model_name,
]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"chat-models",
frontend_config.served_model_name,
frontend_config.endpoint_chat,
]
"""Initialize Frontend service with HTTP server and model configuration."""
self.frontend_config = FrontendConfig(
**ServiceConfig.get_parsed_config("Frontend")
)
self.process = None
# Completions Endpoint
subprocess.run(
[
"llmctl",
"http",
"remove",
"completions",
frontend_config.served_model_name,
]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"completions",
frontend_config.served_model_name,
frontend_config.endpoint_completions,
]
logger.warning(f"Frontend config: {self.frontend_config}")
self.start_ingress_and_processor()
def start_ingress_and_processor(self):
"""Starting dynamo-run based ingress and processor"""
logger.info(
f"Starting HTTP server and processor on port {self.frontend_config.port}"
)
dynamo_run_binary = get_dynamo_run_binary()
cmd = [
dynamo_run_binary,
"in=http",
"out=dyn",
"--http-port",
str(self.frontend_config.port),
"--router-mode",
self.frontend_config.router,
]
logger.info(f"Frontend cmd: {cmd}")
logger.info("Starting HTTP server")
http_binary = get_http_binary_path()
process = subprocess.Popen(
[http_binary, "-p", str(frontend_config.port)], stdout=None, stderr=None
self.process = subprocess.Popen(
cmd,
stdout=None,
stderr=None,
)
try:
process.wait()
except KeyboardInterrupt:
process.terminate()
process.wait()
def close(self):
    """Clean up resources by terminating the subprocess.

    Sends terminate to the child process and waits up to 5 seconds; if the
    process has not exited by then it is killed. Any other error raised
    while stopping it is logged rather than propagated. ``self.process``
    is reset to None afterwards, so repeated calls are no-ops.

    The method is also safe on an instance whose ``process`` attribute was
    never assigned (for example when construction failed before the
    attribute was set and the destructor still calls ``close()``).
    """
    # Read defensively: the attribute may not exist on a partially
    # constructed instance.
    process = getattr(self, "process", None)
    if process is None:
        return
    try:
        logger.info("Terminating subprocess...")
        process.terminate()
        # Wait for process to terminate with a timeout
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        logger.warning("Subprocess did not terminate gracefully, forcing kill")
        process.kill()
        process.wait()
    except Exception as e:
        logger.error(f"Error while terminating subprocess: {e}")
    finally:
        self.process = None
def __del__(self):
"""Destructor to ensure subprocess is cleaned up."""
# Delegates to close(), which terminates the child process when one is still
# recorded on self.process and does nothing when it is None.
self.close()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import random
import traceback
from argparse import Namespace
from typing import AsyncIterator
from common.protocol import Tokens
from components.worker import TensorRTLLMWorker
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
WorkerId = str
def parse_args(service_name, prefix) -> Namespace:
    """Parse the router's options from the service configuration.

    The argument list is produced by ``ServiceConfig.get_instance()
    .as_args(service_name, prefix=prefix)`` and parsed with argparse.

    Args:
        service_name: Name of the service section to read.
        prefix: Prefix forwarded to ``as_args``.

    Returns:
        Namespace with ``min_workers`` (int, default 1), ``model_name``
        (str), ``block_size`` (int, default 32) and ``custom_router``
        (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
    # TODO: Read block size
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="KV block size",
    )
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so the value is parsed explicitly. ``nargs="?"`` with
    # ``const=True`` additionally accepts a bare ``--custom-router`` flag.
    parser.add_argument(
        "--custom-router",
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="Whether to use custom router or not",
    )
    config = ServiceConfig.get_instance()
    config_args = config.as_args(service_name, prefix=prefix)
    args = parser.parse_args(config_args)
    return args


def _parse_bool_flag(value: str) -> bool:
    """Convert a command-line string such as "true" or "0" into a bool."""
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, matching what bool("") returned before.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")
@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """Choose a TensorRTLLMWorker instance for a tokenized request.

    The decision combines the KV-cache prefix overlap reported by
    ``KvIndexer`` with the per-worker load metrics reported by
    ``KvMetricsAggregator``.
    """

    worker = depends(TensorRTLLMWorker)

    def __init__(self):
        logger.info("Initializing KV router.")
        class_name = self.__class__.__name__
        self.args = parse_args(class_name, "")
        # Populated by async_init(). They are defined here so that generate()
        # can detect an uninitialised router with an ``is None`` check instead
        # of failing with AttributeError.
        self.indexer = None
        self.metrics_aggregator = None

    @async_on_start
    async def async_init(self):
        """Connect to the workers and create the indexer and metrics aggregator.

        Polls every 30 seconds until at least ``args.min_workers`` worker
        instances are visible on the ``generate`` endpoint.
        """
        self.runtime = dynamo_context["runtime"]
        self.workers_client = (
            await self.runtime.namespace("dynamo")
            .component("TensorRTLLMWorker")
            .endpoint("generate")
            .client()
        )
        while len(self.workers_client.instance_ids()) < self.args.min_workers:
            logger.info(
                f"Waiting for more workers to be ready.\n"
                f" Current: {len(self.workers_client.instance_ids())},"
                f" Required: {self.args.min_workers}"
            )
            await asyncio.sleep(30)
        kv_listener = self.runtime.namespace("dynamo").component("TensorRTLLMWorker")
        await kv_listener.create_service()
        self.indexer = KvIndexer(kv_listener, self.args.block_size)
        self.metrics_aggregator = KvMetricsAggregator(kv_listener)
        logger.info("KV Router initialized")

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """Score every known worker and pick the best one.

        Args:
            scores: Per-worker count of matching KV blocks, or ``None``.
            metrics: Aggregated per-worker load metrics, or ``None``.
            token_length: Number of tokens in the request.

        Returns:
            ``""`` when there are no workers or every logit is zero (no basis
            for a decision); otherwise a ``(worker_id, overlap_ratio)`` tuple
            where ``overlap_ratio`` is the matched-token fraction used for
            that worker (``0.0`` if it had no overlap score).
        """
        metric_names = (
            "gpu_cache_usage_perc",
            "num_requests_waiting",
            "gpu_prefix_cache_hit_rate",
        )

        worker_scores = {}
        # An empty request has no prefix to match; skipping the block also
        # avoids dividing by zero.
        if scores and token_length > 0:
            for worker_id, score in scores.scores.items():
                # score is number of matching blocks we multiply by block_size to get tokens
                # and compare to token_length. The larger the cache hit the better
                worker_scores[worker_id] = (
                    score * self.indexer.block_size() / token_length
                )
        logger.debug(f"Worker scores: {worker_scores}")

        # Pull metrics for each worker; attributes an endpoint does not expose
        # default to 0.0.
        worker_metrics = {}
        max_waiting = 0.0
        if metrics:
            for worker_endpoint in metrics.endpoints:
                stats = {
                    name: getattr(worker_endpoint, name, 0.0) for name in metric_names
                }
                worker_metrics[worker_endpoint.worker_id] = stats
                max_waiting = max(max_waiting, stats["num_requests_waiting"])
        logger.debug(f"Worker metrics: {worker_metrics}")

        # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
        # and we want all workers to be considered in the logit calculation
        default_metrics = dict.fromkeys(metric_names, 0.0)
        worker_logits = {}
        for worker_id in self.workers_client.instance_ids():
            # Use default values if worker not in scores or metrics
            score = worker_scores.get(worker_id, 0.0)
            metrics_dict = worker_metrics.get(worker_id, default_metrics)
            normalized_waiting = (
                metrics_dict["num_requests_waiting"] / max_waiting
                if max_waiting > 0
                else 0.0
            )
            # Have 1 metric that weights towards cache hit
            # 2 metrics that penalize overloaded worker and queuing
            worker_logits[worker_id] = (
                2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
            )
            logger.debug(
                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
            )

        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
            return ""

        # Select the worker with the highest logit; ties are broken randomly.
        max_logit = max(worker_logits.values())
        best_workers = [
            wid for wid, logit in worker_logits.items() if logit == max_logit
        ]
        best_worker_id = random.choice(best_workers)

        # Log the metrics for the selected worker
        selected_metrics = worker_metrics.get(best_worker_id, {})
        selected_waiting = (
            selected_metrics.get("num_requests_waiting", 0.0) / max_waiting
            if max_waiting > 0
            else 0.0
        )
        logger.debug(
            f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
        )
        logger.debug(
            f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
        )
        logger.debug(
            f"GPU Cache Hit Rate: {selected_metrics.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
        )
        logger.debug(
            f"GPU Cache Usage: {selected_metrics.get('gpu_cache_usage_perc', 0.0):.3f}"
        )
        logger.debug(f"Requests Waiting: {selected_waiting:.3f}")
        return best_worker_id, worker_scores.get(best_worker_id, 0.0)

    @endpoint()
    async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
        """Yield exactly one routing decision for *request*.

        The decision is the string ``"<worker_id>_<prefix_hit_rate>"``. An
        empty ``worker_id`` (``"_0.0"``) means no worker was selected.
        """
        if self.indexer is None or self.metrics_aggregator is None:
            yield "_0.0"
            # Stop here: falling through would yield a second decision.
            return
        lora_id = 0
        try:
            scores = await self.indexer.find_matches_for_request(
                request.tokens, lora_id
            )
            token_length = len(request.tokens)
            metrics = await self.metrics_aggregator.get_metrics()
            schedule_result = self._cost_function(scores, metrics, token_length)
        except Exception:
            # Best effort: a routing failure must not fail the request, so
            # fall back to "no decision".
            schedule_result = ""
            logger.warning(f"Error during worker selection: {traceback.format_exc()}")
        if schedule_result == "":
            worker_id = ""
            prefix_hit_rate = 0.0
        else:
            worker_id, prefix_hit_rate = schedule_result
        yield f"{worker_id}_{prefix_hit_rate}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
from common.chat_processor import ChatProcessorMixin
from common.parser import parse_tensorrt_llm_args
from common.protocol import (
DynamoTRTLLMChatCompletionRequest,
DynamoTRTLLMCompletionRequest,
)
from common.utils import RequestType
from components.kv_router import Router
from components.worker import TensorRTLLMWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Processor(ChatProcessorMixin):
    """Preprocess OpenAI-style requests, dispatch them to a worker and
    postprocess the streamed responses.
    """

    worker = depends(TensorRTLLMWorker)
    router = depends(Router)

    def __init__(
        self,
    ):
        class_name = self.__class__.__name__
        config = ServiceConfig.get_instance()
        config_args = config.as_args(class_name, prefix="")
        args, engine_config = parse_tensorrt_llm_args(config_args)
        self.remote_prefill = args.remote_prefill
        self.router_mode = args.router
        self.min_workers = 1
        self.args = args
        super().__init__(engine_config)

    @async_on_start
    async def async_init(self):
        """Create the worker client (and router client in ``kv`` mode).

        Polls every 30 seconds until at least ``min_workers`` worker
        instances are visible.
        """
        runtime = dynamo_context["runtime"]
        comp_ns, comp_name = TensorRTLLMWorker.dynamo_address()  # type: ignore
        self.worker_client = (
            await runtime.namespace(comp_ns)
            .component(comp_name)
            .endpoint("generate")
            .client()
        )
        if self.router_mode == "kv":
            router_ns, router_name = Router.dynamo_address()  # type: ignore
            self.router_client = (
                await runtime.namespace(router_ns)
                .component(router_name)
                .endpoint("generate")
                .client()
            )
        while len(self.worker_client.instance_ids()) < self.min_workers:
            logger.info(
                f"Waiting for workers to be ready.\n"
                f" Current: {len(self.worker_client.instance_ids())},"
                f" Required: {self.min_workers}"
            )
            await asyncio.sleep(30)

    @staticmethod
    def _apply_nvext_overrides(raw_request, limit_field):
        """Copy ``ignore_eos`` from ``raw_request.nvext`` onto the request.

        When ``ignore_eos`` is truthy, ``min_tokens`` is set to the value of
        the request attribute named by *limit_field*, to guarantee the full
        expected OSL for consistent benchmarking purposes.
        """
        nvext = raw_request.nvext
        if not nvext:
            return
        ignore_eos = nvext.get("ignore_eos")
        if ignore_eos is None:
            # nvext carries no ignore_eos setting; do not overwrite the
            # request's own value with None.
            return
        raw_request.ignore_eos = ignore_eos
        if ignore_eos:
            limit = getattr(raw_request, limit_field)
            logger.debug(
                f"[preprocessor] `ignore_eos` detected, setting `min_tokens` to `{limit_field}`: {limit}"
            )
            raw_request.min_tokens = limit

    async def _generate(self, raw_request, request_type: RequestType):
        """Run one request end to end and yield the decoded JSON responses.

        Args:
            raw_request: Incoming chat or completion request; its
                special-token flags are overwritten before preprocessing.
            request_type: ``RequestType.CHAT`` selects the chat processor;
                any other value selects the completions processor.
        """
        raw_request.skip_special_tokens = False
        raw_request.add_special_tokens = False
        raw_request.spaces_between_special_tokens = False
        logger.debug(f"[preprocessor] Received request: {raw_request}")
        is_chat = request_type == RequestType.CHAT
        if is_chat:
            preprocessed_request = await self.chat_processor.preprocess(raw_request)
        else:
            preprocessed_request = await self.completions_processor.preprocess(
                raw_request
            )

        worker_id = ""
        if self.router_mode == "kv":
            router_generator = await self.router_client.generate(
                preprocessed_request.tokens.model_dump_json()
            )
            decision = await router_generator.__anext__()
            decision = decision.data()
            # The router answers "<worker_id>_<prefix_hit_rate>"; split on the
            # last underscore so only the trailing field is parsed as a float.
            worker_id, _, prefix_hit_rate = decision.rpartition("_")
            prefix_hit_rate = float(prefix_hit_rate)
            logger.info(
                f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
            )

        payload = preprocessed_request.model_dump_json()
        if worker_id == "":
            # No routing decision. The dispatch callable is kept local so
            # concurrent requests never share per-request state on ``self``.
            if self.router_mode == "round-robin":
                send_request = self.worker_client.round_robin
            else:
                # fallback to random
                send_request = self.worker_client.random
            engine_generator = await send_request(payload)
        else:
            engine_generator = await self.worker_client.direct(payload, int(worker_id))

        if is_chat:
            responses = self.chat_processor.postprocess(
                engine_generator,
                raw_request,
                preprocessed_request.conversation,
            )
        else:
            responses = self.completions_processor.postprocess(
                engine_generator, raw_request
            )
        async for response in responses:
            logger.debug(f"[preprocessor] Response: {response}")
            yield json.loads(response)

    @endpoint(name="chat/completions")
    async def generate_chat(self, raw_request: DynamoTRTLLMChatCompletionRequest):
        """Serve a chat completion request.

        Raises:
            ValueError: If both ``max_tokens`` and ``max_completion_tokens``
                are set and differ.
        """
        # max_tokens is deprecated, however if the max_tokens is provided instead
        # of max_completion_tokens, we will use the value as max_completion_tokens.
        if raw_request.max_tokens is not None:
            if raw_request.max_completion_tokens is None:
                raw_request.max_completion_tokens = raw_request.max_tokens
            elif raw_request.max_tokens != raw_request.max_completion_tokens:
                raise ValueError(
                    "max_tokens and max_completion_tokens must be the same"
                )
        # min_tokens isn't currently propagated through the Rust OpenAI HTTP frontend,
        # and ignore_eos is passed through the 'nvext' field, so set both when found.
        self._apply_nvext_overrides(raw_request, "max_completion_tokens")
        async for response in self._generate(raw_request, RequestType.CHAT):
            yield response

    @endpoint(name="completions")
    async def completions(self, raw_request: DynamoTRTLLMCompletionRequest):
        """Serve a text completion request."""
        # min_tokens isn't currently propagated through the Rust OpenAI HTTP frontend,
        # and ignore_eos is passed through the 'nvext' field, so set both when found.
        self._apply_nvext_overrides(raw_request, "max_tokens")
        async for response in self._generate(raw_request, RequestType.COMPLETION):
            yield response
......@@ -21,6 +21,7 @@ from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
......@@ -43,10 +44,12 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args)
self.served_model_name = args.served_model_name
worker_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore
self._min_prefill_workers = args.min_prefill_workers
super().__init__(
namespace_str="dynamo",
namespace_str=namespace,
component_str=class_name,
worker_id=worker_id,
engine_config=engine_config,
......@@ -62,6 +65,24 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
async def async_init(self):
self._init_engine()
runtime = dynamo_context["runtime"]
logger.info("Registering LLM for discovery")
comp_ns, comp_name = TensorRTLLMWorker.dynamo_address() # type: ignore
endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
try:
await register_llm(
ModelType.Backend,
endpoint,
self._engine_config.model_name,
self.served_model_name,
kv_cache_block_size=self._kv_block_size,
)
logger.info("Successfully registered LLM for discovery")
except Exception as e:
logger.error(f"Failed to register LLM for discovery: {e}")
raise
if self._remote_prefill:
runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
......
......@@ -15,20 +15,15 @@
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions
endpoint_chat: dynamo.Processor.chat/completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
router: round-robin
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
TensorRTLLMWorker:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
gpu: 1
\ No newline at end of file
......@@ -15,19 +15,9 @@
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions
endpoint_chat: dynamo.Processor.chat/completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/llm_api_config_router.yaml"
router: kv
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
TensorRTLLMWorker:
engine_args: "configs/llm_api_config_router.yaml"
......
......@@ -16,18 +16,12 @@
Frontend:
# This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint_chat: dynamo.Processor.chat/completions
endpoint_completions: dynamo.Processor.completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin
# Parallelize preprocessing/tokenization to avoid bottlenecks
ServiceArgs:
workers: 5
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin
ServiceArgs:
......
......@@ -16,19 +16,12 @@
Frontend:
# This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint_chat: dynamo.Processor.chat/completions
endpoint_completions: dynamo.Processor.completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
router: round-robin
remote-prefill: true
# Parallelize preprocessing/tokenization to avoid bottlenecks
ServiceArgs:
workers: 5
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml"
remote-prefill: true
......
......@@ -15,23 +15,17 @@
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions
endpoint_chat: dynamo.Processor.chat/completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
router: round-robin
remote-prefill: true
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
TensorRTLLMWorker:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
router: round-robin
remote-prefill: true
min-prefill-workers: 1
router: round-robin
ServiceArgs:
workers: 1
resources:
......
......@@ -15,37 +15,27 @@
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint_completions: dynamo.Processor.completions
endpoint_chat: dynamo.Processor.chat/completions
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
Processor:
engine_args: "configs/llm_api_config_disagg_router.yaml"
router: "kv"
remote-prefill: true
ServiceArgs:
workers: 5 # to reduce the tokenization bottleneck at a high concurrency
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
router: kv
TensorRTLLMWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml"
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml"
router: kv
remote-prefill: true
min-prefill-workers: 1
router: kv
ServiceArgs:
workers: 1
resources:
gpu: 1
TensorRTLLMPrefillWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml"
engine_args: "configs/llm_api_config_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_router_configs/single_node_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
gpu: 1
\ No newline at end of file
......@@ -14,7 +14,6 @@
# limitations under the License.
from components.frontend import Frontend
from components.processor import Processor
from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(TensorRTLLMWorker)
Frontend.link(TensorRTLLMWorker)
......@@ -14,8 +14,6 @@
# limitations under the License.
from components.frontend import Frontend
from components.kv_router import Router
from components.processor import Processor
from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(Router).link(TensorRTLLMWorker)
Frontend.link(TensorRTLLMWorker)
......@@ -15,7 +15,6 @@
from components.frontend import Frontend
from components.prefill_worker import TensorRTLLMPrefillWorker
from components.processor import Processor
from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker)
Frontend.link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker)
......@@ -14,11 +14,7 @@
# limitations under the License.
from components.frontend import Frontend
from components.kv_router import Router
from components.prefill_worker import TensorRTLLMPrefillWorker
from components.processor import Processor
from components.worker import TensorRTLLMWorker
Frontend.link(Processor).link(Router).link(TensorRTLLMWorker).link(
TensorRTLLMPrefillWorker
)
Frontend.link(TensorRTLLMWorker).link(TensorRTLLMPrefillWorker)
......@@ -16,6 +16,18 @@ from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
logging.basicConfig(level=logging.DEBUG)
def _to_signed_i64(value: int | None) -> int | None:
"""Convert a Python int to signed 64-bit range by two's complement."""
if value is None:
return None
if value >= 2**63:
return value - 2**64
if value < -(2**63):
return ((value + 2**63) % 2**64) - 2**63
return value
class ManagedThread(threading.Thread):
"""
A thread that runs a task and handles errors.
......@@ -242,13 +254,13 @@ class Publisher:
event_id = event["event_id"]
data = event["data"]
if data["type"] == "stored":
parent_hash = data["parent_hash"]
parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = []
num_block_tokens = []
block_hashes = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = block["block_hash"]
block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self.kv_block_size:
logging.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}"
......@@ -284,6 +296,7 @@ class Publisher:
elif data["type"] == "removed":
block_hashes = []
for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self.partial_block_hashes:
logging.debug(
f"Skipping removing block hash {block_hash} since it is a partial block"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment