# test_intern_vit.py
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

8
9
from vllm.distributed import cleanup_dist_env_and_memory

10
from ....conftest import _ImageAssets
11
12
13
14
15
16
17
18

# The checkpoint is fetched up front with snapshot_download, limited to the
# file patterns below, so that HF dynamic modules and trust_remote_code do
# not conflict when the same repo is also loaded through hf_runner.
DOWNLOAD_PATTERN = [
    "*.json",
    "*.py",
    "*.safetensors",
    "*.txt",
    "*.model",
]


def run_intern_vit_test(
    image_assets: _ImageAssets,
    model_id: str,
    *,
    dtype: torch.dtype,
):
    """Compare vLLM's InternVisionModel against the HF reference model.

    Every image asset is preprocessed once. The resulting pixel values are
    encoded by both the Hugging Face model and vLLM's InternVisionModel,
    which is loaded with the HF model's own weights. For each image, the
    mean cosine similarity between the two outputs must exceed 0.99.

    Args:
        image_assets: Image fixtures to encode.
        model_id: Hugging Face Hub repo id of an InternViT checkpoint.
        dtype: Torch dtype used for the pixel values and both models.
    """
    # Work from a local snapshot (restricted to DOWNLOAD_PATTERN) rather than
    # loading by repo id; see the comment on DOWNLOAD_PATTERN.
    model_path = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)

    img_processor = CLIPImageProcessor.from_pretrained(model_path)
    images = [asset.pil_image for asset in image_assets]
    # One processor call per image, so each entry holds a single image.
    pixel_values = [
        img_processor(image, return_tensors='pt').pixel_values.to(dtype)
        for image in images
    ]

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # Fill in norm_type when the checkpoint config leaves it unset/empty;
    # "rms_norm" is presumably what such checkpoints use — TODO confirm.
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"

    hf_model = AutoModel.from_pretrained(model_path,
                                         torch_dtype=dtype,
                                         trust_remote_code=True).to("cuda")
    hf_outputs_per_image = [
        hf_model(pixel_value.to("cuda")).last_hidden_state
        for pixel_value in pixel_values
    ]

    # Function-scope import kept as in the original; presumably to avoid
    # importing vLLM model code at test collection time — TODO confirm.
    from vllm.model_executor.models.intern_vit import InternVisionModel
    vllm_model = InternVisionModel(config)
    # Load the HF weights directly so both models share identical params.
    vllm_model.load_weights(hf_model.state_dict().items())

    # Release the reference model before moving the vLLM copy to the GPU.
    del hf_model
    cleanup_dist_env_and_memory()

    vllm_model = vllm_model.to("cuda", dtype)
    vllm_outputs_per_image = [
        vllm_model(pixel_values=pixel_value.to("cuda"))
        for pixel_value in pixel_values
    ]
    del vllm_model
    cleanup_dist_env_and_memory()

    cos_similar = nn.CosineSimilarity(dim=-1)
    for vllm_output, hf_output in zip(vllm_outputs_per_image,
                                      hf_outputs_per_image):
        assert cos_similar(vllm_output, hf_output).mean() > 0.99


65
66
67
68
@pytest.mark.parametrize("model_id", [
    "OpenGVLab/InternViT-300M-448px",
    "OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_models(image_assets, model_id, dtype: torch.dtype) -> None:
    """Check vLLM's InternViT against HF for each model id and dtype.

    Runs without autograd (``torch.inference_mode``) and delegates the
    actual comparison to ``run_intern_vit_test``.
    """
    run_intern_vit_test(
        image_assets,
        model_id,
        dtype=dtype,
    )