Unverified commit 47ed1227, authored by Indrajit Bhosale, committed by GitHub
Browse files

fix: Fix E + PD Multimodal Flow in trtllm (#6726)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent 6243506b
...@@ -414,7 +414,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -414,7 +414,7 @@ class HandlerBase(BaseGenerativeHandler):
ep_disaggregated_params: Optional[Any], ep_disaggregated_params: Optional[Any],
) -> tuple[Any, Any, dict]: ) -> tuple[Any, Any, dict]:
""" """
Setup disaggregated_params based on PREFILL/DECODE mode. Setup disaggregated_params based on disaggregation mode.
For PREFILL mode: For PREFILL mode:
- Uses ep_disaggregated_params from encode worker if available - Uses ep_disaggregated_params from encode worker if available
...@@ -424,6 +424,11 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -424,6 +424,11 @@ class HandlerBase(BaseGenerativeHandler):
- Decodes disaggregated_params from prefill_result - Decodes disaggregated_params from prefill_result
- Extracts EPD metadata for prompt optimization - Extracts EPD metadata for prompt optimization
For PREFILL_AND_DECODE (aggregated) mode:
- Uses ep_disaggregated_params from encode worker if available
(passes multimodal_embedding_handles to TRT-LLM and sets
request_type="context_and_generation" for full prefill + decode)
Args: Args:
request: Request dictionary (may contain prefill_result) request: Request dictionary (may contain prefill_result)
ep_disaggregated_params: Optional params from encode worker (EPD flow) ep_disaggregated_params: Optional params from encode worker (EPD flow)
...@@ -444,6 +449,20 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -444,6 +449,20 @@ class HandlerBase(BaseGenerativeHandler):
request_type="context_only" request_type="context_only"
) )
# AGGREGATED (prefill_and_decode) mode with encoder disaggregation:
# Pass the encode worker's DisaggregatedParams (containing
# multimodal_embedding_handles) directly so TRT-LLM can import
# the vision embeddings. Use "context_and_generation" so the
# engine runs a full prefill + decode cycle.
elif (
self.disaggregation_mode == DisaggregationMode.AGGREGATED
and ep_disaggregated_params is not None
):
disaggregated_params = DisaggregatedParamsCodec.decode(
ep_disaggregated_params
)
disaggregated_params.request_type = "context_and_generation"
# DECODE mode: decode params from prefill_result # DECODE mode: decode params from prefill_result
prefill_result = request.get("prefill_result") prefill_result = request.get("prefill_result")
if prefill_result and "disaggregated_params" in prefill_result: if prefill_result and "disaggregated_params" in prefill_result:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Engine arguments in YAML form.
# NOTE(review): the path of this file is not visible here; presumably it is one
# of the qwen3-vl-2b-instruct engine configs passed via --extra-engine-args — confirm.
tensor_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 2048
max_batch_size: 8
max_seq_len: 8192
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
# NOTE(review): the two keys below are presumably nested under kv_cache_config;
# indentation appears to have been lost in extraction — confirm against the repo.
kv_cache_config:
# NOTE(review): 0.30 is presumably kept low so this worker can share a GPU with
# another worker — verify against the launch script.
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
# NOTE(review): `backend` below is presumably nested under cache_transceiver_config — confirm.
cache_transceiver_config:
backend: DEFAULT
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
tensor_parallel_size: 1 tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false enable_attention_dp: false
max_num_tokens: 1024 max_num_tokens: 1024
max_batch_size: 4 max_batch_size: 4
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
tensor_parallel_size: 1 tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false enable_attention_dp: false
max_num_tokens: 1024 max_num_tokens: 1024
max_batch_size: 4 max_batch_size: 4
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
tensor_parallel_size: 1 tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false enable_attention_dp: false
max_num_tokens: 1024 max_num_tokens: 1024
max_batch_size: 4 max_batch_size: 4
......
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# 1 Encode + 1 PD worker for llava-v1.6-mistral-7b-hf # 1 Encode + 1 PD worker for Qwen3-VL-2B-Instruct
# GPU 0: Encode (vision encoder) # GPU 0: Encode (vision encoder)
# GPU 1: PD worker (prefill + decode, TP=1) # GPU 0: PD worker (prefill + decode, TP=1)
# Environment variables with defaults # Environment variables with defaults
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"} export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"} export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-VL-2B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"} export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen3-VL-2B-Instruct"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/encode.yaml"} export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/qwen3-vl-2b-instruct/encode.yaml"}
export PD_ENGINE_ARGS=${PD_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml"} export PD_ENGINE_ARGS=${PD_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/qwen3-vl-2b-instruct/agg.yaml"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"0"} export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"0"}
export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"} export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"}
export MODALITY=${MODALITY:-"multimodal"} export MODALITY=${MODALITY:-"multimodal"}
export ALLOWED_LOCAL_MEDIA_PATH=${ALLOWED_LOCAL_MEDIA_PATH:-"/tmp"} export ALLOWED_LOCAL_MEDIA_PATH=${ALLOWED_LOCAL_MEDIA_PATH:-"/tmp"}
export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50} export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Extra arguments forwarded to the PD worker (e.g. --multimodal-embedding-cache-capacity-gb 10) # Extra arguments forwarded to the PD worker (e.g. --multimodal-embedding-cache-capacity-gb 10)
EXTRA_PD_ARGS=("$@") EXTRA_PD_ARGS=("$@")
...@@ -47,13 +46,12 @@ CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \ ...@@ -47,13 +46,12 @@ CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--disaggregation-mode encode & --disaggregation-mode encode &
ENCODE_PID=$! ENCODE_PID=$!
# run PD worker 1 (GPU 1) # run PD worker 1 (GPU 0)
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.trtllm \ CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \ --model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PD_ENGINE_ARGS" \ --extra-engine-args "$PD_ENGINE_ARGS" \
--modality "$MODALITY" \ --modality "$MODALITY" \
--custom-jinja-template "$CUSTOM_TEMPLATE" \
--encode-endpoint "$ENCODE_ENDPOINT" \ --encode-endpoint "$ENCODE_ENDPOINT" \
--disaggregation-mode prefill_and_decode \ --disaggregation-mode prefill_and_decode \
"${EXTRA_PD_ARGS[@]}" & "${EXTRA_PD_ARGS[@]}" &
......
...@@ -237,6 +237,32 @@ trtllm_configs = { ...@@ -237,6 +237,32 @@ trtllm_configs = {
"ENCODE_CUDA_VISIBLE_DEVICES": "0", "ENCODE_CUDA_VISIBLE_DEVICES": "0",
}, },
), ),
# Test the Encode worker together with an aggregated PD worker on the same GPU.
# TODO: promote this test from nightly to pre-merge once TRTLLM #5938603 is fixed.
"e_pd_multimodal": TRTLLMConfig(
name="e_pd_multimodal",
directory=trtllm_dir,
script_name="disagg_e_pd.sh",
marks=[
pytest.mark.gpu_1,
pytest.mark.trtllm,
pytest.mark.multimodal,
pytest.mark.nightly,
],
model="Qwen/Qwen3-VL-2B-Instruct",
frontend_port=DefaultPort.FRONTEND.value,
timeout=900,
delayed_start=120,
request_payloads=[
multimodal_payload_default(
text="Describe what you see in this image.",
expected_response=["mountain", "rock", "trees", "road"],
)
],
env={
"ENCODE_CUDA_VISIBLE_DEVICES": "0",
},
),
"completions_only": TRTLLMConfig( "completions_only": TRTLLMConfig(
name="completions_only", name="completions_only",
directory=trtllm_dir, directory=trtllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment