import os import requests import time import random import copy import traceback import pandas as pd from PIL import Image from typing import List, Dict, Tuple, Any from common_utils import encode_image_to_base64 from collections import defaultdict try: from latex2sympy2 import latex2sympy except ImportError: print('Warning: latex2sympy2 not installed. Install with: pip install latex2sympy2') latex2sympy = None FAIL_MSG = 'Failed to obtain answer via API.' def is_equal(asw: str, gt_asw: str) -> bool: """Check if two answers are equal.""" if not isinstance(asw, str) or not isinstance(gt_asw, str): print('Warning: input is not string') print(asw, gt_asw) asw = str(asw).lower().strip() gt_asw = str(gt_asw).lower().strip() if gt_asw == asw: return True try: a = eval(gt_asw) b = eval(asw) if abs(a - b) < 1e-6: return True except: pass if latex2sympy is not None: try: a = latex2sympy(gt_asw) b = latex2sympy(asw) if abs(eval(str(a)) - eval(str(b))) < 1e-6: return True if abs(a - b) < 1e-6: return True except: pass return False def get_gpt4_ICE(): """Get in-context examples for GPT-4 answer extraction.""" example_1 = """ Hint: Please answer the question and provide the final answer at the end.\n Question: Which number is missing?\n Model response: The number missing in the sequence is 14.\n Extracted answer: 14 """ example_2 = """ Hint: Please answer the question and provide the final answer at the end.\n Question: What is the fraction of females facing the camera?\n Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.\n Extracted answer: 0.6 """ example_3 = """ Hint: Please answer the question and provide the final answer at the end.\n Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n Extracted answer: 1.45 """ example_4 = """ Hint: Please answer the question and provide the final answer at the end.\n Question: Between which two years does the line graph saw its maximum peak?\n Model response: The line graph saw its maximum peak between 2007 and 2008.\n Extracted answer: [2007, 2008] """ example_5 = """ Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n Question: What fraction of the shape is blue?\n Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n Model response: The correct answer is (B) 8/11.\n Extracted answer: B """ return [example_1, example_2, example_3, example_4, example_5] def build_mathv_gpt4_prompt(line): """Build the prompt for GPT-4 to extract answer from model response.""" task_description = """ Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.\n """ question = line['question'] prediction = str(line['prediction']) prompt = task_description examples = get_gpt4_ICE() for example in examples: prompt += example + '\n' prompt += question + '\n' prompt += 'Model response: ' + prediction + '\n' prompt += 'Extracted answer: ' return prompt def list_to_dict(lst): """Convert list to dictionary with uppercase letters as keys.""" return {chr(65 + i): val for i, val in enumerate(lst)} def can_infer_option(answer, choices): """Rule-based extraction of answer option.""" if FAIL_MSG in answer: return False reject_to_answer = [ "Sorry, I can't help with images of people yet.", "I can't process this file.", "I'm sorry, but without the image provided", 'Cannot determine the answer' ] for err in reject_to_answer: if err in answer: return 'Z' def count_choice(splits, choices, prefix='', suffix=''): cnt = 0 for c in choices: if prefix + c + suffix in splits: cnt += 1 return cnt answer_mod = copy.copy(answer) chars = '.()[],:;!*#{}' for c in chars: answer_mod = answer_mod.replace(c, ' ') splits = [x.strip() for x in answer_mod.split()] count = count_choice(splits, choices) if count == 1: for ch in choices: if 'A' in splits and len(splits) > 3: return False if ch in splits: return ch elif count == 0 and count_choice(splits, {'Z', ''}) == 1: return 'Z' return False def can_infer_text(answer, choices): """Extract answer by matching text content.""" answer = answer.lower() assert isinstance(choices, dict) for k in choices: choices[k] = str(choices[k]).lower() cands = [] for k in choices: if choices[k] in answer: cands.append(k) if len(cands) == 1: return cands[0] return False def can_infer(answer, choices): """Combined approach to infer answer choice.""" answer = str(answer) copt = can_infer_option(answer, choices) return copt if copt else can_infer_text(answer, choices) def post_check(line, prefetch=False): """Check if the prediction matches the answer.""" res = None ans = line['answer'] response = line['prediction'] if prefetch else line['res'] try: if len(eval(line['choices'])) > 0: ans = line['answer'] choices = list_to_dict(eval(line['choices'])) res = can_infer(response, choices) if prefetch: return res else: res = str(response) ans = str(ans) except ValueError: pass if is_equal(res, ans): return res if prefetch else True else: return False class OpenAIWrapper: """Wrapper for OpenAI API.""" def __init__(self, model, api_base, api_key, timeout=60, retry=5, wait=5): self.model = model self.api_base = api_base self.api_key = api_key self.timeout = timeout self.retry = retry self.wait = wait self.fail_msg = FAIL_MSG def generate(self, prompt, temperature=0): """Generate a response from the API.""" headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.api_key}'} payload = { "model": self.model, "messages": [{"role": "user", "content": prompt}], "max_tokens": 4096, "temperature": temperature } for i in range(self.retry): try: response = requests.post( self.api_base, headers=headers, json=payload, timeout=self.timeout ) if response.status_code == 200: resp_json = response.json() return resp_json['choices'][0]['message']['content'].strip() time.sleep(self.wait) except Exception as e: print(f"API error: {e}") time.sleep(self.wait) return self.fail_msg class DashScopeWrapper: """Wrapper for DashScope API.""" def __init__(self, model, api_base, api_key, timeout=60, retry=5, wait=5): self.model = model self.api_base = api_base self.api_key = api_key self.timeout = timeout self.retry = retry self.wait = wait self.fail_msg = FAIL_MSG def generate(self, prompt, temperature=0): """Generate a response from the API.""" headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.api_key}'} payload = { "model": self.model, "messages": [{"role": "user", "content": prompt}], "max_completion_tokens": 4096, "n": 1, "temperature": temperature, "stream": False } for i in range(self.retry): try: response = requests.post( self.api_base, headers=headers, json=payload, timeout=self.timeout ) if response.status_code == 200: resp_json = response.json() # Check finish reason for output in resp_json['choices']: if output['finish_reason'] not in ['stop', 'function_call']: print(f"DashScope finished with error: {resp_json}") time.sleep(self.wait) continue return resp_json['choices'][0]['message']['content'] else: print(f"DashScope API error: HTTP {response.status_code}") try: error_content = response.json() print(f"Error details: {error_content}") except: print(f"Raw error content: {response.content.decode('utf-8', errors='replace')}") time.sleep(self.wait) except Exception as e: print(f"DashScope error: {e}") time.sleep(self.wait) return self.fail_msg def build_judge(model, api_type): """Build a judge model for evaluation.""" if api_type == 'mit': api_key = os.environ.get('MIT_SPIDER_TOKEN', '') api_base = os.environ.get('MIT_SPIDER_URL', '') return OpenAIWrapper(model, api_base, api_key) elif api_type == 'dash': api_key = os.environ.get('CHATGPT_DASHSCOPE_API_KEY', '') api_base = os.environ.get('DASHSCOPE_API_BASE', '') return DashScopeWrapper(model, api_base, api_key) else: raise ValueError(f"Unsupported API type: {api_type}") def MATH_V_auxeval(args): """Auxiliary evaluation for MathVision - extract answer from model response.""" model, line = args prompt = build_mathv_gpt4_prompt(line) log = '' retry = 5 # Try rule-based extraction first if post_check(line, prefetch=True): res = post_check(line, prefetch=True) log += 'Prefetch succeed.\n' extract_flag = True if not res or res == 'Z': extract_flag = False log += f'Rule extract failed with ans: {res}' else: log += f'Rule extract success with ans: {res}' return dict(log=log, res=res, extract_model='rule', extract_flag=extract_flag) # Use model-based extraction for i in range(retry): prediction = line['prediction'] res = model.generate(prompt, temperature=i * 0.5) if FAIL_MSG in res: log += f'Try {i}: output is {prediction}, failed to parse.\n' else: log += f'{model.model} extract Succeed.\n' return dict(log=log, res=res, extract_model=model.model, extract_flag=True) log += f'All {retry} retries failed.\n {model.model} response:{res}' return dict(log=log, res='', extract_model=model.model, extract_flag=False) def MATH_V_acc(result_file): """Calculate accuracy for MathVision results.""" data = pd.read_excel(result_file) if result_file.endswith('.xlsx') else pd.read_csv(result_file) tot = defaultdict(lambda: 0) fetch = defaultdict(lambda: 0) hit = defaultdict(lambda: 0) lt = len(data) extract_counts = {} for i in range(lt): item = data.iloc[i] cate = item['category'] tot['Overall'] += 1 tot[cate] += 1 if 'Prefetch succeed' in item['log']: fetch['Overall'] += 1 fetch[cate] += 1 if post_check(item, prefetch=False): hit['Overall'] += 1 hit[cate] += 1 # Statistics of answers extracted by rule and gpt extract_model = item['extract_model'] extract_flag = item['extract_flag'] if extract_model in extract_counts: extract_counts[extract_model][1] += 1 else: extract_counts[extract_model] = [0, 1] # succeed, total if extract_flag: extract_counts[extract_model][0] += 1 res = defaultdict(list) for k in tot.keys(): res['Subject'].append(k) res['tot'].append(tot[k]) res['prefetch'].append(fetch[k]) res['hit'].append(hit[k]) res['prefetch_rate'].append(fetch[k] / tot[k] * 100) res['acc'].append(hit[k] / tot[k] * 100) if k == 'Overall': for model_key in extract_counts: res[model_key+'_success'].append(extract_counts[model_key][0]) res[model_key+'_all'].append(extract_counts[model_key][1]) else: for model_key in extract_counts: res[model_key+'_success'].append(0) res[model_key+'_all'].append(0) res = pd.DataFrame(res).sort_values('Subject', ignore_index=True) return res def eval_single_sample(args): """Evaluate a single sample.""" return MATH_V_auxeval(args)