Commit ae4d9ed2 authored by lintangsutawika's avatar lintangsutawika
Browse files

changes for pre-commit

parent 6b72d7b7
...@@ -29,7 +29,7 @@ graph LR; ...@@ -29,7 +29,7 @@ graph LR;
P[Prompt] P[Prompt]
Me[Metric] Me[Metric]
R[Result] R[Result]
T --- I:::empty T --- I:::empty
P --- I P --- I
I --> M I --> M
......
...@@ -6,7 +6,7 @@ from . import extraction ...@@ -6,7 +6,7 @@ from . import extraction
FILTER_REGISTRY = { FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter, "take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter, "regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter, "majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter, "take_first_k": selection.TakeKFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward, # that takes an input and returns a scalar and then should select the max reward,
......
...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter): ...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter):
""" """
return map(lambda r: r[0], resps) return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
...@@ -25,8 +25,10 @@ class TakeKFilter(Filter): ...@@ -25,8 +25,10 @@ class TakeKFilter(Filter):
def apply(self, resps): def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k # check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." assert (
return map(lambda r: r[:self.k], resps) len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[: self.k], resps)
class MajorityVoteFilter(Filter): class MajorityVoteFilter(Filter):
...@@ -37,12 +39,13 @@ class MajorityVoteFilter(Filter): ...@@ -37,12 +39,13 @@ class MajorityVoteFilter(Filter):
def apply(self, resps): def apply(self, resps):
""" """
Each entry of `resps` is a list of model responses. Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`. We select the response that occurs most frequently in each entry of `resps`.
""" """
def select_majority(resp): def select_majority(resp):
counts = Counter(resp) counts = Counter(resp)
vote = counts.most_common(1)[0][0] vote = counts.most_common(1)[0][0]
return vote return vote
return map(lambda r: [select_majority(r)], resps) return map(lambda r: [select_majority(r)], resps)
...@@ -64,4 +64,4 @@ Tasks added in the revamped harness that were not previously available. Again, a ...@@ -64,4 +64,4 @@ Tasks added in the revamped harness that were not previously available. Again, a
- [ ] Chain of Thought - [ ] Chain of Thought
- [ ] Self-consistency ; Least-to-Most prompting, etc. - [ ] Self-consistency ; Least-to-Most prompting, etc.
- [ ] Summarization Tasks - [ ] Summarization Tasks
- [ ] Anthropic Model-Written Evals - [ ] Anthropic Model-Written Evals
\ No newline at end of file
...@@ -29,4 +29,4 @@ Homepage: https://github.com/openai/grade-school-math ...@@ -29,4 +29,4 @@ Homepage: https://github.com/openai/grade-school-math
archivePrefix={arXiv}, archivePrefix={arXiv},
primaryClass={cs.LG} primaryClass={cs.LG}
} }
``` ```
\ No newline at end of file
...@@ -29,4 +29,4 @@ filter_list: ...@@ -29,4 +29,4 @@ filter_list:
- function: "regex" - function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)" regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote" - function: "majority_vote"
- function: "take_first" - function: "take_first"
\ No newline at end of file
...@@ -39,4 +39,4 @@ filter_list: ...@@ -39,4 +39,4 @@ filter_list:
filter: filter:
- function: "regex" - function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)" regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
- function: "take_first" - function: "take_first"
\ No newline at end of file
...@@ -32,4 +32,4 @@ num_fewshot: 5 ...@@ -32,4 +32,4 @@ num_fewshot: 5
# filter: # filter:
# - function: "regex" # - function: "regex"
# regex_pattern: "### (\\-?[0-9\\.\\,]+)" # regex_pattern: "### (\\-?[0-9\\.\\,]+)"
# - function: "take_first" # - function: "take_first"
\ No newline at end of file
# LAMBADA # LAMBADA
### Paper ### Paper
The LAMBADA dataset: Word prediction requiring a broad discourse context The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf https://arxiv.org/pdf/1606.06031.pdf
LAMBADA is a dataset to evaluate the capabilities of computational models for text LAMBADA is a dataset to evaluate the capabilities of computational models for text
...@@ -23,4 +23,4 @@ Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI ...@@ -23,4 +23,4 @@ Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
publisher={Zenodo}, publisher={Zenodo},
year={2016}, year={2016},
month={Aug} month={Aug}
} }
\ No newline at end of file
...@@ -20,4 +20,4 @@ Homepage: https://pile.eleuther.ai/ ...@@ -20,4 +20,4 @@ Homepage: https://pile.eleuther.ai/
journal={arXiv preprint arXiv:2101.00027}, journal={arXiv preprint arXiv:2101.00027},
year={2020} year={2020}
} }
``` ```
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte aggregation: bits_per_byte
higher_is_better: false higher_is_better: false
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment