Commit da38e96a authored by Tanmay Verma, committed by GitHub

feat: TRT-LLM disaggregated serving using UCX (#562)


Signed-off-by: Tanmay Verma <tanmay2592@gmail.com>
Signed-off-by: Tanmay Verma <tanmayv@nvidia.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent 538b4630
@@ -201,6 +201,12 @@ RUN pip install dist/ai_dynamo_runtime*cp312*.whl && \
ENV DYNAMO_KV_CAPI_PATH="/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
ENV DYNAMO_HOME=/workspace
# Copy launch banner
RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \
sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \
echo "cat ~/.launch_screen" >> ~/.bashrc
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
...
@@ -342,7 +342,7 @@ See instructions [here](/examples/tensorrt_llm/README.md#run-container) to run t

Execute the following to load the TensorRT-LLM model specified in the configuration.

```
-dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/agg_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
+dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/trtllm_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
```

#### Dynamo does the pre-processing
...
@@ -25,6 +25,14 @@ This directory contains examples and reference implementations for deploying Lar

See [deployment architectures](../llm/README.md#deployment-architectures) to learn about the general idea of the architecture.
Note that this TensorRT-LLM version does not support all the options yet.
Note: TensorRT-LLM disaggregation does not support conditional disaggregation yet. You can only configure the deployment to always use aggregated or disaggregated serving.
## Getting Started
1. Choose a deployment architecture based on your requirements
2. Configure the components as needed
3. Deploy using the provided scripts
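Concretely, the three steps above map to commands like the following (a sketch assembled from the commands elsewhere in this README; adjust paths to your checkout):

```bash
# 1. Start required services (etcd and NATS)
docker compose -f deploy/docker-compose.yml up -d

# 2. Build and enter the dev container (see the build steps below for the exact script)

# 3. Deploy one of the example graphs from inside the container
cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml
```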
### Prerequisites

Start required services (etcd and NATS) using [Docker Compose](../../deploy/docker-compose.yml)
@@ -68,6 +76,29 @@ This build script internally points to the base container image built with step
```

## Run Deployment
This figure shows an overview of the major components to deploy:
```
+------+ +-----------+ +------------------+ +---------------+
| HTTP |----->| processor |----->| Worker |------------>| Prefill |
| |<-----| |<-----| |<------------| Worker |
+------+ +-----------+ +------------------+ +---------------+
| ^ |
query best | | return | publish kv events
worker | | worker_id v
| | +------------------+
| +---------| kv-router |
+------------->| |
+------------------+
```
Note: The above architecture illustrates all the components. The final components
that get spawned depend upon the chosen graph.
### Example architectures

#### Aggregated serving
@@ -82,21 +113,23 @@ cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml
```
-<!--
-This is work in progress and will be enabled soon.
#### Disaggregated serving

```bash
-cd /workspace/examples/llm
+cd /workspace/examples/tensorrt_llm
-dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
+TRTLLM_USE_UCX_KVCACHE=1 dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```

We are defining TRTLLM_USE_UCX_KVCACHE so that TRTLLM uses UCX for transferring the KV
cache between the context and generation workers.

#### Disaggregated serving with KV Routing

```bash
-cd /workspace/examples/llm
+cd /workspace/examples/tensorrt_llm
-dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
+TRTLLM_USE_UCX_KVCACHE=1 dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
```
--->

We are defining TRTLLM_USE_UCX_KVCACHE so that TRTLLM uses UCX for transferring the KV
cache between the context and generation workers.
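When launching several deployments from one shell, the variable can equivalently be exported once for the session instead of being set inline on each command (same effect as the inline form above):

```bash
# Enable UCX-based KV-cache transfer for all subsequent TRT-LLM launches
export TRTLLM_USE_UCX_KVCACHE=1

cd /workspace/examples/tensorrt_llm
dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```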
### Client
@@ -108,7 +141,7 @@ See [close deployment](../../docs/guides/dynamo_serve.md#close-deployment) secti

Remaining tasks:
-- [ ] Add support for the disaggregated serving.
+- [x] Add support for the disaggregated serving.
- [ ] Add integration test coverage.
- [ ] Add instructions for benchmarking.
- [ ] Add multi-node support.
...
@@ -15,88 +15,129 @@
 import asyncio
+import copy
+import logging
+import os
+import signal
 import threading
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import asdict
+from enum import Enum
 from queue import Queue
 from typing import Any, Optional

-from common.chat_processor import ChatProcessor, CompletionsProcessor
 from common.parser import LLMAPIConfig
+from common.protocol import (
+    DisaggregatedTypeConverter,
+    TRTLLMWorkerRequest,
+    TRTLLMWorkerResponse,
+    TRTLLMWorkerResponseOutput,
+)
-from common.utils import ManagedThread
-from tensorrt_llm._torch import LLM
-from tensorrt_llm.logger import logger
-from transformers import AutoTokenizer
+from common.utils import ManagedThread, ServerType
+from tensorrt_llm.executor import CppExecutorError
+from tensorrt_llm.llmapi import LLM, SamplingParams
+from tensorrt_llm.llmapi.disagg_utils import (
+    CtxGenServerConfig,
+    parse_disagg_config_file,
+)
+from tensorrt_llm.serve.openai_protocol import DisaggregatedParams

 from dynamo.llm import KvMetricsPublisher

 from .kv_cache_event_publisher import KVCacheEventPublisher

-logger.set_level("info")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)


-class ChatProcessorMixin:
-    def __init__(self, engine_config: LLMAPIConfig):
-        self._engine_config = engine_config
-        logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
-        # model name for chat processor
-        self._model_name = self._engine_config.model_name
-        logger.info(f"Set model name: {self._model_name}")
-        # model for LLMAPI input
-        self._model = self._model_name
-        if self._engine_config.model_path:
-            self._model = self._engine_config.model_path
-            self._tokenizer = AutoTokenizer.from_pretrained(
-                self._engine_config.model_path
-            )
-            logger.info(f"Using model from path: {self._engine_config.model_path}")
-        else:
-            self._tokenizer = AutoTokenizer.from_pretrained(
-                self._engine_config.model_name
-            )
-        if self._engine_config.extra_args.get("tokenizer", None):
-            self._tokenizer = AutoTokenizer.from_pretrained(
-                self._engine_config.extra_args.get("tokenizer", None)
-            )
-        self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
-        self.completions_processor = CompletionsProcessor(
-            self._model_name, self._tokenizer
-        )
+class DisaggRequestType(Enum):
+    CONTEXT_ONLY = "context_only"
+    GENERATION_ONLY = "generation_only"


-@dataclass
-class TensorrtLLMEngineConfig:
-    namespace_str: str = "dynamo"
-    component_str: str = "tensorrt-llm"
-    engine_config: LLMAPIConfig = None
-    worker_id: Optional[str] = None
-    kv_metrics_publisher: Optional[KvMetricsPublisher] = None
-    publish_stats: bool = False
-    publish_kv_cache_events: bool = False
-    # default block size is 32 for pytorch backend
-    kv_block_size: int = 32
+def update_args_from_disagg_config(
+    engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
+):
+    # Update the LLM API config with the disaggregated config
+    # Allows for different configs for context and generation servers
+    engine_config.extra_args.update(**server_config.other_args)
+    engine_config.update_sub_configs(server_config.other_args)
+    return engine_config


+def get_sampling_params(sampling_params):
+    # Removes keys starting with '_' from the sampling params which gets
+    # added by the LLM API. TRTLLM does not support creating SamplingParams
+    # from a dictionary with keys starting with '_'.
+    cleaned_dict = {
+        key: value for key, value in sampling_params.items() if not key.startswith("_")
+    }
+    return SamplingParams(**cleaned_dict)


-class BaseTensorrtLLMEngine(ChatProcessorMixin):
+class BaseTensorrtLLMEngine:
     def __init__(
         self,
-        trt_llm_engine_config: TensorrtLLMEngineConfig,
+        namespace_str: str = "dynamo",
+        component_str: str = "tensorrt-llm",
+        worker_id: Optional[str] = None,
+        engine_config: LLMAPIConfig = None,
+        remote_prefill: bool = False,
+        min_workers: int = 0,
+        disagg_config_file: Optional[str] = None,
+        block_size: int = 32,
+        router: str = "round_robin",
+        server_type: ServerType = ServerType.GEN,
     ):
-        super().__init__(trt_llm_engine_config.engine_config)
-        self._namespace_str = trt_llm_engine_config.namespace_str
-        self._component_str = trt_llm_engine_config.component_str
-        self._worker_id = trt_llm_engine_config.worker_id
-        self._kv_metrics_publisher = trt_llm_engine_config.kv_metrics_publisher
-        self._publish_stats = trt_llm_engine_config.publish_stats
-        self._publish_kv_cache_events = trt_llm_engine_config.publish_kv_cache_events
-        self._kv_block_size = trt_llm_engine_config.kv_block_size
-        self._error_queue: Optional[Queue] = None
-
-        self._init_engine()
+        self._namespace_str = namespace_str
+        self._component_str = component_str
+        self._worker_id = worker_id
+        self._remote_prefill = remote_prefill
+        self._min_workers = 0
+        self._kv_block_size = block_size
+        self._router = router
+        self._server_type = server_type
+        self._prefill_client = None
+        self._error_queue: Queue = Queue()
+        self._kv_metrics_publisher = None
+
+        if self._remote_prefill:
+            self._min_workers = min_workers
+            if disagg_config_file is None or not os.path.exists(disagg_config_file):
+                raise ValueError(
+                    "llmapi_disaggregated_config file does not exist or not provided"
+                )
+            disagg_config = parse_disagg_config_file(disagg_config_file)
+            server_config: CtxGenServerConfig = None
+            for config in disagg_config.server_configs:
+                # Select the first context server config
+                if config.type == server_type.value:
+                    server_config = config
+                    break
+            if server_config is None:
+                server_type_str = (
+                    "generation" if server_type == ServerType.GEN else "context"
+                )
+                raise ValueError(
+                    f"No {server_type_str} server config found. Please check the disaggregated config file."
+                )
+            engine_config = update_args_from_disagg_config(engine_config, server_config)
+
+        if router == "kv":
+            self._publish_stats = True
+            self._publish_events = True
+        else:
+            self._publish_stats = False
+            self._publish_events = False
+
+        if self._publish_stats:
+            self._kv_metrics_publisher = KvMetricsPublisher()
+
+        self._engine_config = engine_config

     def _init_engine(self):
         logger.info("Initializing engine")
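The underscore-filtering in `get_sampling_params` above can be illustrated in isolation (a sketch using plain dicts; constructing TRT-LLM's actual `SamplingParams` is elided here, and the helper name is ours):

```python
def strip_private_keys(sampling_params: dict) -> dict:
    # Drop keys the LLM API adds with a leading underscore, e.g.
    # "_last_text_len"; TRT-LLM's SamplingParams rejects such keys.
    return {k: v for k, v in sampling_params.items() if not k.startswith("_")}


params = {"max_tokens": 16, "temperature": 0.7, "_last_text_len": 42}
print(strip_private_keys(params))  # {'max_tokens': 16, 'temperature': 0.7}
```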
@@ -126,12 +167,11 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
             self._event_thread = None
             raise e

-        self._error_queue = Queue()
         try:
             if self._publish_stats:
                 self._init_publish_metrics_thread()

-            if self._publish_kv_cache_events:
+            if self._publish_events:
                 self._init_publish_kv_cache_events_thread()
         except Exception as e:
             logger.error(f"Failed to initialize publish metrics threads: {e}")
@@ -308,7 +348,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
         try:
             llm = await loop.run_in_executor(
                 None,
-                lambda: LLM(model=self._model, **self._engine_config.to_dict()),
+                lambda: LLM(
+                    model=self._engine_config.model_name,
+                    **self._engine_config.to_dict(),
+                ),
             )
             yield llm
         finally:
@@ -368,3 +411,106 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
         self._llm_engine = None
         logger.info("Shutdown complete")

+    async def _get_remote_prefill_response(self, request):
+        prefill_request = copy.deepcopy(request)
+        prefill_request.sampling_params["max_tokens"] = 1
+        prefill_request.disaggregated_params = DisaggregatedParams(
+            request_type=DisaggRequestType.CONTEXT_ONLY.value
+        )
+        if self._prefill_client is None:
+            raise ValueError("Prefill client not initialized")
+        # TODO: Use smart KV router to determine which prefill worker to use.
+        ctx_responses = [
+            ctx_response
+            async for ctx_response in await self._prefill_client.round_robin(
+                prefill_request.model_dump_json()
+            )
+        ]
+        if len(ctx_responses) > 1:
+            raise ValueError(
+                "Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
+            )
+        logger.debug(
+            f"Received response from prefill worker: {ctx_responses[0].data()}"
+        )
+        ctx_response_obj = TRTLLMWorkerResponse.model_validate_json(
+            ctx_responses[0].data()
+        )
+        ctx_response_obj.outputs = [
+            TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
+        ]
+        assert ctx_response_obj.outputs[0].disaggregated_params is not None
+        return ctx_response_obj
+
+    async def generate(self, request: TRTLLMWorkerRequest):
+        if self._llm_engine is None:
+            raise RuntimeError("Engine not initialized")
+        if not self._error_queue.empty():
+            raise self._error_queue.get()
+
+        self._ongoing_request_count += 1
+        try:
+            worker_inputs = request.tokens.tokens
+            disaggregated_params = (
+                DisaggregatedTypeConverter.to_llm_disaggregated_params(
+                    request.disaggregated_params
+                )
+            )
+            if self._remote_prefill and self._server_type == ServerType.GEN:
+                ctx_response_obj = await self._get_remote_prefill_response(request)
+                worker_inputs = ctx_response_obj.prompt_token_ids
+                disaggregated_params = (
+                    DisaggregatedTypeConverter.to_llm_disaggregated_params(
+                        DisaggregatedParams(
+                            **ctx_response_obj.outputs[0].disaggregated_params
+                        )
+                    )
+                )
+                disaggregated_params.request_type = (
+                    DisaggRequestType.GENERATION_ONLY.value
+                )
+            logger.debug(
+                f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
+            )
+            sampling_params = get_sampling_params(request.sampling_params)
+            async for response in self._llm_engine.generate_async(
+                inputs=worker_inputs,
+                sampling_params=sampling_params,
+                disaggregated_params=disaggregated_params,
+                streaming=False
+                if self._server_type == ServerType.CTX
+                else request.streaming,
+            ):
+                # Convert the disaggregated params to OAI format so
+                # it can be sent over the network.
+                response.outputs[
+                    0
+                ].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params(
+                    response.outputs[0].disaggregated_params
+                )
+                yield TRTLLMWorkerResponse(
+                    request_id=request.id,
+                    prompt_token_ids=response.prompt_token_ids,
+                    outputs=[asdict(response.outputs[0])],
+                    finished=response.finished,
+                ).model_dump_json(exclude_unset=True)
+        except CppExecutorError:
+            signal.raise_signal(signal.SIGINT)
+        except Exception as e:
+            raise RuntimeError("Failed to generate: " + str(e))
+
+        self._start_threads()
+        self._ongoing_request_count -= 1
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from dataclasses import asdict
 from typing import Any, Dict, List, Union

+from common.parser import LLMAPIConfig
 from common.protocol import (
     DisaggregatedTypeConverter,
     DynamoTRTLLMChatCompletionResponseStreamChoice,
@@ -27,10 +29,9 @@ from common.protocol import (
     TRTLLMWorkerResponse,
     TRTLLMWorkerResponseOutput,
 )
-from common.utils import ConversationMessage, ServerType
+from common.utils import ConversationMessage
 from openai.types.chat import ChatCompletionMessageParam
 from tensorrt_llm.llmapi.llm import RequestOutput
-from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.openai_protocol import (
     ChatCompletionLogProbs,
     ChatCompletionLogProbsContent,
@@ -41,10 +42,44 @@ from tensorrt_llm.serve.openai_protocol import (
     ToolCall,
     UsageInfo,
 )
+from transformers import AutoTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

-logger.set_level("debug")
+logger = logging.getLogger(__name__)
+
+
+class ChatProcessorMixin:
+    def __init__(
+        self, engine_config: LLMAPIConfig, using_engine_generator: bool = False
+    ):
+        self._engine_config = engine_config
+        logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
+        # model name for chat processor
+        self._model_name = self._engine_config.model_name
+        logger.info(f"Set model name: {self._model_name}")
+        # model for LLMAPI input
+        self._model = self._model_name
+        if self._engine_config.model_path:
+            self._model = self._engine_config.model_path
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self._engine_config.model_path
+            )
+            logger.info(f"Using model from path: {self._engine_config.model_path}")
+        else:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self._engine_config.model_name
+            )
+        if self._engine_config.extra_args.get("tokenizer", None):
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self._engine_config.extra_args.get("tokenizer", None)
+            )
+        self.chat_processor = ChatProcessor(
+            self._model_name, self._tokenizer, using_engine_generator
+        )
+        self.completions_processor = CompletionsProcessor(
+            self._model_name, self._tokenizer
+        )


 def parse_chat_message_content(
@@ -290,6 +325,7 @@ class ChatProcessor(BaseChatProcessor):
         return TRTLLMWorkerRequest(
             id=request.id,
+            model=request.model,
             prompt=prompt,
             sampling_params=asdict(sampling_params),
             conversation=conversation,
@@ -303,8 +339,10 @@ class ChatProcessor(BaseChatProcessor):
         engine_generator,
         request,
         conversation,
-        server_type: ServerType,
     ):
+        first_iteration = True
+        last_text_len = 0
+        last_token_ids_len = 0
         async for raw_response in engine_generator:
             if self.using_engine_generator:
                 response = TRTLLMWorkerResponse(
@@ -317,21 +355,27 @@ class ChatProcessor(BaseChatProcessor):
                 response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
             else:
                 response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
+                response.outputs[0]["text"] = self.tokenizer.decode(
+                    response.outputs[0]["token_ids"]
+                )
+                # Need to keep track of the last text and token ids length
+                # to calculate the diff.
+                # TODO: This is a hack to get the diff. We should identify why
+                # the diff is not being calculated in the worker.
+                response.outputs[0]["_last_text_len"] = last_text_len
+                response.outputs[0]["_last_token_ids_len"] = last_token_ids_len
+                last_text_len = len(response.outputs[0]["text"])
+                last_token_ids_len = len(response.outputs[0]["token_ids"])
                 response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]

-            if (
-                request.disaggregated_params is not None
-                and server_type == ServerType.CTX
-            ):
-                response_data = self.yield_first_chat(request, request.id, response)
-            else:
-                response_data = self.create_chat_stream_response(
-                    request,
-                    request.id,
-                    response,
-                    conversation,
-                    first_iteration=(not request.disaggregated_params is not None),
-                )
+            response_data = self.create_chat_stream_response(
+                request,
+                request.id,
+                response,
+                conversation,
+                first_iteration=first_iteration,
+            )
+            first_iteration = False
             logger.debug(f"[postprocessor] Response: {response_data}")
             yield response_data
...
@@ -14,11 +14,10 @@
 # limitations under the License.

 import ctypes
+import logging
 from ctypes import c_char_p, c_int64, c_uint32

-from tensorrt_llm.logger import logger
-
-logger.set_level("info")
+logger = logging.getLogger(__name__)


 class DynamoResult:
@@ -53,7 +52,7 @@ class KVCacheEventPublisher:
                 logger.info("KVCacheEventPublisher initialization failed!")
         except Exception as e:
-            print(f"Failed to load {lib_path}")
+            logger.exception(f"Failed to load {lib_path}")
             raise e

         self.lib.dynamo_kv_event_publish_stored.argtypes = [
...
@@ -40,6 +40,13 @@ class LLMAPIConfig:
         self.kv_cache_config = kv_cache_config
         self.extra_args = kwargs

+        # Hardcoded to skip tokenizer init for now.
+        # We will handle the tokenization/detokenization
+        # in the base engine.
+        if "skip_tokenizer_init" in self.extra_args:
+            self.extra_args.pop("skip_tokenizer_init")
+        self.skip_tokenizer_init = True
+
     def to_dict(self) -> Dict[str, Any]:
         data = {
             "pytorch_backend_config": self.pytorch_backend_config,
@@ -133,6 +140,12 @@ def parse_tensorrt_llm_args(
         default=1,
         help="Minimum number of workers for aggregated (monolith) server",
     )
+    parser.add_argument(
+        "--min-prefill-workers",
+        type=int,
+        default=1,
+        help="Minimum number of prefill workers for disaggregated server",
+    )
     parser.add_argument(
         "--block-size",
         type=int,
@@ -156,14 +169,6 @@ def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]
     parser.add_argument(
         "--engine_args", type=str, required=True, help="Path to the engine args file"
     )
-    # Disaggregated mode is not supported in dynamo-run launcher yet.
-    # parser.add_argument(
-    #     "--llmapi-disaggregated-config",
-    #     "-c",
-    #     type=str,
-    #     help="Path to the llmapi disaggregated config file",
-    #     default=None,
-    # )
     parser.add_argument(
         "--publish-kv-cache-events",
         action="store_true",
...
@@ -23,7 +23,6 @@ import torch
 from common.utils import ConversationMessage
 from pydantic import BaseModel, ConfigDict, Field
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import SamplingParams
 from tensorrt_llm.serve.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChoice,
@@ -59,6 +58,7 @@ class Request(BaseModel):

 class TRTLLMWorkerRequest(BaseModel):
+    model: str
     id: str
     prompt: str | None = None
     sampling_params: dict
@@ -67,44 +67,6 @@ class TRTLLMWorkerRequest(BaseModel):
     tokens: Optional[Tokens] = Field(default=None)
     disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)

-    def to_sampling_params(self) -> SamplingParams:
-        sampling_params = SamplingParams(
-            frequency_penalty=self.sampling_params.get("frequency_penalty", 0.0),
-            return_log_probs=self.sampling_params.get("logprobs", False),
-            max_tokens=self.sampling_params.get("max_tokens", 16),
-            n=self.sampling_params.get("n", 1),
-            presence_penalty=self.sampling_params.get("presence_penalty", 0.0),
-            seed=self.sampling_params.get("seed", None),
-            stop=self.sampling_params.get("stop", None),
-            temperature=self.sampling_params.get("temperature", 0.7),
-            # chat-completion-sampling-params
-            best_of=self.sampling_params.get("best_of", None),
-            use_beam_search=self.sampling_params.get("use_beam_search", False),
-            top_k=self.sampling_params.get("top_k", 0),
-            top_p=self.sampling_params.get("top_p", 1.0),
-            top_p_min=self.sampling_params.get("top_p_min", None),
-            min_p=self.sampling_params.get("min_p", 0.0),
-            repetition_penalty=self.sampling_params.get("repetition_penalty", 1.0),
-            length_penalty=self.sampling_params.get("length_penalty", 1.0),
-            early_stopping=self.sampling_params.get("early_stopping", False),
-            stop_token_ids=self.sampling_params.get("stop_token_ids", []),
-            include_stop_str_in_output=self.sampling_params.get(
-                "include_stop_str_in_output", False
-            ),
-            ignore_eos=self.sampling_params.get("ignore_eos", False),
-            min_tokens=self.sampling_params.get("min_tokens", 0),
-            skip_special_tokens=self.sampling_params.get("skip_special_tokens", False),
-            spaces_between_special_tokens=self.sampling_params.get(
-                "spaces_between_special_tokens", False
-            ),
-            truncate_prompt_tokens=self.sampling_params.get(
-                "truncate_prompt_tokens", None
-            ),
-            # chat-completion-extra-params
-            add_special_tokens=self.sampling_params.get("add_special_tokens", False),
-        )
-        return sampling_params

 @dataclass
 class TRTLLMWorkerResponseOutput:
...
@@ -15,6 +15,7 @@
 import asyncio
+import logging
 import threading
 import traceback
 import weakref
@@ -22,9 +23,7 @@ from enum import Enum
 from queue import Queue
 from typing import Callable, Optional, TypedDict, Union

-from tensorrt_llm.logger import logger
-
-logger.set_level("info")
+logger = logging.getLogger(__name__)


 class RoutingStrategy(Enum):
@@ -43,6 +42,8 @@ class ServerType(Enum):
     GEN = "gen"
     # Context server used for disaggregated requests
     CTX = "ctx"
+    # Dynamo run server used for Dynamo run requests
+    DYN_RUN = "dyn_run"


 class ConversationMessage(TypedDict):
...
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import subprocess
 from pathlib import Path

-from components.agg_worker import TensorRTLLMWorker
 from components.processor import Processor
+from components.worker import TensorRTLLMWorker
 from pydantic import BaseModel

 from dynamo import sdk
@@ -25,6 +26,8 @@ from dynamo.sdk import depends, service
 from dynamo.sdk.lib.config import ServiceConfig
 from dynamo.sdk.lib.image import DYNAMO_IMAGE

+logger = logging.getLogger(__name__)


 def get_http_binary_path():
     sdk_path = Path(sdk.__file__)
@@ -75,7 +78,7 @@ class Frontend:
         ]
     )

-    print("Starting HTTP server")
+    logger.info("Starting HTTP server")
     http_binary = get_http_binary_path()
     process = subprocess.Popen(
         [http_binary, "-p", str(frontend_config.port)], stdout=None, stderr=None
...
...@@ -15,20 +15,20 @@
import argparse
import asyncio
+import logging
import random
import traceback
from argparse import Namespace
from typing import AsyncIterator

from common.protocol import Tokens
-from components.agg_worker import TensorRTLLMWorker
+from components.worker import TensorRTLLMWorker
-from tensorrt_llm.logger import logger

from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig

-logger.set_level("debug")
+logger = logging.getLogger(__name__)

WorkerId = str
...@@ -92,8 +92,7 @@ class Router:
            .client()
        )
        while len(self.workers_client.endpoint_ids()) < self.args.min_workers:
-            # TODO: replace print w/ vllm_logger.info
-            print(
+            logger.info(
                f"Waiting for more workers to be ready.\n"
                f" Current: {len(self.workers_client.endpoint_ids())},"
                f" Required: {self.args.min_workers}"
...@@ -104,7 +103,7 @@ class Router:
        await kv_listener.create_service()
        self.indexer = KvIndexer(kv_listener, self.args.block_size)
        self.metrics_aggregator = KvMetricsAggregator(kv_listener)
-        print("KV Router initialized")
+        logger.info("KV Router initialized")

    def _cost_function(
        self,
...
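The Router above blocks startup until `min_workers` endpoints have registered, by polling `endpoint_ids()` in a sleep loop. A self-contained sketch of that polling pattern, with a hypothetical `FakeClient` standing in for the dynamo endpoint client:

```python
import asyncio

class FakeClient:
    """Minimal stand-in for a dynamo endpoint client (illustration only)."""
    def __init__(self) -> None:
        self.ids: list[int] = []

    def endpoint_ids(self) -> list[int]:
        return self.ids

async def wait_for_workers(client: FakeClient, min_workers: int,
                           poll_interval: float = 0.01) -> None:
    # Same polling loop as the Router: block until enough workers register.
    while len(client.endpoint_ids()) < min_workers:
        await asyncio.sleep(poll_interval)

async def main() -> list[int]:
    client = FakeClient()

    async def register_workers() -> None:
        # Simulate three workers coming online one by one.
        for worker in range(3):
            await asyncio.sleep(0.02)
            client.ids.append(worker)

    registration = asyncio.create_task(register_workers())
    await wait_for_workers(client, min_workers=3)
    await registration
    return client.endpoint_ids()

ready = asyncio.run(main())
print(ready)  # → [0, 1, 2]
```

The polling interval trades startup latency against log noise; the real services use a 2-second sleep.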
...@@ -13,20 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
-import signal
-from dataclasses import asdict
+import logging

-from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
+from common.base_engine import BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
-from common.protocol import TRTLLMWorkerRequest, TRTLLMWorkerResponse
+from common.protocol import TRTLLMWorkerRequest
-from tensorrt_llm.executor import CppExecutorError
-from tensorrt_llm.logger import logger
+from common.utils import ServerType

-from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig

-logger.set_level("debug")
+logger = logging.getLogger(__name__)


@service(
...@@ -37,84 +34,42 @@ logger.set_level("debug")
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
-class TensorRTLLMWorker(BaseTensorrtLLMEngine):
-    """
-    Request handler for the generate endpoint
-    """
+class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
    def __init__(self):
-        print("Initializing TensorRT-LLM Worker")
+        logger.info("Initializing TensorRT-LLM Prefill Worker")
        class_name = self.__class__.__name__
        config = ServiceConfig.get_instance()
        config_args = config.as_args(class_name, prefix="")
-        self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
+        args, engine_config = parse_tensorrt_llm_args(config_args)
+        worker_id = dynamo_context["endpoints"][0].lease_id()

-        if self.args.router == "kv":
-            publish_stats = True
-            publish_events = True
-        else:
-            publish_stats = False
-            publish_events = False
-        trt_llm_engine_config = TensorrtLLMEngineConfig(
+        super().__init__(
            namespace_str="dynamo",
            component_str=class_name,
-            engine_config=self.engine_config,
-            publish_stats=publish_stats,
-            publish_kv_cache_events=publish_events,
-            kv_block_size=self.args.block_size,
+            worker_id=worker_id,
+            engine_config=engine_config,
+            remote_prefill=args.remote_prefill,
+            min_workers=args.min_workers,
+            disagg_config_file=args.llmapi_disaggregated_config,
+            block_size=args.block_size,
+            router=args.router,
+            server_type=ServerType.CTX,
        )
-        if publish_stats:
-            trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()
-            trt_llm_engine_config.worker_id = dynamo_context["endpoints"][0].lease_id()
-        self.trtllm_engine_args = trt_llm_engine_config

    @async_on_start
    async def async_init(self):
-        super().__init__(self.trtllm_engine_args)
-        task = asyncio.create_task(self.create_metrics_publisher_endpoint())
-        task.add_done_callback(lambda _: print("metrics publisher endpoint created"))
-        print("TensorRT-LLM Worker initialized")
+        self._init_engine()
+        if self._kv_metrics_publisher is not None:
+            task = asyncio.create_task(self.create_metrics_publisher_endpoint())
+            task.add_done_callback(
+                lambda _: logger.info("metrics publisher endpoint created")
+            )
+        logger.info("TensorRT-LLM Prefill Worker initialized")

    async def create_metrics_publisher_endpoint(self):
        component = dynamo_context["component"]
-        await self.trtllm_engine_args.kv_metrics_publisher.create_endpoint(component)
+        await self._kv_metrics_publisher.create_endpoint(component)

    @dynamo_endpoint()
    async def generate(self, request: TRTLLMWorkerRequest):
-        if self._llm_engine is None:
-            raise RuntimeError("Engine not initialized")
-        if self._error_queue.qsize() > 0:
-            error = self._error_queue.get()
-            raise error
-        self._ongoing_request_count += 1
-        try:
-            # TODO: combine with disagg worker
-            # TODO: only send tokens. Should be pretty simple.
-            async for response in self._llm_engine.generate_async(
-                inputs=request.prompt,
-                sampling_params=request.to_sampling_params(),
-                disaggregated_params=None,
-                streaming=True,
-            ):
-                yield TRTLLMWorkerResponse(
-                    request_id=request.id,
-                    prompt=response.prompt,
-                    prompt_token_ids=response.prompt_token_ids,
-                    outputs=[asdict(response.outputs[0])],
-                    finished=response.finished,
-                ).model_dump_json(exclude_unset=True)
-        except CppExecutorError:
-            signal.raise_signal(signal.SIGINT)
-        except Exception as e:
-            raise RuntimeError("Failed to generate: " + str(e))
-        self._start_threads()
-        self._ongoing_request_count -= 1
+        async for response in super().generate(request):
+            yield response
...
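The worker's `generate` endpoint is an async generator that streams serialized responses until a final chunk marked finished. A minimal sketch of producing and consuming such a stream; `fake_generate` and its chunk schema are illustrative stand-ins, not the real `TRTLLMWorkerResponse`:

```python
import asyncio
import json

async def fake_generate(prompt: str):
    """Stand-in for a streaming generate endpoint: yields JSON-serialized
    partial responses, marking the last chunk finished."""
    tokens = ["Hello", " world"]
    text = ""
    for i, token in enumerate(tokens):
        text += token
        yield json.dumps({"text": text, "finished": i == len(tokens) - 1})

async def collect(prompt: str) -> list[dict]:
    # Clients iterate the stream and parse each chunk as it arrives.
    return [json.loads(chunk) async for chunk in fake_generate(prompt)]

chunks = asyncio.run(collect("hi"))
print(chunks[-1])  # → {'text': 'Hello world', 'finished': True}
```

Streaming chunk-by-chunk is what lets the frontend forward tokens to the user before the full completion is ready.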
...@@ -15,19 +15,19 @@
import asyncio
import json
+import logging

-from common.base_engine import ChatProcessorMixin
+from common.chat_processor import ChatProcessorMixin
from common.parser import parse_tensorrt_llm_args
from common.protocol import DynamoTRTLLMChatCompletionRequest
-from common.utils import RequestType, ServerType
+from common.utils import RequestType
-from components.agg_worker import TensorRTLLMWorker
from components.kv_router import Router
-from tensorrt_llm.logger import logger
+from components.worker import TensorRTLLMWorker

from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig

-logger.set_level("debug")
+logger = logging.getLogger(__name__)


@service(
...@@ -48,11 +48,13 @@ class Processor(ChatProcessorMixin):
        class_name = self.__class__.__name__
        config = ServiceConfig.get_instance()
        config_args = config.as_args(class_name, prefix="")
-        self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
-        self.router_mode = self.args.router
-        super().__init__(self.engine_config)
+        args, engine_config = parse_tensorrt_llm_args(config_args)
+        self.remote_prefill = args.remote_prefill
+        self.router_mode = args.router
        self.min_workers = 1
+        super().__init__(engine_config)

    @async_on_start
    async def async_init(self):
        runtime = dynamo_context["runtime"]
...@@ -64,7 +66,7 @@ class Processor(ChatProcessorMixin):
            .client()
        )
        while len(self.worker_client.endpoint_ids()) < self.min_workers:
-            print(
+            logger.info(
                f"Waiting for workers to be ready.\n"
                f" Current: {len(self.worker_client.endpoint_ids())},"
                f" Required: {self.min_workers}"
...@@ -97,15 +99,16 @@ class Processor(ChatProcessorMixin):
                break

        if worker_id == "":
-            if self.args.router == "round-robin":
-                engine_generator = await self.worker_client.round_robin(
-                    preprocessed_request.model_dump_json()
-                )
+            if self.router_mode == "round-robin":
+                self._send_request = self.worker_client.round_robin
            else:
                # fallback to random
-                engine_generator = await self.worker_client.random(
-                    preprocessed_request.model_dump_json()
-                )
+                self._send_request = self.worker_client.random
+
+            engine_generator = await self._send_request(
+                preprocessed_request.model_dump_json()
+            )
        else:
            engine_generator = await self.worker_client.direct(
                preprocessed_request.model_dump_json(), int(worker_id)
...@@ -116,7 +119,6 @@
            engine_generator,
            raw_request,
            preprocessed_request.conversation,
-            ServerType.GEN,
        ):
            logger.debug(f"[preprocessor] Response: {response}")
            yield json.loads(response)
...
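The Processor's routing logic above prefers a worker selected by the KV router, then falls back to round-robin or random dispatch. The decision can be sketched as a pure function (names are illustrative, not the real API):

```python
def choose_dispatch(router_mode: str, worker_id: str) -> str:
    """Mirror the Processor's fallback: a worker picked by the KV router is
    addressed directly; otherwise route by the configured policy."""
    if worker_id != "":
        return "direct"          # KV router found a good prefix match
    if router_mode == "round-robin":
        return "round_robin"
    return "random"              # fallback when no policy matches

print(choose_dispatch("kv", "42"))         # → direct
print(choose_dispatch("round-robin", ""))  # → round_robin
print(choose_dispatch("kv", ""))           # → random
```

Keeping the selection separate from the send makes the fallback chain easy to test and extend.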
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from common.base_engine import BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class TensorRTLLMWorker(BaseTensorrtLLMEngine):
    prefill_worker = depends(TensorRTLLMPrefillWorker)

    def __init__(self):
        logger.info("Initializing TensorRT-LLM Worker")
        class_name = self.__class__.__name__
        config = ServiceConfig.get_instance()
        config_args = config.as_args(class_name, prefix="")
        args, engine_config = parse_tensorrt_llm_args(config_args)
        worker_id = dynamo_context["endpoints"][0].lease_id()
        self._min_prefill_workers = args.min_prefill_workers

        super().__init__(
            namespace_str="dynamo",
            component_str=class_name,
            worker_id=worker_id,
            engine_config=engine_config,
            remote_prefill=args.remote_prefill,
            min_workers=args.min_workers,
            disagg_config_file=args.llmapi_disaggregated_config,
            block_size=args.block_size,
            router=args.router,
            server_type=ServerType.GEN,
        )

    @async_on_start
    async def async_init(self):
        self._init_engine()
        if self._remote_prefill:
            runtime = dynamo_context["runtime"]
            comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address()  # type: ignore
            self._prefill_client = (
                await runtime.namespace(comp_ns)
                .component(comp_name)
                .endpoint("generate")
                .client()
            )
            while len(self._prefill_client.endpoint_ids()) < self._min_prefill_workers:
                logger.info(
                    f"Waiting for prefill workers to be ready.\n"
                    f" Current: {len(self._prefill_client.endpoint_ids())},"
                    f" Required: {self._min_prefill_workers}"
                )
                await asyncio.sleep(2)
        if self._kv_metrics_publisher is not None:
            task = asyncio.create_task(self.create_metrics_publisher_endpoint())
            task.add_done_callback(
                lambda _: logger.info("metrics publisher endpoint created")
            )
        logger.info("TensorRT-LLM Worker initialized")

    async def create_metrics_publisher_endpoint(self):
        component = dynamo_context["component"]
        await self._kv_metrics_publisher.create_endpoint(component)

    @dynamo_endpoint()
    async def generate(self, request: TRTLLMWorkerRequest):
        async for response in super().generate(request):
            yield response
...@@ -19,7 +19,7 @@ Frontend:
  port: 8000

Processor:
-  engine_args: "configs/llm_api_config.yaml"
+  engine_args: "configs/llm_api_config_router.yaml"
  router: kv

Router:
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
  served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  endpoint: dynamo.Processor.chat/completions
  port: 8000

Processor:
  engine_args: "configs/llm_api_config.yaml"
  router: round-robin
  remote-prefill: true

TensorRTLLMWorker:
  engine_args: "configs/llm_api_config.yaml"
  llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
  remote-prefill: true
  min-prefill-workers: 1
  router: round-robin
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1

TensorRTLLMPrefillWorker:
  engine_args: "configs/llm_api_config.yaml"
  llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
  router: round-robin
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
  served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  endpoint: dynamo.Processor.chat/completions
  port: 8000

Processor:
  engine_args: "configs/llm_api_config_disagg_router.yaml"
  router: "kv"
  remote-prefill: true

Router:
  model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  min-workers: 1

TensorRTLLMWorker:
  engine_args: "configs/llm_api_config_disagg_router.yaml"
  llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
  remote-prefill: true
  min-prefill-workers: 1
  router: kv
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1

TensorRTLLMPrefillWorker:
  engine_args: "configs/llm_api_config_disagg_router.yaml"
  llmapi-disaggregated-config: "configs/llmapi_disagg_configs/single_node_config.yaml"
  router: round-robin
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1
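Each service section above is handed to the component via `ServiceConfig.as_args(...)` and parsed by `parse_tensorrt_llm_args`. A hypothetical sketch of how dashed config keys could be flattened into CLI-style tokens for an argparse-based parser (the real `as_args` implementation may differ):

```python
def as_cli_args(section: dict) -> list[str]:
    """Flatten one service-config section into '--key value' tokens.
    Illustrative only; names and behavior are assumptions."""
    args: list[str] = []
    for key, value in section.items():
        if key == "ServiceArgs":  # deployment resources, not engine flags
            continue
        args.extend([f"--{key}", str(value)])
    return args

section = {
    "engine_args": "configs/llm_api_config_disagg_router.yaml",
    "remote-prefill": True,
    "min-prefill-workers": 1,
    "router": "kv",
    "ServiceArgs": {"workers": 1, "resources": {"gpu": 1}},
}
print(as_cli_args(section))
# → ['--engine_args', 'configs/llm_api_config_disagg_router.yaml',
#    '--remote-prefill', 'True', '--min-prefill-workers', '1', '--router', 'kv']
```

Skipping `ServiceArgs` keeps deployment sizing (workers, GPUs) out of the engine's argument namespace.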
...@@ -17,6 +17,8 @@
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
+# TODO: figure out how to generate this from the service config or vice versa

model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
kv_cache_config:
  free_gpu_memory_fraction: 0.95
  event_buffer_max_size: 1024
  enable_block_reuse: true

pytorch_backend_config:
  enable_overlap_scheduler: false
  use_cuda_graph: false
  enable_iter_perf_stats: true
\ No newline at end of file
...@@ -17,6 +17,8 @@
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
+# TODO: figure out how to generate this from the service config or vice versa

model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
...