mplug_owl2.py
import sys
import torch
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class mPLUG_Owl2(BaseModel):
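    """Wrapper for the mPLUG-Owl2 multimodal chat model (default: mplug-owl2-llama2-7b)."""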

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
        try:
            from mplug_owl2.model.builder import load_pretrained_model
            from mplug_owl2.mm_utils import get_model_name_from_path
        except Exception as e:
            logging.critical('Please install mPLUG-Owl2 before using the mPLUG_Owl2 wrapper. ')
            raise e

        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, None, model_name, load_8bit=False, load_4bit=False, device='cpu')

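        # The model is loaded on CPU first, then moved to the GPU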
        self.model = model.cuda()
        self.device = self.model.device
        self.image_processor = image_processor
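        # Decoder-only LMs are typically left-padded for generation; reuse EOS as the pad token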
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.eos_token_id
        self.tokenizer = tokenizer
        self.context_len = context_len

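        # Greedy decoding defaults; any user-provided kwargs override these values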
        kwargs_default = dict(
            max_new_tokens=512, do_sample=False, num_beams=1,
            min_new_tokens=1, length_penalty=1, num_return_sequences=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will be used as the generation config. ')

    def use_custom_prompt(self, dataset):
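        # Custom prompts are built for multiple-choice benchmarks and MMVet, but not MMMU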
        assert dataset is not None
        if listinstr(['MMMU'], dataset):
            return False
        if DATASET_TYPE(dataset) == 'MCQ' or dataset == 'MMVet':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        if dataset == 'MMVet':
            prompt = question + '\nAnswer the question directly. '
        elif DATASET_TYPE(dataset) == 'MCQ':
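            # Gather candidate options (A, B, C, ...) that are present and non-NaN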
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'

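            # Prepend the hint (when provided), then the question and the formatted options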
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'Hint: {hint}\n' if hint is not None else ''
            prompt += f'{question}\n'
            prompt += (
                f"{options_prompt}\nAnswer with the option's letter from the given choices directly. "
                if len(options) else 'Answer the question directly. '
            )
        else:
            raise NotImplementedError

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def generate_inner(self, message, dataset=None):
        from mplug_owl2.constants import IMAGE_TOKEN_INDEX
        from mplug_owl2.mm_utils import process_images, tokenizer_image_token
        kwargs = cp.deepcopy(self.kwargs)
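        # Dataset-specific tweaks: disable the length penalty for open-ended benchmarks
        # (MMVet, LLaVABench, VQA) and cap multiple-choice answers at 10 new tokens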
        if dataset in ['MMVet', 'LLaVABench']:
            kwargs['length_penalty'] = 0
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            kwargs['length_penalty'] = 0
        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
            kwargs['max_new_tokens'] = 10
        num_images = len([x for x in message if x['type'] == 'image'])
        assert num_images >= 0
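        # Assemble a 'USER: ... ASSISTANT:' prompt with one <|image|> placeholder per image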
        prompt_full = 'USER: '
        images = []
        if num_images == 1:
            prompt, image = self.message_to_promptimg(message, dataset=dataset)
            prompt_full += f'<|image|>{prompt} \nASSISTANT: '
            images.append(image)
        else:
            for msg in message:
                if msg['type'] == 'image':
                    images.append(msg['value'])
                    prompt_full += '<|image|>'
                elif msg['type'] == 'text':
                    prompt_full += msg['value']
            prompt_full += '\nASSISTANT: '

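        # Stretch each image to a square whose side equals its longer edge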
        def preproc_image(fname):
            image = Image.open(fname).convert('RGB')
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            return image
        images = [preproc_image(fname) for fname in images]
        image_tensor = process_images(images, self.image_processor)
        image_tensor = image_tensor.to(self.device, dtype=torch.float16)
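        # Tokenize the prompt, mapping each <|image|> placeholder to IMAGE_TOKEN_INDEX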
        input_ids = tokenizer_image_token(
            prompt_full, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

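        # Run generation and decode only the newly generated tokens after the prompt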
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids=input_ids,
                images=image_tensor,
                output_hidden_states=True,
                use_cache=True,
                **kwargs)
        answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        return answer.split('</s>')[0]