Unverified Commit 12fe3551 authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

feat: Multimodal support for dynamo with trtllm backend (#2195)

parent ef2b0e67
...@@ -42,6 +42,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -42,6 +42,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [KV Cache Transfer](#kv-cache-transfer-in-disaggregated-serving) - [KV Cache Transfer](#kv-cache-transfer-in-disaggregated-serving)
- [Client](#client) - [Client](#client)
- [Benchmarking](#benchmarking) - [Benchmarking](#benchmarking)
- [Multimodal Support](#multimodal-support)
## Feature Support Matrix ## Feature Support Matrix
...@@ -261,6 +262,7 @@ DISAGGREGATION_STRATEGY="prefill_first" ./launch/disagg.sh ...@@ -261,6 +262,7 @@ DISAGGREGATION_STRATEGY="prefill_first" ./launch/disagg.sh
Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disaggregated serving: UCX (default) and NIXL (experimental). For detailed information and configuration instructions for each method, see the [KV cache transfer guide](./kv-cache-tranfer.md). Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disaggregated serving: UCX (default) and NIXL (experimental). For detailed information and configuration instructions for each method, see the [KV cache transfer guide](./kv-cache-tranfer.md).
## Request Migration ## Request Migration
You can enable [request migration](../../../docs/architecture/request_migration.md) to handle worker failures gracefully. Use the `--migration-limit` flag to specify how many times a request can be migrated to another worker: You can enable [request migration](../../../docs/architecture/request_migration.md) to handle worker failures gracefully. Use the `--migration-limit` flag to specify how many times a request can be migrated to another worker:
...@@ -281,3 +283,140 @@ NOTE: To send a request to a multi-node deployment, target the node which is run ...@@ -281,3 +283,140 @@ NOTE: To send a request to a multi-node deployment, target the node which is run
To benchmark your deployment with GenAI-Perf, see this utility script, configuring the To benchmark your deployment with GenAI-Perf, see this utility script, configuring the
`model` name and `host` based on your deployment: [perf.sh](../../../benchmarks/llm/perf.sh) `model` name and `host` based on your deployment: [perf.sh](../../../benchmarks/llm/perf.sh)
## Multimodal support
TRTLLM supports multimodal models with dynamo. You can provide multimodal inputs in the following ways:
- By sending image URLs
- By providing paths to pre-computed embedding files
Please note that you should provide **either image URLs or embedding file paths** in a single request.
### Aggregated
Here are quick steps to launch Llama-4 Maverick BF16 in aggregated mode
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export AGG_ENGINE_ARGS=./engine_configs/multinode/agg.yaml
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
./launch/agg.sh
```
### Example Requests
#### With Image URL
Below is an example of an image being sent to `Llama-4-Maverick-17B-128E-Instruct` model
Request :
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the image"
},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
}
}
]
}
],
"stream": false,
"max_tokens": 160
}'
```
Response :
```
{"id":"unknown-id","choices":[{"index":0,"message":{"content":"The image depicts a serene landscape featuring a large rock formation, likely El Capitan in Yosemite National Park, California. The scene is characterized by a winding road that curves from the bottom-right corner towards the center-left of the image, with a few rocks and trees lining its edge.\n\n**Key Features:**\n\n* **Rock Formation:** A prominent, tall, and flat-topped rock formation dominates the center of the image.\n* **Road:** A paved road winds its way through the landscape, curving from the bottom-right corner towards the center-left.\n* **Trees and Rocks:** Trees are visible on both sides of the road, with rocks scattered along the left side.\n* **Sky:** The sky above is blue, dotted with white clouds.\n* **Atmosphere:** The overall atmosphere of the","refusal":null,"tool_calls":null,"role":"assistant","function_call":null,"audio":null},"finish_reason":"stop","logprobs":null}],"created":1753322607,"model":"meta-llama/Llama-4-Maverick-17B-128E-Instruct","service_tier":null,"system_fingerprint":null,"object":"chat.completion","usage":null}
```
### Disaggregated
Here are quick steps to launch in disaggregated mode.
The following is an example of launching a model in disaggregated mode. While this example uses `Qwen/Qwen2-VL-7B-Instruct`, you can adapt it for other models by modifying the environment variables for the model path and engine configurations.
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/multimodal/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/multimodal/decode.yaml"}
export MODALITY=${MODALITY:-"multimodal"}
./launch/disagg.sh
```
For a large model like `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, a multi-node setup is required for disaggregated serving, while aggregated serving can run on a single node. This is because the model with a disaggregated configuration is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires a setup of 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs.
In general, disaggregated serving can run on a single node, provided the model fits on the GPU. The multi-node requirement in this example is specific to the size and configuration of the `meta-llama/Llama-4-Maverick-17B-128E-Instruct` model.
To deploy `Llama-4-Maverick-17B-128E-Instruct` in disaggregated mode, you will need to follow the multi-node setup instructions, which can be found [here](multinode/multinode-multimodal-example.md).
### Using Pre-computed Embeddings (Experimental)
Dynamo with TensorRT-LLM supports providing pre-computed embeddings directly in an inference request. This bypasses the need for the model to process an image and generate embeddings itself, which is useful for performance optimization or when working with custom, pre-generated embeddings.
#### Enabling the Feature
This is an experimental feature that requires using a specific TensorRT-LLM commit.
To enable it build the dynamo container with the `--tensorrtllm-commit` flag, followed by the commit hash:
```bash
./container/build.sh --framework tensorrtllm --tensorrtllm-commit b4065d8ca64a64eee9fdc64b39cb66d73d4be47c
```
#### How to Use
Once the container is built, you can send requests with paths to local embedding files.
- **Format:** Provide the embedding as part of the `messages` array, using the `image_url` content type.
- **URL:** The `url` field should contain the absolute or relative path to your embedding file on the local filesystem.
- **File Types:** Supported embedding file extensions are `.pt`, `.pth`, and `.bin`. Dynamo will automatically detect these extensions.
When a request with a supported embedding file is received, Dynamo will load the tensor from the file and pass it directly to the model for inference, skipping the image-to-embedding pipeline.
#### Example Request
Here is an example of how to send a request with a pre-computed embedding file.
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the content represented by the embeddings"
},
{
"type": "image_url",
"image_url": {
"url": "/path/to/your/embedding.pt"
}
}
]
}
],
"stream": false,
"max_tokens": 160
}'
```
### Supported Multimodal Models
Multimodel models listed [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/inputs/utils.py#L221) are supported by dynamo.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 8
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 4096
max_batch_size: 8
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
kv_cache_config:
free_gpu_memory_fraction: 0.3
enable_block_reuse: false
cache_transceiver_config:
backend: default
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
disable_overlap_scheduler: false
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: default
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 8
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
disable_overlap_scheduler: false
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: default
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 8
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
# Overlap scheduler not currently supported in prefill only workers.
disable_overlap_scheduler: true
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: default
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
# Overlap scheduler not currently supported in prefill only workers.
disable_overlap_scheduler: true
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: default
\ No newline at end of file
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
export MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"} export MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"} export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
export AGG_ENGINE_ARGS=${AGG_ENGINE_ARGS:-"engine_configs/agg.yaml"} export AGG_ENGINE_ARGS=${AGG_ENGINE_ARGS:-"engine_configs/agg.yaml"}
export MODALITY=${MODALITY:-"text"}
# If you want to use multimodal, set MODALITY to "multimodal"
#export MODALITY=${MODALITY:-"multimodal"}
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
...@@ -27,4 +30,5 @@ DYNAMO_PID=$! ...@@ -27,4 +30,5 @@ DYNAMO_PID=$!
python3 -m dynamo.trtllm \ python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \ --model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--modality "$MODALITY" \
--extra-engine-args "$AGG_ENGINE_ARGS" --extra-engine-args "$AGG_ENGINE_ARGS"
...@@ -10,6 +10,9 @@ export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"} ...@@ -10,6 +10,9 @@ export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"} export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"} export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"} export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export MODALITY=${MODALITY:-"text"}
# If you want to use multimodal, set MODALITY to "multimodal"
#export MODALITY=${MODALITY:-"multimodal"}
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
...@@ -33,6 +36,7 @@ CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \ ...@@ -33,6 +36,7 @@ CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \ --extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \ --disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode prefill & --disaggregation-mode prefill &
PREFILL_PID=$! PREFILL_PID=$!
...@@ -42,4 +46,5 @@ CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \ ...@@ -42,4 +46,5 @@ CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \ --extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \ --disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode decode --disaggregation-mode decode
\ No newline at end of file
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Example: Multi-node TRTLLM Workers with Dynamo on Slurm for multimodal models
This guide demonstrates how to deploy large multimodal models that require a multi-node setup. It builds on the general multi-node deployment process described in the main [multinode-examples.md](./multinode-examples.md) guide.
Before you begin, ensure you have completed the initial environment configuration by following the **Setup** section in that guide.
The following sections provide specific instructions for deploying `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, including environment variable setup and launch commands. These steps can be adapted for other large multimodal models.
### Environment Variable Setup
Assuming you have already allocated your nodes via `salloc`, and are
inside an interactive shell on one of the allocated nodes, set the
following environment variables based:
```bash
# NOTE: IMAGE must be set manually for now
# To build an iamge, see the steps here:
# https://github.com/ai-dynamo/dynamo/tree/main/components/backends/trtllm#build-docker
export IMAGE="<dynamo_trtllm_image>"
# MOUNTS are the host:container path pairs that are mounted into the containers
# launched by each `srun` command.
#
# If you want to reference files, such as $MODEL_PATH below, in a
# different location, you can customize MOUNTS or specify additional
# comma-separated mount pairs here.
#
# NOTE: Currently, this example assumes that the local bash scripts and configs
# referenced are mounted into into /mnt inside the container. If you want to
# customize the location of the scripts, make sure to modify `srun_aggregated.sh`
# accordingly for the new locations of `start_frontend_services.sh` and
# `start_trtllm_worker.sh`.
#
# For example, assuming your cluster had a `/lustre` directory on the host, you
# could add that as a mount like so:
#
# export MOUNTS="${PWD}/../:/mnt,/lustre:/lustre"
export MOUNTS="${PWD}/../:/mnt"
# Can point to local FS as weel
# export MODEL_PATH="/location/to/model"
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
# The name the model will be served/queried under, matching what's
# returned by the /v1/models endpoint.
#
# By default this is inferred from MODEL_PATH, but when using locally downloaded
# model weights, it can be nice to have explicit control over the name.
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODALITY=${MODALITY:-"multimodal"}
```
## Disaggregated Mode
Assuming you have at least 4 4xGB200 nodes allocated (2 for prefill, 2 for decode)
following the setup above, follow these steps below to launch a **disaggregated**
deployment across 4 nodes:
> [!Tip]
> Make sure you have a fresh environment and don't still have the aggregated
> example above still deployed on the same set of nodes.
```bash
# Defaults set in srun_disaggregated.sh, but can customize here.
# export PREFILL_ENGINE_CONFIG="/mnt/engine_configs/multimodal/llama4/prefill.yaml"
# export DECODE_ENGINE_CONFIG="/mnt/engine_configs/multimodal/llama4/decode.yaml"
# Customize NUM_PREFILL_NODES to match the desired parallelism in PREFILL_ENGINE_CONFIG
# Customize NUM_DECODE_NODES to match the desired parallelism in DECODE_ENGINE_CONFIG
# The products of NUM_PREFILL_NODES*NUM_GPUS_PER_NODE and
# NUM_DECODE_NODES*NUM_GPUS_PER_NODE should match the respective number of
# GPUs necessary to satisfy the requested parallelism in each config.
# export NUM_PREFILL_NODES=2
# export NUM_DECODE_NODES=2
# GB200 nodes have 4 gpus per node, but for other types of nodes you can configure this.
# export NUM_GPUS_PER_NODE=4
# Launches:
# - frontend + etcd/nats on current (head) node.
# - one large prefill trtllm worker across multiple nodes via MPI tasks
# - one large decode trtllm worker across multiple nodes via MPI tasks
./srun_disaggregated.sh
```
## Understanding the Output
1. The `srun_disaggregated.sh` launches three srun jobs instead of two. One for frontend, one for prefill worker, and one for decode worker.
2. The OpenAI frontend will listen for and dynamically discover workers as
they register themselves with Dynamo's distributed runtime:
```
0: 2025-06-13T02:36:48.160Z INFO dynamo_run::input::http: Watching for remote model at models
0: 2025-06-13T02:36:48.161Z INFO dynamo_llm::http::service::service_v2: Starting HTTP service on: 0.0.0.0:8000 address="0.0.0.0:8000"
```
3. The TRTLLM worker will consist of N (N=8 for TP8) MPI ranks, 1 rank on each
GPU on each node, which will each output their progress while loading the model.
You can see each rank's output prefixed with the rank at the start of each log line
until the model succesfully finishes loading:
```
7: rank7 run mgmn worker node with mpi_world_size: 8 ...
```
4. After the model fully finishes loading on all ranks, the worker will register itself,
and the OpenAI frontend will detect it, signaled by this output:
```
0: 2025-06-13T02:46:35.040Z INFO dynamo_llm::discovery::watcher: added model model_name="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
```
5. At this point, with the worker fully initialized and detected by the frontend,
it is now ready for inference.
## Example Request
To verify the deployed model is working, send a `curl` request:
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the image"
},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
}
}
]
}
],
"stream": false,
"max_tokens": 160
}'
```
## Cleanup
To cleanup background `srun` processes launched by `srun_aggregated.sh` or
`srun_disaggregated.sh`, you can run:
```bash
pkill srun
```
## Known Issues
- Loading `meta-llama/Llama-4-Maverick-17B-128E-Instruct` with 8 nodes of H100 with TP=16 is not posssible due to Llama4 Maverick has a config `"num_attention_heads": 40` , trtllm engine asserts on assert `self.num_heads % tp_size == 0` causing the engine to crash on model loading.
\ No newline at end of file
...@@ -31,6 +31,10 @@ if [[ -n ${DISAGGREGATION_STRATEGY} ]]; then ...@@ -31,6 +31,10 @@ if [[ -n ${DISAGGREGATION_STRATEGY} ]]; then
EXTRA_ARGS+="--disaggregation-strategy ${DISAGGREGATION_STRATEGY} " EXTRA_ARGS+="--disaggregation-strategy ${DISAGGREGATION_STRATEGY} "
fi fi
if [[ -n ${MODALITY} ]]; then
EXTRA_ARGS+="--modality ${MODALITY} "
fi
trtllm-llmapi-launch \ trtllm-llmapi-launch \
python3 -m dynamo.trtllm \ python3 -m dynamo.trtllm \
--model-path "${MODEL_PATH}" \ --model-path "${MODEL_PATH}" \
......
...@@ -18,11 +18,13 @@ from tensorrt_llm.llmapi import ( ...@@ -18,11 +18,13 @@ from tensorrt_llm.llmapi import (
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count from torch.cuda import device_count
from transformers import AutoConfig
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import get_llm_engine from dynamo.trtllm.engine import get_llm_engine
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import get_publisher from dynamo.trtllm.publisher import get_publisher
from dynamo.trtllm.request_handlers.handlers import ( from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig, RequestHandlerConfig,
...@@ -119,7 +121,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -119,7 +121,7 @@ async def init(runtime: DistributedRuntime, config: Config):
capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
dynamic_batch_config=dynamic_batch_config, dynamic_batch_config=dynamic_batch_config,
) )
modality = getattr(config, "modality", None) or "text"
arg_map = { arg_map = {
"model": model_path, "model": model_path,
"scheduler_config": scheduler_config, "scheduler_config": scheduler_config,
...@@ -171,7 +173,22 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -171,7 +173,22 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params = SamplingParams() default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer) default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None default_sampling_params.stop = None
modelType = ModelType.Backend
multimodal_processor = None
if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat
model_config = AutoConfig.from_pretrained(
config.model_path, trust_remote_code=True
)
multimodal_processor = MultimodalRequestProcessor(
model_type=model_config.model_type,
model_dir=config.model_path,
tokenizer=tokenizer,
)
else:
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM. # We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
default_sampling_params.detokenize = False default_sampling_params.detokenize = False
...@@ -181,14 +198,13 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -181,14 +198,13 @@ async def init(runtime: DistributedRuntime, config: Config):
if is_first_worker(config): if is_first_worker(config):
# Register the model with the endpoint if only the worker is first in the disaggregation chain. # Register the model with the endpoint if only the worker is first in the disaggregation chain.
await register_llm( await register_llm(
ModelType.Backend, modelType,
endpoint, endpoint,
config.model_path, config.model_path,
config.served_model_name, config.served_model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
) )
# publisher will be set later if publishing is enabled. # publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig( handler_config = RequestHandlerConfig(
component=component, component=component,
...@@ -198,6 +214,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -198,6 +214,7 @@ async def init(runtime: DistributedRuntime, config: Config):
disaggregation_mode=config.disaggregation_mode, disaggregation_mode=config.disaggregation_mode,
disaggregation_strategy=config.disaggregation_strategy, disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client, next_client=next_client,
multimodal_processor=multimodal_processor,
) )
if config.publish_events_and_metrics and is_first_worker(config): if config.publish_events_and_metrics and is_first_worker(config):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, Protocol, Tuple
import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
class TokenizerProtocol(Protocol):
"""
A protocol for tokenizers that defines a decode method.
This is used for type hinting to resolve mypy errors related to
the tokenizer's decode method not being found on a generic 'object' type.
"""
def decode(self, token_ids: List[int]) -> str:
...
class MultimodalRequestProcessor:
"""Simple processor for OpenAI format multimodal requests."""
def __init__(
self,
model_type: str,
model_dir: str,
tokenizer: Optional[TokenizerProtocol] = None,
):
self.model_type = model_type
self.model_dir = model_dir
self.tokenizer = tokenizer
self.modality = ""
def extract_prompt_and_media(
self, messages: List[Dict]
) -> Tuple[str, List[str], List[str]]:
"""Extracts text prompt, image URLs, and embedding paths from messages."""
text_parts = []
image_urls = []
embedding_paths = []
for message in messages:
for content in message.get("content", []):
if content.get("type") == "text":
text_parts.append(content.get("text", ""))
elif content.get("type") == "image_url":
url = content.get("image_url", {}).get("url", "")
if not url:
continue
self.modality = "image"
if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
image_urls.append(url)
return " ".join(text_parts), image_urls, embedding_paths
async def process_openai_request(self, request: Dict) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data."""
# Normalize the request to handle OpenAI format
if "stop_conditions" not in request:
request["stop_conditions"] = {}
if "max_tokens" in request and "max_tokens" not in request["stop_conditions"]:
request["stop_conditions"]["max_tokens"] = request.pop("max_tokens")
if "sampling_options" not in request:
request["sampling_options"] = {}
if (
"temperature" in request
and "temperature" not in request["sampling_options"]
):
request["sampling_options"]["temperature"] = request.pop("temperature")
messages = request.get("messages", [])
text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media(
messages
)
if not image_urls and not embedding_paths:
# No multimodal content, return None
return None
loader_kwargs = {}
if embedding_paths:
mm_embeds = [torch.load(path) for path in embedding_paths]
loader_kwargs["mm_embeddings"] = mm_embeds
elif image_urls:
loader_kwargs["media"] = [image_urls]
# Process with default_multimodal_input_loader
processed_inputs = default_multimodal_input_loader(
tokenizer=None,
model_dir=self.model_dir,
model_type=self.model_type,
modality=self.modality,
prompts=[text_prompt],
image_data_format="pt",
device="cuda",
**loader_kwargs,
)
# Return the first processed input if available
if processed_inputs:
return processed_inputs[0]
return None
def create_response_chunk(
self,
output: Any,
num_output_tokens_so_far: int,
request_id: str,
model_name: str,
) -> Dict[str, Any]:
"""Creates a response chunk for multimodal streaming."""
if self.tokenizer is None:
raise ValueError("Tokenizer must be provided for creating response chunks.")
new_tokens = output.token_ids[num_output_tokens_so_far:]
# Decode the new token IDs into a string. This is the incremental piece
# of text to be sent to the client.
delta_text = self.tokenizer.decode(new_tokens)
# Assemble the delta payload for the response chunk.
delta = {"content": delta_text if delta_text else ""}
if num_output_tokens_so_far == 0:
# The first chunk must include the "assistant" role.
delta["role"] = "assistant"
choice = {
"index": 0,
"delta": delta,
"finish_reason": output.finish_reason,
}
# Wrap the choice in the final response chunk following the OpenAI
# streaming format.
return {
"id": request_id,
"model": model_name,
"created": int(time.time()),
"object": "chat.completion.chunk",
"choices": [choice],
}
def get_stop_response(self, request_id: str, model_name: str) -> Dict[str, Any]:
"""Creates the final stop response chunk for multimodal streaming."""
final_choice = {
"index": 0,
"delta": {},
"finish_reason": "stop",
}
return {
"id": request_id,
"model": model_name,
"created": int(time.time()),
"object": "chat.completion.chunk",
"choices": [final_choice],
"finish_reason": "stop",
}
...@@ -16,12 +16,14 @@ ...@@ -16,12 +16,14 @@
import logging import logging
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Optional
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import Publisher from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.utils.disagg_utils import ( from dynamo.trtllm.utils.disagg_utils import (
DisaggregatedParams, DisaggregatedParams,
...@@ -55,6 +57,9 @@ class RequestHandlerConfig: ...@@ -55,6 +57,9 @@ class RequestHandlerConfig:
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy disaggregation_strategy: DisaggregationStrategy
next_client: object next_client: object
multimodal_processor: Optional[
MultimodalRequestProcessor
] = None # for multimodal support
class HandlerBase: class HandlerBase:
...@@ -70,6 +75,7 @@ class HandlerBase: ...@@ -70,6 +75,7 @@ class HandlerBase:
self.disaggregation_mode = config.disaggregation_mode self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client self.next_client = config.next_client
self.multimodal_processor = config.multimodal_processor
self.first_generation = True self.first_generation = True
def check_error(self, result: dict): def check_error(self, result: dict):
...@@ -87,9 +93,22 @@ class HandlerBase: ...@@ -87,9 +93,22 @@ class HandlerBase:
""" """
Generate responses based on the disaggregation mode in the request. Generate responses based on the disaggregation mode in the request.
""" """
logging.debug(f"Request: {request}") logging.debug(f"Request: {request}")
# Default to text-based input. This will be overwritten if multimodal
# content is found and processed.
processed_input = None
# Check for multimodal request and process it
if self.multimodal_processor:
processed_input = await self.multimodal_processor.process_openai_request(
request
)
else:
# text-only flow
processed_input = request.get("token_ids")
# Check if there is an error in the publisher error queue # Check if there is an error in the publisher error queue
publishers_error = ( publishers_error = (
self.publisher.check_error_queue() if self.publisher else None self.publisher.check_error_queue() if self.publisher else None
...@@ -97,10 +116,9 @@ class HandlerBase: ...@@ -97,10 +116,9 @@ class HandlerBase:
if publishers_error: if publishers_error:
raise publishers_error raise publishers_error
inputs = request["token_ids"]
# Decode the disaggregated params from the request # Decode the disaggregated params from the request
disaggregated_params = None disaggregated_params = None
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
request["stop_conditions"]["max_tokens"] = 1 request["stop_conditions"]["max_tokens"] = 1
disaggregated_params = LlmDisaggregatedParams(request_type="context_only") disaggregated_params = LlmDisaggregatedParams(request_type="context_only")
...@@ -122,6 +140,7 @@ class HandlerBase: ...@@ -122,6 +140,7 @@ class HandlerBase:
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
sampling_params = self.default_sampling_params sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items(): for key, value in request["sampling_options"].items():
if not value: if not value:
continue continue
...@@ -132,11 +151,11 @@ class HandlerBase: ...@@ -132,11 +151,11 @@ class HandlerBase:
if max_tokens: if max_tokens:
sampling_params.max_tokens = max_tokens sampling_params.max_tokens = max_tokens
ignore_eos = request["stop_conditions"]["ignore_eos"] ignore_eos = request["stop_conditions"].get("ignore_eos")
if ignore_eos: if ignore_eos:
sampling_params.ignore_eos = ignore_eos sampling_params.ignore_eos = ignore_eos
min_tokens = request["stop_conditions"]["min_tokens"] min_tokens = request["stop_conditions"].get("min_tokens")
if min_tokens: if min_tokens:
sampling_params.min_tokens = min_tokens sampling_params.min_tokens = min_tokens
...@@ -146,8 +165,12 @@ class HandlerBase: ...@@ -146,8 +165,12 @@ class HandlerBase:
False if self.disaggregation_mode == DisaggregationMode.PREFILL else True False if self.disaggregation_mode == DisaggregationMode.PREFILL else True
) )
request_id = request.get("id") or request.get("request_id", "unknown-id")
model_name = request.get("model", "unknown_model")
# NEW: Updated engine call to include multimodal data
async for res in self.engine.llm.generate_async( async for res in self.engine.llm.generate_async(
inputs=inputs, inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params, sampling_params=sampling_params,
disaggregated_params=disaggregated_params, disaggregated_params=disaggregated_params,
streaming=streaming, streaming=streaming,
...@@ -158,7 +181,15 @@ class HandlerBase: ...@@ -158,7 +181,15 @@ class HandlerBase:
self.publisher.start() self.publisher.start()
self.first_generation = False self.first_generation = False
# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if res.finished and self.disaggregation_mode != DisaggregationMode.PREFILL: if res.finished and self.disaggregation_mode != DisaggregationMode.PREFILL:
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
)
yield final_out
else:
yield {"finish_reason": "stop", "token_ids": []} yield {"finish_reason": "stop", "token_ids": []}
break break
...@@ -167,7 +198,14 @@ class HandlerBase: ...@@ -167,7 +198,14 @@ class HandlerBase:
break break
output = res.outputs[0] output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids) next_total_toks = len(output.token_ids)
if self.multimodal_processor:
out = self.multimodal_processor.create_response_chunk(
output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason: if output.finish_reason:
out["finish_reason"] = output.finish_reason out["finish_reason"] = output.finish_reason
...@@ -178,5 +216,6 @@ class HandlerBase: ...@@ -178,5 +216,6 @@ class HandlerBase:
out["disaggregated_params"] = asdict( out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params) DisaggregatedParamsCodec.encode(output.disaggregated_params)
) )
# Yield the chunk to the client and update the token count for the next iteration.
yield out yield out
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
...@@ -124,6 +124,8 @@ class DecodeHandler(HandlerBase): ...@@ -124,6 +124,8 @@ class DecodeHandler(HandlerBase):
# If operating under decode_first strategy, the decode handler needs to trigger # If operating under decode_first strategy, the decode handler needs to trigger
# the prefill handler. # the prefill handler.
response_count = 0 response_count = 0
# Do not yield the prefill response directly.
# Instead, capture it and extract the state.
async for res in self.remote_prefill(request): async for res in self.remote_prefill(request):
prefill_response = res prefill_response = res
response_count += 1 response_count += 1
...@@ -136,6 +138,7 @@ class DecodeHandler(HandlerBase): ...@@ -136,6 +138,7 @@ class DecodeHandler(HandlerBase):
if prefill_response is not None and self.check_error(response_data): if prefill_response is not None and self.check_error(response_data):
yield response_data yield response_data
return return
if prefill_response is not None and response_data is not None: if prefill_response is not None and response_data is not None:
request["disaggregated_params"] = response_data["disaggregated_params"] request["disaggregated_params"] = response_data["disaggregated_params"]
......
...@@ -46,6 +46,7 @@ class Config: ...@@ -46,6 +46,7 @@ class Config:
DEFAULT_DISAGGREGATION_STRATEGY DEFAULT_DISAGGREGATION_STRATEGY
) )
self.next_endpoint: str = "" self.next_endpoint: str = ""
self.modality: str = "text"
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -69,7 +70,8 @@ class Config: ...@@ -69,7 +70,8 @@ class Config:
f"publish_events_and_metrics={self.publish_events_and_metrics}, " f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, " f"disaggregation_mode={self.disaggregation_mode}, "
f"disaggregation_strategy={self.disaggregation_strategy}, " f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint})" f"next_endpoint={self.next_endpoint}, "
f"modality={self.modality})"
) )
...@@ -215,6 +217,13 @@ def cmd_line_args(): ...@@ -215,6 +217,13 @@ def cmd_line_args():
choices=[strategy.value for strategy in DisaggregationStrategy], choices=[strategy.value for strategy in DisaggregationStrategy],
help=f"Strategy to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_STRATEGY}", help=f"Strategy to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_STRATEGY}",
) )
parser.add_argument(
"--modality",
type=str,
default="text",
choices=["text", "multimodal"],
help="Modality to use for the model. Default: text. Current supported modalities are image.",
)
parser.add_argument( parser.add_argument(
"--next-endpoint", "--next-endpoint",
type=str, type=str,
...@@ -279,5 +288,6 @@ def cmd_line_args(): ...@@ -279,5 +288,6 @@ def cmd_line_args():
config.migration_limit = args.migration_limit config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics config.publish_events_and_metrics = args.publish_events_and_metrics
config.modality = args.modality
return config return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment