Unverified Commit 901715b5 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

refactor: Refactor the TRTLLM examples remove dynamo SDK (#1884)

parent 5bf23d54
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Optional
from common.protocol import DisaggregatedTypeConverter, TRTLLMWorkerRequest
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import (
DisaggregatedParams as OAIDisaggregatedParams,
)
from dynamo.llm import get_tensorrtllm_engine, get_tensorrtllm_publisher
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
return (endpoint_parts[0], endpoint_parts[1], endpoint_parts[2])
@dataclass
class BaseEngineConfig:
"""Base engine configuration"""
namespace: str
component: str
endpoint: str
model_path: str
served_model_name: Optional[str] = None
kv_block_size: int = 32
extra_engine_args: str = ""
publish_events_and_metrics: bool = False
disaggregation_mode: str = "prefill_and_decode"
remote_prefill_endpoint: Optional[str] = None
lease_id: int = 0
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint}, "
f"lease_id={self.lease_id})"
)
class BaseTensorrtLLMEngine:
def __init__(
self,
config: BaseEngineConfig,
):
self._config = config
self._prefill_client = None
self._llm_engine = None
self._llm_engine_context = None
self._llm_publisher = None
self._llm_publisher_context = None
self._runtime = None
self._first_generation = True
# Initialize default sampling params
self.default_sampling_params = SamplingParams()
async def initialize(self, runtime: DistributedRuntime):
"""Initialize the engine and prefill client if needed"""
self._runtime = runtime
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(self._config.model_path)
# Initialize the LLM engine
engine_args: dict[str, Any] = {
"model": model_path,
"tensor_parallel_size": 1,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if self._config.extra_engine_args:
# TODO: Support extra engine args from json file as well.
engine_args = update_llm_args_with_extra_options(
engine_args, self._config.extra_engine_args
)
# Update the model path in the config to the model path used by the engine.
self._config.model_path = str(engine_args["model"])
if not self._config.model_path:
raise ValueError(
"Model specification is required. Present neither in the config nor in the extra engine args."
)
# Populate default sampling params from the model
tokenizer = tokenizer_factory(self._config.model_path)
self.default_sampling_params = SamplingParams()
self.default_sampling_params._setup(tokenizer)
self.default_sampling_params.stop = None
if self._config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config: dict[str, Any] | Any = None
if "kv_cache_config" not in engine_args:
kv_cache_config = {}
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = engine_args["kv_cache_config"]
if (
hasattr(kv_cache_config, "event_buffer_max_size")
and not kv_cache_config.event_buffer_max_size
):
kv_cache_config.event_buffer_max_size = (
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
)
elif (
isinstance(kv_cache_config, dict)
and "event_buffer_max_size" not in kv_cache_config
):
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
engine_args["kv_cache_config"] = kv_cache_config
# Enable iter perf stats by default if we are publishing events and metrics.
if not engine_args.get("enable_iter_perf_stats"):
engine_args["enable_iter_perf_stats"] = True
# Only pytorch backend is supported for now to publish events and metrics.
if engine_args.get("backend") != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
raise RuntimeError(
"Only pytorch backend is supported for now to publish events and metrics. Hence, KV router is not supported."
)
logging.info(f"TRTLLM engine args: {engine_args}")
# Get the engine using the asynccontextmanager
self._llm_engine_context = get_tensorrtllm_engine(engine_args)
if self._llm_engine_context is not None:
self._llm_engine = await self._llm_engine_context.__aenter__()
else:
raise RuntimeError("Failed to create LLM engine context")
if (
self._config.publish_events_and_metrics
and self._config.disaggregation_mode != "prefill"
):
kv_listener = runtime.namespace(self._config.namespace).component(
self._config.component
)
self._llm_publisher_context = get_tensorrtllm_publisher(
kv_listener,
self._llm_engine,
kv_listener,
self._config.lease_id,
self._config.kv_block_size,
)
if self._llm_publisher_context is not None:
self._llm_publisher = await self._llm_publisher_context.__aenter__()
else:
raise RuntimeError("Failed to create LLM publisher context")
# Initialize prefill client if in decode mode
if self._config.disaggregation_mode == "decode":
if self._config.remote_prefill_endpoint is None:
raise ValueError("remote_prefill_endpoint is required for decode mode")
logging.info(
f"Initializing remote prefill client for endpoint: {self._config.remote_prefill_endpoint}"
)
(
parsed_namespace,
parsed_component_name,
parsed_endpoint_name,
) = parse_endpoint(self._config.remote_prefill_endpoint)
if self._runtime is not None:
self._prefill_client = (
await self._runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
else:
raise RuntimeError("Runtime not initialized")
async def cleanup(self):
"""Cleanup resources"""
if self._llm_publisher_context:
try:
await self._llm_publisher_context.__aexit__(None, None, None)
except Exception as e:
logging.error(f"Error during publisher cleanup: {e}")
finally:
self._llm_publisher = None
self._llm_publisher_context = None
if self._llm_engine_context:
try:
await self._llm_engine_context.__aexit__(None, None, None)
except Exception as e:
logging.error(f"Error during engine cleanup: {e}")
finally:
self._llm_engine = None
self._llm_engine_context = None
self._prefill_client = None
async def remote_prefill(self, request: TRTLLMWorkerRequest):
"""
Send a prefill request to the remote prefill worker.
Args:
request: The original request to be sent for prefill
Returns:
The response from the remote prefill worker
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = request.model_copy(deep=True)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request.stop_conditions.max_tokens = 1
prefill_request.disaggregated_params = OAIDisaggregatedParams(
request_type="context_only"
)
if self._prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self._prefill_client.round_robin(
prefill_request.model_dump_json()
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
if self._llm_publisher:
publishers_error = self._llm_publisher.check_error_queue()
if publishers_error:
raise publishers_error
inputs = request.token_ids
# Decode the disaggregated params from the request
disaggregated_params = DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
num_output_tokens_so_far = 0
if self._config.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
# Decode the disaggregated params from the remote prefill response
# Decode the disaggregated params from the remote prefill response
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
OAIDisaggregatedParams(
**remote_prefill_response["disaggregated_params"]
)
)
)
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params
for key, value in request.sampling_options.model_dump().items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
ignore_eos = request.stop_conditions.ignore_eos
if ignore_eos:
sampling_params.ignore_eos = ignore_eos
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self._llm_engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self._config.disaggregation_mode != "prefill"),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self._first_generation and self._llm_publisher:
self._llm_publisher.start()
self._first_generation = False
if res.finished and self._config.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._config.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).model_dump()
yield out
num_output_tokens_so_far = next_total_toks
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_tensorrt_llm_args(
config_args,
) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--model-path",
type=str,
default=None,
help="Path to disk model or HuggingFace model identifier to load.",
)
parser.add_argument(
"--served_model_name",
type=str,
help="Name to serve the model under.",
)
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--kv-block-size",
type=int,
default=32,
help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.",
)
parser.add_argument(
"--enable-disagg",
action="store_true",
help="Enable remote prefill for the worker",
)
args = parser.parse_args(config_args)
return args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from typing import List, Optional
from pydantic import BaseModel, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
class Tokens(BaseModel):
tokens: list[int]
TokenIdType = int
class DisaggregatedTypeConverter:
@staticmethod
def to_llm_disaggregated_params(
disaggregated_params: DisaggregatedParams,
) -> LlmDisaggregatedParams:
if disaggregated_params is None:
return None
else:
opaque_state = (
base64.b64decode(disaggregated_params.encoded_opaque_state)
if disaggregated_params.encoded_opaque_state is not None
else None
)
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
)
@staticmethod
def to_oai_disaggregated_params(
tllm_disagg_params: LlmDisaggregatedParams,
) -> DisaggregatedParams:
if tllm_disagg_params is None:
return None
else:
encoded_opaque_state = (
base64.b64encode(tllm_disagg_params.opaque_state).decode("utf-8")
if tllm_disagg_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=tllm_disagg_params.request_type,
first_gen_tokens=tllm_disagg_params.first_gen_tokens,
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encoded_opaque_state,
)
# TODO: move these to common for all LLMs once we adopt dynamo-run
# derived from lib/llm/src/protocols/common/preprocessor.rs
class StopConditions(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids_hidden: Optional[List[TokenIdType]] = None
min_tokens: Optional[int] = None
ignore_eos: Optional[bool] = None
class SamplingOptions(BaseModel):
n: Optional[int] = None
best_of: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
use_beam_search: Optional[bool] = None
length_penalty: Optional[float] = None
seed: Optional[int] = None
class TRTLLMWorkerRequest(BaseModel):
token_ids: List[TokenIdType]
stop_conditions: StopConditions
sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
estimated_prefix_hit_num_blocks: Optional[int] = None
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from pathlib import Path
from components.worker import TensorRTLLMWorker
from fastapi import FastAPI
from pydantic import BaseModel
from dynamo import sdk
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
def get_dynamo_run_binary():
"""Find the dynamo-run binary path in SDK or fallback to 'dynamo-run' command."""
sdk_path = Path(sdk.__file__)
binary_path = sdk_path.parent / "cli/bin/dynamo-run"
if not binary_path.exists():
return "dynamo-run"
else:
return str(binary_path)
class FrontendConfig(BaseModel):
"""Configuration for the Frontend service including model and HTTP server settings."""
served_model_name: str
endpoint: str
port: int = 8000
router: str = "round-robin"
block_size: int = 32
# todo this should be called ApiServer
@service(
dynamo={
"namespace": "dynamo",
},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="TensorRT-LLM Example"),
)
class Frontend:
worker = depends(TensorRTLLMWorker)
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
self.frontend_config = FrontendConfig(
**ServiceConfig.get_parsed_config("Frontend")
)
self.process = None
logger.warning(f"Frontend config: {self.frontend_config}")
self.start_ingress_and_processor()
def start_ingress_and_processor(self):
"""Starting dynamo-run based ingress and processor"""
logger.info(
f"Starting HTTP server and processor on port {self.frontend_config.port}"
)
dynamo_run_binary = get_dynamo_run_binary()
cmd = [
dynamo_run_binary,
"in=http",
"out=dyn",
"--http-port",
str(self.frontend_config.port),
"--router-mode",
self.frontend_config.router,
]
logger.info(f"Frontend cmd: {cmd}")
self.process = subprocess.Popen(
cmd,
stdout=None,
stderr=None,
)
def close(self):
"""Clean up resources by terminating the subprocess."""
if self.process is not None:
try:
logger.info("Terminating subprocess...")
self.process.terminate()
# Wait for process to terminate with a timeout
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
logger.warning("Subprocess did not terminate gracefully, forcing kill")
self.process.kill()
self.process.wait()
except Exception as e:
logger.error(f"Error while terminating subprocess: {e}")
finally:
self.process = None
def __del__(self):
"""Destructor to ensure subprocess is cleaned up."""
self.close()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest
from dynamo.sdk import async_on_start, dynamo_context, endpoint, on_shutdown, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
def __init__(self):
logger.info("Initializing TensorRT-LLM Prefill Worker")
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
args = parse_tensorrt_llm_args(config_args)
lease_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
engine_config = BaseEngineConfig(
namespace=namespace,
component=class_name,
endpoint="generate",
model_path=args.model_path,
served_model_name=args.served_model_name,
kv_block_size=args.kv_block_size,
extra_engine_args=args.extra_engine_args,
publish_events_and_metrics=False,
disaggregation_mode="prefill",
remote_prefill_endpoint=None,
lease_id=lease_id,
)
super().__init__(config=engine_config)
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
await self.initialize(runtime)
logger.info("TensorRT-LLM Prefill Worker initialized")
@on_shutdown
async def async_cleanup(self):
logger.info("Cleaning up TensorRT-LLM Prefill Worker")
await self.cleanup()
logger.info("TensorRT-LLM Prefill Worker cleanup completed")
@endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
async for response in super().generate(request):
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); import asyncio
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
import os
import signal
import sys
from typing import TYPE_CHECKING
import uvloop
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine from dynamo.llm import (
from common.parser import parse_tensorrt_llm_args ModelType,
from common.protocol import TRTLLMWorkerRequest get_tensorrtllm_engine,
from components.prefill_worker import TensorRTLLMPrefillWorker get_tensorrtllm_publisher,
register_llm,
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import (
async_on_start,
depends,
dynamo_context,
endpoint,
on_shutdown,
service,
) )
from dynamo.sdk.lib.config import ServiceConfig from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
logger = logging.getLogger(__name__) if TYPE_CHECKING:
from utils.trtllm_utils import Config
@service( def _setup_path_and_imports():
dynamo={ """Setup path and import utils modules"""
"namespace": "dynamo", # Add the parent directory to the Python path so we can import utils
}, parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, if parent_dir not in sys.path:
workers=1, sys.path.insert(0, parent_dir)
)
class TensorRTLLMWorker(BaseTensorrtLLMEngine): from utils.request_handlers.handlers import (
prefill_worker = depends(TensorRTLLMPrefillWorker) RequestHandlerConfig,
RequestHandlerFactory,
def __init__(self): )
logger.info("Initializing TensorRT-LLM Worker") from utils.trtllm_utils import (
class_name = self.__class__.__name__ Config,
config = ServiceConfig.get_instance() cmd_line_args,
config_args = config.as_args(class_name, prefix="") is_first_worker,
args = parse_tensorrt_llm_args(config_args) parse_endpoint,
lease_id = dynamo_context["endpoints"][0].lease_id() )
namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore
endpoint_name = "generate" return (
publish_events_and_metrics = args.router == "kv" RequestHandlerConfig,
prefill_class_name = "TensorRTLLMPrefillWorker" RequestHandlerFactory,
Config,
if args.enable_disagg: cmd_line_args,
disaggregation_mode = "decode" is_first_worker,
else: parse_endpoint,
disaggregation_mode = "prefill_and_decode" )
engine_config = BaseEngineConfig(
namespace=namespace, # Import utils modules
component=class_name, (
endpoint=endpoint_name, RequestHandlerConfig,
model_path=args.model_path, RequestHandlerFactory,
served_model_name=args.served_model_name, Config,
kv_block_size=args.kv_block_size, cmd_line_args,
extra_engine_args=args.extra_engine_args, is_first_worker,
publish_events_and_metrics=publish_events_and_metrics, parse_endpoint,
disaggregation_mode=disaggregation_mode, ) = _setup_path_and_imports()
remote_prefill_endpoint=f"dyn://{namespace}.{prefill_class_name}.generate",
lease_id=lease_id, # Default buffer size for kv cache events.
) DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging()
async def graceful_shutdown(runtime):
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
config = cmd_line_args()
await init(runtime, config)
super().__init__(config=engine_config)
@async_on_start async def init(runtime: DistributedRuntime, config: Config):
async def async_init(self): """
runtime = dynamo_context["runtime"] Instantiate and serve
await self.initialize(runtime) """
logging.info(f"Initializing the worker with config: {config}")
logger.info("Registering LLM for discovery") next_client = None
endpoint = ( if config.next_endpoint:
runtime.namespace(self._config.namespace) logging.info(
.component(self._config.component) f"Initializing next worker client for endpoint: {config.next_endpoint}"
.endpoint(self._config.endpoint)
) )
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.next_endpoint
)
next_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
arg_map = {
"model": model_path,
"tensor_parallel_size": config.tensor_parallel_size,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
if "kv_cache_config" not in arg_map:
kv_cache_config = {}
kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = arg_map["kv_cache_config"]
if not kv_cache_config.event_buffer_max_size:
kv_cache_config.event_buffer_max_size = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map["kv_cache_config"] = kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map:
arg_map["backend"] = "pytorch"
elif arg_map["backend"] != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys.exit(1)
logging.info(f"TensorRT-LLM engine args: {arg_map}")
engine_args = arg_map
try: # Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
if is_first_worker(config):
# Register the model with the endpoint if only the worker is first in the disaggregation chain.
await register_llm( await register_llm(
ModelType.Backend, ModelType.Backend,
endpoint, endpoint,
self._config.model_path, config.model_path,
self._config.served_model_name, config.served_model_name,
kv_cache_block_size=self._config.kv_block_size, kv_cache_block_size=config.kv_block_size,
) )
logger.info("Successfully registered LLM for discovery")
except Exception as e: # publisher will be set later if publishing is enabled.
logger.error(f"Failed to register LLM for discovery: {e}") handler_config = RequestHandlerConfig(
raise component=component,
engine=engine,
logger.info("TensorRT-LLM Worker initialized") default_sampling_params=default_sampling_params,
publisher=None,
@on_shutdown disaggregation_mode=config.disaggregation_mode,
async def async_cleanup(self): disaggregation_strategy=config.disaggregation_strategy,
logger.info("Cleaning up TensorRT-LLM Worker") next_client=next_client,
await self.cleanup() )
logger.info("TensorRT-LLM Worker cleanup completed")
if config.publish_events_and_metrics and is_first_worker(config):
@endpoint() # Initialize and pass in the publisher to the request handler to
async def generate(self, request: TRTLLMWorkerRequest): # publish events and metrics.
async for response in super().generate(request): kv_listener = runtime.namespace(config.namespace).component(
yield response config.component
)
async with get_tensorrtllm_publisher(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
) as publisher:
handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config)
await endpoint.serve_endpoint(handler.generate)
else:
handler = RequestHandlerFactory().get_request_handler(handler_config)
await endpoint.serve_endpoint(handler.generate)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: kv
TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: kv
ServiceArgs:
workers: 1
resources:
gpu: 1
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
# This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
# This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/decode_config.yaml"
enable-disagg: true
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
TensorRTLLMPrefillWorker:
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: "nvidia/DeepSeek-R1-FP4"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/decode_config.yaml"
router: round-robin
enable-disagg: true
ServiceArgs:
workers: 1
resources:
gpu: 4
TensorRTLLMPrefillWorker:
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/decode_config.yaml"
enable-disagg: true
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
TensorRTLLMPrefillWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: kv
TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/decode_config.yaml"
enable-disagg: true
router: kv
ServiceArgs:
workers: 1
resources:
gpu: 1
TensorRTLLMPrefillWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 1
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
# This is the client-facing model name, you can set this to anything you'd like.
served_model_name: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
model-path: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
extra-engine-args: "configs/llama4/eagle/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
endpoint: dynamo.TensorRTLLMWorker.generate
port: 8000
router: round-robin
TensorRTLLMWorker:
served_model_name: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
model-path: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/llama4/eagle/engine_configs/decode_config.yaml"
router: round-robin
enable-disagg: true
ServiceArgs:
workers: 1
resources:
gpu: 4
TensorRTLLMPrefillWorker:
model-path: "nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/llama4/eagle/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
resources:
gpu: 4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment