Commit cb485883 authored by lintangsutawika's avatar lintangsutawika
Browse files

filter takes docs as argument in case filtering requires it

parent 7d16a7cd
......@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import List
from lm_eval.api.instance import Instance
from datasets import Dataset
class Filter:
"""
......@@ -18,7 +18,7 @@ class Filter:
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps):
def apply(self, resps, docs):
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,14 +40,14 @@ class FilterEnsemble:
name: str
filters: List[Filter]
def apply(self, instances: List[Instance]):
def apply(self, instances: List[Instance], docs: List[Dataset]):
resps = [
inst.resps for inst in instances
] # operate just on the model responses
for f in self.filters:
# apply filters in sequence
resps = f.apply(resps)
resps = f.apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
......@@ -627,19 +627,19 @@ class ConfigurableTask(Task):
)
if self.has_test_docs():
docs = self.test_docs()
self.task_docs = self.test_docs()
elif self.has_validation_docs():
docs = self.validation_docs()
self.task_docs = self.validation_docs()
else:
assert (
False
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
# Test One Doc
self.features = list(docs.features.keys())
self.features = list(self.task_docs.features.keys())
self.multiple_input = 0
self.multiple_target = 0
test_doc = docs[0]
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
......@@ -743,6 +743,15 @@ class ConfigurableTask(Task):
)
return super().fewshot_docs()
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances, self.task_docs)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
def should_decontaminate(self):
return self._config.should_decontaminate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment