test_intern_vit.py 2.62 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
6
7
import pytest
import torch
import torch.nn as nn
zhuwenwen's avatar
zhuwenwen committed
8
# from huggingface_hub import snapshot_download
9
10
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

zhuwenwen's avatar
zhuwenwen committed
11

12
from vllm.distributed import cleanup_dist_env_and_memory
13
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
14

15
from ....conftest import ImageTestAssets
16
from ....utils import models_path_prefix
17
18
19
20
21
22

# Glob patterns that restrict ``snapshot_download`` to the files needed to load
# the model; fetching a snapshot this way avoids conflicts between
# dynamic_module and trust_remote_code for hf_runner.
# NOTE(review): run_intern_vit_test currently loads checkpoints from a local
# directory under ``models_path_prefix`` instead of calling snapshot_download,
# so in the code visible here this constant is not referenced by live code.
# Keep it only if the download path is expected to be re-enabled.
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]


23
@torch.inference_mode()
def run_intern_vit_test(
    image_assets: ImageTestAssets,
    model_id: str,
    *,
    dtype: str,
):
    """Compare vLLM's ``InternVisionModel`` against the HF reference model.

    Every image in ``image_assets`` is run through both models, and for each
    image the mean cosine similarity of the two outputs must exceed 0.99.

    Args:
        image_assets: Test images; each asset's ``pil_image`` is used.
        model_id: Checkpoint name, resolved as a local directory under
            ``models_path_prefix`` (no download from the Hugging Face Hub).
        dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` (e.g. ``"half"``).

    Raises:
        AssertionError: If any image's similarity is not above 0.99.
    """
    model = os.path.join(models_path_prefix, model_id)
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Preprocess each image on its own, so each forward pass below gets the
    # pixel values of exactly one image.
    img_processor = CLIPImageProcessor.from_pretrained(model)
    images = [asset.pil_image for asset in image_assets]
    pixel_values = [
        img_processor(image, return_tensors="pt").pixel_values.to(torch_dtype)
        for image in images
    ]

    # The config may lack ``norm_type``; default it to RMSNorm. Only the vLLM
    # model is built from this config; the HF model below loads its own.
    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"

    hf_model = AutoModel.from_pretrained(
        model, dtype=torch_dtype, trust_remote_code=True
    ).to("cuda")
    hf_outputs_per_image = [
        hf_model(pixel_value.to("cuda")).last_hidden_state
        for pixel_value in pixel_values
    ]

    # NOTE(review): deferred import, presumably to avoid vLLM import-time side
    # effects before the HF reference has run; confirm before hoisting it.
    from vllm.model_executor.models.intern_vit import InternVisionModel

    # Copy the HF weights into the vLLM model, then free the HF model so both
    # are not resident on the GPU at the same time.
    vllm_model = InternVisionModel(config)
    vllm_model.load_weights(hf_model.state_dict().items())

    del hf_model
    cleanup_dist_env_and_memory()

    vllm_model = vllm_model.to("cuda", torch_dtype)
    vllm_outputs_per_image = [
        vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values
    ]
    del vllm_model
    cleanup_dist_env_and_memory()

    cos_similar = nn.CosineSimilarity(dim=-1)
    for idx, (vllm_output, hf_output) in enumerate(
        zip(vllm_outputs_per_image, hf_outputs_per_image)
    ):
        similarity = cos_similar(vllm_output, hf_output).mean().item()
        # Report which model/image diverged and by how much on failure.
        assert similarity > 0.99, (
            f"{model_id} image {idx}: mean cosine similarity {similarity:.4f} <= 0.99"
        )


73
74
75
76
77
78
79
@pytest.mark.parametrize(
    "model_id",
    (
        "OpenGVLab/InternViT-300M-448px",
        "OpenGVLab/InternViT-6B-448px-V1-5",
    ),
)
@pytest.mark.parametrize("dtype", ("half",))
def test_models(
    default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
    """Check the vLLM InternViT encoder against HF for each model/dtype pair.

    ``default_vllm_config`` and ``dist_init`` are not used in the body; they
    are presumably requested for their setup side effects (TODO confirm).
    """
    run_intern_vit_test(image_assets, model_id, dtype=dtype)