Commit 007d485b authored by lintangsutawika
Browse files

modified apply method to accept docs

parent cb485883
...@@ -17,14 +17,16 @@ FILTER_REGISTRY = { ...@@ -17,14 +17,16 @@ FILTER_REGISTRY = {
def get_filter(filter_name): def get_filter(filter_name):
return FILTER_REGISTRY[filter_name] if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
return filter_name
def build_filter_ensemble(filter_name, components): def build_filter_ensemble(filter_name, components):
""" """
Create a filtering pipeline. Create a filtering pipeline.
""" """
filters = [] filters = []
for (function, kwargs) in components: for (function, kwargs) in components:
if kwargs is None: if kwargs is None:
......
...@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter): ...@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
""" """
self._decontam_results = None self._decontam_results = None
def apply(self, reps): def apply(self, reps, docs):
""" """
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
""" """
......
...@@ -15,7 +15,7 @@ class RegexFilter(Filter): ...@@ -15,7 +15,7 @@ class RegexFilter(Filter):
self.regex = re.compile(regex_pattern) self.regex = re.compile(regex_pattern)
self.fallback = fallback self.fallback = fallback
def apply(self, resps): def apply(self, resps, docs):
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
...@@ -44,7 +44,7 @@ class WhitespaceFilter(Filter): ...@@ -44,7 +44,7 @@ class WhitespaceFilter(Filter):
def __init__(self): def __init__(self):
pass pass
def apply(self, resps): def apply(self, resps, docs):
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
......
...@@ -23,7 +23,7 @@ class TakeKFilter(Filter): ...@@ -23,7 +23,7 @@ class TakeKFilter(Filter):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def apply(self, resps): def apply(self, resps, docs):
# check we have at least k responses per doc, else we can't take the first k # check we have at least k responses per doc, else we can't take the first k
assert ( assert (
len(resps[0]) >= self.k len(resps[0]) >= self.k
...@@ -37,7 +37,7 @@ class MajorityVoteFilter(Filter): ...@@ -37,7 +37,7 @@ class MajorityVoteFilter(Filter):
Can define custom behavior here, if an individual instantiation of a Filter class should have state. Can define custom behavior here, if an individual instantiation of a Filter class should have state.
""" """
def apply(self, resps): def apply(self, resps, docs):
""" """
Each entry of `resps` is a list of model responses. Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`. We select the response that occurs most frequently in each entry of `resps`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment