Commit 90b055b6 authored by Baber

mcq_to_generative

parent a682edad
@@ -1779,3 +1779,27 @@ class PerplexityTask(Task):
    def count_words(cls, doc) -> int:
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


class Generate_MultipleChoice(ConfigurableTask):
    OUTPUT_TYPE = "generate_until"

    def process_results(self, doc, results):
        letters = [chr(i) for i in range(65, 91)]
        gold = self.doc_to_target(doc)
        result = results[0]
        if isinstance(gold, int):
            gold = letters[gold]
        elif (self.config.doc_to_choice is not None) and (gold not in letters):
            # If you set doc_to_choice,
            # it assumes that doc_to_target returns a number.
            choices = self.doc_to_choice(doc)
            _index = choices.index(gold)
            gold = letters[_index]
        for metric in self._metric_fn_list.keys():
            result_score = self._metric_fn_list[metric](
                references=[gold],
                predictions=[result],
                **self._metric_fn_kwargs[metric],
            )
        return result_score
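As an editorial aside (not part of the diff), the gold-normalization step in `process_results` above can be read as the following minimal sketch: `chr(65 + i)` maps a 0-based choice index to "A", "B", "C", ...; the helper name `normalize_gold` is made up for illustration.

# Editorial sketch (not in the commit): how gold labels become answer letters.
letters = [chr(i) for i in range(65, 91)]  # ["A", "B", ..., "Z"]

def normalize_gold(gold, choices):
    """Hypothetical helper mirroring the branches in process_results above."""
    if isinstance(gold, int):       # integer index -> letter
        return letters[gold]
    if gold not in letters:         # raw choice string -> letter of its index
        return letters[choices.index(gold)]
    return gold                     # already a letter

print(normalize_gold(2, ["red", "green", "blue"]))       # C
print(normalize_gold("blue", ["red", "green", "blue"]))  # C
print(normalize_gold("B", ["red", "green", "blue"]))     # B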
@@ -7,7 +7,7 @@ from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
-from lm_eval.api.task import ConfigurableTask, Task
+from lm_eval.api.task import ConfigurableTask, Generate_MultipleChoice, Task
from lm_eval.evaluator_utils import get_subtask_list
@@ -25,51 +25,29 @@ def convert_mcq_to_generative(cfg: dict):
        + cfg.get("doc_to_text", "")
        + 'Your response should end with "The best answer is [the_answer_letter]" where the [the_answer_letter] is one of choice letters, A, B, C etc.'
    )
cfg["generation_kwargs"] = ({"until": ["."], "max_gen_toks": 10},)
cfg["filter_list"] = (
[
{
"name": "strict_match",
"filter": [
{"function": "remove_whitespace"},
{"function": "take_first"},
],
}
],
)
cfg["generation_kwargs"] = {"until": ["."], "max_gen_toks": 10}
cfg["filter_list"] = [
{
"name": "strict_match",
"filter": [
{"function": "remove_whitespace"},
{"function": "take_first"},
],
}
]
cfg["metric_list"] = [
{
"metric": "exact_match",
"aggregation": "mean",
"higher_is_better": True,
"ignore_case": True,
"ignore_punctuation": True,
"regexes_to_ignore": ["\\$", "\\.$"],
}
]
return cfg
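As an editorial illustration (not part of the commit), here is roughly how the rewritten `convert_mcq_to_generative` above is expected to transform a task config; `toy_cfg` below is hypothetical, and only fields visible in this hunk are checked.

# Editorial sketch (not in the commit). `toy_cfg` is a made-up MCQ task config;
# the converter above appends the answer-letter instruction to doc_to_text and
# sets generation, filtering, and metric options.
toy_cfg = {
    "output_type": "multiple_choice",
    "doc_to_text": "{{question}}\nAnswer:",
    "doc_to_choice": "{{choices}}",
}
converted = convert_mcq_to_generative(toy_cfg)
print(converted["generation_kwargs"])         # {'until': ['.'], 'max_gen_toks': 10}
print(converted["filter_list"][0]["name"])    # strict_match
print(converted["metric_list"][0]["metric"])  # exact_match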
# def convert_mcq_to_generative(cfg: dict):
#     Prompt = """Given the following question and candidate answers, choose the correct answer."""
#     if cfg.get("output_type", "generate_until") == "generate_until":
#         return cfg
#     else:
#         cfg["output_type"] = "generate_until"
#         doc_to_text: str = cfg.get("doc_to_text", "")
#         doc_to_choice = cfg.get("doc_to_choice")
#         assert doc_to_choice is not None, "doc_to_choice is required!"
#         if isinstance(doc_to_choice, str):
#             doc_to_choice = doc_to_choice.replace("{", "").replace("}", "")
#         if doc_to_text.lower().rfind("answer") != -1:
#             doc_to_text = doc_to_text[:doc_to_text.lower().rfind(r"answer")].strip()
#         elif doc_to_text.lower().rfind("a:") != -1:
#             doc_to_text = doc_to_text[:doc_to_text.lower().rfind(r"a:")].strip()
#
#         cfg['doc_to_text'] = (
#             f"{Prompt + '\n' + doc_to_text + '\n'}"
#             "{% set letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'] %}"
#             f"{{% for choice in {doc_to_choice} %}}"
#             "{{letters[loop.index0]}}. {{choice}}" + "\n"
#             "{% endfor %}\n"
#             """Your response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of the answer letters."""
#         )
#         del cfg["doc_to_choice"]
#         cfg["gen_prefix"] = "The answer is"
#
#         return cfg


class TaskManager:
    """TaskManager indexes all tasks from the default `lm_eval/tasks/`
    and an optional directory if provided.
@@ -81,6 +59,7 @@ class TaskManager:
verbosity="INFO",
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
mcq_to_generative: bool = False,
) -> None:
self.verbosity = verbosity
self.include_path = include_path
@@ -107,6 +86,7 @@
        )
        self.task_group_map = collections.defaultdict(list)
        self.mcq_to_generative = mcq_to_generative

    def initialize_tasks(
        self,
@@ -333,8 +313,11 @@ class TaskManager:
                    # very scuffed: set task name here. TODO: fixme?
                    task_object.config.task = task
            else:
-                config = convert_mcq_to_generative(config)
-                task_object = ConfigurableTask(config=config)
+                if self.mcq_to_generative:
+                    config = convert_mcq_to_generative(config)
+                    task_object = Generate_MultipleChoice(config=config)
+                else:
+                    task_object = ConfigurableTask(config=config)
            return {task: task_object}
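Finally, an editorial usage sketch (not part of the commit): with the new flag, loading tasks through the manager would look roughly like this, assuming lm-evaluation-harness's public `TaskManager` and `get_task_dict` entry points; `arc_easy` is just an example task name.

# Editorial sketch (not in the commit): enabling the MCQ-to-generative path.
from lm_eval.tasks import TaskManager, get_task_dict

tm = TaskManager(verbosity="INFO", mcq_to_generative=True)
# Multiple-choice YAML tasks resolved here would be run through
# convert_mcq_to_generative and instantiated as Generate_MultipleChoice.
task_dict = get_task_dict(["arc_easy"], tm)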