Commit ae4d9ed2 authored by lintangsutawika's avatar lintangsutawika
Browse files

changes for pre-commit

parent 6b72d7b7
...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter): ...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter):
""" """
return map(lambda r: r[0], resps) return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
...@@ -25,8 +25,10 @@ class TakeKFilter(Filter): ...@@ -25,8 +25,10 @@ class TakeKFilter(Filter):
def apply(self, resps): def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k # check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." assert (
return map(lambda r: r[:self.k], resps) len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[: self.k], resps)
class MajorityVoteFilter(Filter): class MajorityVoteFilter(Filter):
...@@ -40,6 +42,7 @@ class MajorityVoteFilter(Filter): ...@@ -40,6 +42,7 @@ class MajorityVoteFilter(Filter):
Each entry of `resps` is a list of model responses. Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`. We select the response that occurs most frequently in each entry of `resps`.
""" """
def select_majority(resp): def select_majority(resp):
counts = Counter(resp) counts = Counter(resp)
vote = counts.most_common(1)[0][0] vote = counts.most_common(1)[0][0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment