Unverified Commit 8d30cd45 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

fix: Multimodal flag was ignored for TRTLLM (#6468)

* Why?

Commit `5a67b246` refactored configs for the TRTLLM backend, breaking
`--modality multimodal`.

* What?

This commit fixes this bug, and adds a unit test verified to fail
without it.
parent 42d69805
......@@ -3,8 +3,10 @@
"""Unit tests for TRTLLM backend components."""
import asyncio
import re
from pathlib import Path
from unittest import mock
import pytest
import torch
......@@ -17,8 +19,10 @@ if not torch.cuda.is_available():
)
from dynamo.trtllm.args import Config, parse_args
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.tests.conftest import make_cli_args_fixture
from dynamo.trtllm.utils.trtllm_utils import deep_update
from dynamo.trtllm.workers.llm_worker import init_llm_worker
# Get path relative to this test file
REPO_ROOT = Path(__file__).resolve().parents[5]
......@@ -164,3 +168,27 @@ def test_deep_update_adds_new_keys():
source = {"b": 2, "c": {"nested": 3}}
deep_update(target, source)
assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
class MultimodalProcessorInstantiated(Exception):
"""Custom exception for testing MultimodalRequestProcessor."""
@pytest.mark.asyncio
async def test_init_llm_worker_creates_multimodal_processor():
config = parse_args(["--model", "fake-model", "--modality", "multimodal"])
assert config.modality == Modality.MULTIMODAL
# Mock everything init_llm_worker touches before MultimodalRequestProcessor.
with mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"), mock.patch(
"dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained",
), mock.patch(
"dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor",
side_effect=MultimodalProcessorInstantiated,
):
with pytest.raises(MultimodalProcessorInstantiated):
await init_llm_worker(
runtime=mock.MagicMock(),
config=config,
shutdown_event=asyncio.Event(),
)
......@@ -46,7 +46,7 @@ from dynamo.llm import (
)
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
......@@ -290,7 +290,7 @@ async def init_llm_worker(
# This overrides the skip_tokenizer_init=True set earlier
engine_args["skip_tokenizer_init"] = False
if modality == "multimodal":
if modality == Modality.MULTIMODAL:
engine_args["skip_tokenizer_init"] = False
model_config = AutoConfig.from_pretrained(config.model, trust_remote_code=True)
multimodal_processor = MultimodalRequestProcessor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment