import torch
import warnings
import copy as cp
from PIL import Image
import pandas as pd
import string
import re
from .base import BaseModel
from ..smp import isimg, listinstr, cn_string
from ..dataset import DATASET_TYPE, DATASET_MODALITY


class Aria(BaseModel):
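    """VLMEvalKit wrapper for the rhymes-ai/Aria multimodal model.

    The model and processor are loaded from HuggingFace with ``trust_remote_code=True``,
    and prompts are assembled in Aria's chat format (``<|im_start|>`` / ``<|im_end|>``
    with ``<fim_prefix><|img|><fim_suffix>`` placeholders for images).

    Minimal usage sketch (illustrative only; the image path is hypothetical and
    VLMEvalKit normally drives this class through its own evaluation runner)::

        model = Aria(model_path='rhymes-ai/Aria')
        message = [
            dict(type='image', value='example.jpg'),
            dict(type='text', value='Describe the image.'),
        ]
        answer = model.generate_inner(message)
    """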

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='rhymes-ai/Aria', **kwargs):
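        """Load the Aria model and processor and store default generation kwargs."""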
        from transformers import AutoModelForCausalLM, AutoProcessor
        assert model_path is not None
        self.model_path = model_path
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        tokenizer = processor.tokenizer
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.unk_token_id
        self.processor = processor
        self.tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map='cuda',
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).eval()
        default_kwargs = dict(
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            num_return_sequences=1,
            use_cache=True,
            output_hidden_states=True,
            pad_token_id=tokenizer.unk_token_id,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
        )
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        warnings.warn(f'Received kwargs: {self.kwargs}; they will be used as the generation config.')
        torch.cuda.empty_cache()

    def use_custom_prompt(self, dataset):
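        """Return True if this class should build its own prompt for ``dataset``."""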
        assert dataset is not None
        if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            # No custom prompt for multi-turn datasets
            return False
        if DATASET_MODALITY(dataset) == 'VIDEO':
            # No custom prompt for video benchmarks here
            return False
        return True

    def build_prompt(self, line, dataset=None):
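        """Build an interleaved image/text message for a single dataset record.

        The hint (if any) is prepended to the question, multiple-choice options are
        appended as ``A. ...`` lines, and a dataset-specific answer instruction is added.
        """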
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            if listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse'], dataset):
                pass  # keep the prompt unchanged for these benchmarks
            elif listinstr(['LLaVABench', 'MMBench-Video'], dataset):
                prompt += '\nAnswer this question in detail.'
            elif listinstr(['DocVQA'], dataset):
                prompt += '\nAnswer briefly and directly.'
            else:
                prompt += '\nAnswer the question using a single word or phrase.'

        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message

    def build_video_prompt(self, prompt, dataset=None):
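        """Apply video-benchmark-specific rewrites to an already built prompt."""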
        if listinstr(['MMBench-Video'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
            prompt = prompt.replace(
                'Question: ',
                'Please carefully check the video and then answer the following question with details:'
            )
        elif listinstr(['Video-MME'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
            prompt += "\nAnswer with the option's letter from the given choices directly."
        elif listinstr(['MVBench'], dataset):
            prompt = prompt.replace('Best option:(', '')
            system_prompt = 'Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n'  # noqa: E501
            prompt = prompt.replace(system_prompt, '')

        return prompt

    def adjust_kwargs(self, dataset):
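        """Derive per-dataset generation kwargs (max image size, image splitting,
        max_new_tokens) from the defaults set in ``__init__``."""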
        kwargs = cp.deepcopy(self.kwargs)
        kwargs["temperature"] = 0.0
        kwargs["do_sample"] = False

        if DATASET_MODALITY(dataset) == "VIDEO":
            kwargs["max_image_size"] = 490
        else:
            kwargs["max_image_size"] = 980

        kwargs["split_image"] = False

        if listinstr(['MMMU', 'MMStar', 'Math'], dataset):
            # These datasets may elicit CoT-like behaviour from the model,
            # so allow longer outputs.
            kwargs['max_new_tokens'] = 512
            return kwargs
        if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
            kwargs['max_new_tokens'] = 64
        elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
            kwargs['max_new_tokens'] = 64
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
                kwargs['max_new_tokens'] = 128
            elif listinstr(['TextVQA'], dataset):
                kwargs['max_new_tokens'] = 32

        if listinstr(['OCR', 'ChartQA', 'DocVQA', 'InfoVQA', 'TextVQA'], dataset):
            # OCR-related datasets that require image splitting
            kwargs["split_image"] = True

        return kwargs

    def generate_inner(self, message, dataset=None):
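        """Encode one interleaved message, run generation, and return the decoded answer."""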
        if dataset is not None:
            kwargs = self.adjust_kwargs(dataset)
        else:
            # Copy so that popping keys below does not mutate self.kwargs
            kwargs = cp.deepcopy(self.kwargs)

        # adjust_kwargs always sets these; fall back to defaults when dataset is None
        max_image_size = kwargs.pop("max_image_size", 980)
        split_image = kwargs.pop("split_image", False)

        prompt = '<|im_start|>user\n'
        images = []
        last_message_modality = "text"

        if dataset is not None and listinstr(['MLVU', 'TempCompass', 'MVBench'], dataset):
            # Re-arrange the message so that all images precede the text
            new_message = [s for s in message if s['type'] == 'image']
            new_message += [s for s in message if s['type'] == 'text']
            message = new_message

        for s in message:
            if s['type'] == 'image':
                prompt += '<fim_prefix><|img|><fim_suffix>'
                images.append(s['value'])
                last_message_modality = "image"
            elif s['type'] == 'text':
                text = re.sub(r"<image \d+>", "", s["value"])
                if last_message_modality == "image":
                    prompt += "\n"
                    last_message_modality = "text"
                prompt += text

        if dataset is not None and DATASET_MODALITY(dataset) == 'VIDEO':
            prompt = self.build_video_prompt(prompt, dataset)

        prompt += '<|im_end|>\n<|im_start|>assistant\n'
        if images:
            images = [Image.open(s).convert('RGB') for s in images]
            encoded = self.processor(
                text=prompt,
                images=images,
                return_tensors='pt',
                padding='longest',
                max_image_size=max_image_size,
                split_image=split_image,
            )
        else:
            encoded = self.processor(text=prompt, return_tensors='pt', padding='longest')
        if "pixel_values" in encoded:
            # Cast vision inputs to the model dtype (bfloat16); text-only inputs have no pixel_values
            encoded["pixel_values"] = encoded["pixel_values"].to(self.model.dtype)
        encoded = {k: v.to(self.model.device) for k, v in encoded.items()}

        pred = self.model.generate(**encoded, **kwargs)
        answer = self.tokenizer.decode(pred[0][encoded['input_ids'].size(1):].cpu(), skip_special_tokens=True).strip()
        answer = answer.replace('<|im_end|>', '')
        return answer