# minicpmv.py

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Lift PIL's decompression-bomb guard so large document scans can be opened.
Image.MAX_IMAGE_PIXELS = 1000000000

# Per-dataset caps on generated tokens; unlisted datasets fall back to 1024.
max_token = {
    'docVQA': 100,
    'textVQA': 100,
    'docVQATest': 100
}

class MiniCPM_V:

    def __init__(self, model_path, ckpt, device=None) -> None:
        self.model_path = model_path
        self.ckpt = ckpt
        # Load the base model; optionally override its weights with a fine-tuned checkpoint.
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).eval()
        if self.ckpt is not None:
            self.state_dict = torch.load(self.ckpt, map_location=torch.device('cpu'))
            self.model.load_state_dict(self.state_dict)

        self.model = self.model.to(dtype=torch.float16)
        self.model.to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()

    def generate(self, images, questions, datasetname):
        image = Image.open(images[0]).convert('RGB')
        max_new_tokens = max_token.get(datasetname, 1024)
        if datasetname in ('docVQA', 'docVQATest', 'textVQA'):
            prompt = "Answer the question directly with single word." + "\n" + questions[0]
        else:
            # Fall back to the bare question for datasets without a dedicated instruction.
            prompt = questions[0]

        msgs = [{'role': 'user', 'content': prompt}]
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=3
        )
        res = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )
        
        return [res]
    
    def generate_with_interleaved(self, images, questions, datasetname):
        max_new_tokens = max_token.get(datasetname, 1024)
        
        prompt = "Answer the question directly with single word."
        
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=3
        )
        
        # Interleave the instruction text, the image, and the question into one user turn.
        content = []
        message = [
            {'type': 'text', 'value': prompt},
            {'type': 'image', 'value': images[0]},
            {'type': 'text', 'value': questions[0]}
        ]
        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                content.append(image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        print(f"Q: {content}, \nA: {res}")
        return [res]
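

# Usage sketch (illustrative only): the model path, checkpoint, image path, and question
# below are hypothetical placeholders, not values defined elsewhere in this module.
if __name__ == '__main__':
    model = MiniCPM_V(
        model_path='openbmb/MiniCPM-V-2',  # assumed Hugging Face repo id; substitute your own
        ckpt=None,                         # or a path to a fine-tuned state_dict
        device='cuda'
    )
    answers = model.generate(
        images=['example.jpg'],            # placeholder image path
        questions=['What is the title of the document?'],
        datasetname='docVQA'
    )
    print(answers[0])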