import ast
import json
import os
import random
import re
from collections import defaultdict

import numpy as np

temp_path = './eval/mmmu_pro/options.jsonl'
with open(temp_path, 'r') as f:
    lines = [json.loads(line) for line in f.readlines()]
id2options = {line['id']: line['options'] for line in lines}


def mmmu_process_results(results):
    pred = results['response']
    if isinstance(pred, dict):
        pred = ''
    try:
        index2ans, all_choices = get_multi_choice_info(ast.literal_eval(str(results['options'])))
        parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    except Exception:
        # Fall back to the options recorded in options.jsonl for this question id.
        index2ans, all_choices = get_multi_choice_info(ast.literal_eval(str(id2options[results['id']])))
        parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    if_right = parsed_pred == results['answer']
    results['pred_indexs'] = parsed_pred
    results['if_right'] = if_right
    return results


def extract_subset_name(input_string):
    # Match ids of the form "<split>_<subset>_<number>" and return the subset name.
    split = input_string.split('_')[0]
    pattern = re.compile(rf'^{split}_(.+?)_\d+$')
    match = pattern.search(input_string)
    if match:
        return match.group(1)
    else:
        raise ValueError(f'No match found in "{input_string}"')


def mmmu_aggregate_results(results):
    evaluation_result = {}
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result['subdomain']].append(result)
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_mmmu(sub_eval_samples)
        metric_dict.update({'num_example': len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict
    printable_results = {}
    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
        in_domain_cat_results = {}
        for cat_name in in_domain_cats:
            if cat_name in evaluation_result.keys():
                in_domain_cat_results[cat_name] = evaluation_result[cat_name]
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
        in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
        printable_results['Overall-' + domain] = {
            'num': int(in_domain_data_num),
            'acc': round(in_domain_ins_acc, 3),
        }
        # add sub category
        for cat_name, cat_results in in_domain_cat_results.items():
            printable_results[cat_name] = {
                'num': int(cat_results['num_example']),
                'acc': round(cat_results['acc'], 3),
            }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results['Overall'] = {
        'num': sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
        'acc': round(all_ins_acc, 3),
    }
    print(printable_results)
    return printable_results['Overall']['acc']
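# Shape of a single result record consumed by mmmu_process_results, sketched from
# how the fields are accessed in this file (actual files may carry extra keys; the
# values below are illustrative):
#   {
#       "id": "validation_Art_1",                              # "<split>_<subset>_<number>"
#       "options": "['first option', 'second option', ...]",   # string repr of a list
#       "response": "model output ... Answer: B",
#       "answer": "B",
#       "subdomain": "Art"                                     # used by mmmu_aggregate_results
#   }
# Processing adds "pred_indexs" (the parsed option letter) and "if_right"
# (whether it equals "answer") to the record.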
##################
# Helper functions written by official MMMU repo.
##################


def calculate_ins_level_acc(results):
    """Calculate the instruction level accuracy for the given subject results.

    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L246
    """
    acc = 0
    ins_num = 0
    for cat_results in results.values():
        acc += cat_results['acc'] * cat_results['num_example']
        ins_num += cat_results['num_example']
    if ins_num == 0:
        return 0
    return acc / ins_num


DOMAIN_CAT2SUB_CAT = {
    'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
    'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
    'Science': [
        'Biology',
        'Chemistry',
        'Geography',
        'Math',
        'Physics',
    ],
    'Health and Medicine': [
        'Basic_Medical_Science',
        'Clinical_Medicine',
        'Diagnostics_and_Laboratory_Medicine',
        'Pharmacy',
        'Public_Health',
    ],
    'Humanities and Social Science': [
        'History',
        'Literature',
        'Sociology',
        'Psychology',
    ],
    'Tech and Engineering': [
        'Agriculture',
        'Architecture_and_Engineering',
        'Computer_Science',
        'Electronics',
        'Energy_and_Power',
        'Materials',
        'Mechanical_Engineering',
    ],
}


def eval_multi_choice(gold_i, pred_i):
    """
    Evaluate a multiple-choice instance.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L175
    """
    correct = False
    # only if prediction and gold are exactly the same do we consider it correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct


def eval_open(gold_i, pred_i):
    """
    Evaluate an open question instance.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L191
    """
    correct = False
    if isinstance(gold_i, list):
        # use float to avoid trivial matches
        norm_answers = []
        for answer in gold_i:
            norm_answers.extend(normalize_str(answer))
    else:
        norm_answers = normalize_str(gold_i)
    for pred in pred_i:  # pred is already normalized in the parse-response phase
        if isinstance(pred, str):  # if it's a string, check whether an answer appears in the pred
            for norm_ans in norm_answers:
                # only check whether a string answer is contained in the string pred
                if isinstance(norm_ans, str) and norm_ans in pred:
                    if not correct:
                        correct = True
                    break
        else:  # it's a float number
            if pred in norm_answers:
                if not correct:
                    correct = True
                break
    return correct


def evaluate_mmmu(samples):
    """
    Batch evaluation for multiple-choice and open questions.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L219
    """
    pred_correct = 0
    judge_dict = dict()
    for sample in samples:
        gold_i = sample['answers']
        pred_i = sample['parsed_pred']
        correct = eval_multi_choice(gold_i, pred_i)
        if correct:
            judge_dict[sample['id']] = 'Correct'
            pred_correct += 1
        else:
            judge_dict[sample['id']] = 'Wrong'
    if len(samples) == 0:
        # return a tuple so callers can always unpack (judge_dict, metrics)
        return {}, {'acc': 0}
    return judge_dict, {'acc': pred_correct / len(samples)}


def parse_multi_choice_responses(response):
    # Unused stub; returns an empty list of predicted indexes.
    pred_indexs = []
    return pred_indexs
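# Illustrative input/output for evaluate_mmmu (assumed sample dicts mirroring the
# keys read above; not taken from a real result file):
#   >>> samples = [
#   ...     {'id': 'validation_Math_1', 'answers': 'B', 'parsed_pred': 'B'},
#   ...     {'id': 'validation_Math_2', 'answers': 'C', 'parsed_pred': 'A'},
#   ... ]
#   >>> evaluate_mmmu(samples)
#   ({'validation_Math_1': 'Correct', 'validation_Math_2': 'Wrong'}, {'acc': 0.5})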
def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the prediction from the generated response.
    Return the predicted index, e.g., A, B, C, D.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10
    """
    last_answer_pos = response.rfind('Answer:')
    if last_answer_pos != -1:
        # Extract the string after "Answer:"
        answer_str = response[last_answer_pos + len('Answer:'):].strip()
        # Find a unique match in the options
        matching_options = [option for option in all_choices if option in answer_str]
        # If a unique match is found, return that option
        if len(matching_options) == 1:
            return matching_options[0]

    if isinstance(response, str):
        for char in [',', '.', '!', '?', ';', ':', "'"]:
            response = response.strip(char)
        response = ' ' + response + ' '  # add space to avoid partial match
    else:
        print(response)
        response = ''

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D)
        if f'({choice})' in response:
            candidates.append(choice)
            ans_with_brack = True
    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f'{choice} ' in response:
                candidates.append(choice)
    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A. B. C. D.
            if f'{choice}.' in response:
                candidates.append(choice)
    # if none of the above yields candidates and the response is longer than 5 tokens,
    # try to match the option contents instead of the option letters
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it's a content answer
    if len(candidates) == 0:  # still no answer, randomly choose one
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f'({can})')
                    start_indexes.append(index)  # -1 will be ignored anyway
                # start_indexes = [generated_response.index(f'({can})') for can in candidates]
            else:
                for can in candidates:
                    index = response.rfind(f' {can} ')
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # get the last one
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it
        pred_index = candidates[0]
    return pred_index


def extract_numbers(string):
    """
    Extract all forms of numbers from a string with regex.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L100
    """
    # Pattern for numbers with commas
    pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
    # Pattern for scientific notation
    pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
    # Pattern for simple numbers without commas
    pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbers
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers


def check_is_number(string):
    """
    Check if the given string is a number.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L65
    """
    try:
        float(string.replace(',', ''))
        return True
    except ValueError:
        # commas have already been removed, so any remaining failure means non-numeric
        return False
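# Illustrative behaviour of the parsing helpers above (a sketch with assumed example
# strings; `index2ans` / `all_choices` would come from get_multi_choice_info):
#   >>> parse_multi_choice_response('I think the answer is (B).', ['A', 'B', 'C', 'D'], index2ans)
#   'B'
#   >>> parse_multi_choice_response('Final Answer: C', ['A', 'B', 'C', 'D'], index2ans)
#   'C'
#   >>> extract_numbers('approximately 42.5 meters')
#   ['42.5']
#   >>> check_is_number('1,234.5'), check_is_number('n/a')
#   (True, False)
# When no option letter or option text can be matched, parse_multi_choice_response
# falls back to random.choice over all_choices.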
def normalize_str(string):
    """
    Normalize the str to lower case and make them float numbers if possible.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L76
    """
    # check the characters in the string; if it is a number, numerize it
    string = string.strip()
    is_number = check_is_number(string)
    if is_number:
        string = string.replace(',', '')
        string = float(string)
        # leave 2 decimal places
        string = round(string, 2)
        return [string]
    else:  # it's likely to be a string
        # lower it
        string = string.lower()
        if len(string) == 1:
            return [' ' + string, string + ' ']  # avoid trivial matches
        return [string]


def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L122
    """

    # content = content.strip("\n").strip(".").strip(" ")
    def get_key_subresponses(response):
        key_responses = []
        response = response.strip().strip('.').lower()
        sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
        indicators_of_keys = [
            'could be ',
            'so ',
            'is ',
            'thus ',
            'therefore ',
            'final ',
            'answer ',
            'result ',
        ]
        for index, resp in enumerate(sub_responses):
            # if this is the last sub-response, also accept an equation
            # (the entire response can be a single sentence with an equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(['='])
            shortest_key_response = None
            # the shortest response that may contain the answer (tail part of the response)
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                            shortest_key_response = resp.split(indicator)[-1].strip()
                    # key_responses.append(resp.split(indicator)[1].strip())
            if shortest_key_response:
                # and it's not trivial
                if shortest_key_response.strip() not in [
                    ':',
                    ',',
                    '.',
                    '!',
                    '?',
                    ';',
                    ':',
                    "'",
                ]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)
    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))
    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list
    # remove duplicates
    pred_list = list(set(pred_list))
    return pred_list


def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return index2ans and all_choices.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/data_utils.py#L54
    """
    start_chr = 'A'
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))
    return index2ans, all_choices


NUM = 1730
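# check_files below scans input_dir for per-model result files whose names encode the
# model, the prompt setting, and the answering method, as matched by the regex inside
# the function, e.g. (illustrative model name):
#   "gpt-4o_standard (10 options)_cot.jsonl"
#   "gpt-4o_vision_direct.jsonl"
# Each file is expected to contain exactly NUM (= 1730) JSON lines.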
def check_files(input_dir):
    pattern = re.compile(
        r'(?P<model_name>.+)_(?P<setting>standard \(\d+ options\)|vision)_(?P<method>cot|direct)\.jsonl')
    for file_name in os.listdir(input_dir):
        match = pattern.match(file_name)
        if match:
            model_name = match.group('model_name')
            method = match.group('method')
            setting = match.group('setting')
            file_path = os.path.join(input_dir, file_name)
            with open(file_path, 'r', encoding='utf-8') as infile:
                results = [json.loads(data) for data in infile]
            if len(results) != NUM:
                print(f'Error: {file_path} has {len(results)} results, expected {NUM}')
                continue
            processed_results = []
            true_nums = 0
            false_nums = 0
            for i, result in enumerate(results):
                new_data = mmmu_process_results(result)
                processed_results.append(new_data)
                if new_data['if_right']:
                    true_nums += 1
                else:
                    false_nums += 1
            # Calculate and output the accuracy
            total = true_nums + false_nums
            acc = true_nums / total * 100 if total > 0 else 0
            print(f'Model: {model_name:<15} Method: {method:<8} Setting: {setting:<8} - Accuracy: {acc:>6.2f}%')
            # Write the processed results back to the file
            with open(file_path, 'w', encoding='utf-8') as outfile:
                for item in processed_results:
                    outfile.write(json.dumps(item) + '\n')


if __name__ == '__main__':
    input_directory = './eval/mmmu_pro/results'  # Replace with your input directory
    check_files(input_directory)
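# Example invocation (the directory path is the default above; adjust input_directory
# as needed):
#   python <path to this script>
# Every matching JSONL under input_directory is re-scored, rewritten in place with
# "pred_indexs" and "if_right" added to each record, and a per-file accuracy line is
# printed.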