"vscode:/vscode.git/clone" did not exist on "8d5cdd534e5b939f3bbb1128a5fd6b1ade56a59f"
Unverified commit 0005d2a3, authored by Harry Mellor, committed by GitHub
Browse files

Use Transformers v5 `WeightRenaming` for Transformers modeling backend (#31545)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent d0b40297
...@@ -206,9 +206,7 @@ VLM_TEST_SETTINGS = { ...@@ -206,9 +206,7 @@ VLM_TEST_SETTINGS = {
"model_impl": "transformers", "model_impl": "transformers",
"default_torch_num_threads": 1, "default_torch_num_threads": 1,
}, },
# FIXME: Investigate why the test hangs marks=[pytest.mark.core_model],
# when processing the 3rd prompt in vLLM
marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")],
), ),
# Gemma3 has bidirectional mask on images # Gemma3 has bidirectional mask on images
"gemma3-transformers": VLMTestInfo( "gemma3-transformers": VLMTestInfo(
......
...@@ -5,9 +5,10 @@ from collections.abc import Iterable ...@@ -5,9 +5,10 @@ from collections.abc import Iterable
import pytest import pytest
import torch import torch
import transformers import transformers
from transformers import AutoConfig, PreTrainedModel from transformers import AutoConfig, AutoModel, PreTrainedModel
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models.transformers.base import Base as TransformersBase
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata from vllm.transformers_utils.config import try_get_safetensors_metadata
...@@ -23,6 +24,16 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]: ...@@ -23,6 +24,16 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
return ((name, torch.empty(0)) for name in weight_names) return ((name, torch.empty(0)) for name in weight_names)
def create_dummy_base_model(repo: str, model_arch: str) -> PreTrainedModel:
    """
    Instantiate the HF base model for `repo` from its config alone.

    The config is loaded with `AutoConfig.from_pretrained` and the model is built
    with `AutoModel.from_config` inside a `torch.device("meta")` context.
    `model_arch` is accepted (same signature as `create_dummy_model`) but is not
    read here.
    """
    hf_config = AutoConfig.from_pretrained(repo)
    with torch.device("meta"):
        return AutoModel.from_config(hf_config)
def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel: def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
""" """
Create weights from a dummy meta deserialized hf model with name conversion Create weights from a dummy meta deserialized hf model with name conversion
...@@ -79,6 +90,19 @@ def test_hf_model_weights_mapper(model_arch: str): ...@@ -79,6 +90,19 @@ def test_hf_model_weights_mapper(model_arch: str):
dtype=model_info.dtype, dtype=model_info.dtype,
) )
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
if issubclass(model_cls, TransformersBase):
# Transformers backend models create their mapper during __init__
# by inspecting the HF model instance. We simulate this by calling
# _create_hf_to_vllm_mapper with a minimal proxy object.
model_cls = type(
"ProxyModelCls",
(),
{
"model": create_dummy_base_model(model_id, model_arch),
"_maybe_apply_model_mapping": lambda self: None,
},
)()
TransformersBase._create_hf_to_vllm_mapper(model_cls)
original_weights = create_repo_dummy_weights(model_id) original_weights = create_repo_dummy_weights(model_id)
hf_dummy_model = create_dummy_model(model_id, model_arch) hf_dummy_model = create_dummy_model(model_id, model_arch)
...@@ -102,6 +126,9 @@ def test_hf_model_weights_mapper(model_arch: str): ...@@ -102,6 +126,9 @@ def test_hf_model_weights_mapper(model_arch: str):
# after they are tied in the model, so the mapper will not be able to map them. # after they are tied in the model, so the mapper will not be able to map them.
# We exclude them from the reference weight names for this test. # We exclude them from the reference weight names for this test.
if isinstance(tied := getattr(hf_dummy_model, "_tied_weights_keys", None), dict): if isinstance(tied := getattr(hf_dummy_model, "_tied_weights_keys", None), dict):
config = hf_dummy_model.config
key = "tie_word_embeddings"
if getattr(config.get_text_config(), key, False) or getattr(config, key, False):
mapped_tied_weights = mapper.apply((k, None) for k in tied) mapped_tied_weights = mapper.apply((k, None) for k in tied)
tied_weight_names = set(map(lambda x: x[0], mapped_tied_weights)) tied_weight_names = set(map(lambda x: x[0], mapped_tied_weights))
ref_weight_names -= tied_weight_names ref_weight_names -= tied_weight_names
......
...@@ -995,19 +995,10 @@ class SupportsQuant: ...@@ -995,19 +995,10 @@ class SupportsQuant:
def __new__(cls, *args, **kwargs) -> Self: def __new__(cls, *args, **kwargs) -> Self:
instance = super().__new__(cls) instance = super().__new__(cls)
# find config passed in arguments # find config passed in arguments and attach it to model for general use
quant_config = cls._find_quant_config(*args, **kwargs) instance.quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
# attach config to model for general use cls._maybe_apply_model_mapping(instance)
instance.quant_config = quant_config
# apply model mappings to config for proper config-model matching
if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping
)
return instance return instance
...@@ -1026,6 +1017,15 @@ class SupportsQuant: ...@@ -1026,6 +1017,15 @@ class SupportsQuant:
return None return None
def _maybe_apply_model_mapping(self):
    """
    Apply model mappings to config for proper config-model matching.

    Does nothing when `self.quant_config` is None. Otherwise:
    - `self.hf_to_vllm_mapper`, if not None, is handed to
      `quant_config.apply_vllm_mapper`
    - `self.packed_modules_mapping`, if not None, is merged into
      `quant_config.packed_modules_mapping`
    """
    quant_config = self.quant_config
    if quant_config is None:
        return

    mapper = self.hf_to_vllm_mapper
    if mapper is not None:
        quant_config.apply_vllm_mapper(mapper)

    packed_modules = self.packed_modules_mapping
    if packed_modules is not None:
        quant_config.packed_modules_mapping.update(packed_modules)
@runtime_checkable @runtime_checkable
class SupportsRealtime(Protocol): class SupportsRealtime(Protocol):
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Transformers modeling backend base class.""" """Transformers modeling backend base class."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import chain
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import regex as re import regex as re
...@@ -107,27 +108,6 @@ class Base( ...@@ -107,27 +108,6 @@ class Base(
SupportsEagle3, SupportsEagle3,
): ):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
# Prefix rewrites applied to checkpoint weight names so that they line up with
# this class's module layout (base model under `model`, heads next to it)
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        # Add `model.` prefix for base model checkpoints,
        # handling the case where it is already present
        "": "model.",
        "model.model.": "model.",
        # Heads will be adjacent to `model` (pooling included because of adapters)
        "model.lm_head.": "lm_head.",
        "model.score.": "classifier.",
        "model.classifier.": "classifier.",
    }
)
def __init_subclass__(cls, *args, **kwargs):
    """
    Merge hf_to_vllm_mapper in MRO from most specific to least specific.

    Every class in `cls.__mro__` that exposes a truthy `hf_to_vllm_mapper` is
    folded, in MRO order, into a fresh `WeightsMapper`, which then replaces
    `cls.hf_to_vllm_mapper`.
    """
    super().__init_subclass__(*args, **kwargs)
    merged = WeightsMapper()
    for klass in cls.__mro__:
        inherited = getattr(klass, "hf_to_vllm_mapper", None)
        if inherited:
            merged |= inherited
    cls.hf_to_vllm_mapper = merged
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__() super().__init__()
...@@ -174,8 +154,8 @@ class Base( ...@@ -174,8 +154,8 @@ class Base(
if "gptq" in quant_method_name: if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias") self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors # Patch config and init on "meta" to delay allocating GPU tensors
self.text_config._attn_implementation = "vllm" self._patch_config()
with init_on_device_without_buffers("meta"): with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
self.config, self.config,
...@@ -183,6 +163,8 @@ class Base( ...@@ -183,6 +163,8 @@ class Base(
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
# Create weight name to module qualname mapper
self._create_hf_to_vllm_mapper()
# Remove layers not on this pipeline parallel rank # Remove layers not on this pipeline parallel rank
self.pipeline_parallel() self.pipeline_parallel()
# Substitute remaining layers with vLLM's layers as needed # Substitute remaining layers with vLLM's layers as needed
...@@ -216,6 +198,104 @@ class Base( ...@@ -216,6 +198,104 @@ class Base(
["hidden_states"], self.text_config.hidden_size ["hidden_states"], self.text_config.hidden_size
) )
def _patch_config(self):
    """
    Patch the config to ensure that the model is created correctly:
    - Sets the attention implementation to "vllm" so the attention instances from
      `create_attention_instances` are used
    - Sets the dtype to the default torch dtype set by vLLM because Transformers
      uses the config dtype when creating the model
    - Propagates this dtype to any sub-configs because Transformers model
      implementations do not support/use different dtypes in sub-models

    Names listed in `config.sub_configs` whose attribute is missing or `None` on
    this config are skipped instead of raising `AttributeError`.
    """
    self.text_config._attn_implementation = "vllm"
    self.config.dtype = torch.get_default_dtype()
    # Read the value back once so every sub-config is compared against exactly
    # what the top-level config now reports
    dtype = self.config.dtype
    # TODO(hmellor): Remove this when Transformers v4 support is dropped
    for sub_config_name in getattr(self.config, "sub_configs", {}):
        sub_config = getattr(self.config, sub_config_name, None)
        # A declared sub-config may be unset; there is nothing to patch then
        if sub_config is None:
            continue
        if sub_config.dtype != dtype:
            sub_config.dtype = dtype
def _create_hf_to_vllm_mapper(self):
    """
    Create a WeightsMapper to map checkpoint weight names to module qualnames.
    This handles:
    - Transformers weight renaming:
      - from `WeightRenaming` in Transformers v5
      - from `_checkpoint_conversion_mapping` in Transformers v4
    - Checkpoints saved with a base model prefix that is not `model`
    - Checkpoints saved with no base model prefix
    - Any quantization config specific mappings

    The mapper is stored on `self.hf_to_vllm_mapper`. Rules are added to its
    `orig_to_new_regex` dict in the order they are meant to run, so the sections
    below must not be reordered. Names taken from the model (base model prefix,
    direct children) are escaped before being embedded in a pattern so that they
    can only match literally.
    """
    self.hf_to_vllm_mapper = WeightsMapper()
    orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex

    if Version(transformers.__version__) >= Version("5.0.0"):
        from transformers.conversion_mapping import (
            WeightRenaming,
            get_model_conversion_mapping,
        )

        for mapping in get_model_conversion_mapping(self.model):
            # Handle weights which have been renamed in Transformers
            if isinstance(mapping, WeightRenaming):
                # Recompile using regex (Transformers used re)
                compiled_sources = re.compile(
                    mapping.compiled_sources.pattern, mapping.compiled_sources.flags
                )
                target_pattern = mapping.target_patterns[0]
                orig_to_new_regex[compiled_sources] = target_pattern
            # TODO: Handle WeightConverter to enable layer merging
    else:
        # Replace legacy suffixes used for norms
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        orig_to_new_regex.update(
            {
                re.compile(r"\.gamma$"): ".weight",
                re.compile(r"\.beta$"): ".bias",
            }
        )
        # Handle weights which have been renamed in Transformers
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
        for source, target in ccm.items():
            orig_to_new_regex[re.compile(source)] = target

    # Handle unexpected weights which should be ignored
    # (a `None` target tells the mapper to drop the weight)
    if self.model._keys_to_ignore_on_load_unexpected is not None:
        for key in self.model._keys_to_ignore_on_load_unexpected:
            orig_to_new_regex[re.compile(key)] = None

    # Standardise base model prefix
    bmp = self.model.base_model_prefix
    expected_bmp = r"model.\1"

    # Handle checkpoints saved with different base model prefix
    if bmp and bmp != "model":
        different_bmp_pattern = re.compile(rf"^{re.escape(bmp)}\.(.+)")
        orig_to_new_regex[different_bmp_pattern] = expected_bmp

    # Collect the names of everything that lives directly on self.model
    direct_children = chain(
        self.model.named_children(),
        self.model.named_parameters(recurse=False),
        self.model.named_buffers(recurse=False),
    )
    model_children = "|".join(re.escape(name) for name, _ in direct_children)

    # An empty alternation would make the "missing prefix" rule match every
    # key, so the two child-based rules are only added when there are children
    if model_children:
        # Handle direct children of self.model which were saved without the
        # model prefix
        missing_bmp_pattern = re.compile(rf"^(?!model\.)(({model_children}).*)")
        orig_to_new_regex[missing_bmp_pattern] = expected_bmp

        # Handle weights saved as direct children of self.model which no longer are
        unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
        orig_to_new_regex[unexpected_bmp_pattern] = r"\2"

    # Handle lm_head which was saved inside the base model
    nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
    orig_to_new_regex[nested_lm_head_pattern] = r"\2"

    # Apply mapping to quantization config if needed
    self._maybe_apply_model_mapping()
def pipeline_parallel(self): def pipeline_parallel(self):
""" """
Apply the model's pipeline parallelization plan. Apply the model's pipeline parallelization plan.
......
...@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING ...@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
import torch import torch
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -28,20 +27,6 @@ if TYPE_CHECKING: ...@@ -28,20 +27,6 @@ if TYPE_CHECKING:
class LegacyMixin: class LegacyMixin:
# Name rewrites for BERT-like checkpoints and legacy norm parameter names
hf_to_vllm_mapper = WeightsMapper(
    # These are applied in order, so the order matters!
    orig_to_new_prefix={
        # Handle BERT-like models: their `roberta`/`bert` prefix becomes `model`
        "roberta": "model",
        "bert": "model",
    },
    orig_to_new_suffix={
        # Replace legacy suffixes used for norms
        ".gamma": ".weight",
        ".beta": ".bias",
    },
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
......
...@@ -24,7 +24,6 @@ import torch ...@@ -24,7 +24,6 @@ import torch
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
...@@ -273,30 +272,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -273,30 +272,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
supports_multimodal_raw_input_only = True supports_multimodal_raw_input_only = True
# Backwards compatibility for previously released models. Their state dicts
# used a different layout and cannot be loaded with the `AutoModel` mapping
# as is, so the old prefixes are rewritten to the current ones here.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "language_model.model": "model.language_model",
        "text_model.model": "model.text_model",
        "vision_tower": "model.vision_tower",
        "vqmodel": "model.vqmodel",
        "visual": "model.visual",
        "vision_model": "model.vision_model",
        "vision_embed_tokens": "model.vision_embed_tokens",
        "image_newline": "model.image_newline",
        "multi_modal_projector": "model.multi_modal_projector",
        "text_model.lm_head": "lm_head",
        "language_model.lm_head": "lm_head",
        # Qwen models used "model" as the name for the language model.
        # Therefore, we must map each submodule explicitly to avoid
        # conflicts with newer models that use "model.language_model".
        "model.embed_tokens": "model.language_model.embed_tokens",
        "model.layers": "model.language_model.layers",
        "model.norm": "model.language_model.norm",
    }
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip SupportsMRoPE.__init__ and call the next class in MRO # Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
......
...@@ -7,6 +7,7 @@ from contextlib import contextmanager ...@@ -7,6 +7,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload from typing import Any, Literal, Protocol, overload
import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.modules.module import register_module_module_registration_hook from torch.nn.modules.module import register_module_module_registration_hook
...@@ -38,17 +39,17 @@ from vllm.utils.torch_utils import ( ...@@ -38,17 +39,17 @@ from vllm.utils.torch_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
@dataclass @dataclass
class WeightsMapper: class WeightsMapper:
"""Maps the name of each weight if they match the following patterns.""" """Maps the name of each weight if they match the following patterns.
If a key maps to a value of `None`, the corresponding weight is ignored."""
orig_to_new_substr: WeightsMapping = field(default_factory=dict) orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
orig_to_new_prefix: WeightsMapping = field(default_factory=dict) orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_suffix: WeightsMapping = field(default_factory=dict) orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
def __or__(self, other: "WeightsMapper") -> "WeightsMapper": def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
"""Combine two `WeightsMapper`s by merging their mappings.""" """Combine two `WeightsMapper`s by merging their mappings."""
...@@ -59,6 +60,13 @@ class WeightsMapper: ...@@ -59,6 +60,13 @@ class WeightsMapper:
) )
def _map_name(self, key: str) -> str | None: def _map_name(self, key: str) -> str | None:
for pattern, new_key in self.orig_to_new_regex.items():
if pattern.search(key):
if new_key is None:
return None
key = pattern.sub(new_key, key)
for substr, new_key in self.orig_to_new_substr.items(): for substr, new_key in self.orig_to_new_substr.items():
if substr in key: if substr in key:
if new_key is None: if new_key is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment