Unverified Commit 901715b5 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

refactor: Refactor the TRTLLM examples remove dynamo SDK (#1884)

parent 5bf23d54
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export AGG_ENGINE_ARGS=${AGG_ENGINE_ARGS:-"engine_configs/agg.yaml"}
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID 2>/dev/null || true
wait $DYNAMO_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run clear_namespace
python3 utils/clear_namespace.py --namespace dynamo
# run ingress
dynamo run in=http out=dyn --router-mode kv --http-port=8000 &
DYNAMO_PID=$!
# run worker
python3 components/worker.py \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$AGG_ENGINE_ARGS" \
--publish-events-and-metrics
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run clear_namespace
python3 utils/clear_namespace.py --namespace dynamo
# run ingress
dynamo run in=http out=dyn --http-port=8000 &
DYNAMO_PID=$!
# run prefill worker
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 components/worker.py \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--disaggregation-mode prefill &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 components/worker.py \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--disaggregation-mode decode
\ No newline at end of file
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"prefill_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run clear_namespace
python3 utils/clear_namespace.py --namespace dynamo
# run ingress
dynamo run in=http out=dyn --router-mode kv --http-port=8000 &
DYNAMO_PID=$!
EXTRA_PREFILL_ARGS=()
EXTRA_DECODE_ARGS=()
if [ "$DISAGGREGATION_STRATEGY" == "prefill_first" ]; then
EXTRA_PREFILL_ARGS+=(--publish-events-and-metrics)
else
EXTRA_DECODE_ARGS+=(--publish-events-and-metrics)
fi
# run prefill worker
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 components/worker.py \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-mode prefill \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
"${EXTRA_PREFILL_ARGS[@]}" &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 components/worker.py \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-mode decode \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
"${EXTRA_DECODE_ARGS[@]}"
\ No newline at end of file
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Llama 4 Maverick Instruct with Eagle Speculative Decoding on SLURM
This guide demonstrates how to deploy Llama 4 Maverick Instruct with Eagle Speculative Decoding on GB200x4 nodes. We will be following the [multi-node deployment instructions](./multinode/multinode-examples.md) to set up the environment for the following scenarios:
- **Aggregated Serving:**
Deploy the entire Llama 4 model on a single GB200x4 node for end-to-end serving.
- **Disaggregated Serving:**
Distribute the workload across two GB200x4 nodes:
- One node runs the decode worker.
- The other node runs the prefill worker.
For advanced control over how requests are routed between prefill and decode workers in disaggregated mode, refer to the [Disaggregation Strategy](./README.md#disaggregation-strategy) section.
## Notes
* To run Eagle Speculative Decoding with Llama 4, ensure the container meets the following criteria:
* Built with a version of TensorRT-LLM based on the 0.21 release [Link](https://github.com/NVIDIA/TensorRT-LLM/tree/release/0.21)
* The TensorRT-LLM build includes the changes from this PR [Link](https://github.com/NVIDIA/TensorRT-LLM/pull/5975)
* If you need to download model weights off huggingface, make sure you run the command `huggingface-cli login` and have access to the necessary gated models.
## Setup
Assuming you have already allocated your nodes via `salloc`, and are
inside an interactive shell on one of the allocated nodes, set the
following environment variables based:
```bash
cd $DYNAMO_ROOT/examples/tensorrt_llm
export IMAGE="<dynamo_trtllm_image>"
# export MOUNTS="${PWD}/:/mnt,/lustre:/lustre"
export MOUNTS="${PWD}/:/mnt"
export MODEL_PATH="nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
export SERVED_MODEL_NAME="nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"
```
See [this](./multinode/multinode-examples.md#setup) section from multinode guide to learn more about the above options.
## Aggregated Serving
```bash
export NUM_NODES=1
export ENGINE_CONFIG="/mnt/engine_configs/llama4/eagle/eagle_agg.yaml"
./multinode/srun_aggregated.sh
```
* Known Issue: In Aggregated Serving, setting `max_num_tokens` to higher values (e.g. `max_num_tokens: 8448`) can lead to Out of Memory (OOM) errors. This is being investigated by the TRTLLM team.
## Disaggregated Serving
```bash
export NUM_PREFILL_NODES=1
export PREFILL_ENGINE_CONFIG="/mnt/engine_configs/llama4/eagle/eagle_prefill.yaml"
export NUM_DECODE_NODES=1
export DECODE_ENGINE_CONFIG="/mnt/engine_configs/llama4/eagle/eagle_decode.yaml"
./multinode/srun_disaggregated.sh
```
* Known Issue: In Aggregated Serving, setting `max_num_tokens` to higher values (e.g. `max_num_tokens: 8448`) can lead to Out of Memory (OOM) errors. This is being investigated by the TRTLLM team.
## Example Request
See [here](./multinode/multinode-examples.md#example-request) to learn how to send a request to the deployment.
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Example: Multi-node TRTLLM Workers with Dynamo on Slurm
......@@ -94,8 +106,8 @@ export IMAGE="<dynamo_trtllm_image>"
# For example, assuming your cluster had a `/lustre` directory on the host, you
# could add that as a mount like so:
#
# export MOUNTS="${PWD}:/mnt,/lustre:/lustre"
export MOUNTS="${PWD}:/mnt"
# export MOUNTS="${PWD}/../:/mnt,/lustre:/lustre"
export MOUNTS="${PWD}/../:/mnt"
# NOTE: In general, Deepseek R1 is very large, so it is recommended to
# pre-download the model weights and save them in some shared location,
......@@ -124,7 +136,7 @@ follow these steps below to launch an **aggregated** deployment across 4 nodes:
```bash
# Default set in srun_aggregated.sh, but can customize here.
# export ENGINE_CONFIG="/mnt/engine_configs/wide_ep_agg.yaml"
# export ENGINE_CONFIG="/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_agg.yaml"
# Customize NUM_NODES to match the desired parallelism in ENGINE_CONFIG
# The product of NUM_NODES*NUM_GPUS_PER_NODE should match the number of
......@@ -153,8 +165,8 @@ deployment across 8 nodes:
```bash
# Defaults set in srun_disaggregated.sh, but can customize here.
# export PREFILL_ENGINE_CONFIG="/mnt/engine_configs/wide_ep_prefill.yaml"
# export DECODE_ENGINE_CONFIG="/mnt/engine_configs/wide_ep_decode.yaml"
# export PREFILL_ENGINE_CONFIG="/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_prefill.yaml"
# export DECODE_ENGINE_CONFIG="/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_decode.yaml"
# Customize NUM_PREFILL_NODES to match the desired parallelism in PREFILL_ENGINE_CONFIG
# Customize NUM_DECODE_NODES to match the desired parallelism in DECODE_ENGINE_CONFIG
......
......@@ -10,7 +10,7 @@ IMAGE="${IMAGE:-""}"
# but you may freely customize the mounts based on your cluster. A common practice
# is to mount paths to NFS storage for common scripts, model weights, etc.
# NOTE: This can be a comma separated list of multiple mounts as well.
DEFAULT_MOUNT="${PWD}:/mnt"
DEFAULT_MOUNT="${PWD}/../:/mnt"
MOUNTS="${MOUNTS:-${DEFAULT_MOUNT}}"
# Example values, assuming 4 nodes with 4 GPUs on each node, such as 4xGB200 nodes.
......@@ -18,7 +18,7 @@ MOUNTS="${MOUNTS:-${DEFAULT_MOUNT}}"
NUM_NODES=${NUM_NODES:-4}
NUM_GPUS_PER_NODE=${NUM_GPUS_PER_NODE:-4}
export ENGINE_CONFIG="${ENGINE_CONFIG:-/mnt/engine_configs/wide_ep_agg.yaml}"
export ENGINE_CONFIG="${ENGINE_CONFIG:-/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_agg.yaml}"
# Automate settings of certain variables for convenience, but you are free
# to manually set these for more control as well.
......@@ -51,20 +51,19 @@ srun \
--nodelist "${HEAD_NODE}" \
--nodes 1 \
--jobid "${SLURM_JOB_ID}" \
/mnt/start_frontend_services.sh &
/mnt/multinode/start_frontend_services.sh &
# NOTE: Output streamed to stdout for ease of understanding the example, but
# in practice you would probably set `srun --output ... --error ...` to pipe
# the stdout/stderr to files.
echo "Launching multi-node worker in background."
# No --task for the worker defaults to aggregated mode
TASK="" \
DISAGGREGATION_MODE="prefill_and_decode" \
srun \
--mpi pmix \
--oversubscribe \
--container-image "${IMAGE}" \
--container-mounts "${MOUNTS}" \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,TASK,ENGINE_CONFIG \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,DISAGGREGATION_MODE,ENGINE_CONFIG \
--verbose \
--label \
-A "${ACCOUNT}" \
......@@ -72,4 +71,4 @@ srun \
--nodes "${NUM_NODES}" \
--ntasks-per-node "${NUM_GPUS_PER_NODE}" \
--jobid "${SLURM_JOB_ID}" \
/mnt/start_trtllm_worker.sh &
/mnt/multinode/start_trtllm_worker.sh &
\ No newline at end of file
......@@ -10,16 +10,18 @@ IMAGE="${IMAGE:-""}"
# but you may freely customize the mounts based on your cluster. A common practice
# is to mount paths to NFS storage for common scripts, model weights, etc.
# NOTE: This can be a comma separated list of multiple mounts as well.
DEFAULT_MOUNT="${PWD}:/mnt"
DEFAULT_MOUNT="${PWD}/../:/mnt"
MOUNTS="${MOUNTS:-${DEFAULT_MOUNT}}"
NUM_GPUS_PER_NODE=${NUM_GPUS_PER_NODE:-4}
NUM_PREFILL_NODES=${NUM_PREFILL_NODES:-4}
PREFILL_ENGINE_CONFIG="${PREFILL_ENGINE_CONFIG:-/mnt/engine_configs/wide_ep_prefill.yaml}"
PREFILL_ENGINE_CONFIG="${PREFILL_ENGINE_CONFIG:-/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_prefill.yaml}"
NUM_DECODE_NODES=${NUM_DECODE_NODES:-4}
DECODE_ENGINE_CONFIG="${DECODE_ENGINE_CONFIG:-/mnt/engine_configs/wide_ep_decode.yaml}"
DECODE_ENGINE_CONFIG="${DECODE_ENGINE_CONFIG:-/mnt/engine_configs/deepseek_r1/wide_ep/wide_ep_decode.yaml}"
DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
# Automate settings of certain variables for convenience, but you are free
# to manually set these for more control as well.
......@@ -52,20 +54,20 @@ srun \
--nodelist "${HEAD_NODE}" \
--nodes 1 \
--jobid "${SLURM_JOB_ID}" \
/mnt/start_frontend_services.sh &
/mnt/multinode/start_frontend_services.sh &
# NOTE: Output streamed to stdout for ease of understanding the example, but
# in practice you would probably set `srun --output ... --error ...` to pipe
# the stdout/stderr to files.
echo "Launching multi-node prefill worker in background."
TASK=prefill \
DISAGGREGATION_MODE=prefill \
ENGINE_CONFIG=${PREFILL_ENGINE_CONFIG} \
srun \
--mpi pmix \
--oversubscribe \
--container-image "${IMAGE}" \
--container-mounts "${MOUNTS}" \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,TASK,ENGINE_CONFIG \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,DISAGGREGATION_MODE,DISAGGREGATION_STRATEGY,ENGINE_CONFIG \
--verbose \
--label \
-A "${ACCOUNT}" \
......@@ -73,17 +75,17 @@ srun \
--nodes "${NUM_PREFILL_NODES}" \
--ntasks-per-node "${NUM_GPUS_PER_NODE}" \
--jobid "${SLURM_JOB_ID}" \
/mnt/start_trtllm_worker.sh &
/mnt/multinode/start_trtllm_worker.sh &
echo "Launching multi-node decode worker in background."
TASK=decode \
DISAGGREGATION_MODE=decode \
ENGINE_CONFIG=${DECODE_ENGINE_CONFIG} \
srun \
--mpi pmix \
--oversubscribe \
--container-image "${IMAGE}" \
--container-mounts "${MOUNTS}" \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,TASK,ENGINE_CONFIG \
--container-env ETCD_ENDPOINTS,NATS_SERVER,HEAD_NODE_IP,HEAD_NODE,DISAGGREGATION_MODE,DISAGGREGATION_STRATEGY,ENGINE_CONFIG \
--verbose \
--label \
-A "${ACCOUNT}" \
......@@ -91,4 +93,4 @@ srun \
--nodes "${NUM_DECODE_NODES}" \
--ntasks-per-node "${NUM_GPUS_PER_NODE}" \
--jobid "${SLURM_JOB_ID}" \
/mnt/start_trtllm_worker.sh &
/mnt/multinode/start_trtllm_worker.sh &
\ No newline at end of file
......@@ -22,11 +22,6 @@ if [[ -z ${ENGINE_CONFIG} ]]; then
exit 1
fi
EXTRA_ARGS=""
if [[ -n ${TASK} ]]; then
EXTRA_ARGS+="--task ${TASK}"
fi
# NOTE: When this script is run directly from srun, the environment variables
# for TRTLLM KV cache are not set. So we need to set them here.
# Related issue: https://github.com/ai-dynamo/dynamo/issues/1743
......@@ -34,13 +29,18 @@ if [[ -z ${TRTLLM_USE_UCX_KVCACHE} ]] && [[ -z ${TRTLLM_USE_NIXL_KVCACHE} ]]; th
export TRTLLM_USE_UCX_KVCACHE=1
fi
# NOTE: trtllm_inc.py is a standalone python script that launches a Dynamo+TRTLLM
# worker and registers itself with the runtime. It is currently easier to wrap
# this standalone script with `trtllm-llmapi-launch` for MPI handling purposes,
# but this may be refactored into 'dynamo serve' in the future.
EXTRA_ARGS=""
if [[ -n ${DISAGGREGATION_MODE} ]]; then
EXTRA_ARGS+="--disaggregation-mode ${DISAGGREGATION_MODE} "
fi
if [[ -n ${DISAGGREGATION_STRATEGY} ]]; then
EXTRA_ARGS+="--disaggregation-strategy ${DISAGGREGATION_STRATEGY} "
fi
trtllm-llmapi-launch \
python3 /workspace/launch/dynamo-run/src/subprocess/trtllm_inc.py \
python3 /mnt/components/worker.py \
--model-path "${MODEL_PATH}" \
--model-name "${SERVED_MODEL_NAME}" \
--served-model-name "${SERVED_MODEL_NAME}" \
--extra-engine-args "${ENGINE_CONFIG}" \
${EXTRA_ARGS}
${EXTRA_ARGS}
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import argparse
import asyncio
import logging
from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
@dynamo_worker()
async def clear_namespace(runtime: DistributedRuntime, namespace: str):
etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
f"/{namespace}/",
{},
)
await etcd_kv_cache.clear_all()
logger.info(f"Cleared /{namespace} in EtcdKvCache")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--namespace", type=str, required=True)
args = parser.parse_args()
asyncio.run(clear_namespace(args.namespace))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from tensorrt_llm.llmapi import DisaggregatedParams
class DisaggregatedParamsCodec:
"""
Codec for encoding and decoding disaggregated params for network transfer.
"""
@staticmethod
def decode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
opaque_state = (
base64.b64decode(disaggregated_params.opaque_state)
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod
def encode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
encoded_opaque_state = (
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8")
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import asdict, dataclass
from enum import Enum
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from utils.disagg_utils import DisaggregatedParams, DisaggregatedParamsCodec
from dynamo.llm.tensorrtllm.engine import TensorRTLLMEngine
from dynamo.llm.tensorrtllm.publisher import Publisher
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
class DisaggregationStrategy(Enum):
PREFILL_FIRST = "prefill_first"
DECODE_FIRST = "decode_first"
@dataclass
class RequestHandlerConfig:
"""
Configuration for the request handler
"""
component: object
engine: TensorRTLLMEngine
default_sampling_params: SamplingParams
publisher: Publisher
disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy
next_client: object
class HandlerBase:
"""
Base class for request handlers.
"""
def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher
self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client
self.first_generation = True
def check_error(self, result: dict):
"""
Check if there is an error in the result.
"""
if self.disaggregation_mode == DisaggregationMode.PREFILL:
return result["finish_reason"] == "error"
else:
return (
result["finish_reason"] == "stop" or result["finish_reason"] == "error"
)
async def generate_locally(self, request: dict):
"""
Generate responses based on the disaggregation mode in the request.
"""
logging.debug(f"Request: {request}")
# Check if there is an error in the publisher error queue
publishers_error = (
self.publisher.check_error_queue() if self.publisher else None
)
if publishers_error:
raise publishers_error
inputs = request["token_ids"]
# Decode the disaggregated params from the request
disaggregated_params = None
if self.disaggregation_mode == DisaggregationMode.PREFILL:
request["stop_conditions"]["max_tokens"] = 1
disaggregated_params = LlmDisaggregatedParams(request_type="context_only")
if "disaggregated_params" in request:
if self.disaggregation_mode == DisaggregationMode.PREFILL:
raise ValueError("Cannot provide disaggregated_params in prefill mode")
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
)
disaggregated_params.request_type = "generation_only"
if (
self.disaggregation_mode == DisaggregationMode.DECODE
and disaggregated_params is None
):
raise ValueError("Disaggregated params are required for decode mode")
num_output_tokens_so_far = 0
sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
# TODO: Instead of True, we should use streaming from the request.
# However, currently dynamo run does not send streaming in the request.
streaming = (
False if self.disaggregation_mode == DisaggregationMode.PREFILL else True
)
async for res in self.engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
if res.finished and self.disaggregation_mode != DisaggregationMode.PREFILL:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield out
num_output_tokens_so_far = next_total_toks
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
from utils.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
HandlerBase,
RequestHandlerConfig,
)
class RequestHandlerFactory:
def __init__(self):
self.handlers = {
"prefill": PrefillHandler,
"decode": DecodeHandler,
"prefill_and_decode": AggregatedHandler,
}
def _validate_config(self, config: RequestHandlerConfig):
if config.disaggregation_mode.value not in self.handlers:
raise ValueError(
f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'"
)
if not config.next_client:
if (
config.disaggregation_mode == DisaggregationMode.PREFILL
and config.disaggregation_strategy
== DisaggregationStrategy.PREFILL_FIRST
):
raise ValueError(
"Next client is required for the main worker when disaggregation_mode='prefill' and disaggregation_strategy='prefill_first'."
)
if (
config.disaggregation_mode == DisaggregationMode.DECODE
and config.disaggregation_strategy
== DisaggregationStrategy.DECODE_FIRST
):
raise ValueError(
"Next client is required for the decode worker when disaggregation_mode='decode' and disaggregation_strategy='decode_first'."
)
def get_request_handler(self, config: RequestHandlerConfig) -> HandlerBase:
self._validate_config(config)
return self.handlers[config.disaggregation_mode.value](config)
def get_request_handler(config: RequestHandlerConfig) -> HandlerBase:
return RequestHandlerFactory().get_request_handler(config)
class AggregatedHandler(HandlerBase):
"""
Handler for the aggregated mode.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def generate(self, request: dict):
# Implement all steps locally.
async for res in self.generate_locally(request):
yield res
class PrefillHandler(HandlerBase):
"""
Handler for the prefill mode.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def remote_decode(self, request: dict):
async for res in await self.next_client.round_robin(request):
yield res.data()
async def generate(self, request: dict):
# Generate the prefill response locally
prefill_request = copy.deepcopy(request)
prefill_response = None
response_count = 0
async for res in self.generate_locally(prefill_request):
prefill_response = res
response_count += 1
if response_count > 1:
raise ValueError("Prefill response should be generated only once.")
if (
self.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
and not self.check_error(prefill_response)
):
# If operating under prefill_first strategy, the prefill handler needs to trigger
# the decode handler.
if prefill_response is not None:
request["disaggregated_params"] = prefill_response[
"disaggregated_params"
]
async for res in self.remote_decode(request):
yield res
else:
# Return response to the decode handler.
yield prefill_response
class DecodeHandler(HandlerBase):
"""
Handler for the decode mode.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def remote_prefill(self, request: dict):
async for res in await self.next_client.round_robin(request):
yield res
async def generate(self, request: dict):
if self.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST:
prefill_response = None
# If operating under decode_first strategy, the decode handler needs to trigger
# the prefill handler.
response_count = 0
async for res in self.remote_prefill(request):
prefill_response = res
response_count += 1
if response_count > 1:
raise ValueError("Prefill response should be generated only once.")
response_data = (
prefill_response.data() if prefill_response is not None else None
)
if prefill_response is not None and self.check_error(response_data):
yield response_data
return
if prefill_response is not None and response_data is not None:
request["disaggregated_params"] = response_data["disaggregated_params"]
async for res in self.generate_locally(request):
yield res
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from utils.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
)
# Default endpoint for the next worker.
DEFAULT_ENDPOINT = "dyn://dynamo.tensorrt_llm.generate"
DEFAULT_MODEL_PATH = "TinyLlama-1.1B-Instruct"
DEFAULT_NEXT_ENDPOINT = "dyn://dynamo.tensorrt_llm_next.generate"
DEFAULT_DISAGGREGATION_STRATEGY = DisaggregationStrategy.DECODE_FIRST
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED
class Config:
"""Command line parameters or defaults"""
def __init__(self) -> None:
self.namespace: str = ""
self.component: str = ""
self.endpoint: str = ""
self.model_path: str = ""
self.served_model_name: Optional[str] = None
self.tensor_parallel_size: int = 1
self.kv_block_size: int = 32
self.extra_engine_args: str = ""
self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
self.disaggregation_strategy: DisaggregationStrategy = (
DEFAULT_DISAGGREGATION_STRATEGY
)
self.next_endpoint: str = ""
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint})"
)
def is_first_worker(config):
"""
Check if the current worker is the first worker in the disaggregation chain.
"""
is_primary_worker = config.disaggregation_mode == DisaggregationMode.AGGREGATED
if not is_primary_worker:
is_primary_worker = (
config.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
) and (config.disaggregation_mode == DisaggregationMode.PREFILL)
if not is_primary_worker:
is_primary_worker = (
config.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST
) and (config.disaggregation_mode == DisaggregationMode.DECODE)
return is_primary_worker
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
namespace, component, endpoint_name = endpoint_parts
return namespace, component, endpoint_name
def cmd_line_args():
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default="",
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} if first worker, {DEFAULT_NEXT_ENDPOINT} if next worker",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL_PATH,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL_PATH}",
)
parser.add_argument(
"--served-model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
)
parser.add_argument(
"--disaggregation-mode",
type=str,
default=DEFAULT_DISAGGREGATION_MODE,
choices=[mode.value for mode in DisaggregationMode],
help=f"Mode to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_MODE}",
)
parser.add_argument(
"--disaggregation-strategy",
type=str,
default=DEFAULT_DISAGGREGATION_STRATEGY,
choices=[strategy.value for strategy in DisaggregationStrategy],
help=f"Strategy to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_STRATEGY}",
)
parser.add_argument(
"--next-endpoint",
type=str,
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
)
args = parser.parse_args()
config = Config()
# Set the model path and served model name.
config.model_path = args.model_path
if args.served_model_name:
config.served_model_name = args.served_model_name
else:
# This becomes an `Option` on the Rust side
config.served_model_name = None
# Set the disaggregation mode and strategy.
config.disaggregation_mode = DisaggregationMode(args.disaggregation_mode)
config.disaggregation_strategy = DisaggregationStrategy(
args.disaggregation_strategy
)
# Set the appropriate defaults for the endpoint and next endpoint.
if is_first_worker(config):
if args.endpoint == "":
args.endpoint = DEFAULT_ENDPOINT
if (
args.next_endpoint == ""
and config.disaggregation_mode != DisaggregationMode.AGGREGATED
):
args.next_endpoint = DEFAULT_NEXT_ENDPOINT
else:
if args.endpoint == "":
args.endpoint = DEFAULT_NEXT_ENDPOINT
if args.next_endpoint != "":
raise ValueError("Next endpoint is not allowed for the next worker")
endpoint = args.endpoint
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.next_endpoint = args.next_endpoint
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment