Commit 8181f43c authored by Baber

nits

parent 2106fbeb
@@ -5,8 +5,8 @@ import itertools
 import json
 import os
 from functools import cached_property
-from operator import itemgetter
 from io import BytesIO
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Union

 from PIL import Image
@@ -15,10 +15,8 @@ from tqdm import tqdm

 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.api_models import TemplateAPI
-from lm_eval.models.utils import handle_stop_sequences
 from lm_eval.models.api_models import JsonChatStr, TemplateAPI
-from lm_eval.models.utils import Collator
+from lm_eval.models.utils import Collator, handle_stop_sequences
 from lm_eval.utils import eval_logger
...
@@ -6,7 +6,6 @@ output_type: "generate_until"
 doc_to_image:
   - decoded_image
 doc_to_text: "<image>{{query}}"
-#doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}'
 doc_to_target: answer
 process_results: !function utils.process_results
 generation_kwargs:
@@ -15,6 +14,11 @@ generation_kwargs:
   temperature: 0.0
   do_sample: false
   max_gen_toks: 1024
+filter_list:
+  - name: "extract_answer"
+    filter:
+      - function: "custom"
+        filter_fn: !function utils.build_predictions
 metric_list:
   - metric: acc
     aggregation: mean
...
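Note: the new filter_list entry registers utils.build_predictions as a custom filter, but that function is not part of this commit's visible hunks. A minimal sketch of what it could look like, assuming lm-eval's custom-filter contract of a callable applied to the per-instance responses and their docs; it would simply mirror extract_all_answers, added in utils.py below:

# hypothetical sketch of utils.build_predictions -- not shown in this commit;
# assumes the custom filter receives all model responses plus their docs and
# returns one extracted answer per doc, like extract_all_answers below
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[str]:
    return [extract_answer(resp[0], doc) for resp, doc in zip(resps, docs)]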
 import re
 from typing import Optional
+# from api_model import make_concurrent_requests
 from Levenshtein import distance
@@ -53,6 +54,13 @@ def create_test_prompt(demo_prompt, query, response):
     return full_prompt


+def verify_extraction(extraction):
+    extraction = extraction.strip()
+    if not extraction:
+        return False
+    return True
+
+
 # taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
 def get_most_similar(prediction: str, choices: list) -> float:
     """
@@ -140,24 +148,16 @@ def safe_equal(prediction, answer):
         return False


-def extract_answer(response: str, problem: dict) -> str:
+def extract_answer(response: str, problem: dict, quick_extract=True) -> str:
     question_type = problem["question_type"]
     answer_type = problem["answer_type"]
     choices = problem["choices"]
     # query = problem["query"]
-    # pid = problem['pid']
+    # pid = problem["pid"]

     if response == "":
         return ""

-    ### This is not in the original code:
-    extract = re.findall(
-        r"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?)", response
-    )
-    if extract:
-        return str(extract[0])
-    ###
-
     if question_type == "multi_choice" and response in choices:
         return response
@@ -175,7 +175,39 @@ def extract_answer(response: str, problem: dict) -> str:
         except Exception:
             pass

-    return response
+    # quick extraction
+    if quick_extract:
+        # The answer is "text". -> "text"
+        try:
+            result = re.search(r'The answer is "(.*)"\.', response)
+            if result:
+                extraction = result.group(1)
+                return extraction
+        except Exception:
+            pass
+
+    # general extraction
+    # try:
+    #     full_prompt = create_test_prompt(DEMO_PROMPT, query, response)
+    #     extraction = make_concurrent_requests(full_prompt)
+    #     return extraction
+    # except Exception:
+    #     print(
+    #         f"Error in extracting answer for problem: {pid} with response: {response}"
+    #     )
+    #     # logging.info(f"Error in extracting answer for problem: {pid} with response: {response}")
+    #     # logging.info(e)
+
+    return ""
+
+
+def extract_all_answers(
+    resps: list[list[str]], docs: dict, quick_extract=True
+) -> list[str]:
+    return [
+        extract_answer(resp[0], doc, quick_extract=quick_extract)
+        for resp, doc in zip(resps, docs)
+    ]


 # adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
@@ -186,11 +218,16 @@ def process_results(doc: dict, results: list[str]):
     answer_type = doc["answer_type"]
     precision = doc["precision"]  # noqa: F841
     answer = doc["answer"]

-    extracted_answer = extract_answer(response, doc)
-    normalized_extraction = normalize_extracted_answer(
-        extracted_answer, choices, question_type, answer_type, precision
-    )
-    res = safe_equal(normalized_extraction, answer)
+    # step 1: extract the answer from the model response
+    # extracted_answer = extract_answer(response, doc)
+    extracted_answer = response[0]
+    if verify_extraction(extracted_answer):
+        normalized_extraction = normalize_extracted_answer(
+            extracted_answer, choices, question_type, answer_type, precision
+        )
+        res = safe_equal(normalized_extraction, answer)
+    else:
+        res = False

     return {"acc": 1.0} if res else {"acc": 0.0}
...
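For reference, the quick-extraction branch added to extract_answer boils down to a single regex over the response. A standalone sketch, where the helper name quick_extract is hypothetical and the pattern is copied verbatim from the hunk above:

import re

def quick_extract(response: str) -> str:
    # matches the sentence form: The answer is "text". -> text
    match = re.search(r'The answer is "(.*)"\.', response)
    return match.group(1) if match else ""

# quick_extract('The answer is "2.5".') -> "2.5"
# quick_extract("no final-answer sentence") -> ""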