Commit 9deea116 authored by lintangsutawika

Merge branch 'local-file' into dataset-metric-log

parents 9f62d3ac 6acc5c47
@@ -18,6 +18,7 @@ from collections.abc import Callable
 from lm_eval import utils
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance
+from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.metrics import (
     METRIC_REGISTRY,
     AGGREGATION_REGISTRY,
@@ -187,13 +188,22 @@ class Task(abc.ABC):
         - `datasets.DownloadMode.FORCE_REDOWNLOAD`
             Fresh download and fresh dataset.
         """
-        self.dataset = datasets.load_dataset(
-            path=self.DATASET_PATH,
-            name=self.DATASET_NAME,
-            data_dir=data_dir,
-            cache_dir=cache_dir,
-            download_mode=download_mode,
-        )
+        if self.DATASET_PATH in ["json", "csv"]:
+            self.dataset = datasets.load_dataset(
+                path=self.DATASET_PATH,
+                data_files=self.DATASET_NAME,
+                data_dir=data_dir,
+                cache_dir=cache_dir,
+                download_mode=download_mode,
+            )
+        else:
+            self.dataset = datasets.load_dataset(
+                path=self.DATASET_PATH,
+                name=self.DATASET_NAME,
+                data_dir=data_dir,
+                cache_dir=cache_dir,
+                download_mode=download_mode,
+            )
 
     @abc.abstractmethod
     def has_training_docs(self):
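
The new branch relies on the Hugging Face `datasets` convention that a path of "json" or "csv" selects a packaged loader, with the actual file locations supplied via `data_files` (here repurposing `DATASET_NAME`). A minimal standalone sketch of that loader call; the file name `data.jsonl` is hypothetical:

    import datasets

    # "json" selects the built-in JSON/JSON-lines builder; data_files points
    # at the local file(s), and each record becomes one row of the split.
    dataset = datasets.load_dataset("json", data_files="data.jsonl")
    print(dataset["train"][0])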
@@ -436,8 +446,12 @@ class Task(abc.ABC):
     def apply_filters(self):
-        for f in self._filters:
-            f.apply(self._instances)
+        if hasattr(self, "_filters"):
+            for f in self._filters:
+                f.apply(self._instances)
+        else:
+            eval_logger.warning("No filter defined, passing through instances")
+            return self._instances
 
 
 class ConfigurableTask(Task):
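
apply_filters now degrades gracefully on Task subclasses that never set `_filters`: `hasattr` checks for the attribute before iterating, and otherwise logs a warning and returns the instances untouched. A self-contained sketch of the same guard, with illustrative names not taken from the repository:

    import logging

    eval_logger = logging.getLogger("lm-eval")

    class BareTask:
        """A task that never configured filters; _filters is absent."""

        def __init__(self):
            self._instances = ["raw output 1", "raw output 2"]

        def apply_filters(self):
            # Guard: only iterate if a subclass actually set _filters.
            if hasattr(self, "_filters"):
                for f in self._filters:
                    f.apply(self._instances)
            else:
                eval_logger.warning("No filter defined, passing through instances")
                return self._instances

    print(BareTask().apply_filters())  # falls through to the warning branch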
@@ -514,8 +528,8 @@ class ConfigurableTask(Task):
         self._training_docs = None
         self._fewshot_docs = None
+        self._filters = []
 
         if self._config.filter_list is not None:
-            self._filters = []
             for filter_config in self._config.filter_list:
                 for filter_pipeline in filter_config:
                     filter_name = filter_config["name"]
@@ -528,11 +542,9 @@ class ConfigurableTask(Task):
                         components.append([function["function"], kwargs])
                     filter_pipeline = build_filter_ensemble(filter_name, components)
                 self._filters.append(filter_pipeline)
         else:
-            self._filters = [
-                build_filter_ensemble("take_first", [["take_first", None]])
-            ]
+            self._filters = [build_filter_ensemble("none", [("none", None)])]
 
         if self._config.use_prompt is not None:
             eval_logger.info(f"loading prompt {self._config.use_prompt}")
@@ -768,7 +780,7 @@ class ConfigurableTask(Task):
                 references=[gold], predictions=[result], **self._metric_kwargs[key]
             )
-            result_dict[key] = _dict[key]
+            result_dict = {**result_dict, **_dict}
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
...
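
Merging with `{**result_dict, **_dict}` keeps every key the metric returned rather than copying back only the key being iterated, which matters for metrics that compute several values in one call; on duplicate keys the right-hand dict wins. A plain-Python illustration with made-up scores:

    result_dict = {"acc": 0.5}
    _dict = {"bleu": 21.3, "bleu_stderr": 0.8}  # hypothetical metric output
    result_dict = {**result_dict, **_dict}
    print(result_dict)  # {'acc': 0.5, 'bleu': 21.3, 'bleu_stderr': 0.8}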