Commit 55be51ea authored by Baber's avatar Baber
Browse files

feat: implement check_gold_index_error utility and refactor process_results...

feat: implement the check_gold_index_error utility and refactor process_results for improved error handling; remove the multiple-choice (doc_to_choice) handling from the generate_until branch
parent 16030317
......@@ -21,6 +21,7 @@ from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import TaskConfig
......@@ -380,7 +381,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def process_results(self, doc: dict, results: list):
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
......@@ -390,7 +391,7 @@ class Task(abc.ABC):
:param results:
The results of the requests created in construct_requests.
"""
pass
raise NotImplementedError
@deprecated("not used anymore")
def aggregation(self):
......@@ -949,11 +950,13 @@ class ConfigurableTask(Task):
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
if hasattr(self, "_filters") and self._instances:
for f in self._filters:
f.ensemble.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
eval_logger.warning(
"No filter defined or instances found. Passing through instances"
)
return self._instances
def should_decontaminate(self):
......@@ -1255,7 +1258,7 @@ class ConfigurableTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: list) -> dict:
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
......@@ -1269,9 +1272,12 @@ class ConfigurableTask(Task):
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
(loglikelihood, *_) = results
assert isinstance(_target := self.doc_to_target(doc), str), (
"Require target to be a string for loglikelihood_rolling"
)
_words = self.count_words(_target)
_bytes = self.count_bytes(_target)
return {
**(
{"word_perplexity": (loglikelihood, _words)}
......@@ -1316,19 +1322,7 @@ class ConfigurableTask(Task):
else:
gold = self.doc_to_target(doc)
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
gold, gold_index_error = check_gold_index_error(choices, gold)
if gold_index_error:
eval_logger.warning(
......@@ -1376,11 +1370,6 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
for metric in self._metric_fn_list:
try:
result_score = self._metric_fn_list[metric](
......
from __future__ import annotations
def check_gold_index_error(
    choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
    """Validate gold label(s) against ``choices`` and flag unusable ones.

    A gold label may be given as an integer index into ``choices``, as a
    string that must be equal to one of the entries of ``choices``, or as a
    list of such labels. Every label that cannot be matched to a choice is
    replaced by the sentinel ``-100``.

    Resolution rules:
      * ``int``: kept when it is smaller than ``len(choices)``, otherwise
        replaced by the sentinel. Negative integers are smaller than
        ``len(choices)`` and therefore pass through unchanged.
      * ``str``: replaced by its position in ``choices``, or by the sentinel
        when it is not present.
      * ``list``: each element is resolved individually. String elements are
        looked up like scalar strings (previously they raised ``TypeError``
        because they were compared against an integer); all other elements
        are bounds-checked like integers.
      * anything else: returned untouched.

    :param choices:
        The answer choices of the document.
    :param gold:
        The gold label, or list of gold labels, to validate.
    :return:
        A ``(gold, gold_index_error)`` tuple: the resolved label(s) and a
        flag that is ``True`` when at least one label equals the sentinel.
    """
    # Marker for "no valid gold index"; callers receive it inside `gold`.
    invalid_index = -100

    if isinstance(gold, list):
        num_choices = len(choices)
        resolved = []
        for label in gold:
            if isinstance(label, str):
                # Mirror the scalar branch: map the label text to its index.
                resolved.append(
                    choices.index(label) if label in choices else invalid_index
                )
            else:
                resolved.append(label if label < num_choices else invalid_index)
        return resolved, invalid_index in resolved

    if isinstance(gold, int):
        if gold >= len(choices):
            gold = invalid_index
    elif isinstance(gold, str):
        gold = choices.index(gold) if gold in choices else invalid_index

    return gold, bool(gold == invalid_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment