Unverified commit f0bfda1e, authored by Indrajit Bhosale, committed by GitHub
Browse files

fix: Skip Encoder llm creation for unsupported models in trtllm (#6866)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent eb0bf24e
...@@ -432,8 +432,15 @@ class EncodeHelper: ...@@ -432,8 +432,15 @@ class EncodeHelper:
"error": "model_dir and model_type are required for full EPD encode" "error": "model_dir and model_type are required for full EPD encode"
} }
return return
if engine is None: if engine is None or not engine.encoder_available:
yield {"error": "No engine configured on encode worker for full EPD"} yield {
"error": (
"MultimodalEncoder is not available on this encode worker. "
"The model architecture may not support standalone encoder "
"in TRT-LLM. Use the embedding-path flow or run without "
"disaggregated encode mode."
)
}
return return
# Use token_ids from request (Rust preprocessor already applied # Use token_ids from request (Rust preprocessor already applied
# chat template and tokenized; token_ids then include image placeholder tokens # chat template and tokenized; token_ids then include image placeholder tokens
......
...@@ -9,11 +9,17 @@ from typing import AsyncGenerator, Optional ...@@ -9,11 +9,17 @@ from typing import AsyncGenerator, Optional
from tensorrt_llm import LLM, MultimodalEncoder from tensorrt_llm import LLM, MultimodalEncoder
from tensorrt_llm.llmapi.llm import BaseLLM from tensorrt_llm.llmapi.llm import BaseLLM
from transformers import AutoConfig
from dynamo.trtllm.constants import DisaggregationMode from dynamo.trtllm.constants import DisaggregationMode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Model architectures without standalone encoder support in TRT-LLM
# (missing @register_vision_encoder). These handle vision encoding
# inside the main model (prefill/decode) instead.
# Kept as a frozenset so this module-level lookup table cannot be mutated
# by accident at runtime; membership tests behave exactly as with a set.
_UNSUPPORTED_STANDALONE_ENCODER_ARCHS = frozenset(
    {"Llama4ForConditionalGeneration"}
)
class Backend(str, enum.Enum): class Backend(str, enum.Enum):
"""Supported TensorRT-LLM backend types.""" """Supported TensorRT-LLM backend types."""
...@@ -52,6 +58,11 @@ class TensorRTLLMEngine: ...@@ -52,6 +58,11 @@ class TensorRTLLMEngine:
self.engine_args = engine_args self.engine_args = engine_args
@property
def encoder_available(self) -> bool:
    """Report whether an LLM/encoder object has been created.

    Returns:
        True when ``self._llm`` holds an object, False while it is still
        ``None``.
    """
    llm_missing = self._llm is None
    return not llm_missing
async def initialize(self): async def initialize(self):
if not self._llm: if not self._llm:
if self.disaggregation_mode == DisaggregationMode.ENCODE: if self.disaggregation_mode == DisaggregationMode.ENCODE:
...@@ -60,8 +71,14 @@ class TensorRTLLMEngine: ...@@ -60,8 +71,14 @@ class TensorRTLLMEngine:
# (model, backend settings, kv cache config, etc.). ENCODE workers instead use # (model, backend settings, kv cache config, etc.). ENCODE workers instead use
# TRT-LLM's `MultimodalEncoder`, which has a different constructor surface. # TRT-LLM's `MultimodalEncoder`, which has a different constructor surface.
# We intentionally pass only the supported parameters to avoid unexpected kwargs. # We intentionally pass only the supported parameters to avoid unexpected kwargs.
max_batch_size = self.engine_args.get("max_batch_size", 1)
model = self.engine_args.get("model") model = self.engine_args.get("model")
# Skip MultimodalEncoder for architectures that handle vision
# encoding inside the main model (e.g. Llama4).
if self._is_unsupported_encoder_arch(model):
return
max_batch_size = self.engine_args.get("max_batch_size", 1)
logging.info( logging.info(
f"Initializing multimodal encoder with max_batch_size: {max_batch_size}" f"Initializing multimodal encoder with max_batch_size: {max_batch_size}"
) )
...@@ -135,6 +152,17 @@ class TensorRTLLMEngine: ...@@ -135,6 +152,17 @@ class TensorRTLLMEngine:
field_name, field_name,
) )
@staticmethod
def _is_unsupported_encoder_arch(model_path: Optional[str]) -> bool:
    """Return True if *model_path*'s architecture is not supported by
    TRT-LLM's standalone MultimodalEncoder.

    The check loads the model's config via ``AutoConfig`` and compares its
    ``architectures`` list against ``_UNSUPPORTED_STANDALONE_ENCODER_ARCHS``.

    Args:
        model_path: Model identifier or path handed to
            ``AutoConfig.from_pretrained``. May be ``None`` or empty.

    Returns:
        True when any listed architecture is in the unsupported set.
        False when none match, when no model path is given, or when the
        config cannot be loaded or inspected (best-effort: the caller then
        proceeds as if the architecture were supported).
    """
    # Nothing to inspect; skip the lookup instead of relying on the
    # loader to fail on an empty/None path.
    if not model_path:
        return False
    try:
        # NOTE(review): trust_remote_code=True lets the checkpoint's own
        # config code run at load time; confirm model paths are trusted.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        archs = getattr(config, "architectures", None) or []
        return any(a in _UNSUPPORTED_STANDALONE_ENCODER_ARCHS for a in archs)
    except Exception:
        # Deliberately best-effort, but do not fail silently: a bad path
        # or unreadable config would otherwise look like "supported".
        logger.warning(
            "Could not determine model architecture for %r; assuming the "
            "standalone multimodal encoder is supported.",
            model_path,
            exc_info=True,
        )
        return False
@asynccontextmanager @asynccontextmanager
async def get_llm_engine( async def get_llm_engine(
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
<</SYS>> <</SYS>>
{% elif message['role'] == 'user' -%} {% elif message['role'] == 'user' -%}
[INST] {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for item in message['content'] %}{% if item['type'] == 'image_url' %}<image> [INST] {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for item in message['content'] %}{% if item['type'] == 'image_url' or item['type'] == 'image' %}<image>
{% elif item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %} [/INST] {% elif item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %} [/INST]
{% elif message['role'] == 'assistant' -%} {% elif message['role'] == 'assistant' -%}
{{ message['content'] }}</s> {{ message['content'] }}</s>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment