test.py 1.9 KB
Newer Older
zhe chen's avatar
zhe chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from PIL import Image
from transformers import (AutoConfig, AutoModel,
                          AutoModelForImageClassification, CLIPImageProcessor)


def _print_output_shapes(output):
    """Print the shape of every entry in a model output mapping.

    For each ``key, value`` in ``output.items()``:
    list/tuple values print one line per element as ``<key>_<idx> shape: ...``,
    ``None`` prints ``<key> is None``, and anything else prints
    ``<key> shape: <value.shape>``.
    """
    for key, value in output.items():
        # Per-stage outputs may arrive as a list or a tuple; a plain
        # ``type(v) == list`` check would send tuples to ``.shape`` and crash.
        if isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                print(f'{key}_{idx} shape: {item.shape}')
        elif value is None:
            print(f'{key} is None')
        else:
            print(f'{key} shape: {value.shape}')


def test_model(model_name, image_path='img_1.png'):
    """Smoke-test one checkpoint through its backbone and classifier heads.

    Preprocesses ``image_path`` with the checkpoint's ``CLIPImageProcessor``.
    It then runs the resulting ``pixel_values`` through ``AutoModel`` and
    ``AutoModelForImageClassification`` (both loaded with
    ``trust_remote_code=True``) and prints the shape of every output entry.
    Finally it prints the argmax index of the classification logits.

    Args:
        model_name: Path (or hub id) of the checkpoint to load.
        image_path: Image file fed to the models; defaults to ``'img_1.png'``
            relative to the current working directory.
    """
    print('model_name:', model_name)
    image_processor = CLIPImageProcessor.from_pretrained(model_name)
    # Close the file handle as soon as the tensor has been produced.
    with Image.open(image_path) as raw_image:
        image = image_processor(images=raw_image, return_tensors='pt').pixel_values
    print('image shape:', image.shape)

    backbone = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    # Inference only: skip autograd bookkeeping, which would otherwise keep
    # every intermediate activation alive.
    with torch.no_grad():
        _print_output_shapes(backbone(image))
    # Free the backbone before loading the classifier so both are not
    # resident at the same time.
    del backbone

    print('------------------------')

    classifier = AutoModelForImageClassification.from_pretrained(
        model_name, trust_remote_code=True)
    with torch.no_grad():
        output = classifier(image)
    _print_output_shapes(output)
    logits = output['logits']
    # int() on the argmax assumes a batch of exactly one image, as built above.
    argmax = int(torch.argmax(logits, dim=1))
    print(argmax)


# Checkpoint directories smoke-tested when this file is run as a script,
# listed in the order they are exercised.
MODEL_DIRS = (
    './22k_model/internimage_l_22k_384',
    './22k_model/internimage_xl_22k_384',
    './22k_model/internimage_h_jointto22k_384',
    './22k_model/internimage_g_jointto22k_384',
    './in1k_model/internimage_t_1k_224',
    './in1k_model/internimage_s_1k_224',
    './in1k_model/internimage_b_1k_224',
    './in1k_model/internimage_l_22kto1k_384',
    './in1k_model/internimage_xl_22kto1k_384',
    './in1k_model/internimage_h_22kto1k_640',
    './in1k_model/internimage_g_22kto1k_512',
)

if __name__ == '__main__':
    # Guarded so that importing this module does not load every checkpoint.
    for model_dir in MODEL_DIRS:
        test_model(model_dir)