Commit 9a3a31fd authored by Harry Mellor, committed by khluu
Browse files

Don't compile vision encoder for Transformers backend (#30518)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: khluu <khluu000@gmail.com>
parent 7bd3f40d
...@@ -205,6 +205,8 @@ def support_torch_compile( ...@@ -205,6 +205,8 @@ def support_torch_compile(
if v.annotation in [ if v.annotation in [
torch.Tensor, torch.Tensor,
torch.Tensor | None, torch.Tensor | None,
torch.FloatTensor,
torch.FloatTensor | None,
IntermediateTensors, IntermediateTensors,
IntermediateTensors | None, IntermediateTensors | None,
]: ]:
...@@ -346,7 +348,7 @@ def _support_torch_compile( ...@@ -346,7 +348,7 @@ def _support_torch_compile(
def __init__( def __init__(
self: _T, self: _T,
*, *args,
vllm_config: VllmConfig | None = None, vllm_config: VllmConfig | None = None,
prefix: str = "", prefix: str = "",
**kwargs: Any, **kwargs: Any,
...@@ -357,11 +359,24 @@ def _support_torch_compile( ...@@ -357,11 +359,24 @@ def _support_torch_compile(
# NOTE: to support multimodal models (such as encoder), # NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch it # we may not have vllm_config so we may need to patch it
sig = inspect.signature(old_init) sig = inspect.signature(old_init)
# Check that any positional arguments match the old_init method signature
annotations = [p.annotation for p in sig.parameters.values()]
for arg, annotation in zip(args, annotations):
if annotation is inspect._empty:
continue
if not isinstance(arg, annotation):
init = f"'{type(self).__name__}.__init__'"
arg_type = f"'{type(arg).__name__}'"
raise TypeError(
f"{init} received a positional argument of type {arg_type}, "
"but no parameter of that type was found in the method signature. "
f"Please either annotate {init} or pass it as a keyword argument."
)
if "vllm_config" in sig.parameters: if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters: if "prefix" in sig.parameters:
kwargs["prefix"] = prefix kwargs["prefix"] = prefix
old_init(self, **kwargs) old_init(self, *args, **kwargs)
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config self.compilation_config = self.vllm_config.compilation_config
......
...@@ -488,9 +488,10 @@ class CompilationConfig: ...@@ -488,9 +488,10 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs).""" If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder. """Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models on selected
on selected platforms. Disabled by default until more models platforms. It may also work for models loaded with the Transformers modeling backend
are supported/tested to work.""" if the encoder is compilable. Disabled by default until more models are
supported/tested to work."""
# Vision encoder CUDA graph # Vision encoder CUDA graph
cudagraph_mm_encoder: bool = False cudagraph_mm_encoder: bool = False
......
...@@ -16,13 +16,11 @@ ...@@ -16,13 +16,11 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import ( from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder, MultiModalDummyInputsBuilder,
MultiModalMixin, MultiModalMixin,
MultiModalProcessingInfo, MultiModalProcessingInfo,
...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import ( ...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin, EmbeddingMixin,
SequenceClassificationMixin, SequenceClassificationMixin,
) )
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models # Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ... class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... ...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... ...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM( class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base MoEMixin, MultiModalMixin, CausalMixin, Base
): ... ): ...
# Embedding models # Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ... class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... ...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ... class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models # Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification( class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base SequenceClassificationMixin, LegacyMixin, Base
): ... ): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification( class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base SequenceClassificationMixin, MoEMixin, Base
): ... ): ...
...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification( ...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification( class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base SequenceClassificationMixin, MultiModalMixin, Base
): ... ): ...
......
...@@ -20,7 +20,9 @@ from collections.abc import Mapping ...@@ -20,7 +20,9 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from transformers import AutoModel
from vllm.compilation.decorators import should_torch_compile_mm_encoder
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform ...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import BatchFeature from transformers import BatchFeature, PreTrainedModel
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO # Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def _get_encoder_cls(
    self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
    """
    Resolve the class of the model's encoder for the given modality.

    A throwaway model is built with `AutoModel.from_config` inside a
    `torch.device("meta")` context (presumably so that no real weight memory
    is allocated) and its `get_encoder` method is asked for the encoder.

    Args:
        modality: Forwarded to `get_encoder` to select which encoder to look up.
        kwargs: The kwargs to create the model, forwarded to
            `AutoModel.from_config`.

    Returns:
        The class of the object returned by `get_encoder`.

    Raises:
        ValueError: If the encoder's class is the very same class as the model
            itself, i.e. no distinct encoder class could be identified.
    """
    with torch.device("meta"):
        meta_model: PreTrainedModel = AutoModel.from_config(**kwargs)
    model_cls = type(meta_model)
    encoder_cls = type(meta_model.get_encoder(modality=modality))
    # Only the classes are needed from here on, so drop the instance early
    del meta_model
    logger.debug("Identified encoder class as: %s", encoder_cls)
    if model_cls is not encoder_cls:
        return encoder_cls
    # `get_encoder` handed back an instance of the model's own class
    raise ValueError(
        "Unable to infer vision encoder class from the model. "
        "You must either: update the model so that "
        "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
        " can detect the vision encoder correctly, or remove "
        "'compile_mm_encoder'."
    )
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Mark the decoder and, optionally, the multimodal encoder class as
    supporting torch compile.

    The parent implementation is always invoked first. The encoder class is
    only decorated when `compilation_config.compile_mm_encoder` is set, and the
    decoration itself is gated on `should_torch_compile_mm_encoder`.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the decoder
            and encoder classes.
    """
    super()._decorate_for_torch_compile(**kwargs)

    # Nothing more to do unless encoder compilation was requested
    if not self.compilation_config.compile_mm_encoder:
        return

    # NOTE(review): presumably enforces a minimum Transformers version - confirm
    self.check_version("5.0.0", "multimodal encoder compilation support")
    logger.warning_once(
        "Multimodal encoder compilation with the Transformers modeling backend "
        "is an experimental feature. It relies on:\n"
        "- The vision encoder being torch compilable.\n"
        "- All vision encoder tensor inputs must be type hinted as either "
        "`torch.Tensor` or `torch.FloatTensor`.\n"
        "- The 0-th dimension of all tensor inputs to the vision encoder being "
        "the dynamic dimension (i.e., sequence length or number of patches).\n"
        "Please report any issues you encounter to help us improve it."
    )
    encoder_cls = self._get_encoder_cls(**kwargs)
    self._decorate_cls_for_torch_compile(
        cls=encoder_cls,
        # TODO: properly infer dynamic_arg_dims based on the encoder's forward
        # method signature. Currently we assume dim 0 for all tensor inputs.
        dynamic_arg_dims=None,
        enable_if=should_torch_compile_mm_encoder,
        is_encoder=True,
    )
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly # Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs # Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
# Positions shape handling for MRoPE models
if self.model_config.uses_mrope:
# [3, seq_len] -> [3, 1, seq_len]
positions = positions[:, None]
model_output = super().forward( model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment