# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

8
from vllm.distributed import cleanup_dist_env_and_memory
9
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
10

11
from ....conftest import ImageTestAssets
12
13
14
15
16
17

# Only these files are fetched via snapshot_download; loading the model from
# that local snapshot avoids conflicts between dynamic_module and
# trust_remote_code for hf_runner.
DOWNLOAD_PATTERN = [
    "*.json",
    "*.py",
    "*.safetensors",
    "*.txt",
    "*.model",
]


@torch.inference_mode()
def run_intern_vit_test(
    image_assets: ImageTestAssets,
    model_id: str,
    *,
    dtype: str,
) -> None:
    """Compare vLLM's InternVisionModel against the HF remote-code model.

    The checkpoint is downloaded locally (restricted to ``DOWNLOAD_PATTERN``).
    Each test image is preprocessed on its own with ``CLIPImageProcessor``.
    The HF model (``trust_remote_code=True``) and a vLLM
    ``InternVisionModel`` loaded with the same weights are both run on CUDA.
    For every image, the mean cosine similarity (over the last dimension)
    between the vLLM output and the HF ``last_hidden_state`` must exceed 0.99.

    Args:
        image_assets: Test images; the ``.pil_image`` of each asset is used.
        model_id: Hugging Face Hub repo id of the InternViT checkpoint.
        dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` (e.g. ``"half"``).

    Raises:
        AssertionError: If no outputs were produced, the two models produced
            a different number of outputs, or any image falls below the
            similarity threshold.
    """
    model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    img_processor = CLIPImageProcessor.from_pretrained(model)
    images = [asset.pil_image for asset in image_assets]
    # One processor call per image, so each entry holds a single image's
    # pixel values.
    pixel_values = [
        img_processor(image, return_tensors='pt').pixel_values.to(torch_dtype)
        for image in images
    ]

    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    # Fall back to "rms_norm" when the config has no (or an empty) norm_type.
    if not getattr(config, "norm_type", None):
        config.norm_type = "rms_norm"

    hf_model = AutoModel.from_pretrained(model,
                                         torch_dtype=torch_dtype,
                                         trust_remote_code=True).to("cuda")
    hf_outputs_per_image = [
        hf_model(pixel_value.to("cuda")).last_hidden_state
        for pixel_value in pixel_values
    ]

    from vllm.model_executor.models.intern_vit import InternVisionModel
    vllm_model = InternVisionModel(config)
    # Load the HF model's weights so both models use identical parameters.
    vllm_model.load_weights(hf_model.state_dict().items())

    # Drop the HF model before moving the vLLM model onto the GPU.
    del hf_model
    cleanup_dist_env_and_memory()

    vllm_model = vllm_model.to("cuda", torch_dtype)
    vllm_outputs_per_image = [
        vllm_model(pixel_values=pixel_value.to("cuda"))
        for pixel_value in pixel_values
    ]
    del vllm_model
    cleanup_dist_env_and_memory()

    # zip() stops at the shorter list, so an empty or mismatched result would
    # otherwise let the comparison below pass without checking anything.
    assert len(hf_outputs_per_image) > 0
    assert len(vllm_outputs_per_image) == len(hf_outputs_per_image)

    cos_similar = nn.CosineSimilarity(dim=-1)
    for vllm_output, hf_output in zip(vllm_outputs_per_image,
                                      hf_outputs_per_image):
        assert cos_similar(vllm_output, hf_output).mean() > 0.99


@pytest.mark.parametrize("model_id", [
    "OpenGVLab/InternViT-300M-448px",
    "OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id: str, dtype: str) -> None:
    """Run the HF-vs-vLLM InternViT comparison for each model id and dtype.

    ``dist_init`` is not used in the body; it is requested only for the
    fixture's side effects (presumably initializing the distributed
    environment the vLLM model needs — TODO confirm in conftest).
    """
    run_intern_vit_test(
        image_assets,
        model_id,
        dtype=dtype,
    )