Commit 55be51ea authored by Baber

feat: implement check_gold_index_error utility and refactor process_results for improved error handling; remove generate_until multiple-choice handling
parent 16030317
@@ -21,6 +21,7 @@ from typing_extensions import deprecated
 from lm_eval import utils
 from lm_eval.api.instance import Instance, OutputType
 from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
+from lm_eval.api.utils import check_gold_index_error
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.task import TaskConfig
@@ -380,7 +381,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def process_results(self, doc: dict, results: list):
+    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
@@ -390,7 +391,7 @@ class Task(abc.ABC):
         :param results:
             The results of the requests created in construct_requests.
         """
-        pass
+        raise NotImplementedError

     @deprecated("not used anymore")
     def aggregation(self):
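For reference, a minimal self-contained sketch (not part of this commit) of what the signature change means downstream: subclasses declare a dict[str, Any] return, and calling the unimplemented base method now raises NotImplementedError instead of silently passing. BaseTask and ExampleTask below are hypothetical stand-ins, not the harness classes.

from __future__ import annotations

import abc
from typing import Any


class BaseTask(abc.ABC):
    @abc.abstractmethod
    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
        raise NotImplementedError


class ExampleTask(BaseTask):
    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
        # Score one document: exact match between the first generation and the target.
        return {"exact_match": float(results[0].strip() == doc["target"])}


print(ExampleTask().process_results({"target": "Paris"}, ["Paris"]))  # {'exact_match': 1.0}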
@@ -949,11 +950,13 @@ class ConfigurableTask(Task):

     def apply_filters(self) -> list[Instance] | None:
         """Iterates over FilterEnsembles and applies them to instances"""
-        if hasattr(self, "_filters"):
+        if hasattr(self, "_filters") and self._instances:
             for f in self._filters:
                 f.ensemble.apply(self._instances)
         else:
-            eval_logger.warning("No filter defined, passing through instances")
+            eval_logger.warning(
+                "No filter defined or instances found. Passing through instances"
+            )
         return self._instances

     def should_decontaminate(self):
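A small illustrative sketch of the new guard, using hypothetical stand-ins rather than the harness's Task/FilterEnsemble classes: filters are applied only when _filters exists and _instances is non-empty, otherwise instances pass through with a warning.

import logging

eval_logger = logging.getLogger("sketch")


class DemoTask:
    def __init__(self, instances, filters=None):
        self._instances = instances
        if filters is not None:
            self._filters = filters

    def apply_filters(self):
        # Same guard as the commit: require both a _filters attribute and a
        # non-empty _instances list before applying filters.
        if hasattr(self, "_filters") and self._instances:
            for f in self._filters:
                f(self._instances)  # stand-in for f.ensemble.apply(self._instances)
        else:
            eval_logger.warning(
                "No filter defined or instances found. Passing through instances"
            )
        return self._instances


# An empty instance list now takes the warning branch even though a filter is configured.
print(DemoTask([], filters=[lambda insts: None]).apply_filters())  # []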
@@ -1255,7 +1258,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )

-    def process_results(self, doc: dict, results: list) -> dict:
+    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
@@ -1269,9 +1272,12 @@ class ConfigurableTask(Task):
                 **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
             }
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
-            (loglikelihood,) = results
-            _words = self.count_words(self.doc_to_target(doc))
-            _bytes = self.count_bytes(self.doc_to_target(doc))
+            (loglikelihood, *_) = results
+            assert isinstance(_target := self.doc_to_target(doc), str), (
+                "Require target to be a string for loglikelihood_rolling"
+            )
+            _words = self.count_words(_target)
+            _bytes = self.count_bytes(_target)
             return {
                 **(
                     {"word_perplexity": (loglikelihood, _words)}
@@ -1316,19 +1322,7 @@ class ConfigurableTask(Task):
             else:
                 gold = self.doc_to_target(doc)

-            gold_index_error = False
-            if isinstance(gold, list):
-                gold = [i if i < len(choices) else -100 for i in gold]
-                if -100 in gold:
-                    gold_index_error = True
-            else:
-                if isinstance(gold, int):
-                    gold = gold if gold < len(choices) else -100
-                elif isinstance(gold, str):
-                    gold = choices.index(gold) if gold in choices else -100
-
-                if gold == -100:
-                    gold_index_error = True
+            gold, gold_index_error = check_gold_index_error(choices, gold)

             if gold_index_error:
                 eval_logger.warning(
@@ -1376,11 +1370,6 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
-            if self.config.doc_to_choice is not None:
-                # If you set doc_to_choice,
-                # it assumes that doc_to_target returns a number.
-                choices = self.doc_to_choice(doc)
-                gold = choices[gold]
             for metric in self._metric_fn_list:
                 try:
                     result_score = self._metric_fn_list[metric](
lm_eval/api/utils.py:
from __future__ import annotations


def check_gold_index_error(
    choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
    """Map gold answers onto choice indices; out-of-range indices or strings
    not found in choices become -100, and the flag is True when that happens."""
    gold_index_error = False
    if isinstance(gold, list):
        gold = [i if i < len(choices) else -100 for i in gold]
        if -100 in gold:
            gold_index_error = True
        return gold, gold_index_error
    else:
        if isinstance(gold, int):
            gold = gold if gold < len(choices) else -100
        elif isinstance(gold, str):
            gold = choices.index(gold) if gold in choices else -100

        if gold == -100:
            gold_index_error = True
        return gold, gold_index_error
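A short usage sketch of check_gold_index_error with illustrative inputs: in-range indices and strings found in choices resolve to indices, while out-of-range indices and unknown strings become -100 and set the error flag.

choices = ["yes", "no"]

print(check_gold_index_error(choices, 1))        # (1, False)        valid index
print(check_gold_index_error(choices, 5))        # (-100, True)      index out of range
print(check_gold_index_error(choices, "no"))     # (1, False)        string mapped to its index
print(check_gold_index_error(choices, "maybe"))  # (-100, True)      string not among the choices
print(check_gold_index_error(choices, [0, 3]))   # ([0, -100], True) any bad index flags the list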