Commit b27bc18e authored by Baber

add processing code

parent 0d187eda
(task config YAML)

 dataset_path: AI4Math/MathVista
-task: mathvista_mcq
+task: mathvista
 test_split: testmini
-output_type: "greedy_until"
+output_type: "generate_until"
-process_docs: !function utils.process_docs
+#process_docs: !function utils.process_docs
-doc_to_image: !function utils.doc_to_image
+doc_to_image: decoded_image
 doc_to_text: "<image> {{query}}"
 #doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}'
 doc_to_target: answer
 process_results: !function utils.process_results
+generation_kwargs:
+  until:
+    - "<|endoftext|>"
+  temperature: 0.0
+  do_sample: false
+  max_gen_toks: 64
 metric_list:
   - metric: acc
     aggregation: mean
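For reference, `doc_to_text` above is a Jinja2 template rendered against each doc. A minimal sketch of how it resolves for one record; the doc contents here are made up, and only the `query` field name is taken from the config:

# Illustrative rendering of the doc_to_text template; the doc values are
# invented, real docs come from the AI4Math/MathVista dataset.
from jinja2 import Template

doc = {"query": "Question: What is the value of x?\nChoices:\n(A) 1\n(B) 2"}

prompt = Template("<image> {{query}}").render(**doc)
print(prompt)
# <image> Question: What is the value of x?
# Choices:
# (A) 1
# (B) 2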
utils.py
 import re
+from typing import Optional

 from Levenshtein import distance


 # taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
-def get_most_similar(prediction: str, choices: list):
+def get_most_similar(prediction: str, choices: list) -> float:
     """
     Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
     """
@@ -14,17 +15,19 @@ def get_most_similar(prediction: str, choices: list):
     # return min(choices, key=lambda choice: distance(prediction, choice))


+# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
 def normalize_extracted_answer(
-    extraction,
+    extraction: str,
     choices: list,
     question_type: str,
     answer_type: str,
     precision,
-    ignore_empty_extractions=False,
-):
+    ignore_empty_extractions=True,
+) -> Optional[str]:
     """
     Normalize the extracted answer to match the answer type
     """
     if question_type == "multi_choice":
         # make sure the extraction is a string
         if isinstance(extraction, str):
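The hunk cuts off inside the multi_choice branch. In the upstream extract_answer.py that branch continues roughly as sketched below; `_finish_multi_choice` is a hypothetical helper name used here for illustration only, and the body is an assumption about the upstream logic, not this commit's code:

import re

def _finish_multi_choice(extraction: str, choices: list):
    # Hypothetical continuation of the multi_choice branch, assumed from
    # the upstream MathVista script.
    extraction = extraction.strip()

    # Pull a bare option letter out of forms like "(B) green".
    letters = re.findall(r"\(([a-zA-Z])\)", extraction)
    if letters:
        extraction = letters[0].upper()

    # "A", "B", ... up to the number of available choices.
    sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

    # A bare letter indexes into choices; anything else falls back to the
    # closest choice by edit distance via get_most_similar above.
    if extraction in sequential_characters:
        return choices[sequential_characters.index(extraction)]
    return get_most_similar(extraction, choices)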
@@ -88,37 +91,47 @@ def safe_equal(prediction, answer):
     return False


-def get_acc_with_contion(res_pd, key, value):
-    if key == "skills":
-        total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
-    else:
-        total_pd = res_pd[res_pd[key] == value]
-
-    correct_pd = total_pd[total_pd["true_false"] == True]  # noqa: E712
-    acc = len(correct_pd) / len(total_pd)
-    return len(correct_pd), len(total_pd), acc
-
-
-# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
-def process_results(doc, results):
-    response = results[0]
-    choices = doc["choices"]
-    question_type = doc["question_type"]
-    answer_type = doc["answer_type"]
-    precision = doc["precision"]  # noqa: F841
-    extraction = doc["extraction"]  # noqa: F841
+def extract_answer(response: str, problem: dict) -> str:
+    question_type = problem["question_type"]
+    answer_type = problem["answer_type"]
+    choices = problem["choices"]
+    # query = problem["query"]
+    # pid = problem['pid']
+
+    if response == "":
+        return ""

     if question_type == "multi_choice" and response in choices:
-        return {"acc": 1.0}
+        return response

     if answer_type == "integer":
         try:
             extraction = int(response)
             return str(extraction)
         except Exception:
             pass

     if answer_type == "float":
         try:
             extraction = str(float(response))
             return extraction
         except Exception:
             pass
+
+    return ""
+
+
+# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
+def process_results(doc: dict, results: list[str]):
+    response = results[0]  # noqa: F841
+    choices = doc["choices"]
+    question_type = doc["question_type"]
+    answer_type = doc["answer_type"]
+    precision = doc["precision"]  # noqa: F841
+    answer = doc["answer"]
+
+    extracted_answer = extract_answer(response, doc)
+    normalized_extraction = normalize_extracted_answer(
+        extracted_answer, choices, question_type, answer_type, precision
+    )
+
+    res = safe_equal(normalized_extraction, answer)
+
+    return {"acc": 1.0} if res else {"acc": 0.0}
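With the new utils.py, scoring is a pure function of the doc and the model response. A quick end-to-end check, assuming the functions from the diff above are importable; the doc is hand-made with illustrative values, real docs come from the testmini split:

# Hand-made doc with illustrative values.
doc = {
    "question_type": "multi_choice",
    "answer_type": "text",
    "choices": ["red", "green", "blue"],
    "precision": None,
    "answer": "green",
}

print(process_results(doc, ["green"]))  # exact choice match -> {"acc": 1.0}
print(process_results(doc, ["N/A"]))    # nothing extractable -> {"acc": 0.0}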