# sensechat_vision.py
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import img_root_map
from vlmeval.dataset import DATASET_TYPE


class SenseChatVisionWrapper(BaseAPI):
    """API wrapper for SenseTime's SenseChat-Vision models (SenseNova platform).

    Authenticates with an access-key / secret-key pair (JWT signed with HS256)
    and sends multimodal (text + base64-encoded image) chat-completion requests.
    """

    # Marks this model as API-backed for the evaluation framework.
    is_api: bool = True

    def __init__(self,
                 model: str = 'SenseChat-5-Vision',
                 retry: int = 5,
                 wait: int = 5,
                 ak: str = None,
                 sk: str = None,
                 verbose: bool = True,
                 system_prompt: str = None,
                 max_tokens: int = 1024,
                 proxy: str = None,
                 **kwargs):
        """Initialize the wrapper.

        Args:
            model (str): Model name on the SenseNova platform.
            retry (int): Max retry count, forwarded to BaseAPI.
            wait (int): Seconds to wait between retries, forwarded to BaseAPI.
            ak (str): Access key; falls back to the SENSECHAT_AK env variable.
            sk (str): Secret key; falls back to the SENSECHAT_SK env variable.
            verbose (bool): Whether to log requests and answers.
            system_prompt (str): Optional system prompt, forwarded to BaseAPI.
            max_tokens (int): Default generation budget (stored as max_new_tokens).
            proxy (str): Unused here; kept for signature compatibility.
        """
        self.model = model
        self.fail_msg = 'Failed to obtain answer via API. '
        self.ak = os.environ.get('SENSECHAT_AK', None) if ak is None else ak
        self.sk = os.environ.get('SENSECHAT_SK', None) if sk is None else sk
        assert self.ak is not None and self.sk is not None, \
            'Provide ak/sk arguments or set SENSECHAT_AK / SENSECHAT_SK environment variables'
        self.max_new_tokens = max_tokens
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def dump_image(self, line, dataset):
        """Dump the image(s) of the input line to the corresponding dataset folder.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            list[str]: The paths of the dumped images.
        """
        ROOT = LMUDataRoot()
        assert isinstance(dataset, str)
        img_root = osp.join(ROOT, 'images', img_root_map(dataset))
        os.makedirs(img_root, exist_ok=True)
        if 'image' in line:
            if isinstance(line['image'], list):
                # Multiple base64 images: dump each under its provided name.
                tgt_path = []
                assert 'image_path' in line
                for img, im_name in zip(line['image'], line['image_path']):
                    path = osp.join(img_root, im_name)
                    if not read_ok(path):
                        decode_base64_to_image_file(img, path)
                    tgt_path.append(path)
            else:
                # Single base64 image: name it after the line index.
                tgt_path = osp.join(img_root, f"{line['index']}.jpg")
                if not read_ok(tgt_path):
                    decode_base64_to_image_file(line['image'], tgt_path)
                tgt_path = [tgt_path]
        else:
            # No inline image payload: paths must already be provided.
            assert 'image_path' in line
            tgt_path = toliststr(line['image_path'])

        return tgt_path

    def image_to_base64(self, image_path):
        """Read the file at `image_path` and return its base64 string (utf-8)."""
        import base64
        with open(image_path, 'rb') as image_file:
            encoded_string = base64.b64encode(image_file.read())
            return encoded_string.decode('utf-8')

    def encode_jwt_token(self, ak, sk):
        """Build a short-lived HS256 JWT for the SenseNova API from ak/sk."""
        import jwt
        headers = {'alg': 'HS256', 'typ': 'JWT'}
        payload = {
            'iss': ak,
            'exp': int(time.time()) + 1800,  # token valid for 30 minutes
            'nbf': int(time.time()) - 5,  # effective 5 seconds in the past (clock-skew tolerance)
        }
        token = jwt.encode(payload, sk, headers=headers)
        return token

    def use_custom_prompt(self, dataset):
        """This wrapper always builds its own dataset-specific prompts."""
        return True

    def build_multi_choice_prompt(self, line, dataset=None):
        """Assemble a multiple-choice prompt: optional hint, question, options.

        Appends a direct-answer instruction in Chinese or English depending on
        the prompt language; falls back to a direct-answer instruction when the
        line carries no options.
        """
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the interleaved text/image message for one dataset line.

        Returns:
            list[dict]: [{'type': 'text', ...}] followed by one image entry
            per dumped image path.
        """
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)

        tgt_path = self.dump_image(line, dataset)

        if dataset is not None and listinstr(['MME'], dataset):
            question = line['question']
            prompt = question + ' Answer the question using a single word or phrase.'
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question = line['question']
            prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ' and 'MMMU' not in dataset:
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            if 'MathVista' in dataset:
                prompt = line['question']
            elif listinstr(['LLaVABench'], dataset):
                question = line['question']
                prompt = question + '\nAnswer this question in detail.'
            elif listinstr(['MMVet'], dataset):
                prompt = line['question']
            else:
                question = line['question']
                prompt = question + '\nAnswer the question using a single word or phrase.'
        elif dataset is not None and 'MMMU' in dataset:
            # MMMU: chain-of-thought style instruction prepended to the question.
            question = line['question']
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            for key, item in options.items():
                question += f'\n{key}. {item}'
            prompt = {
                'multiple-choice': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n',  # noqa
                'open': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n'  # noqa
            }
            subject = '_'.join(line['id'].split('_')[1:-1])
            # NOTE(review): the templates above contain no '{}' placeholders,
            # so .format(subject, subject) is currently a no-op; kept for
            # compatibility in case the templates regain placeholders.
            prompt = prompt[line['question_type']].format(subject, subject) + '\n' + question
        else:
            prompt = line['question']

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])

        return message

    def message_to_promptimg(self, message, dataset=None):
        """Flatten an interleaved message into (prompt, image_paths).

        Text segments are joined with newlines. For MMMU/BLINK (and unknown
        datasets) only the first image is forwarded; otherwise all images are.
        """
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        image = [x['value'] for x in message if x['type'] == 'image']
        if dataset is None or listinstr(['MMMU', 'BLINK'], dataset):
            # Keep at most one image (empty list if the message carried none).
            image = image[:1]
        return prompt, image

    def generate_inner(self, inputs, **kwargs) -> str:
        """Send one multimodal request to the SenseNova chat-completions API.

        Args:
            inputs (str | list[dict]): Interleaved text/image message.
            **kwargs: May contain 'dataset', used to tune generation settings.

        Returns:
            tuple: (0, answer, 'Succeeded! ') on success,
                   (-1, api_error_message, '') on failure.
        """
        assert isinstance(inputs, str) or isinstance(inputs, list)
        inputs = [inputs] if isinstance(inputs, str) else inputs
        dataset = kwargs.get('dataset', None)

        # Per-dataset tile budget (presumably consumed by the backend's
        # dynamic image slicing — TODO confirm against the API docs).
        if dataset is not None and listinstr(['ChartQA_TEST', 'MathVista_MINI'], dataset):
            self.max_num = 12
        elif dataset is not None and listinstr(['DocVQA_VAL', 'DocVQA_TEST'], dataset):
            self.max_num = 18
        elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench'], dataset):
            self.max_num = 24
        else:
            self.max_num = 6

        # Per-dataset overrides of the generation length.
        if dataset is None:
            pass
        elif listinstr(['AI2D_TEST'], dataset):
            self.max_new_tokens = 10
        elif 'MMMU' in dataset:
            self.max_new_tokens = 4096
        elif 'MMBench' in dataset:
            self.max_new_tokens = 100
        elif 'MathVista_MINI' in dataset:
            self.max_new_tokens = 4096

        prompt, image = self.message_to_promptimg(message=inputs, dataset=dataset)

        url = 'https://api.sensenova.cn/v1/llm/chat-completions'
        api_secret_key = self.encode_jwt_token(self.ak, self.sk)

        # One payload entry per image. (The original carried a duplicate
        # 'text' key here, which Python silently collapses — removed.)
        content = [{
            'image_base64': self.image_to_base64(item),
            'image_file_id': '',
            'image_url': '',
            'text': '',
            'type': 'image_base64'
        } for item in image]

        content.append({
            'image_base64': '',
            'image_file_id': '',
            'image_url': '',
            'text': prompt,
            'type': 'text'
        })

        message = [{'content': content, 'role': 'user'}]

        data = {
            'messages': message,
            'max_new_tokens': self.max_new_tokens,
            'temperature': 0,
            'top_k': 0,
            'top_p': 0.99,
            'repetition_penalty': 1.05,
            'model': self.model,
            'stream': False,
        }
        headers = {
            'Content-type': 'application/json',
            'Authorization': 'Bearer ' + api_secret_key
        }

        response = requests.post(
            url,
            headers=headers,
            json=data,
        )
        # .get(): the header may be absent on gateway errors, and a KeyError
        # here would mask the real API failure.
        request_id = response.headers.get('x-request-id', '')

        time.sleep(1)  # crude client-side rate limiting between requests
        try:
            assert response.status_code == 200
            # Bind the parsed answer to a new name instead of rebinding
            # `response`, so the except branch can still call response.json().
            answer = response.json()['data']['choices'][0]['message'].strip()
            if self.verbose:
                self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
            return 0, answer, 'Succeeded! '
        except Exception as err:
            if self.verbose:
                self.logger.error('---------------------------ERROR---------------------------')
                self.logger.error(response.json())
                self.logger.error(err)
                self.logger.error('---------------------------request_id---------------------------' + request_id)
                self.logger.error(
                    'api error' + response.json()['error']['message']
                    + str([input['value'] if input['type'] == 'image' else None for input in inputs])
                )
                self.logger.error(f'The input messages are {inputs}.')
            return -1, response.json()['error']['message'], ''

class SenseChatVisionAPI(SenseChatVisionWrapper):
    """Public entry point for SenseChat-Vision; delegates to the wrapper."""

    def generate(self, message, dataset=None):
        """Forward `message` (and the optional `dataset` hint) to the base class."""
        return super().generate(message, dataset=dataset)