test_mapping.py 3.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import pytest
import torch
import transformers
from transformers import AutoConfig, PreTrainedModel

from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata

from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS


def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
    """Create placeholder weights named after a repo's safetensors checkpoint.

    Only the tensor *names* from the checkpoint's safetensors metadata
    (``weight_map`` keys) are used; every value is an empty tensor created
    under the ``meta`` device context, so no tensor data is materialized.

    Args:
        repo: Repository identifier handed to
            ``try_get_safetensors_metadata``.

    Returns:
        A list of ``(name, tensor)`` pairs, one per weight name in the
        checkpoint's ``weight_map``.

    Raises:
        ValueError: If no safetensors metadata could be retrieved for ``repo``.
    """
    metadata = try_get_safetensors_metadata(repo)
    # NOTE(review): the ``try_`` helper presumably returns None when metadata
    # is unavailable; fail with a clear message instead of an AttributeError.
    if metadata is None:
        raise ValueError(f"Could not retrieve safetensors metadata for {repo!r}")
    weight_names = list(metadata.weight_map.keys())
    # Materialize inside the context manager: a lazy generator expression would
    # only call ``torch.empty`` when iterated, i.e. after the ``meta`` device
    # context has already exited, silently allocating on the default device.
    with torch.device("meta"):
        return [(name, torch.empty(0)) for name in weight_names]


25
def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
    """Instantiate the ``transformers`` class ``model_arch`` on the meta device.

    The model is built from ``repo``'s config via ``AutoConfig`` without
    loading any checkpoint weights, so its ``named_parameters()`` /
    ``named_buffers()`` expose the names used by the HF model class itself
    at essentially no memory cost.

    Args:
        repo: Repository identifier passed to ``AutoConfig.from_pretrained``.
        model_arch: Name of a model class exported by ``transformers``.

    Returns:
        An instance of ``model_arch`` whose tensors live on the meta device.
    """
    # ``getattr`` yields the model *class*, not an instance.
    model_cls: type[PreTrainedModel] = getattr(transformers, model_arch)
    config = AutoConfig.from_pretrained(repo)
    # ``_from_config`` constructs eagerly, so returning from inside the
    # ``meta`` context is safe here.
    with torch.device("meta"):
        return model_cls._from_config(config)
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


def model_architectures_for_test() -> list[str]:
    """List multimodal example architectures whose HF class remaps checkpoint names.

    An architecture is selected when it does not require remote code, is
    exported by ``transformers``, and its class carries a non-empty
    ``_checkpoint_conversion_mapping``.
    """

    def _has_conversion_mapping(arch: str, trust_remote_code) -> bool:
        # Check remote-code first so ``transformers`` is never probed for
        # architectures that would need custom code.
        if trust_remote_code or not hasattr(transformers, arch):
            return False
        hf_cls: type[PreTrainedModel] = getattr(transformers, arch)
        return bool(getattr(hf_cls, "_checkpoint_conversion_mapping", None))

    return [
        arch
        for arch, example in _MULTIMODAL_EXAMPLE_MODELS.items()
        if _has_conversion_mapping(arch, example.trust_remote_code)
    ]


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", model_architectures_for_test())
def test_hf_model_weights_mapper(model_arch: str):
    """Ensure ``hf_to_vllm_mapper`` maps both HF naming schemes to the same names.

    Two sources of names are pushed through the vLLM model's
    ``hf_to_vllm_mapper``:

    * the tensor names listed in the checkpoint's safetensors metadata;
    * the parameter names of the ``transformers`` model class instantiated
      on the meta device.

    After mapping, both name sets must match exactly. Names the HF model
    reports as buffers are removed from the checkpoint side beforehand.
    """
    hf_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    hf_info.check_available_online(on_fail="skip")
    hf_info.check_transformers_version(on_fail="skip")

    model_config = hf_info.build_model_config(config_format="hf")
    vllm_model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    repo = model_config.model

    checkpoint_weights = create_repo_dummy_weights(repo)
    hf_model = create_dummy_model(repo, model_arch)
    hf_params = hf_model.named_parameters()
    hf_buffers = hf_model.named_buffers()

    mapper: WeightsMapper = vllm_model_cls.hf_to_vllm_mapper
    mapped_checkpoint = mapper.apply(checkpoint_weights)
    mapped_params = mapper.apply(hf_params)
    mapped_buffers = mapper.apply(hf_buffers)

    expected = {name for name, _ in mapped_checkpoint}
    actual = {name for name, _ in mapped_params}
    # Some checkpoints also store buffers; they are not parameters, so they
    # are excluded from the comparison.
    expected -= {name for name, _ in mapped_buffers}

    weights_missing = expected - actual
    weights_unmapped = actual - expected
    assert not (weights_missing or weights_unmapped), (
        f"Following weights are not mapped correctly: {weights_unmapped}, "
        f"Missing expected weights: {weights_missing}."
    )