Commit 4eecbabb authored by Baber

Merge branch 'main' into prefill

parents dac8b534 fb963f0f
task: mmmu_val_marketing
include: _template_yaml
task_alias: Marketing
dataset_name: Marketing
task: mmmu_val_materials
include: _template_yaml
task_alias: Materials
dataset_name: Materials
task: mmmu_val_math
include: _template_yaml
task_alias: Math
dataset_name: Math
task: mmmu_val_mechanical_engineering
include: _template_yaml
task_alias: Mechanical Engineering
dataset_name: Mechanical_Engineering
task: mmmu_val_music
include: _template_yaml
task_alias: Music
dataset_name: Music
task: mmmu_val_pharmacy
include: _template_yaml
task_alias: Pharmacy
dataset_name: Pharmacy
task: mmmu_val_physics
include: _template_yaml
task_alias: Physics
dataset_name: Physics
task: mmmu_val_psychology
include: _template_yaml
task_alias: Psychology
dataset_name: Psychology
task: mmmu_val_public_health
include: _template_yaml
task_alias: Public Health
dataset_name: Public_Health
task: mmmu_val_sociology
include: _template_yaml
task_alias: Sociology
dataset_name: Sociology
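# Each MMMU subject above is a small task config that pulls its shared settings
# from _template_yaml. A further subject file (illustrative only, not part of this
# excerpt) would follow the same four-line pattern:
#
# task: mmmu_val_history
# include: _template_yaml
# task_alias: History
# dataset_name: History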
import ast
import random
import re
import numpy as np
random.seed(42)
# source for prompt fstrings: https://github.com/MMMU-Benchmark/MMMU/blob/7787d60648c82a9d40acd656fa541a6c74f58995/eval/configs/llava1.5.yaml#L3
MULTI_CHOICE_EXAMPLE_FORMAT = """{}
{}
Answer with the option's letter from the given choices directly."""
SHORT_ANS_EXAMPLE_FORMAT = """{}
Answer the question using a single word or phrase."""
START_CHR = "A"
def doc_to_image(doc):
# get the formatted prompt (incl. multiple-choice options) before <image {i}> placeholders are rewritten
input_text = _doc_to_text(doc)
# locate <image {i}> instances in input
image_placeholders = [
img.replace(" ", "_").replace("<", "").replace(">", "")
for img in re.findall("<image [1-7]>", input_text)
]
# collect visuals (a given image may be duplicated or appear out of order)
# e.g. validation_Math_19 references <image 1> and <image 2>, but its placeholders appear as [<image 1>, <image 2>, <image 1>, <image 1>, <image 2>]
visuals = [doc[img] for img in image_placeholders]
return visuals
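# Illustrative example (hypothetical doc, not from the dataset): if the formatted
# prompt contains "<image 1>" and "<image 2>", image_placeholders becomes
# ["image_1", "image_2"] and the function returns [doc["image_1"], doc["image_2"]].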
def doc_to_text(doc):
"""Get the prompt for a given document."""
prompt = _doc_to_text(doc)
for i in range(1, 8):
# replace <image {i}> with <image>. TODO: check this is always the right decision incl. for non-HF models
prompt = prompt.replace(f"<image {i}>", "<image>")
return prompt
def _doc_to_text(doc):
"""Helper--get the prompt for a given document but DO NOT yet replace <image {i}> with <image>."""
if doc["question_type"] == "multiple-choice":
choices_str = ""
for i, choice in enumerate(ast.literal_eval(doc["options"])):
# add (A) {choice1}\n , (B) {choice2}\n , and so on
# to create the list of formatted choices in the prompt
choices_str += f"\n({chr(ord(START_CHR) + i)}) {choice}"
choices_str = (
choices_str.lstrip()
) # remove the extraneous prepended \n that we added
prompt = MULTI_CHOICE_EXAMPLE_FORMAT.format(doc["question"], choices_str)
else:
prompt = SHORT_ANS_EXAMPLE_FORMAT.format(doc["question"])
return prompt
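# Illustrative example (hypothetical doc, not from the dataset): for a question
# "What is 2 + 2?" with options '["3", "4", "5"]', _doc_to_text produces:
#
#   What is 2 + 2?
#   (A) 3
#   (B) 4
#   (C) 5
#   Answer with the option's letter from the given choices directly.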
def process_results(doc, results):
if doc["question_type"] == "multiple-choice":
# multichoice logic
option_strs = ast.literal_eval(doc["options"])
option_letters = ["A", "B", "C", "D", "E", "F", "G", "H", "I"]
all_choices = option_letters[: len(option_strs)]
index2ans = {index: ans for index, ans in zip(option_letters, option_strs)}
pred = parse_multi_choice_response(results[0], all_choices, index2ans)
# print(pred, all_choices, index2ans)
is_correct = eval_multi_choice(doc["answer"], pred)
else:
# freeform response handling
pred = parse_open_response(results[0])
is_correct = eval_open(doc["answer"], pred)
return {"acc": float(is_correct)}
# TODO: it would be better if we could use a Filter for this logic.
### Output parsing and answer selection taken from
### https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/data_utils.py
### and
### https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py
# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Return the predicted option letter, e.g. A, B, C, or D.
"""
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # add space to avoid partial match
index_ans = True
ans_with_brack = False
candidates = []
for choice in all_choices: # e.g., (A) (B) (C) (D)
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f" {choice} " in response:
candidates.append(choice)
# if none of the above finds a candidate and the response is longer than 5 tokens, try to match an option's text content
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False # it's content ans.
if len(candidates) == 0:  # still no answer found; choose one at random.
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
index = response.rfind(f"({can})")
start_indexes.append(index) # -1 will be ignored anyway
# start_indexes = [generated_response.index(f'({can})') for can in candidates]
else:
for can in candidates:
index = response.rfind(f" {can} ")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
start_indexes.append(index)
# get the last one
pred_index = candidates[np.argmax(start_indexes)]
else: # if only one candidate, use it.
pred_index = candidates[0]
# print(response, all_choices, index2ans, pred_index)
return pred_index
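# Illustrative examples (hypothetical responses, options A-D where option D's
# text is "right triangle"):
#   "I would go with option C"                    -> "C"  (bare letter match)
#   "The shape shown is clearly a right triangle" -> "D"  (matched on option text,
#                                                          since the reply has > 5 tokens)
#   a response matching nothing                   -> random.choice(all_choices)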
# ----------- Process Open -------------
def check_is_number(string):
"""
Check if the given string is a number.
"""
try:
float(string.replace(",", ""))
return True
except ValueError:
# not parseable as a float even after removing commas
return False
def normalize_str(string):
"""
Normalize the string to lower case, converting it to a float if possible.
"""
# if the string represents a number, convert it to a float
string = string.strip()
is_number = check_is_number(string)
if is_number:
string = string.replace(",", "")
string = float(string)
# keep 2 decimal places
string = round(string, 2)
return [string]
else: # it's likely to be a string
# lower it
string = string.lower()
if len(string) == 1:
return [" " + string, string + " "] # avoid trivial matches
return [string]
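# Illustrative examples (hypothetical inputs):
#   normalize_str("1,234.567") -> [1234.57]    (comma removed, rounded to 2 decimals)
#   normalize_str("Paris")     -> ["paris"]
#   normalize_str("A")         -> [" a", "a "] (padded to avoid trivial substring matches)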
def extract_numbers(string):
"""
Extract all forms of numbers from a string with regex.
"""
# Pattern for numbers with commas
pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
# Pattern for scientific notation
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
# Pattern for simple numbers without commas
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
# Extract numbers with commas
numbers_with_commas = re.findall(pattern_commas, string)
# Extract numbers in scientific notation
numbers_scientific = re.findall(pattern_scientific, string)
# Extract simple numbers without commas
numbers_simple = re.findall(pattern_simple, string)
# Combine all extracted numbers
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
return all_numbers
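# Illustrative example (hypothetical input):
#   extract_numbers("about 1,200 samples at 0.75 accuracy") -> ["1,200", "200", "0.75"]
# Note that the simple-number pattern can also pick up the tail of a comma-grouped
# number, so downstream normalization may see overlapping values.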
def parse_open_response(response):
"""
Parse the prediction from the generated response.
Return a list of predicted strings or numbers.
"""
# content = content.strip("\n").strip(".").strip(" ")
def get_key_subresponses(response):
key_responses = []
response = response.strip().strip(".").lower()
sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
indicators_of_keys = [
"could be ",
"so ",
"is ",
"thus ",
"therefore ",
"final ",
"answer ",
"result ",
]
key_responses = []
for index, resp in enumerate(sub_responses):
# if this is the last sub-response, also accept "=" as an indicator (the entire response may be a single sentence containing an equation)
if index == len(sub_responses) - 1:
indicators_of_keys.extend(["="])
shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
for indicator in indicators_of_keys:
if indicator in resp:
if not shortest_key_response:
shortest_key_response = resp.split(indicator)[-1].strip()
else:
if len(resp.split(indicator)[-1].strip()) < len(
shortest_key_response
):
shortest_key_response = resp.split(indicator)[-1].strip()
# key_responses.append(resp.split(indicator)[1].strip())
if shortest_key_response:
# and it's not trivial
if shortest_key_response.strip() not in [
":",
",",
".",
"!",
"?",
";",
":",
"'",
]:
key_responses.append(shortest_key_response)
if len(key_responses) == 0:  # did not find any key response; fall back to the full response
return [response]
return key_responses
# pdb.set_trace()
key_responses = get_key_subresponses(response)
pred_list = key_responses.copy() # keep the original string response
for resp in key_responses:
pred_list.extend(extract_numbers(resp))
tmp_pred_list = []
for i in range(len(pred_list)):
tmp_pred_list.extend(normalize_str(pred_list[i]))
pred_list = tmp_pred_list
# remove duplicates
pred_list = list(set(pred_list))
return pred_list
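# Illustrative example (hypothetical response):
#   parse_open_response("Thus the answer is 42.") -> [42.0]
# The shortest tail following any indicator word ("is " here) is kept, numbers are
# extracted and normalized, and duplicates are removed.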
# ----------- Evaluation -------------
def eval_multi_choice(gold_i, pred_i):
"""
Evaluate a multiple choice instance.
"""
correct = False
# only if prediction and gold are exactly the same do we consider it correct
if isinstance(gold_i, list):
for answer in gold_i:
if answer == pred_i:
correct = True
break
else: # gold_i is a string
if gold_i == pred_i:
correct = True
return correct
def eval_open(gold_i, pred_i):
"""
Evaluate an open question instance.
"""
correct = False
if isinstance(gold_i, list):
# use float to avoid trivial matches
norm_answers = []
for answer in gold_i:
norm_answers.extend(normalize_str(answer))
else:
norm_answers = normalize_str(gold_i)
for pred in pred_i: # pred is already normalized in parse response phase
if isinstance(pred, str):  # if it's a string, check whether a gold answer string appears in it
for norm_ans in norm_answers:
# only check whether the string answer is contained in the string pred
if isinstance(norm_ans, str) and norm_ans in pred:
if not correct:
correct = True
break
else: # it's a float number
if pred in norm_answers:
if not correct:
correct = True
break
return correct
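# Illustrative end-to-end example (hypothetical values):
#   eval_multi_choice("B", "B")  # -> True
#   pred = parse_open_response("Therefore the total is 5,000 meters")
#   eval_open("5,000", pred)     # -> True, since both sides normalize to 5000.0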
@@ -4,3 +4,10 @@ task:
- tmmluplus_social_sciences
- tmmluplus_humanities
- tmmluplus_STEM
aggregate_metric_list:
- metric: acc
weight_by_size: True
- metric: acc_norm
weight_by_size: True
metadata:
version: 2.0
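# Note (our reading of lm-evaluation-harness aggregation, not stated in this diff):
# with weight_by_size: True, each group's score is the mean over its subtasks
# weighted by the number of samples per subtask (a micro average), rather than a
# plain mean of per-task scores.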
group: tmmluplus_STEM
task:
- tmmluplus_STEM_tasks
aggregate_metric_list:
- metric: acc
weight_by_size: True
- metric: acc_norm
weight_by_size: True
metadata:
version: 2.0
group: tmmluplus_humanities
task:
- tmmluplus_humanities_tasks
aggregate_metric_list:
- metric: acc
weight_by_size: True
- metric: acc_norm
weight_by_size: True
metadata:
version: 2.0
group: tmmluplus_other
task:
- tmmluplus_other_tasks
aggregate_metric_list:
- metric: acc
weight_by_size: True
- metric: acc_norm
weight_by_size: True
metadata:
version: 2.0
group: tmmluplus_social_sciences
task:
- tmmluplus_social_sciences_tasks
aggregate_metric_list:
- metric: acc
weight_by_size: True
- metric: acc_norm
weight_by_size: True
metadata:
version: 2.0
@@ -16,4 +16,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
version: 1.0
version: 2.0
"dataset_name": "accounting"
"description": "以下為會計學的單選題,請提供正確答案的選項。\n\n"
"group": "tmmluplus_other"
"group_alias": "other"
"include": "_default_template_yaml"
"tag": "tmmluplus_other_tasks"
"include": "_tmmluplus_template_yaml"
"task": "tmmluplus_accounting"
"task_alias": "accounting"
"dataset_name": "administrative_law"
"description": "以下為行政法的單選題,請提供正確答案的選項。\n\n"
"group": "tmmluplus_humanities"
"group_alias": "humanities"
"include": "_default_template_yaml"
"tag": "tmmluplus_humanities_tasks"
"include": "_tmmluplus_template_yaml"
"task": "tmmluplus_administrative_law"
"task_alias": "administrative law"
"dataset_name": "advance_chemistry"
"description": "以下為化學的單選題,請提供正確答案的選項。\n\n"
"group": "tmmluplus_STEM"
"group_alias": "STEM"
"include": "_default_template_yaml"
"tag": "tmmluplus_STEM_tasks"
"include": "_tmmluplus_template_yaml"
"task": "tmmluplus_advance_chemistry"
"task_alias": "advance chemistry"