Commit da38e96a authored by Tanmay Verma, committed by GitHub

feat: TRT-LLM disaggregated serving using UCX (#562)


Signed-off-by: Tanmay Verma <tanmay2592@gmail.com>
Signed-off-by: Tanmay Verma <tanmayv@nvidia.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent 538b4630
......@@ -201,6 +201,12 @@ RUN pip install dist/ai_dynamo_runtime*cp312*.whl && \
ENV DYNAMO_KV_CAPI_PATH="/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
ENV DYNAMO_HOME=/workspace
# Copy launch banner
RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \
sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \
echo "cat ~/.launch_screen" >> ~/.bashrc
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
......
......@@ -342,7 +342,7 @@ See instructions [here](/examples/tensorrt_llm/README.md#run-container) to run t
Execute the following to load the TensorRT-LLM model specified in the configuration.
```
dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/agg_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/trtllm_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
```
#### Dynamo does the pre-processing
......
......@@ -25,6 +25,14 @@ This directory contains examples and reference implementations for deploying Lar
See [deployment architectures](../llm/README.md#deployment-architectures) to learn about the general idea of the architecture.
Note that this TensorRT-LLM version does not support all the options yet.
Note: TensorRT-LLM disaggregation does not support conditional disaggregation yet. You can only configure the deployment to always use either aggregated or disaggregated serving.
## Getting Started
1. Choose a deployment architecture based on your requirements
2. Configure the components as needed
3. Deploy using the provided scripts
### Prerequisites
Start the required services (etcd and NATS) using [Docker Compose](../../deploy/docker-compose.yml).
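A minimal sketch of starting them, assuming you run it from this directory so the compose file is at the linked relative path:
```bash
# Launch etcd and NATS in the background using the repository's compose file
docker compose -f ../../deploy/docker-compose.yml up -d
```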
......@@ -68,6 +76,29 @@ This build script internally points to the base container image built with step
```
## Run Deployment
This figure shows an overview of the major components to deploy:
```
+------+ +-----------+ +------------------+ +---------------+
| HTTP |----->| processor |----->| Worker |------------>| Prefill |
| |<-----| |<-----| |<------------| Worker |
+------+ +-----------+ +------------------+ +---------------+
| ^ |
query best | | return | publish kv events
worker | | worker_id v
| | +------------------+
| +---------| kv-router |
+------------->| |
+------------------+
```
Note: The diagram above shows all of the components. The components that actually
get spawned depend on the chosen graph.
### Example architectures
#### Aggregated serving
......@@ -82,21 +113,23 @@ cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml
```
<!--
This is work in progress and will be enabled soon.
#### Disaggregated serving
```bash
cd /workspace/examples/llm
dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
cd /workspace/examples/tensorrt_llm
TRTLLM_USE_UCX_KVCACHE=1 dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```
We set TRTLLM_USE_UCX_KVCACHE so that TensorRT-LLM uses UCX for transferring the KV
cache between the context and generation workers.
#### Disaggregated serving with KV Routing
```bash
cd /workspace/examples/llm
dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
cd /workspace/examples/tensorrt_llm
TRTLLM_USE_UCX_KVCACHE=1 dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
```
-->
We set TRTLLM_USE_UCX_KVCACHE so that TensorRT-LLM uses UCX for transferring the KV
cache between the context and generation workers.
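Equivalently, the variable can be exported once for the shell session before launching; a sketch using the disaggregated graph shown above:
```bash
export TRTLLM_USE_UCX_KVCACHE=1
dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```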
### Client
......@@ -108,7 +141,7 @@ See [close deployment](../../docs/guides/dynamo_serve.md#close-deployment) secti
Remaining tasks:
- [ ] Add support for disaggregated serving.
- [x] Add support for disaggregated serving.
- [ ] Add integration test coverage.
- [ ] Add instructions for benchmarking.
- [ ] Add multi-node support.
......
......@@ -15,88 +15,129 @@
import asyncio
import copy
import logging
import os
import signal
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from dataclasses import asdict
from enum import Enum
from queue import Queue
from typing import Any, Optional
from common.chat_processor import ChatProcessor, CompletionsProcessor
from common.parser import LLMAPIConfig
from common.utils import ManagedThread
from tensorrt_llm._torch import LLM
from tensorrt_llm.logger import logger
from transformers import AutoTokenizer
from common.protocol import (
DisaggregatedTypeConverter,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
parse_disagg_config_file,
)
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
from dynamo.llm import KvMetricsPublisher
from .kv_cache_event_publisher import KVCacheEventPublisher
logger.set_level("info")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class ChatProcessorMixin:
def __init__(self, engine_config: LLMAPIConfig):
self._engine_config = engine_config
logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
class DisaggRequestType(Enum):
CONTEXT_ONLY = "context_only"
GENERATION_ONLY = "generation_only"
self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
self.completions_processor = CompletionsProcessor(
self._model_name, self._tokenizer
)
def update_args_from_disagg_config(
engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
# Update the LLM API config with the disaggregated config
# Allows for different configs for context and generation servers
engine_config.extra_args.update(**server_config.other_args)
engine_config.update_sub_configs(server_config.other_args)
return engine_config
@dataclass
class TensorrtLLMEngineConfig:
namespace_str: str = "dynamo"
component_str: str = "tensorrt-llm"
engine_config: LLMAPIConfig = None
worker_id: Optional[str] = None
kv_metrics_publisher: Optional[KvMetricsPublisher] = None
publish_stats: bool = False
publish_kv_cache_events: bool = False
# default block size is 32 for pytorch backend
kv_block_size: int = 32
def get_sampling_params(sampling_params):
# Remove keys starting with '_' from the sampling params; these get
# added by the LLM API, and TRT-LLM does not support creating SamplingParams
# from a dictionary containing keys that start with '_'.
cleaned_dict = {
key: value for key, value in sampling_params.items() if not key.startswith("_")
}
return SamplingParams(**cleaned_dict)
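# For illustration (hypothetical values): a dict such as
# {"max_tokens": 16, "temperature": 0.7, "_stream": True} becomes
# SamplingParams(max_tokens=16, temperature=0.7); the "_stream" key is dropped.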
class BaseTensorrtLLMEngine(ChatProcessorMixin):
class BaseTensorrtLLMEngine:
def __init__(
self,
trt_llm_engine_config: TensorrtLLMEngineConfig,
namespace_str: str = "dynamo",
component_str: str = "tensorrt-llm",
worker_id: Optional[str] = None,
engine_config: LLMAPIConfig = None,
remote_prefill: bool = False,
min_workers: int = 0,
disagg_config_file: Optional[str] = None,
block_size: int = 32,
router: str = "round_robin",
server_type: ServerType = ServerType.GEN,
):
super().__init__(trt_llm_engine_config.engine_config)
self._namespace_str = trt_llm_engine_config.namespace_str
self._component_str = trt_llm_engine_config.component_str
self._worker_id = trt_llm_engine_config.worker_id
self._kv_metrics_publisher = trt_llm_engine_config.kv_metrics_publisher
self._publish_stats = trt_llm_engine_config.publish_stats
self._publish_kv_cache_events = trt_llm_engine_config.publish_kv_cache_events
self._kv_block_size = trt_llm_engine_config.kv_block_size
self._error_queue: Optional[Queue] = None
self._init_engine()
self._namespace_str = namespace_str
self._component_str = component_str
self._worker_id = worker_id
self._remote_prefill = remote_prefill
self._min_workers = 0
self._kv_block_size = block_size
self._router = router
self._server_type = server_type
self._prefill_client = None
self._error_queue: Queue = Queue()
self._kv_metrics_publisher = None
if self._remote_prefill:
self._min_workers = min_workers
if disagg_config_file is None or not os.path.exists(disagg_config_file):
raise ValueError(
"llmapi_disaggregated_config file does not exist or not provided"
)
disagg_config = parse_disagg_config_file(disagg_config_file)
server_config: CtxGenServerConfig = None
for config in disagg_config.server_configs:
# Select the first server config matching this server's type
if config.type == server_type.value:
server_config = config
break
if server_config is None:
server_type_str = (
"generation" if server_type == ServerType.GEN else "context"
)
raise ValueError(
f"No {server_type_str} server config found. Please check the disaggregated config file."
)
engine_config = update_args_from_disagg_config(engine_config, server_config)
if router == "kv":
self._publish_stats = True
self._publish_events = True
else:
self._publish_stats = False
self._publish_events = False
if self._publish_stats:
self._kv_metrics_publisher = KvMetricsPublisher()
self._engine_config = engine_config
def _init_engine(self):
logger.info("Initializing engine")
......@@ -126,12 +167,11 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
self._event_thread = None
raise e
self._error_queue = Queue()
try:
if self._publish_stats:
self._init_publish_metrics_thread()
if self._publish_kv_cache_events:
if self._publish_events:
self._init_publish_kv_cache_events_thread()
except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}")
......@@ -308,7 +348,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
try:
llm = await loop.run_in_executor(
None,
lambda: LLM(model=self._model, **self._engine_config.to_dict()),
lambda: LLM(
model=self._engine_config.model_name,
**self._engine_config.to_dict(),
),
)
yield llm
finally:
......@@ -368,3 +411,106 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
self._llm_engine = None
logger.info("Shutdown complete")
async def _get_remote_prefill_response(self, request):
prefill_request = copy.deepcopy(request)
prefill_request.sampling_params["max_tokens"] = 1
prefill_request.disaggregated_params = DisaggregatedParams(
request_type=DisaggRequestType.CONTEXT_ONLY.value
)
if self._prefill_client is None:
raise ValueError("Prefill client not initialized")
# TODO: Use smart KV router to determine which prefill worker to use.
ctx_responses = [
ctx_response
async for ctx_response in await self._prefill_client.round_robin(
prefill_request.model_dump_json()
)
]
if len(ctx_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
logger.debug(
f"Received response from prefill worker: {ctx_responses[0].data()}"
)
ctx_response_obj = TRTLLMWorkerResponse.model_validate_json(
ctx_responses[0].data()
)
ctx_response_obj.outputs = [
TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
]
assert ctx_response_obj.outputs[0].disaggregated_params is not None
return ctx_response_obj
async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
if not self._error_queue.empty():
raise self._error_queue.get()
self._ongoing_request_count += 1
try:
worker_inputs = request.tokens.tokens
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
)
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response_obj = await self._get_remote_prefill_response(request)
worker_inputs = ctx_response_obj.prompt_token_ids
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
DisaggregatedParams(
**ctx_response_obj.outputs[0].disaggregated_params
)
)
)
disaggregated_params.request_type = (
DisaggRequestType.GENERATION_ONLY.value
)
logger.debug(
f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
)
sampling_params = get_sampling_params(request.sampling_params)
async for response in self._llm_engine.generate_async(
inputs=worker_inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=False
if self._server_type == ServerType.CTX
else request.streaming,
):
# Convert the disaggregated params to OAI format so
# it can be sent over the network.
response.outputs[
0
].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params(
response.outputs[0].disaggregated_params
)
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt_token_ids=response.prompt_token_ids,
outputs=[asdict(response.outputs[0])],
finished=response.finished,
).model_dump_json(exclude_unset=True)
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
self._start_threads()
self._ongoing_request_count -= 1
......@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Union
from common.parser import LLMAPIConfig
from common.protocol import (
DisaggregatedTypeConverter,
DynamoTRTLLMChatCompletionResponseStreamChoice,
......@@ -27,10 +29,9 @@ from common.protocol import (
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ConversationMessage, ServerType
from common.utils import ConversationMessage
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
......@@ -41,10 +42,44 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
logger.set_level("debug")
logger = logging.getLogger(__name__)
class ChatProcessorMixin:
def __init__(
self, engine_config: LLMAPIConfig, using_engine_generator: bool = False
):
self._engine_config = engine_config
logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self.chat_processor = ChatProcessor(
self._model_name, self._tokenizer, using_engine_generator
)
self.completions_processor = CompletionsProcessor(
self._model_name, self._tokenizer
)
def parse_chat_message_content(
......@@ -290,6 +325,7 @@ class ChatProcessor(BaseChatProcessor):
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
prompt=prompt,
sampling_params=asdict(sampling_params),
conversation=conversation,
......@@ -303,8 +339,10 @@ class ChatProcessor(BaseChatProcessor):
engine_generator,
request,
conversation,
server_type: ServerType,
):
first_iteration = True
last_text_len = 0
last_token_ids_len = 0
async for raw_response in engine_generator:
if self.using_engine_generator:
response = TRTLLMWorkerResponse(
......@@ -317,21 +355,27 @@ class ChatProcessor(BaseChatProcessor):
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"]
)
# Keep track of the last text length and token-id count
# to compute the incremental diff.
# TODO: This is a hack to get the diff. We should identify why
# the diff is not being calculated in the worker.
response.outputs[0]["_last_text_len"] = last_text_len
response.outputs[0]["_last_token_ids_len"] = last_token_ids_len
last_text_len = len(response.outputs[0]["text"])
last_token_ids_len = len(response.outputs[0]["token_ids"])
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
if (
request.disaggregated_params is not None
and server_type == ServerType.CTX
):
response_data = self.yield_first_chat(request, request.id, response)
else:
response_data = self.create_chat_stream_response(
request,
request.id,
response,
conversation,
first_iteration=(not request.disaggregated_params is not None),
)
response_data = self.create_chat_stream_response(
request,
request.id,
response,
conversation,
first_iteration=first_iteration,
)
first_iteration = False
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
......
......@@ -14,11 +14,10 @@
# limitations under the License.
import ctypes
import logging
from ctypes import c_char_p, c_int64, c_uint32
from tensorrt_llm.logger import logger
logger.set_level("info")
logger = logging.getLogger(__name__)
class DynamoResult:
......@@ -53,7 +52,7 @@ class KVCacheEventPublisher:
logger.info("KVCacheEventPublisher initialization failed!")
except Exception as e:
print(f"Failed to load {lib_path}")
logger.exception(f"Failed to load {lib_path}")
raise e
self.lib.dynamo_kv_event_publish_stored.argtypes = [
......
......@@ -40,6 +40,13 @@ class LLMAPIConfig:
self.kv_cache_config = kv_cache_config
self.extra_args = kwargs
# Hardcoded to skip tokenizer init for now.
# We will handle the tokenization/detokenization
# in the base engine.
if "skip_tokenizer_init" in self.extra_args:
self.extra_args.pop("skip_tokenizer_init")
self.skip_tokenizer_init = True
def to_dict(self) -> Dict[str, Any]:
data = {
"pytorch_backend_config": self.pytorch_backend_config,
......@@ -133,6 +140,12 @@ def parse_tensorrt_llm_args(
default=1,
help="Minimum number of workers for aggregated (monolith) server",
)
parser.add_argument(
"--min-prefill-workers",
type=int,
default=1,
help="Minimum number of prefill workers for disaggregated server",
)
parser.add_argument(
"--block-size",
type=int,
......@@ -156,14 +169,6 @@ def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
# Disaggregated mode is not supported in dynamo-run launcher yet.
# parser.add_argument(
# "--llmapi-disaggregated-config",
# "-c",
# type=str,
# help="Path to the llmapi disaggregated config file",
# default=None,
# )
parser.add_argument(
"--publish-kv-cache-events",
action="store_true",
......
......@@ -23,7 +23,6 @@ import torch
from common.utils import ConversationMessage
from pydantic import BaseModel, ConfigDict, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
......@@ -59,6 +58,7 @@ class Request(BaseModel):
class TRTLLMWorkerRequest(BaseModel):
model: str
id: str
prompt: str | None = None
sampling_params: dict
......@@ -67,44 +67,6 @@ class TRTLLMWorkerRequest(BaseModel):
tokens: Optional[Tokens] = Field(default=None)
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
def to_sampling_params(self) -> SamplingParams:
sampling_params = SamplingParams(
frequency_penalty=self.sampling_params.get("frequency_penalty", 0.0),
return_log_probs=self.sampling_params.get("logprobs", False),
max_tokens=self.sampling_params.get("max_tokens", 16),
n=self.sampling_params.get("n", 1),
presence_penalty=self.sampling_params.get("presence_penalty", 0.0),
seed=self.sampling_params.get("seed", None),
stop=self.sampling_params.get("stop", None),
temperature=self.sampling_params.get("temperature", 0.7),
# chat-completion-sampling-params
best_of=self.sampling_params.get("best_of", None),
use_beam_search=self.sampling_params.get("use_beam_search", False),
top_k=self.sampling_params.get("top_k", 0),
top_p=self.sampling_params.get("top_p", 1.0),
top_p_min=self.sampling_params.get("top_p_min", None),
min_p=self.sampling_params.get("min_p", 0.0),
repetition_penalty=self.sampling_params.get("repetition_penalty", 1.0),
length_penalty=self.sampling_params.get("length_penalty", 1.0),
early_stopping=self.sampling_params.get("early_stopping", False),
stop_token_ids=self.sampling_params.get("stop_token_ids", []),
include_stop_str_in_output=self.sampling_params.get(
"include_stop_str_in_output", False
),
ignore_eos=self.sampling_params.get("ignore_eos", False),
min_tokens=self.sampling_params.get("min_tokens", 0),
skip_special_tokens=self.sampling_params.get("skip_special_tokens", False),
spaces_between_special_tokens=self.sampling_params.get(
"spaces_between_special_tokens", False
),
truncate_prompt_tokens=self.sampling_params.get(
"truncate_prompt_tokens", None
),
# chat-completion-extra-params
add_special_tokens=self.sampling_params.get("add_special_tokens", False),
)
return sampling_params
@dataclass
class TRTLLMWorkerResponseOutput:
......
......@@ -15,6 +15,7 @@
import asyncio
import logging
import threading
import traceback
import weakref
......@@ -22,9 +23,7 @@ from enum import Enum
from queue import Queue
from typing import Callable, Optional, TypedDict, Union
from tensorrt_llm.logger import logger
logger.set_level("info")
logger = logging.getLogger(__name__)
class RoutingStrategy(Enum):
......@@ -43,6 +42,8 @@ class ServerType(Enum):
GEN = "gen"
# Context server used for disaggregated requests
CTX = "ctx"
# Server used for requests launched via dynamo run
DYN_RUN = "dyn_run"
class ConversationMessage(TypedDict):
......
......@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from pathlib import Path
from components.agg_worker import TensorRTLLMWorker
from components.processor import Processor
from components.worker import TensorRTLLMWorker
from pydantic import BaseModel
from dynamo import sdk
......@@ -25,6 +26,8 @@ from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
def get_http_binary_path():
sdk_path = Path(sdk.__file__)
......@@ -75,7 +78,7 @@ class Frontend:
]
)
print("Starting HTTP server")
logger.info("Starting HTTP server")
http_binary = get_http_binary_path()
process = subprocess.Popen(
[http_binary, "-p", str(frontend_config.port)], stdout=None, stderr=None
......
......@@ -15,20 +15,20 @@
import argparse
import asyncio
import logging
import random
import traceback
from argparse import Namespace
from typing import AsyncIterator
from common.protocol import Tokens
from components.agg_worker import TensorRTLLMWorker
from tensorrt_llm.logger import logger
from components.worker import TensorRTLLMWorker
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
logger = logging.getLogger(__name__)
WorkerId = str
......@@ -92,8 +92,7 @@ class Router:
.client()
)
while len(self.workers_client.endpoint_ids()) < self.args.min_workers:
# TODO: replace print w/ vllm_logger.info
print(
logger.info(
f"Waiting for more workers to be ready.\n"
f" Current: {len(self.workers_client.endpoint_ids())},"
f" Required: {self.args.min_workers}"
......@@ -104,7 +103,7 @@ class Router:
await kv_listener.create_service()
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
print("KV Router initialized")
logger.info("KV Router initialized")
def _cost_function(
self,
......
......@@ -13,20 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import signal
from dataclasses import asdict
import logging
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.base_engine import BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest, TRTLLMWorkerResponse
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.logger import logger
from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
logger = logging.getLogger(__name__)
@service(
......@@ -37,84 +34,42 @@ logger.set_level("debug")
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class TensorRTLLMWorker(BaseTensorrtLLMEngine):
"""
Request handler for the generate endpoint
"""
class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
def __init__(self):
print("Initializing TensorRT-LLM Worker")
logger.info("Initializing TensorRT-LLM Prefill Worker")
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
if self.args.router == "kv":
publish_stats = True
publish_events = True
else:
publish_stats = False
publish_events = False
trt_llm_engine_config = TensorrtLLMEngineConfig(
args, engine_config = parse_tensorrt_llm_args(config_args)
worker_id = dynamo_context["endpoints"][0].lease_id()
super().__init__(
namespace_str="dynamo",
component_str=class_name,
engine_config=self.engine_config,
publish_stats=publish_stats,
publish_kv_cache_events=publish_events,
kv_block_size=self.args.block_size,
worker_id=worker_id,
engine_config=engine_config,
remote_prefill=args.remote_prefill,
min_workers=args.min_workers,
disagg_config_file=args.llmapi_disaggregated_config,
block_size=args.block_size,
router=args.router,
server_type=ServerType.CTX,
)
if publish_stats:
trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()
trt_llm_engine_config.worker_id = dynamo_context["endpoints"][0].lease_id()
self.trtllm_engine_args = trt_llm_engine_config
@async_on_start
async def async_init(self):
super().__init__(self.trtllm_engine_args)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(lambda _: print("metrics publisher endpoint created"))
print("TensorRT-LLM Worker initialized")
self._init_engine()
if self._kv_metrics_publisher is not None:
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
logger.info("TensorRT-LLM Prefill Worker initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.trtllm_engine_args.kv_metrics_publisher.create_endpoint(component)
await self.kv_metrics_publisher.create_endpoint(component)
@dynamo_endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
if self._error_queue.qsize() > 0:
error = self._error_queue.get()
raise error
self._ongoing_request_count += 1
try:
# TODO: combine with disagg worker
# TODO: only send tokens. Should be pretty simple.
async for response in self._llm_engine.generate_async(
inputs=request.prompt,
sampling_params=request.to_sampling_params(),
disaggregated_params=None,
streaming=True,
):
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
outputs=[asdict(response.outputs[0])],
finished=response.finished,
).model_dump_json(exclude_unset=True)
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
self._start_threads()
self._ongoing_request_count -= 1
async for response in super().generate(request):
yield response
......@@ -15,19 +15,19 @@
import asyncio
import json
import logging
from common.base_engine import ChatProcessorMixin
from common.chat_processor import ChatProcessorMixin
from common.parser import parse_tensorrt_llm_args
from common.protocol import DynamoTRTLLMChatCompletionRequest
from common.utils import RequestType, ServerType
from components.agg_worker import TensorRTLLMWorker
from common.utils import RequestType
from components.kv_router import Router
from tensorrt_llm.logger import logger
from components.worker import TensorRTLLMWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
logger = logging.getLogger(__name__)
@service(
......@@ -48,11 +48,13 @@ class Processor(ChatProcessorMixin):
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
self.router_mode = self.args.router
super().__init__(self.engine_config)
args, engine_config = parse_tensorrt_llm_args(config_args)
self.remote_prefill = args.remote_prefill
self.router_mode = args.router
self.min_workers = 1
super().__init__(engine_config)
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
......@@ -64,7 +66,7 @@ class Processor(ChatProcessorMixin):
.client()
)
while len(self.worker_client.endpoint_ids()) < self.min_workers:
print(
logger.info(
f"Waiting for workers to be ready.\n"
f" Current: {len(self.worker_client.endpoint_ids())},"
f" Required: {self.min_workers}"
......@@ -97,15 +99,16 @@ class Processor(ChatProcessorMixin):
break
if worker_id == "":
if self.args.router == "round-robin":
engine_generator = await self.worker_client.round_robin(
preprocessed_request.model_dump_json()
)
if self.router_mode == "round-robin":
self._send_request = self.worker_client.round_robin
else:
# fallback to random
engine_generator = await self.worker_client.random(
preprocessed_request.model_dump_json()
)
self._send_request = self.worker_client.random
engine_generator = await self._send_request(
preprocessed_request.model_dump_json()
)
else:
engine_generator = await self.worker_client.direct(
preprocessed_request.model_dump_json(), int(worker_id)
......@@ -116,7 +119,6 @@ class Processor(ChatProcessorMixin):
engine_generator,
raw_request,
preprocessed_request.conversation,
ServerType.GEN,
):
logger.debug(f"[preprocessor] Response: {response}")
yield json.loads(response)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from common.base_engine import BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class TensorRTLLMWorker(BaseTensorrtLLMEngine):
prefill_worker = depends(TensorRTLLMPrefillWorker)
def __init__(self):
logger.info("Initializing TensorRT-LLM Worker")
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args)
worker_id = dynamo_context["endpoints"][0].lease_id()
self._min_prefill_workers = args.min_prefill_workers
super().__init__(
namespace_str="dynamo",
component_str=class_name,
worker_id=worker_id,
engine_config=engine_config,
remote_prefill=args.remote_prefill,
min_workers=args.min_workers,
disagg_config_file=args.llmapi_disaggregated_config,
block_size=args.block_size,
router=args.router,
server_type=ServerType.GEN,
)
@async_on_start
async def async_init(self):
self._init_engine()
if self._remote_prefill:
runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
self._prefill_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
while len(self._prefill_client.endpoint_ids()) < self._min_prefill_workers:
logger.info(
f"Waiting for prefill workers to be ready.\n"
f" Current: {len(self._prefill_client.endpoint_ids())},"
f" Required: {self._min_prefill_workers}"
)
await asyncio.sleep(2)
if self._kv_metrics_publisher is not None:
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
logger.info("TensorRT-LLM Worker initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self._kv_metrics_publisher.create_endpoint(component)
@dynamo_endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
async for response in super().generate(request):
yield response
......@@ -19,7 +19,7 @@ Frontend:
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
engine_args: "configs/llm_api_config_router.yaml"
router: kv
Router:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
router: round-robin
remote-prefill: true
TensorRTLLMWorker:
engine_args: "configs/llm_api_config.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
remote-prefill: true
min-prefill-workers: 1
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
TensorRTLLMPrefillWorker:
engine_args: "configs/llm_api_config.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
engine_args: "configs/llm_api_config_disagg_router.yaml"
router: "kv"
remote-prefill: true
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
TensorRTLLMWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
remote-prefill: true
min-prefill-workers: 1
router: kv
ServiceArgs:
workers: 1
resources:
gpu: 1
TensorRTLLMPrefillWorker:
engine_args: "configs/llm_api_config_disagg_router.yaml"
llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
......@@ -17,6 +17,8 @@
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
kv_cache_config:
free_gpu_memory_fraction: 0.95
event_buffer_max_size: 1024
enable_block_reuse: true
pytorch_backend_config:
enable_overlap_scheduler: false
use_cuda_graph: false
enable_iter_perf_stats: true
\ No newline at end of file
......@@ -17,6 +17,8 @@
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
......