Unverified Commit 0005d2a3 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Use Transformers v5 `WeightRenaming` for Transformers modeling backend (#31545)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent d0b40297
......@@ -206,9 +206,7 @@ VLM_TEST_SETTINGS = {
"model_impl": "transformers",
"default_torch_num_threads": 1,
},
# FIXME: Investigate why the test hangs
# when processing the 3rd prompt in vLLM
marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")],
marks=[pytest.mark.core_model],
),
# Gemma3 has bidirectional mask on images
"gemma3-transformers": VLMTestInfo(
......
......@@ -5,9 +5,10 @@ from collections.abc import Iterable
import pytest
import torch
import transformers
from transformers import AutoConfig, PreTrainedModel
from transformers import AutoConfig, AutoModel, PreTrainedModel
from vllm.config import ModelConfig
from vllm.model_executor.models.transformers.base import Base as TransformersBase
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata
......@@ -23,6 +24,16 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
return ((name, torch.empty(0)) for name in weight_names)
def create_dummy_base_model(repo: str, model_arch: str) -> PreTrainedModel:
"""
Create weights from a dummy meta deserialized hf base model with name conversion
"""
config = AutoConfig.from_pretrained(repo)
with torch.device("meta"):
model = AutoModel.from_config(config)
return model
def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
"""
Create weights from a dummy meta deserialized hf model with name conversion
......@@ -79,6 +90,19 @@ def test_hf_model_weights_mapper(model_arch: str):
dtype=model_info.dtype,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
if issubclass(model_cls, TransformersBase):
# Transformers backend models create their mapper during __init__
# by inspecting the HF model instance. We simulate this by calling
# _create_hf_to_vllm_mapper with a minimal proxy object.
model_cls = type(
"ProxyModelCls",
(),
{
"model": create_dummy_base_model(model_id, model_arch),
"_maybe_apply_model_mapping": lambda self: None,
},
)()
TransformersBase._create_hf_to_vllm_mapper(model_cls)
original_weights = create_repo_dummy_weights(model_id)
hf_dummy_model = create_dummy_model(model_id, model_arch)
......@@ -102,6 +126,9 @@ def test_hf_model_weights_mapper(model_arch: str):
# after they are tied in the model, so the mapper will not be able to map them.
# We exclude them from the reference weight names for this test.
if isinstance(tied := getattr(hf_dummy_model, "_tied_weights_keys", None), dict):
config = hf_dummy_model.config
key = "tie_word_embeddings"
if getattr(config.get_text_config(), key, False) or getattr(config, key, False):
mapped_tied_weights = mapper.apply((k, None) for k in tied)
tied_weight_names = set(map(lambda x: x[0], mapped_tied_weights))
ref_weight_names -= tied_weight_names
......
......@@ -995,19 +995,10 @@ class SupportsQuant:
def __new__(cls, *args, **kwargs) -> Self:
instance = super().__new__(cls)
# find config passed in arguments
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
# attach config to model for general use
instance.quant_config = quant_config
# apply model mappings to config for proper config-model matching
if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping
)
# find config passed in arguments and attach it to model for general use
instance.quant_config = cls._find_quant_config(*args, **kwargs)
cls._maybe_apply_model_mapping(instance)
return instance
......@@ -1026,6 +1017,15 @@ class SupportsQuant:
return None
def _maybe_apply_model_mapping(self):
"""Apply model mappings to config for proper config-model matching"""
if self.quant_config is None:
return
if (hf_to_vllm_mapper := self.hf_to_vllm_mapper) is not None:
self.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if self.packed_modules_mapping is not None:
self.quant_config.packed_modules_mapping.update(self.packed_modules_mapping)
@runtime_checkable
class SupportsRealtime(Protocol):
......
......@@ -17,6 +17,7 @@
"""Transformers modeling backend base class."""
from collections.abc import Iterable
from itertools import chain
from typing import TYPE_CHECKING
import regex as re
......@@ -107,27 +108,6 @@ class Base(
SupportsEagle3,
):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints,
# handling the case where it is already present
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model` (pooling included because of adapters)
"model.lm_head.": "lm_head.",
"model.score.": "classifier.",
"model.classifier.": "classifier.",
}
)
def __init_subclass__(cls, *args, **kwargs):
"""Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
super().__init_subclass__(*args, **kwargs)
hf_to_vllm_mapper = WeightsMapper()
for base in cls.__mro__:
if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
hf_to_vllm_mapper |= base_hf_to_vllm_mapper
cls.hf_to_vllm_mapper = hf_to_vllm_mapper
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__()
......@@ -174,8 +154,8 @@ class Base(
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors
self.text_config._attn_implementation = "vllm"
# Patch config and init on "meta" to delay allocating GPU tensors
self._patch_config()
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
......@@ -183,6 +163,8 @@ class Base(
trust_remote_code=self.model_config.trust_remote_code,
)
# Create weight name to module qualname mapper
self._create_hf_to_vllm_mapper()
# Remove layers not on this pipeline parallel rank
self.pipeline_parallel()
# Substitute remaining layers with vLLM's layers as needed
......@@ -216,6 +198,104 @@ class Base(
["hidden_states"], self.text_config.hidden_size
)
def _patch_config(self):
"""
Patch the config to ensure that the model is created correctly:
- Sets the attention implementation to "vllm" so the attention instances from
`create_attention_instances` are used
- Sets the dtype to the default torch dtype set by vLLM because Transformers
uses the config dtype when creating the model
- Propagates this dtype to any sub-configs because Transformers model
implementations do not support/use different dtypes in sub-models
"""
self.text_config._attn_implementation = "vllm"
self.config.dtype = torch.get_default_dtype()
# TODO(hmellor): Remove this when Transformers v4 support is dropped
for sub_config_name in getattr(self.config, "sub_configs", {}):
sub_config = getattr(self.config, sub_config_name)
if sub_config.dtype != (dtype := self.config.dtype):
sub_config.dtype = dtype
def _create_hf_to_vllm_mapper(self):
"""
Create a WeightsMapper to map checkpoint weight names to module qualnames.
This handles:
- Transformers weight renaming:
- from `WeightRenaming` in Transformers v5
- from `_checkpoint_conversion_mapping` in Transformers v4
- Checkpoints saved with a base model prefix that is not `model`
- Checkpoints saved with no base model prefix
- Any quantization config specific mappings
"""
self.hf_to_vllm_mapper = WeightsMapper()
orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex
if Version(transformers.__version__) >= Version("5.0.0"):
from transformers.conversion_mapping import (
WeightRenaming,
get_model_conversion_mapping,
)
for mapping in get_model_conversion_mapping(self.model):
# Handle weights which have been renamed in Transformers
if isinstance(mapping, WeightRenaming):
# Recompile using regex (Transformers used re)
compiled_sources = re.compile(
mapping.compiled_sources.pattern, mapping.compiled_sources.flags
)
target_pattern = mapping.target_patterns[0]
orig_to_new_regex[compiled_sources] = target_pattern
# TODO: Handle WeightConverter to enable layer merging
else:
# Replace legacy suffixes used for norms
# TODO(hmellor): Remove this when Transformers v4 support is dropped
orig_to_new_regex.update(
{
re.compile(r"\.gamma$"): ".weight",
re.compile(r"\.beta$"): ".bias",
}
)
# Handle weights which have been renamed in Transformers
# TODO(hmellor): Remove this when Transformers v4 support is dropped
ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
for source, target in ccm.items():
orig_to_new_regex[re.compile(source)] = target
# Handle unexpected weights which should be ignored
if self.model._keys_to_ignore_on_load_unexpected is not None:
for key in self.model._keys_to_ignore_on_load_unexpected:
orig_to_new_regex[re.compile(key)] = None
# Standardise base model prefix
bmp = self.model.base_model_prefix
expected_bmp = r"model.\1"
# Handle checkpoints saved with different base model prefix
if bmp and bmp != "model":
different_bmp_pattern = re.compile(rf"^{bmp}\.(.+)")
orig_to_new_regex[different_bmp_pattern] = expected_bmp
# Handle direct children of self.model which were saved without the model prefix
direct_children = chain(
self.model.named_children(),
self.model.named_parameters(recurse=False),
self.model.named_buffers(recurse=False),
)
model_children = "|".join(name for name, _ in direct_children)
missing_bmp_pattern = re.compile(rf"^(?!model\.)(({model_children}).*)")
orig_to_new_regex[missing_bmp_pattern] = expected_bmp
# Handle weights saved as direct children of self.model which no longer are
unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
# Handle lm_head which was saved inside the base model
nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
orig_to_new_regex[nested_lm_head_pattern] = r"\2"
# Apply mapping to quantization config if needed
self._maybe_apply_model_mapping()
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
......
......@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
import torch
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
......@@ -28,20 +27,6 @@ if TYPE_CHECKING:
class LegacyMixin:
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
......
......@@ -24,7 +24,6 @@ import torch
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
......@@ -273,30 +272,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
supports_multimodal_raw_input_only = True
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
}
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
......
......@@ -7,6 +7,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload
import regex as re
import torch
import torch.nn as nn
from torch.nn.modules.module import register_module_module_registration_hook
......@@ -38,17 +39,17 @@ from vllm.utils.torch_utils import (
logger = init_logger(__name__)
WeightsMapping = Mapping[str, str | None]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
@dataclass
class WeightsMapper:
"""Maps the name of each weight if they match the following patterns."""
"""Maps the name of each weight if they match the following patterns.
If a key maps to a value of `None`, the corresponding weight is ignored."""
orig_to_new_substr: WeightsMapping = field(default_factory=dict)
orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
"""Combine two `WeightsMapper`s by merging their mappings."""
......@@ -59,6 +60,13 @@ class WeightsMapper:
)
def _map_name(self, key: str) -> str | None:
for pattern, new_key in self.orig_to_new_regex.items():
if pattern.search(key):
if new_key is None:
return None
key = pattern.sub(new_key, key)
for substr, new_key in self.orig_to_new_substr.items():
if substr in key:
if new_key is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment