Unverified Commit 8d30cd45 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

fix: Multimodal flag was ignored for TRTLLM (#6468)

* Why?

Commit `5a67b246` refactored configs for the TRTLLM backend, breaking
`--modality multimodal`.

* What?

This commit fixes this bug, and adds a unit test verified to fail
without it.
parent 42d69805
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
"""Unit tests for TRTLLM backend components.""" """Unit tests for TRTLLM backend components."""
import asyncio
import re import re
from pathlib import Path from pathlib import Path
from unittest import mock
import pytest import pytest
import torch import torch
...@@ -17,8 +19,10 @@ if not torch.cuda.is_available(): ...@@ -17,8 +19,10 @@ if not torch.cuda.is_available():
) )
from dynamo.trtllm.args import Config, parse_args from dynamo.trtllm.args import Config, parse_args
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.tests.conftest import make_cli_args_fixture from dynamo.trtllm.tests.conftest import make_cli_args_fixture
from dynamo.trtllm.utils.trtllm_utils import deep_update from dynamo.trtllm.utils.trtllm_utils import deep_update
from dynamo.trtllm.workers.llm_worker import init_llm_worker
# Get path relative to this test file # Get path relative to this test file
REPO_ROOT = Path(__file__).resolve().parents[5] REPO_ROOT = Path(__file__).resolve().parents[5]
...@@ -164,3 +168,27 @@ def test_deep_update_adds_new_keys(): ...@@ -164,3 +168,27 @@ def test_deep_update_adds_new_keys():
source = {"b": 2, "c": {"nested": 3}} source = {"b": 2, "c": {"nested": 3}}
deep_update(target, source) deep_update(target, source)
assert target == {"a": 1, "b": 2, "c": {"nested": 3}} assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
class MultimodalProcessorInstantiated(Exception):
    """Sentinel exception for testing MultimodalRequestProcessor.

    Used as the ``side_effect`` of a mocked ``MultimodalRequestProcessor`` so
    that a test can detect the moment the processor would be constructed and
    stop execution there.
    """
@pytest.mark.asyncio
async def test_init_llm_worker_creates_multimodal_processor():
    """Check that ``--modality multimodal`` makes the worker build a processor.

    Regression test for the modality flag being ignored: the mocked
    ``MultimodalRequestProcessor`` raises ``MultimodalProcessorInstantiated``
    when called, so the test passes only if ``init_llm_worker`` actually
    reaches the point where it constructs the processor.
    """
    config = parse_args(["--model", "fake-model", "--modality", "multimodal"])
    # The parsed value must be the enum member, not a bare string.
    assert config.modality == Modality.MULTIMODAL

    # Mock everything init_llm_worker touches before MultimodalRequestProcessor.
    with mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"), mock.patch(
        "dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained",
    ), mock.patch(
        "dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor",
        side_effect=MultimodalProcessorInstantiated,
    ):
        # The sentinel escaping init_llm_worker proves the multimodal branch
        # was taken; if the flag were ignored, no exception would be raised
        # here and pytest.raises would fail the test.
        with pytest.raises(MultimodalProcessorInstantiated):
            await init_llm_worker(
                runtime=mock.MagicMock(),
                config=config,
                shutdown_event=asyncio.Event(),
            )
...@@ -46,7 +46,7 @@ from dynamo.llm import ( ...@@ -46,7 +46,7 @@ from dynamo.llm import (
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
...@@ -290,7 +290,7 @@ async def init_llm_worker( ...@@ -290,7 +290,7 @@ async def init_llm_worker(
# This overrides the skip_tokenizer_init=True set earlier # This overrides the skip_tokenizer_init=True set earlier
engine_args["skip_tokenizer_init"] = False engine_args["skip_tokenizer_init"] = False
if modality == "multimodal": if modality == Modality.MULTIMODAL:
engine_args["skip_tokenizer_init"] = False engine_args["skip_tokenizer_init"] = False
model_config = AutoConfig.from_pretrained(config.model, trust_remote_code=True) model_config = AutoConfig.from_pretrained(config.model, trust_remote_code=True)
multimodal_processor = MultimodalRequestProcessor( multimodal_processor = MultimodalRequestProcessor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment