import ast import random import re import numpy as np random.seed(42) # source for prompt fstrings: https://github.com/MMMU-Benchmark/MMMU/blob/7787d60648c82a9d40acd656fa541a6c74f58995/eval/configs/llava1.5.yaml#L3 MULTI_CHOICE_EXAMPLE_FORMAT = """{} {} Answer with the option's letter from the given choices directly.""" SHORT_ANS_EXAMPLE_FORMAT = """{} Answer the question using a single word or phrase.""" START_CHR = "A" def doc_to_image(doc): # get formatted prompt (incl. multi-choice options) pre- reformatting input_text = _doc_to_text(doc) # locate instances in input image_placeholders = [ img.replace(" ", "_").replace("<", "").replace(">", "") for img in re.findall("", input_text) ] # collect visuals (can have dupes of a given image or be out of order) # E.g. validation_Math_19 contains and but seen as [, , , , ] visuals = [doc[img] for img in image_placeholders] return visuals def doc_to_text(doc): """Get the prompt for a given document.""" prompt = _doc_to_text(doc) for i in range(1, 8): # replace with . TODO: check this is always the right decision incl. for non-HF models prompt = prompt.replace(f"", "") return prompt def _doc_to_text(doc): """Helper--get the prompt for a given document but DO NOT yet replace with .""" if doc["question_type"] == "multiple-choice": choices_str = "" for i, choice in enumerate(ast.literal_eval(doc["options"])): # add (A) {choice1}\n , (B) {choice2}\n , and so on # to create the list of formatted choices in the prompt choices_str += f"\n({chr(ord(START_CHR) + i)}) {choice}" choices_str = ( choices_str.lstrip() ) # remove the extraneous prepended \n that we added prompt = MULTI_CHOICE_EXAMPLE_FORMAT.format(doc["question"], choices_str) else: prompt = SHORT_ANS_EXAMPLE_FORMAT.format(doc["question"]) return prompt def process_results(doc, results): if doc["question_type"] == "multiple-choice": # multichoice logic option_strs = ast.literal_eval(doc["options"]) option_letters = ["A", "B", "C", "D", "E", "F", "G", "H", "I"] all_choices = option_letters[: len(option_strs)] index2ans = {index: ans for index, ans in zip(option_letters, option_strs)} pred = parse_multi_choice_response(results[0], all_choices, index2ans) # print(pred, all_choices, index2ans) is_correct = eval_multi_choice(doc["answer"], pred) else: # freeform response handling pred = parse_open_response(results[0]) is_correct = eval_open(doc["answer"], pred) return {"acc": float(is_correct)} # TODO: it would be better if we could use a Filter for this logic. ### Output parsing and answer selection taken from ### https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/data_utils.py ### and ### https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py # ----------- Process Multi-choice ------------- def parse_multi_choice_response(response, all_choices, index2ans): """ Parse the prediction from the generated response. Return the predicted index e.g., A, B, C, D. """ for char in [",", ".", "!", "?", ";", ":", "'"]: response = response.strip(char) response = " " + response + " " # add space to avoid partial match index_ans = True ans_with_brack = False candidates = [] for choice in all_choices: # e.g., (A) (B) (C) (D) if f"({choice})" in response: candidates.append(choice) ans_with_brack = True if len(candidates) == 0: for choice in all_choices: # e.g., A B C D if f" {choice} " in response: candidates.append(choice) # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example if len(candidates) == 0 and len(response.split()) > 5: for index, ans in index2ans.items(): if ans.lower() in response.lower(): candidates.append(index) index_ans = False # it's content ans. if len(candidates) == 0: # still not get answer, randomly choose one. pred_index = random.choice(all_choices) elif len(candidates) > 1: start_indexes = [] if index_ans: if ans_with_brack: for can in candidates: index = response.rfind(f"({can})") start_indexes.append(index) # -1 will be ignored anyway # start_indexes = [generated_response.index(f'({can})') for can in candidates] else: for can in candidates: index = response.rfind(f" {can} ") start_indexes.append(index) else: for can in candidates: index = response.lower().rfind(index2ans[can].lower()) start_indexes.append(index) # get the last one pred_index = candidates[np.argmax(start_indexes)] else: # if only one candidate, use it. pred_index = candidates[0] # print(response, all_choices, index2ans, pred_index) return pred_index # ----------- Process Open ------------- def check_is_number(string): """ Check if the given string a number. """ try: float(string.replace(",", "")) return True except ValueError: # check if there's comma inside return False def normalize_str(string): """ Normalize the str to lower case and make them float numbers if possible. """ # check if characters in the string # if number, numerize it. string = string.strip() is_number = check_is_number(string) if is_number: string = string.replace(",", "") string = float(string) # leave 2 decimal string = round(string, 2) return [string] else: # it's likely to be a string # lower it string = string.lower() if len(string) == 1: return [" " + string, string + " "] # avoid trivial matches return [string] def extract_numbers(string): """ Exact all forms of numbers from a string with regex. """ # Pattern for numbers with commas pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b" # Pattern for scientific notation pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+" # Pattern for simple numbers without commas pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])" # Extract numbers with commas numbers_with_commas = re.findall(pattern_commas, string) # Extract numbers in scientific notation numbers_scientific = re.findall(pattern_scientific, string) # Extract simple numbers without commas numbers_simple = re.findall(pattern_simple, string) # Combine all extracted numbers all_numbers = numbers_with_commas + numbers_scientific + numbers_simple return all_numbers def parse_open_response(response): """ Parse the prediction from the generated response. Return a list of predicted strings or numbers. """ # content = content.strip("\n").strip(".").strip(" ") def get_key_subresponses(response): key_responses = [] response = response.strip().strip(".").lower() sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response) indicators_of_keys = [ "could be ", "so ", "is ", "thus ", "therefore ", "final ", "answer ", "result ", ] key_responses = [] for index, resp in enumerate(sub_responses): # if last one, accept it's an equation (the entire response can be just one sentence with equation) if index == len(sub_responses) - 1: indicators_of_keys.extend(["="]) shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) for indicator in indicators_of_keys: if indicator in resp: if not shortest_key_response: shortest_key_response = resp.split(indicator)[-1].strip() else: if len(resp.split(indicator)[-1].strip()) < len( shortest_key_response ): shortest_key_response = resp.split(indicator)[-1].strip() # key_responses.append(resp.split(indicator)[1].strip()) if shortest_key_response: # and it's not trivial if shortest_key_response.strip() not in [ ":", ",", ".", "!", "?", ";", ":", "'", ]: key_responses.append(shortest_key_response) if len(key_responses) == 0: # did not found any return [response] return key_responses # pdb.set_trace() key_responses = get_key_subresponses(response) pred_list = key_responses.copy() # keep the original string response for resp in key_responses: pred_list.extend(extract_numbers(resp)) tmp_pred_list = [] for i in range(len(pred_list)): tmp_pred_list.extend(normalize_str(pred_list[i])) pred_list = tmp_pred_list # remove duplicates pred_list = list(set(pred_list)) return pred_list # ----------- Evaluation ------------- def eval_multi_choice(gold_i, pred_i): """ Evaluate a multiple choice instance. """ correct = False # only they are exactly the same, we consider it as correct if isinstance(gold_i, list): for answer in gold_i: if answer == pred_i: correct = True break else: # gold_i is a string if gold_i == pred_i: correct = True return correct def eval_open(gold_i, pred_i): """ Evaluate an open question instance """ correct = False if isinstance(gold_i, list): # use float to avoid trivial matches norm_answers = [] for answer in gold_i: norm_answers.extend(normalize_str(answer)) else: norm_answers = normalize_str(gold_i) for pred in pred_i: # pred is already normalized in parse response phase if isinstance(pred, str): # if it's a string, then find if ans in the pred_i for norm_ans in norm_answers: # only see if the string answer in the string pred if isinstance(norm_ans, str) and norm_ans in pred: if not correct: correct = True break else: # it's a float number if pred in norm_answers: if not correct: correct = True break return correct