Commit 9a3a31fd, authored by Harry Mellor and committed by khluu
Browse files

Don't compile vision encoder for Transformers backend (#30518)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: khluu <khluu000@gmail.com>
parent 7bd3f40d
......@@ -205,6 +205,8 @@ def support_torch_compile(
if v.annotation in [
torch.Tensor,
torch.Tensor | None,
torch.FloatTensor,
torch.FloatTensor | None,
IntermediateTensors,
IntermediateTensors | None,
]:
......@@ -346,7 +348,7 @@ def _support_torch_compile(
def __init__(
self: _T,
*,
*args,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
......@@ -357,11 +359,24 @@ def _support_torch_compile(
# NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch it
sig = inspect.signature(old_init)
# Check that any positional arguments match the old_init method signature
annotations = [p.annotation for p in sig.parameters.values()]
for arg, annotation in zip(args, annotations):
if annotation is inspect._empty:
continue
if not isinstance(arg, annotation):
init = f"'{type(self).__name__}.__init__'"
arg_type = f"'{type(arg).__name__}'"
raise TypeError(
f"{init} received a positional argument of type {arg_type}, "
"but no parameter of that type was found in the method signature. "
f"Please either annotate {init} or pass it as a keyword argument."
)
if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters:
kwargs["prefix"] = prefix
old_init(self, **kwargs)
old_init(self, *args, **kwargs)
self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config
......
......@@ -488,9 +488,10 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
on selected platforms. Disabled by default until more models
are supported/tested to work."""
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models on selected
platforms. It may also work for models loaded with the Transformers modeling backend
if the encoder is compilable. Disabled by default until more models are
supported/tested to work."""
# Vision encoder CUDA graph
cudagraph_mm_encoder: bool = False
......
......@@ -16,13 +16,11 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder,
MultiModalMixin,
MultiModalProcessingInfo,
......@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
......@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
......@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base
): ...
# Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
......@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base
): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base
): ...
......@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base
): ...
......
......@@ -20,7 +20,9 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING
import torch
from transformers import AutoModel
from vllm.compilation.decorators import should_torch_compile_mm_encoder
from vllm.config.utils import getattr_iter
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
from vllm.logger import init_logger
......@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import BatchFeature
from transformers import BatchFeature, PreTrainedModel
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
logger = init_logger(__name__)
......@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def _get_encoder_cls(
    self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
    """
    Look up the class of the model's encoder for the given modality.

    A throwaway model is built from `kwargs` under the `meta` device, its
    encoder is fetched via `get_encoder(modality=...)`, and only the
    encoder's class is kept.

    Args:
        modality: The modality forwarded to `get_encoder`.
        kwargs: The kwargs to create the model; forwarded unchanged to
            `AutoModel.from_config`.

    Returns:
        The encoder class.

    Raises:
        ValueError: If the encoder's class is the very same class as the
            model itself, i.e. no distinct encoder could be identified.
    """
    # Only the class is of interest, so the model is built under the meta device.
    with torch.device("meta"):
        meta_model: PreTrainedModel = AutoModel.from_config(**kwargs)

    model_cls = type(meta_model)
    found_cls = type(meta_model.get_encoder(modality=modality))
    # The throwaway instance is not needed once both classes are known.
    del meta_model

    logger.debug("Identified encoder class as: %s", found_cls)

    # Getting the model's own class back means no separate encoder was found.
    if model_cls is found_cls:
        raise ValueError(
            "Unable to infer vision encoder class from the model. "
            "You must either: update the model so that "
            "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
            " can detect the vision encoder correctly, or remove "
            "'compile_mm_encoder'."
        )
    return found_cls
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Mark the model's classes as torch-compile capable for vLLM.

    The parent implementation is invoked first with the same kwargs. After
    that, and only when `compilation_config.compile_mm_encoder` is set, the
    encoder class is looked up and decorated with
    `should_torch_compile_mm_encoder` as its enabling condition.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the
            decoder and encoder classes.
    """
    super()._decorate_for_torch_compile(**kwargs)

    # Nothing more to do unless encoder compilation was explicitly requested.
    if not self.compilation_config.compile_mm_encoder:
        return

    self.check_version("5.0.0", "multimodal encoder compilation support")

    experimental_notice = (
        "Multimodal encoder compilation with the Transformers modeling backend "
        "is an experimental feature. It relies on:\n"
        "- The vision encoder being torch compilable.\n"
        "- All vision encoder tensor inputs must be type hinted as either "
        "`torch.Tensor` or `torch.FloatTensor`.\n"
        "- The 0-th dimension of all tensor inputs to the vision encoder being "
        "the dynamic dimension (i.e., sequence length or number of patches).\n"
        "Please report any issues you encounter to help us improve it."
    )
    logger.warning_once(experimental_notice)

    vision_encoder_cls = self._get_encoder_cls(**kwargs)
    self._decorate_cls_for_torch_compile(
        cls=vision_encoder_cls,
        # TODO: properly infer dynamic_arg_dims based on the encoder's forward
        # method signature. Currently we assume dim 0 for all tensor inputs.
        dynamic_arg_dims=None,
        enable_if=should_torch_compile_mm_encoder,
        is_encoder=True,
    )
def forward(
self,
input_ids: torch.Tensor | None,
......@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
# Positions shape handling for MRoPE models
if self.model_config.uses_mrope:
# [3, seq_len] -> [3, 1, seq_len]
positions = positions[:, None]
model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment