Commit 6acc5c47 authored by lintangsutawika

result_dict process

parent b698048d
@@ -18,6 +18,7 @@ from collections.abc import Callable
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.metrics import (
    METRIC_REGISTRY,
    AGGREGATION_REGISTRY,
@@ -187,13 +188,22 @@ class Task(abc.ABC):
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )
        if self.DATASET_PATH in ["json", "csv"]:
            self.dataset = datasets.load_dataset(
                path=self.DATASET_PATH,
                data_files=self.DATASET_NAME,
                data_dir=data_dir,
                cache_dir=cache_dir,
                download_mode=download_mode,
            )
        else:
            self.dataset = datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_dir=data_dir,
                cache_dir=cache_dir,
                download_mode=download_mode,
            )

    @abc.abstractmethod
    def has_training_docs(self):
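
The new branch routes tasks whose DATASET_PATH is the generic "json" or "csv" loader through `data_files` instead of `name`. A minimal sketch of the two call shapes this produces, using a hypothetical local file and a Hub dataset purely for illustration:

import datasets

# Local-file task: DATASET_PATH selects the generic loader and DATASET_NAME is
# interpreted as the file(s) to read ("my_task.jsonl" is a hypothetical path).
local_ds = datasets.load_dataset(path="json", data_files="my_task.jsonl")

# Hub task: DATASET_PATH is the dataset repository and DATASET_NAME the config name.
hub_ds = datasets.load_dataset(path="super_glue", name="boolq")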
@@ -436,8 +446,12 @@ class Task(abc.ABC):
    def apply_filters(self):
        for f in self._filters:
            f.apply(self._instances)
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances


class ConfigurableTask(Task):
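
The guarded loop only assumes that each entry in `self._filters` exposes an `apply(instances)` method, which is what the ensembles built by `build_filter_ensemble` provide. A hypothetical stand-in showing that contract (`PassthroughFilter` is illustrative and not part of lm_eval):

class PassthroughFilter:
    # Anything stored in self._filters just needs an apply(instances) method;
    # a real FilterEnsemble does more per-instance work than this sketch.
    def apply(self, instances):
        return instances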
@@ -514,8 +528,8 @@ class ConfigurableTask(Task):
        self._training_docs = None
        self._fewshot_docs = None
        self._filters = []
        if self._config.filter_list is not None:
            self._filters = []
            for filter_config in self._config.filter_list:
                for filter_pipeline in filter_config:
                    filter_name = filter_config["name"]
@@ -528,11 +542,9 @@ class ConfigurableTask(Task):
                        components.append([function["function"], kwargs])
                    filter_pipeline = build_filter_ensemble(filter_name, components)
                    self._filters.append(filter_pipeline)
                self._filters.append(filter_pipeline)
        else:
            self._filters = [
                build_filter_ensemble("take_first", [["take_first", None]])
            ]
            self._filters = [build_filter_ensemble("none", [("none", None)])]

        if self._config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self._config.use_prompt}")
@@ -768,7 +780,7 @@ class ConfigurableTask(Task):
                    references=[gold], predictions=[result], **self._metric_kwargs[key]
                )
                result_dict[key] = _dict[key]
                result_dict = {**result_dict, **_dict}
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
......
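
The result_dict change swaps a single-key assignment for a dict merge, so a metric function that returns several scores keeps all of them. A small illustration with made-up metric names:

# Hypothetical per-sample metric outputs; only the names are invented.
result_dict = {"acc": 1.0}
_dict = {"bleu": 27.3, "chrf": 0.61}

# Old behaviour: result_dict[key] = _dict[key] kept just one entry.
# New behaviour: every metric returned for this sample is merged in.
result_dict = {**result_dict, **_dict}
# result_dict == {"acc": 1.0, "bleu": 27.3, "chrf": 0.61}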