"test/client_app/client_app_impl.hpp" did not exist on "0aa899aa6cefded54ab98dcdf3dd094e60ca1535"
Commit cb485883 authored by lintangsutawika's avatar lintangsutawika
Browse files

filter takes docs as argument in case filtering requires it

parent 7d16a7cd
...@@ -2,7 +2,7 @@ from dataclasses import dataclass ...@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import List from typing import List
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from datasets import Dataset
class Filter: class Filter:
""" """
...@@ -18,7 +18,7 @@ class Filter: ...@@ -18,7 +18,7 @@ class Filter:
Can define custom behavior here, if an individual instantiation of a Filter class should have state. Can define custom behavior here, if an individual instantiation of a Filter class should have state.
""" """
def apply(self, resps): def apply(self, resps, docs):
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g. Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...@@ -40,14 +40,14 @@ class FilterEnsemble: ...@@ -40,14 +40,14 @@ class FilterEnsemble:
name: str name: str
filters: List[Filter] filters: List[Filter]
def apply(self, instances: List[Instance]): def apply(self, instances: List[Instance], docs: List[Dataset]):
resps = [ resps = [
inst.resps for inst in instances inst.resps for inst in instances
] # operate just on the model responses ] # operate just on the model responses
for f in self.filters: for f in self.filters:
# apply filters in sequence # apply filters in sequence
resps = f.apply(resps) resps = f.apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances. # add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name. # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
...@@ -627,19 +627,19 @@ class ConfigurableTask(Task): ...@@ -627,19 +627,19 @@ class ConfigurableTask(Task):
) )
if self.has_test_docs(): if self.has_test_docs():
docs = self.test_docs() self.task_docs = self.test_docs()
elif self.has_validation_docs(): elif self.has_validation_docs():
docs = self.validation_docs() self.task_docs = self.validation_docs()
else: else:
assert ( assert (
False False
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
# Test One Doc # Test One Doc
self.features = list(docs.features.keys()) self.features = list(self.task_docs.features.keys())
self.multiple_input = 0 self.multiple_input = 0
self.multiple_target = 0 self.multiple_target = 0
test_doc = docs[0] test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc) test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc) test_target = self.doc_to_target(test_doc)
...@@ -743,6 +743,15 @@ class ConfigurableTask(Task): ...@@ -743,6 +743,15 @@ class ConfigurableTask(Task):
) )
return super().fewshot_docs() return super().fewshot_docs()
def apply_filters(self):
    """
    Run every configured filter over this task's instances.

    Each entry of `self._filters` has its `apply` method called with
    `self._instances` and `self.task_docs`, so a filter can consult the task
    documents if it needs them. The value returned by each `apply` call is
    discarded here. If the task has no `_filters` attribute, a warning is
    logged and the instances are passed through untouched.

    :return: `self._instances`, whether or not any filter was applied.
    """
    if hasattr(self, "_filters"):
        for f in self._filters:
            f.apply(self._instances, self.task_docs)
    else:
        eval_logger.warning("No filter defined, passing through instances")
    # Return on every path so callers get the instances (never None),
    # regardless of whether filters were defined.
    return self._instances
def should_decontaminate(self): def should_decontaminate(self):
return self._config.should_decontaminate return self._config.should_decontaminate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment