import torch
from PIL import Image
from transformers import TextStreamer
import os

from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_docowl.conversation import conv_templates, SeparatorStyle
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from mplug_docowl.processor import DocProcessor
from icecream import ic
import time


class DocOwlInfer():
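    """Lightweight wrapper around mPLUG-DocOwl: loads the checkpoint once and
    exposes inference(image, query) for single-image question answering."""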
    def __init__(self, ckpt_path, anchors='grid_9', add_global_img=True, load_8bit=False, load_4bit=False):
        model_name = get_model_name_from_path(ckpt_path)
        ic(model_name)
        self.tokenizer, self.model, _, _ = load_pretrained_model(ckpt_path, None, model_name, load_8bit=load_8bit, load_4bit=load_4bit, device="cuda")
        self.doc_image_processor = DocProcessor(image_size=448, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
        self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

    def inference(self, image, query):
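        # DocProcessor crops the page image into sub-images according to the anchor grid
        # (plus the optional global view) and splices textual crop indicators into the query.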
        image_tensor, patch_positions, text = self.doc_image_processor(images=image, query='<|image|>'+query)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
        patch_positions = patch_positions.to(self.model.device)

        # ic(image_tensor.shape, patch_positions.shape, text)

        conv = conv_templates["mplug_owl2"].copy()
        roles = conv.roles # ("USER", "ASSISTANT")

        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # ic(prompt)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
        
        # ic(input_ids)

        stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
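            # Greedy decoding (do_sample=False); tokens are streamed to stdout by TextStreamer,
            # and generation stops once the conversation separator (conv.sep2) appears.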
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                patch_positions=patch_positions,
                do_sample=False,
                temperature=1.0,
                max_new_tokens=512,
                streamer=self.streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

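        # Decode only the newly generated tokens (everything after the prompt).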
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        return outputs.replace('</s>', '')



if __name__ == '__main__':

    model_path = '/home/wanglch/mPLUG-DocOwl/DocOwl1.5-Omni-base/'
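    # anchors='grid_9' controls how each page is cropped into sub-images;
    # add_global_img additionally keeps a resized full-page view.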
    docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=True)
    print('Loaded model from', model_path)
    # exit(0)

    qas = [
        # example queries; "详细描述这张图片" means "Describe this image in detail."
        {"image_path": "/home/wanglch/mPLUG-DocOwl/image/hp.jpg",
         "question": "详细描述这张图片"},
        {"image_path": "/home/wanglch/mPLUG-DocOwl/image/R-C.jpg",
         "question": "详细描述这张图片"},
    ]
    

    for qa in qas:
        image = qa['image_path']
        query = qa['question']
        start_time = time.time()
        ## default query: gives a relatively long answer
        answer = docowl.inference(image, query)
        end_time = time.time()
        cost_seconds = end_time - start_time

        ## answer with detailed explanation
        # query = qa['question']+'Answer the question with detailed explanation.'
        # answer = docowl.inference(image, query)
        ic(image)
        ic(query, answer)
        ic(cost_seconds)
        # ic(query_simple, answer_simple)
        print('==================')