test_intern_vit.py 2.51 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

9
from vllm.distributed import cleanup_dist_env_and_memory
10
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
11

12
from ....conftest import ImageTestAssets
13
14
15
16
17
18

# Glob patterns handed to ``snapshot_download(allow_patterns=...)``.
# Models are fetched into a local snapshot via snapshot_download so that
# dynamic_module and trust_remote_code do not conflict for hf_runner.
DOWNLOAD_PATTERN = [
    "*.json",
    "*.py",
    "*.safetensors",
    "*.txt",
    "*.model",
]


19
@torch.inference_mode()
def run_intern_vit_test(
    image_assets: ImageTestAssets,
    model_id: str,
    *,
    dtype: str,
):
    """Compare vLLM's ``InternVisionModel`` with the HF remote-code model.

    Both models are run on every image in ``image_assets``. For each image
    the outputs must have identical shapes, and their cosine similarity along
    the last dimension, averaged over all positions, must exceed 0.99.

    Args:
        image_assets: Test image assets; each asset's ``pil_image`` is used.
        model_id: Hugging Face repo id of the InternViT checkpoint.
        dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` (e.g. ``"half"``).
    """
    # Load everything from a local snapshot path rather than the repo id.
    model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    img_processor = CLIPImageProcessor.from_pretrained(model)
    images = [asset.pil_image for asset in image_assets]
    # Preprocess each image on its own so every forward pass sees one image.
    pixel_values = [
        img_processor(image, return_tensors="pt").pixel_values.to(torch_dtype)
        for image in images
    ]

    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    # Fall back to "rms_norm" when the checkpoint config does not set
    # norm_type (presumably the remote-code default — TODO confirm).
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"

    hf_model = AutoModel.from_pretrained(
        model, dtype=torch_dtype, trust_remote_code=True
    ).to("cuda")
    hf_outputs_per_image = [
        hf_model(pixel_value.to("cuda")).last_hidden_state
        for pixel_value in pixel_values
    ]

    from vllm.model_executor.models.intern_vit import InternVisionModel

    # Build the vLLM model from the same config and copy in the HF weights.
    vllm_model = InternVisionModel(config)
    vllm_model.load_weights(hf_model.state_dict().items())

    # Release the HF model before moving the vLLM model onto the GPU.
    del hf_model
    cleanup_dist_env_and_memory()

    vllm_model = vllm_model.to("cuda", torch_dtype)
    vllm_outputs_per_image = [
        vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values
    ]
    del vllm_model
    cleanup_dist_env_and_memory()

    cos_similar = nn.CosineSimilarity(dim=-1)
    for idx, (vllm_output, hf_output) in enumerate(
        zip(vllm_outputs_per_image, hf_outputs_per_image)
    ):
        # CosineSimilarity broadcasts its inputs, so a shape mismatch would
        # otherwise be averaged silently instead of failing the test.
        assert vllm_output.shape == hf_output.shape, (
            f"image {idx}: output shape mismatch "
            f"{tuple(vllm_output.shape)} vs {tuple(hf_output.shape)}"
        )
        similarity = cos_similar(vllm_output, hf_output).mean()
        assert similarity > 0.99, (
            f"image {idx}: mean cosine similarity {similarity.item():.4f} <= 0.99"
        )


68
69
70
71
72
73
74
@pytest.mark.parametrize(
    "model_id",
    (
        "OpenGVLab/InternViT-300M-448px",
        "OpenGVLab/InternViT-6B-448px-V1-5",
    ),
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
    """Run the vLLM-vs-HF InternViT comparison for one checkpoint and dtype.

    ``default_vllm_config`` and ``dist_init`` are not referenced here; they
    are presumably requested for their fixture setup side effects.
    """
    run_intern_vit_test(
        image_assets=image_assets,
        model_id=model_id,
        dtype=dtype,
    )