"tests/vscode:/vscode.git/clone" did not exist on "ca1969186dde5fc0f76d22f2124dc0e6e0c9b792"
Unverified Commit dfe5e316 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Don't compile vision encoder for Transformers backend (#30518)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2ce3d0ce
...@@ -205,6 +205,8 @@ def support_torch_compile( ...@@ -205,6 +205,8 @@ def support_torch_compile(
if v.annotation in [ if v.annotation in [
torch.Tensor, torch.Tensor,
torch.Tensor | None, torch.Tensor | None,
torch.FloatTensor,
torch.FloatTensor | None,
IntermediateTensors, IntermediateTensors,
IntermediateTensors | None, IntermediateTensors | None,
]: ]:
...@@ -346,7 +348,7 @@ def _support_torch_compile( ...@@ -346,7 +348,7 @@ def _support_torch_compile(
def __init__( def __init__(
self: _T, self: _T,
*, *args,
vllm_config: VllmConfig | None = None, vllm_config: VllmConfig | None = None,
prefix: str = "", prefix: str = "",
**kwargs: Any, **kwargs: Any,
...@@ -357,11 +359,24 @@ def _support_torch_compile( ...@@ -357,11 +359,24 @@ def _support_torch_compile(
# NOTE: to support multimodal models (such as encoder), # NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch it # we may not have vllm_config so we may need to patch it
sig = inspect.signature(old_init) sig = inspect.signature(old_init)
# Check that any positional arguments match the old_init method signature
annotations = [p.annotation for p in sig.parameters.values()]
for arg, annotation in zip(args, annotations):
if annotation is inspect._empty:
continue
if not isinstance(arg, annotation):
init = f"'{type(self).__name__}.__init__'"
arg_type = f"'{type(arg).__name__}'"
raise TypeError(
f"{init} received a positional argument of type {arg_type}, "
"but no parameter of that type was found in the method signature. "
f"Please either annotate {init} or pass it as a keyword argument."
)
if "vllm_config" in sig.parameters: if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters: if "prefix" in sig.parameters:
kwargs["prefix"] = prefix kwargs["prefix"] = prefix
old_init(self, **kwargs) old_init(self, *args, **kwargs)
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config self.compilation_config = self.vllm_config.compilation_config
......
...@@ -495,9 +495,10 @@ class CompilationConfig: ...@@ -495,9 +495,10 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs).""" If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder. """Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models on selected
on selected platforms. Disabled by default until more models platforms. It may also work for models loaded with the Transformers modeling backend
are supported/tested to work.""" if the encoder is compilable. Disabled by default until more models are
supported/tested to work."""
# Vision encoder CUDA graph # Vision encoder CUDA graph
cudagraph_mm_encoder: bool = False cudagraph_mm_encoder: bool = False
......
...@@ -16,13 +16,11 @@ ...@@ -16,13 +16,11 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import ( from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder, MultiModalDummyInputsBuilder,
MultiModalMixin, MultiModalMixin,
MultiModalProcessingInfo, MultiModalProcessingInfo,
...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import ( ...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin, EmbeddingMixin,
SequenceClassificationMixin, SequenceClassificationMixin,
) )
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models # Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ... class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... ...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... ...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM( class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base MoEMixin, MultiModalMixin, CausalMixin, Base
): ... ): ...
# Embedding models # Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ... class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... ...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ... class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models # Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification( class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base SequenceClassificationMixin, LegacyMixin, Base
): ... ): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification( class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base SequenceClassificationMixin, MoEMixin, Base
): ... ): ...
...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification( ...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification( class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base SequenceClassificationMixin, MultiModalMixin, Base
): ... ): ...
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
"""Transformers modeling backend base class.""" """Transformers modeling backend base class."""
import sys
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import chain from itertools import chain
from operator import attrgetter from operator import attrgetter
...@@ -29,6 +30,7 @@ from torch import nn ...@@ -29,6 +30,7 @@ from torch import nn
from transformers import AutoModel from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
...@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import ( ...@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
) )
from vllm.model_executor.models.interfaces_base import VllmModel from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.transformers.utils import ( from vllm.model_executor.models.transformers.utils import (
can_enable_torch_compile,
get_feature_request_tip, get_feature_request_tip,
init_on_device_without_buffers, init_on_device_without_buffers,
log_replacement, log_replacement,
...@@ -117,6 +120,7 @@ class Base( ...@@ -117,6 +120,7 @@ class Base(
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.text_config = self.config.get_text_config() self.text_config = self.config.get_text_config()
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -155,14 +159,16 @@ class Base( ...@@ -155,14 +159,16 @@ class Base(
if "gptq" in quant_method_name: if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias") self.ignore_unexpected_suffixes.append(".bias")
# Patch config and init on "meta" to delay allocating GPU tensors
self._patch_config() self._patch_config()
with init_on_device_without_buffers("meta"): from_config_kwargs = dict(
self.model: PreTrainedModel = AutoModel.from_config( config=self.config,
self.config,
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
self._decorate_for_torch_compile(**from_config_kwargs)
# Init on "meta" to delay allocating GPU tensors
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(**from_config_kwargs)
# Create weight name to module qualname mapper # Create weight name to module qualname mapper
self._create_hf_to_vllm_mapper() self._create_hf_to_vllm_mapper()
...@@ -218,6 +224,82 @@ class Base( ...@@ -218,6 +224,82 @@ class Base(
if sub_config.dtype != (dtype := self.config.dtype): if sub_config.dtype != (dtype := self.config.dtype):
sub_config.dtype = dtype sub_config.dtype = dtype
def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
"""
Get the decoder class from the model.
Args:
kwargs: The kwargs to create the model.
Returns:
The decoder class.
"""
with torch.device("meta"):
model: PreTrainedModel = AutoModel.from_config(**kwargs)
decoder_cls = type(model.get_decoder())
logger.debug("Identified decoder class as: %s", decoder_cls)
del model
return decoder_cls
def _decorate_cls_for_torch_compile(
self,
cls: type[PreTrainedModel],
dynamic_arg_dims: dict[str, int] | None,
enable_if: Callable[["VllmConfig"], bool],
is_encoder: bool,
):
"""
Decorate `cls` to indicate to vLLM that it supports torch compile.
Args:
cls: The PreTrainedModel class to decorate.
dynamic_arg_dims: A mapping from argument name to the dynamic dimensions
of the argument. If None, default dynamic arg dims will be used. See
[`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
for more details.
enable_if: A function which takes in the vLLM config and returns whether
torch compile should be enabled for this class.
is_encoder: Whether the class being decorated is an encoder.
"""
logger.debug(
"Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
cls.__name__,
"encoder" if is_encoder else "decoder",
dynamic_arg_dims,
)
@support_torch_compile(
dynamic_arg_dims=dynamic_arg_dims,
enable_if=enable_if,
is_encoder=is_encoder,
)
class SupportTorchCompileWrapper(cls): ...
# Patch the class in its module
module = sys.modules[cls.__module__]
setattr(module, cls.__name__, SupportTorchCompileWrapper)
def _decorate_for_torch_compile(self, **kwargs: dict):
"""
Decorate the model's decoder class to indicate to vLLM that it supports torch
compile if `can_enable_torch_compile` is True.
Args:
kwargs: The kwargs to create the model, which are needed to get the decoder
class.
"""
self._decorate_cls_for_torch_compile(
cls=self._get_decoder_cls(**kwargs),
# Applied to a PreTrainedModel so the batch dimension will exist
dynamic_arg_dims=dict[str, int](
input_ids=1, # shape: [1, seq_len]
inputs_embeds=1, # shape: [1, seq_len, hidden_size]
position_ids=-1, # shape: [1, seq_len] or [3, 1, seq_len] for mrope
),
enable_if=can_enable_torch_compile,
is_encoder=False,
)
def _create_hf_to_vllm_mapper(self): def _create_hf_to_vllm_mapper(self):
""" """
Create a WeightsMapper to map checkpoint weight names to module qualnames. Create a WeightsMapper to map checkpoint weight names to module qualnames.
...@@ -553,11 +635,6 @@ class Base( ...@@ -553,11 +635,6 @@ class Base(
input_ids = None input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"] inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
# If the model scales embeddings inside the input embedding layer we must # If the model scales embeddings inside the input embedding layer we must
# ensure they are scaled here since VocabParallelEmbedding will not do it # ensure they are scaled here since VocabParallelEmbedding will not do it
if ( if (
...@@ -568,22 +645,29 @@ class Base( ...@@ -568,22 +645,29 @@ class Base(
inputs_embeds = self.embed_input_ids(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
input_ids = None input_ids = None
if self.model_config.uses_mrope: # Add batch dimension before entering Transformers model
position_ids = positions[:, None] if input_ids is not None and input_ids.ndim == 1:
else: # [seq_len] -> [1, seq_len]
position_ids = positions[None, ...] input_ids = input_ids[None, ...]
if inputs_embeds is not None and inputs_embeds.ndim == 2:
# [seq_len, hidden_size] -> [1, seq_len, hidden_size]
inputs_embeds = inputs_embeds[None, ...]
if positions.ndim == 1:
# [seq_len] -> [1, seq_len]
positions = positions[None, ...]
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=False, use_cache=False,
position_ids=position_ids, position_ids=positions,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False, return_dict=False,
**self._output_aux_hidden_states_kwargs, **self._output_aux_hidden_states_kwargs,
**kwargs, **kwargs,
) )
# We must remove the batch dimension from these outputs
# Remove batch dimension after exiting Transformers model
hidden_states = outputs[0][0, ...] hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs: if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]] aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
......
...@@ -20,7 +20,9 @@ from collections.abc import Mapping ...@@ -20,7 +20,9 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from transformers import AutoModel
from vllm.compilation.decorators import should_torch_compile_mm_encoder
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform ...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import BatchFeature from transformers import BatchFeature, PreTrainedModel
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO # Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def _get_encoder_cls(
self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
"""
Get the encoder class from the model.
Args:
kwargs: The kwargs to create the model.
Returns:
The encoder class.
"""
with torch.device("meta"):
model: PreTrainedModel = AutoModel.from_config(**kwargs)
encoder_cls = type(model.get_encoder(modality=modality))
logger.debug("Identified encoder class as: %s", encoder_cls)
if type(model) is encoder_cls:
raise ValueError(
"Unable to infer vision encoder class from the model. "
"You must either: update the model so that "
"https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
" can detect the vision encoder correctly, or remove "
"'compile_mm_encoder'."
)
del model
return encoder_cls
def _decorate_for_torch_compile(self, **kwargs: dict):
"""
Decorate the model's decoder and encoder classes to indicate to vLLM that they
support torch compile if `can_enable_torch_compile` and
`should_torch_compile_mm_encoder` are True respectively.
Args:
kwargs: The kwargs to create the model, which are needed to get the decoder
and encoder classes.
"""
super()._decorate_for_torch_compile(**kwargs)
# Decorate the vision encoder model class to support torch compile if needed
if self.compilation_config.compile_mm_encoder:
self.check_version("5.0.0", "multimodal encoder compilation support")
logger.warning_once(
"Multimodal encoder compilation with the Transformers modeling backend "
"is an experimental feature. It relies on:\n"
"- The vision encoder being torch compilable.\n"
"- All vision encoder tensor inputs must be type hinted as either "
"`torch.Tensor` or `torch.FloatTensor`.\n"
"- The 0-th dimension of all tensor inputs to the vision encoder being "
"the dynamic dimension (i.e., sequence length or number of patches).\n"
"Please report any issues you encounter to help us improve it."
)
self._decorate_cls_for_torch_compile(
cls=self._get_encoder_cls(**kwargs),
# TODO: properly infer dynamic_arg_dims based on the encoder's forward
# method signature. Currently we assume dim 0 for all tensor inputs.
dynamic_arg_dims=None,
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly # Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs # Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
# Positions shape handling for MRoPE models
if self.model_config.uses_mrope:
# [3, seq_len] -> [3, 1, seq_len]
positions = positions[:, None]
model_output = super().forward( model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment