"examples/basic/offline_inference/classify.py" did not exist on "02f0c7b220422792f5e53de2a7d51d2d3ff2df28"
Unverified Commit fb5e10d3 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Refactor Transformers backend to use mixins (#26906)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent b2f78cba
...@@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson ...@@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/offloading @ApostaC /tests/v1/offloading @ApostaC
# Transformers backend # Transformers backend
/vllm/model_executor/models/transformers.py @hmellor /vllm/model_executor/models/transformers @hmellor
/tests/models/test_transformers.py @hmellor /tests/models/test_transformers.py @hmellor
# Docs # Docs
......
...@@ -912,11 +912,11 @@ _TRANSFORMERS_BACKEND_MODELS = { ...@@ -912,11 +912,11 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo( "TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True "hmellor/Ilama-3.2-1B", trust_remote_code=True
), ),
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo( "TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" "allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
), ),
"TransformersMoEForMultimodalLM": _HfExamplesInfo( "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
), ),
"TransformersMoEEmbeddingModel": _HfExamplesInfo( "TransformersMoEEmbeddingModel": _HfExamplesInfo(
...@@ -925,6 +925,10 @@ _TRANSFORMERS_BACKEND_MODELS = { ...@@ -925,6 +925,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersMoEForSequenceClassification": _HfExamplesInfo( "TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
), ),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
"google/gemma-3-4b-it"
),
} }
_EXAMPLE_MODELS = { _EXAMPLE_MODELS = {
......
...@@ -37,7 +37,7 @@ MINIMAL_MODEL_ARCH_LIST = [ ...@@ -37,7 +37,7 @@ MINIMAL_MODEL_ARCH_LIST = [
"JinaVLForRanking", "JinaVLForRanking",
"InternVLChatModel", "InternVLChatModel",
"InternLM2ForRewardModel", "InternLM2ForRewardModel",
"TransformersForMultimodalLM", "TransformersMultiModalForCausalLM",
"PrithviGeoSpatialMAE", "PrithviGeoSpatialMAE",
"UltravoxModel", "UltravoxModel",
"DeepSeekMTPModel", "DeepSeekMTPModel",
......
...@@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model): ...@@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model):
def test_pooling(hf_runner, vllm_runner, example_prompts, arch): def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
model = get_model(arch) model = get_model(arch)
vllm_kwargs = dict( vllm_kwargs = dict(max_model_len=None, model_impl="transformers")
max_model_len=None,
model_impl="transformers",
compilation_config=dict(cudagraph_capture_sizes=[8]),
)
hf_kwargs = dict() hf_kwargs = dict()
if arch == "TransformersEmbeddingModel": if arch == "TransformersEmbeddingModel":
......
...@@ -147,6 +147,10 @@ class ModelConfig: ...@@ -147,6 +147,10 @@ class ModelConfig:
seed: int | None = None seed: int | None = None
"""Random seed for reproducibility. Initialized to None in V0, but """Random seed for reproducibility. Initialized to None in V0, but
initialized to 0 in V1.""" initialized to 0 in V1."""
hf_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the model."""
hf_text_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the text model (same as hf_config for text models)."""
hf_config_path: str | None = None hf_config_path: str | None = None
"""Name or path of the Hugging Face config to use. If unspecified, model """Name or path of the Hugging Face config to use. If unspecified, model
name or path will be used.""" name or path will be used."""
...@@ -771,8 +775,10 @@ class ModelConfig: ...@@ -771,8 +775,10 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str: def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if """Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`.""" `model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers" cls = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else "" # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults # Check if the architecture we're wrapping has defaults
runner = None runner = None
convert = None convert = None
...@@ -788,18 +794,15 @@ class ModelConfig: ...@@ -788,18 +794,15 @@ class ModelConfig:
runner = "generate" runner = "generate"
if convert in {None, "none"}: if convert in {None, "none"}:
convert = "embed" convert = "embed"
# Resolve Transformers backend pooling classes # Resolve Transformers backend task
if runner == "pooling": if runner == "pooling":
if convert == "embed": if convert == "embed":
return prefix + "EmbeddingModel" return cls + "EmbeddingModel"
if convert == "classify": if convert == "classify":
return prefix + "ForSequenceClassification" return cls + "ForSequenceClassification"
# Resolve Transformers backend generate classes else:
if self.hf_config != self.hf_text_config: cls += "ForCausalLM"
# If 'hf_text_config' is the same as 'hf_config'. If not, it is return cls
# probably a composite config, i.e. multimodal
return prefix + "ForMultimodalLM"
return prefix + "ForCausalLM"
def using_transformers_backend(self) -> bool: def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class.""" """Check if the model is using the Transformers backend class."""
......
...@@ -19,7 +19,7 @@ from vllm.config.multimodal import BaseDummyOptions ...@@ -19,7 +19,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers import replace_linear_class from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
......
...@@ -401,32 +401,44 @@ _TRANSFORMERS_SUPPORTED_MODELS = { ...@@ -401,32 +401,44 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
# Text generation models # Text generation models
"SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"), "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
# Multimodal models # Multimodal models
"Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "Emu3ForConditionalGeneration": (
"transformers",
"TransformersMultiModalForCausalLM",
),
} }
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
# Text generation models
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 # Multimodal models
"TransformersMoEForMultimodalLM": ( "TransformersMultiModalForCausalLM": (
"transformers_moe", "transformers",
"TransformersMoEForMultimodalLM", "TransformersMultiModalForCausalLM",
),
"TransformersMultiModalMoEForCausalLM": (
"transformers",
"TransformersMultiModalMoEForCausalLM",
), ),
"TransformersEmbeddingModel": ( # Embedding models
"transformers_pooling", "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
"TransformersEmbeddingModel", "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
"TransformersMultiModalEmbeddingModel": (
"transformers",
"TransformersMultiModalEmbeddingModel",
), ),
# Sequence classification models
"TransformersForSequenceClassification": ( "TransformersForSequenceClassification": (
"transformers_pooling", "transformers",
"TransformersForSequenceClassification", "TransformersForSequenceClassification",
), ),
"TransformersMoEForSequenceClassification": ( "TransformersMoEForSequenceClassification": (
"transformers_pooling", "transformers",
"TransformersMoEForSequenceClassification", "TransformersMoEForSequenceClassification",
), ),
"TransformersMoEEmbeddingModel": ( "TransformersMultiModalForSequenceClassification": (
"transformers_pooling", "transformers",
"TransformersMoEEmbeddingModel", "TransformersMultiModalForSequenceClassification",
), ),
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder,
MultiModalMixin,
MultiModalProcessingInfo,
MultiModalProcessor,
)
from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
# Multimodal models
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base
): ...
# Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base
): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base
): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base
): ...
def __getattr__(name: str):
"""Handle imports of non-existent classes with a helpful error message."""
if name not in globals():
raise AttributeError(
"The Transformers backend does not currently have a class to handle "
f"the requested model type: {name}. Please open an issue at "
"https://github.com/vllm-project/vllm/issues/new"
)
return globals()[name]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for causal language models."""
from typing import TYPE_CHECKING
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
if TYPE_CHECKING:
import torch
from vllm.config import VllmConfig
class CausalMixin(VllmModelForTextGeneration):
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForTextGeneration.__init__ and call the next class in MRO
super(VllmModelForTextGeneration, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
# Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings()
)
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
)
else:
self.lm_head = PPMissingLayer()
def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for legacy models."""
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
class LegacyMixin:
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend(
[
"model.lm_head.",
"model.predictions.",
"model.qa_outputs.",
"model.embeddings_project.",
"model.discriminator_predictions.",
]
)
# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
# roberta-like models an extra padding in positions.
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
# we should find a better way to handle this.
self.is_roberta = "roberta" in self.text_config.model_type
self.padding_idx = self.text_config.pad_token_id
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if self.is_roberta:
# RoBERTa-specific positions padding
positions += self.padding_idx + 1
return super().forward(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
...@@ -14,31 +14,27 @@ ...@@ -14,31 +14,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` MoE models.""" """Transformers backend mixin for Mixture of Experts (MoE) models."""
from typing import Any from typing import TYPE_CHECKING, Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .interfaces import MixtureOfExperts, SupportsMultiModal from .utils import log_replacement
from .transformers import (
TransformersBase, if TYPE_CHECKING:
TransformersForCausalLM, from vllm.config import VllmConfig
TransformersForMultimodalLM,
can_enable_torch_compile,
log_replacement,
)
from .utils import maybe_prefix
@CustomOp.register("transformers_fused_moe") @CustomOp.register("transformers_fused_moe")
...@@ -117,11 +113,11 @@ direct_register_custom_op( ...@@ -117,11 +113,11 @@ direct_register_custom_op(
) )
class TransformersMoEBase(TransformersBase, MixtureOfExperts): class MoEMixin(MixtureOfExperts):
def __init__(self, *, vllm_config, prefix=""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
self.check_version("4.57.0.dev0", "MoE models support") self.check_version("4.57.0.dev0", "MoE models support")
self.ep_group = get_ep_group() # Skip MixtureOfExperts.__init__ and call the next class in MRO
super().__init__(vllm_config=vllm_config, prefix=prefix) super(MixtureOfExperts, self).__init__(vllm_config=vllm_config, prefix=prefix)
def set_eplb_state( def set_eplb_state(
self, self,
...@@ -242,7 +238,7 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts): ...@@ -242,7 +238,7 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
# MixtureOfExperts mixin settings # MixtureOfExperts mixin settings
ep_size = self.ep_group.world_size ep_size = get_ep_group().world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods self.mlp_layers = [] # Used for MixtureOfExperts methods
self.expert_weights = [] self.expert_weights = []
...@@ -316,24 +312,5 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts): ...@@ -316,24 +312,5 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
_recursive_replace(child_module, prefix=qual_name) _recursive_replace(child_module, prefix=qual_name)
_recursive_replace(self.model, prefix="model") _recursive_replace(self.model, prefix="model")
# Continue with the replacement of layers in TransformersBase # Continue with the replacement of layers in Base
super().recursive_replace() super().recursive_replace()
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
pass
@support_torch_compile(
# set `positions` to last dim to support Qwen-mrope
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
},
enable_if=can_enable_torch_compile,
)
class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM):
get_input_embeddings = SupportsMultiModal.get_input_embeddings
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for multi-modal models."""
from collections.abc import Mapping
from typing import TYPE_CHECKING
import torch
from vllm.config.utils import getattr_iter
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalUUIDDict,
PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
width, height = self.get_max_image_size()
processor = self.get_hf_processor()
multimodal_config = self.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
mm_tokens = processor._get_num_multimodal_tokens(
image_sizes=([height, width],), **mm_processor_kwargs
)
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
}
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
):
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
HF processor if the HF process does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
for each multi-modal item.
"""
return None
def _get_mm_fields_config(
self,
hf_inputs: "BatchFeature",
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
num_image_patches = hf_inputs.get("num_image_patches")
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches
)
# Keep these as batched, as they always have batch size as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[Mapping[str, object], Mapping[str, object]]:
"""
In contrast to the base class, this method always adds
`return_mm_token_type_ids` to the processor data
"""
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True
return processor_data, passthrough_data
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key
token_type_key = (
"mm_token_type_ids"
if "mm_token_type_ids" in processed_data
else "token_type_ids"
)
mm_token_type_ids = processed_data.pop(token_type_key)
# We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
)
mm_placeholders = {}
split_sizes = mm_tokens_per_modality["num_image_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
)
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
]
mm_placeholders = {"image": ranges}
processed_data["num_image_patches"] = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
)
return MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
supports_multimodal_raw_input_only = True
merge_by_field_config = True
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
}
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return model_output
def get_language_model(self) -> torch.nn.Module:
"""Transformers backend multimodal classes do not contain a separate vLLM
language model class. Therefore, in order to return a language model vLLM class,
we use a wrapper to give `self` the same interface as a text model."""
# Exclude self and object
bases = self.__class__.mro()[1:-1]
# Keep only classes defined in `vllm.model_executor.models.transformers`
bases = [b for b in bases if ".transformers." in b.__module__]
# Exclude MultiModalMixin itself
bases = [b for b in bases if b is not MultiModalMixin]
class LanguageModel(*bases):
def __init__(self, multimodal_model):
# Don't call super().__init__() to avoid re-initialization
self.__dict__.update(multimodal_model.__dict__)
model = getattr_iter(self.model, ("language_model", "text_model"), None)
return LanguageModel(self)
def get_multimodal_embeddings(self, **kwargs):
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
# Model might use `image_patches` instead of `pixel_values`
if pixel_values is None:
pixel_values = kwargs.pop("image_patches", None)
if image_embeds is not None:
return image_embeds
if pixel_values is None:
return None
num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward`
if pixel_values is not None:
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
# Embeddings have to be 2D tensors of length `num_images`
# but transformers returns concat tensors if each patch
# is of different size. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_image_patches.flatten().tolist()
)
vision_embeddings = [
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
return vision_embeddings
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)):
raise NotImplementedError("Transformers backend only supports images.")
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
)
mrope_positions = mrope_positions[:, 0, context_len:seq_len]
mrope_position_delta = mrope_position_delta[0].item()
return mrope_positions, mrope_position_delta
...@@ -14,121 +14,34 @@ ...@@ -14,121 +14,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models for pooling tasks.""" """Transformers backend mixins for pooling models."""
from typing import TYPE_CHECKING
import torch import torch
from transformers import AutoModelForSequenceClassification from transformers import AutoModelForSequenceClassification
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
ClassifierPooler, ClassifierPooler,
CLSPool, CLSPool,
DispatchPooler, DispatchPooler,
Pooler, Pooler,
) )
from vllm.sequence import IntermediateTensors from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile
from .transformers_moe import TransformersMoEBase
from .utils import WeightsMapper
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend(
[
"model.lm_head.",
"model.predictions.",
"model.qa_outputs.",
"model.embeddings_project.",
"model.discriminator_predictions.",
]
)
# Some encoder models have the position_ids buffer in the checkpoint. if TYPE_CHECKING:
# vLLM will always pass position_ids as an argument, so we skip loading from vllm.config import VllmConfig
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
# roberta-like models an extra padding in positions.
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
# we should find a better way to handle this.
self.is_roberta = "roberta" in self.text_config.model_type
self.padding_idx = self.text_config.pad_token_id
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
self.check_version("4.57.0.dev0", "encoder models support")
return super().create_attention_instances(attn_type)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if self.is_roberta:
# RoBERTa-specific positions padding
positions += self.padding_idx + 1
return super().forward(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
@support_torch_compile(enable_if=can_enable_torch_compile) class EmbeddingMixin(VllmModelForPooling):
class TransformersEmbeddingModel(TransformersPoolingBase):
default_pooling_type = "CLS" default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) # Skip VllmModelForPooling.__init__ and call the next class in MRO
super(VllmModelForPooling, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
...@@ -141,12 +54,14 @@ class TransformersEmbeddingModel(TransformersPoolingBase): ...@@ -141,12 +54,14 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
) )
@support_torch_compile(enable_if=can_enable_torch_compile) class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
class TransformersForSequenceClassification(TransformersPoolingBase):
default_pooling_type = "CLS" default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) # Skip VllmModelForPooling.__init__ and call the next class in MRO
super(VllmModelForPooling, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
...@@ -201,15 +116,3 @@ class TransformersForSequenceClassification(TransformersPoolingBase): ...@@ -201,15 +116,3 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
), ),
} }
) )
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel):
pass
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
TransformersMoEBase, TransformersForSequenceClassification
):
pass
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend utilities."""
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
logger = init_logger(__name__)
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
"""
A context manager under which models are initialized with all
parameters on the specified device. However buffers are not
initialized on specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
"""
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
for torch_function_name in tensor_constructors_to_patch:
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
nn.Module.register_parameter = old_register_parameter
for (
torch_function_name,
old_torch_function,
) in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: "QuantizationConfig | None" = None,
*,
prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config: Quantization config for the new linear.
Returns:
The new linear.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
return_bias=False,
**vllm_linear_kwargs,
)
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
This method assumes:
- Weight is stored as `weight`.
- Epsilon is stored as `eps` or `variance_epsilon`.
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
and Transformers doesn't appear to have the same concept.
"""
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
kwargs = {"hidden_size": hidden_size, "eps": eps}
# Update hidden size if weight is available
weight_meta = getattr(rms_norm, "weight", None)
if weight_meta is not None:
kwargs["hidden_size"] = weight_meta.size(0)
# Check if weight is all zeros, which indicates GemmaRMSNorm
# We must create a new instance because rms_norm is on meta
try:
with torch.device("cpu"):
weight_test = getattr(rms_norm.__class__(1), "weight", None)
except Exception:
logger.warning(
"Failed to determine if RMSNorm weight is centered on zero or one. "
"Defaulting to one."
)
weight_test = None
if weight_test is not None and torch.all(weight_test == 0):
return GemmaRMSNorm(**kwargs)
# Otherwise assume it's a regular RMSNorm
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
if weight_meta is not None:
kwargs["dtype"] = weight_meta.dtype
else:
# No weight, fall back to weightless RMSNorm
kwargs["has_weight"] = False
return RMSNorm(**kwargs)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def get_feature_request_tip(
model: str,
trust_remote_code: bool,
) -> str:
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
url = hf_url if trust_remote_code else gh_url
prefix = f"Please open {url} to request support for this feature. "
if Path(model).exists():
prefix = ""
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
tip = f"See {doc_url} for instructions on how to add support yourself."
return f"{prefix}{tip}"
def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
"""
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
Defaults to `True` but is disabled in the following situations:
- The model uses dynamic rope scaling.
"""
text_config = vllm_config.model_config.hf_config.get_text_config()
# Dynamic rope scaling is not compatible with torch.compile
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
return rope_scaling.get("rope_type") != "dynamic"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment