Commit 8181f43c authored by Baber

nits

parent 2106fbeb
@@ -5,8 +5,8 @@ import itertools
 import json
 import os
 from functools import cached_property
-from operator import itemgetter
 from io import BytesIO
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Union

 from PIL import Image
@@ -15,10 +15,8 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.api_models import TemplateAPI
-from lm_eval.models.utils import handle_stop_sequences
+from lm_eval.models.api_models import JsonChatStr, TemplateAPI
-from lm_eval.models.utils import Collator
+from lm_eval.models.utils import Collator, handle_stop_sequences
 from lm_eval.utils import eval_logger
@@ -6,7 +6,6 @@ output_type: "generate_until"
 doc_to_image:
 - decoded_image
 doc_to_text: "<image>{{query}}"
-#doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}'
 doc_to_target: answer
 process_results: !function utils.process_results
 generation_kwargs:
@@ -15,6 +14,11 @@ generation_kwargs:
   temperature: 0.0
   do_sample: false
   max_gen_toks: 1024
+filter_list:
+  - name: "extract_answer"
+    filter:
+      - function: "custom"
+        filter_fn: !function utils.build_predictions
 metric_list:
   - metric: acc
     aggregation: mean
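
For context: with this filter_list, lm-eval runs the custom filter's filter_fn over the collected generations before scoring. utils.build_predictions itself is not shown in this commit; a minimal sketch of what it plausibly does, assuming it simply delegates to the extract_all_answers helper added in utils.py below (the name and wiring are inferred, not confirmed by this diff):

# Hypothetical sketch -- build_predictions is referenced by the YAML but its
# body is not part of this commit. lm-eval hands a "custom" filter_fn the raw
# generations (one list per doc) plus the docs, and expects one value per doc.
def build_predictions(resps, docs):
    # assumed: delegate to extract_all_answers (added in utils.py below)
    return extract_all_answers(resps, docs, quick_extract=True)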
 import re
+from typing import Optional
+# from api_model import make_concurrent_requests
 from Levenshtein import distance
@@ -53,6 +54,13 @@ def create_test_prompt(demo_prompt, query, response):
     return full_prompt


+def verify_extraction(extraction):
+    extraction = extraction.strip()
+    if not extraction:
+        return False
+    return True
+
+
 # taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
 def get_most_similar(prediction: str, choices: list) -> float:
     """
@@ -140,24 +148,16 @@ def safe_equal(prediction, answer):
     return False


-def extract_answer(response: str, problem: dict) -> str:
+def extract_answer(response: str, problem: dict, quick_extract=True) -> str:
     question_type = problem["question_type"]
     answer_type = problem["answer_type"]
     choices = problem["choices"]
     # query = problem["query"]
-    # pid = problem['pid']
+    # pid = problem["pid"]

     if response == "":
         return ""

-    ### This is not in the original code:
-    extract = re.findall(
-        r"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?)", response
-    )
-    if extract:
-        return str(extract[0])
-    ###
-
     if question_type == "multi_choice" and response in choices:
         return response
@@ -175,7 +175,39 @@ def extract_answer(response: str, problem: dict) -> str:
         except Exception:
             pass

-    return response
+    # quick extraction
+    if quick_extract:
+        # The answer is "text". -> "text"
+        try:
+            result = re.search(r'The answer is "(.*)"\.', response)
+            if result:
+                extraction = result.group(1)
+                return extraction
+        except Exception:
+            pass
+
+    # general extraction
+    # try:
+    #     full_prompt = create_test_prompt(DEMO_PROMPT, query, response)
+    #     extraction = make_concurrent_requests(full_prompt)
+    #     return extraction
+    # except Exception:
+    #     print(
+    #         f"Error in extracting answer for problem: {pid} with response: {response}"
+    #     )
+    #     # logging.info(f"Error in extracting answer for problem: {pid} with response: {response}")
+    #     # logging.info(e)
+
+    return ""
+
+
+def extract_all_answers(
+    resps: list[list[str]], docs: dict, quick_extract=True
+) -> list[str]:
+    return [
+        extract_answer(resp[0], doc, quick_extract=quick_extract)
+        for resp, doc in zip(resps, docs)
+    ]


 # adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
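
A quick illustration of the quick_extract path: the pattern only matches the exact phrasing The answer is "...". so any other response shape now falls through to return "" instead of echoing the raw response:

import re

pattern = r'The answer is "(.*)"\.'

m = re.search(pattern, 'The answer is "B".')
print(m.group(1) if m else "")  # -> B

m = re.search(pattern, "I believe the answer is B")
print(m.group(1) if m else "")  # -> empty: extract_answer returns ""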
@@ -186,11 +218,16 @@ def process_results(doc: dict, results: list[str]):
     answer_type = doc["answer_type"]
     precision = doc["precision"]  # noqa: F841
     answer = doc["answer"]

-    extracted_answer = extract_answer(response, doc)
-    normalized_extraction = normalize_extracted_answer(
-        extracted_answer, choices, question_type, answer_type, precision
-    )
-    res = safe_equal(normalized_extraction, answer)
+    # step 1: extract the answer from the model response
+    # extracted_answer = extract_answer(response, doc)
+    extracted_answer = response[0]
+    if verify_extraction(extracted_answer):
+        normalized_extraction = normalize_extracted_answer(
+            extracted_answer, choices, question_type, answer_type, precision
+        )
+        res = safe_equal(normalized_extraction, answer)
+    else:
+        res = False

     return {"acc": 1.0} if res else {"acc": 0.0}