Unverified commit cda25fef, authored by Lintang Sutawika and committed by GitHub
Browse files

Merge branch 'main' into standardize_metrics

parents dfb41835 4d10ad56
import datasets
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Map every record of ``dataset`` to a four-option multiple-choice document.

    Each record is read for the keys ``question``, ``option#1`` .. ``option#4``
    and ``answer``. The mapped record carries:

    * ``question`` -- the full prompt: a Korean instruction line, the original
      question, the four numbered options and an answer lead-in.
    * ``choices`` -- the fixed labels ``"(1)"`` .. ``"(4)"``.
    * ``gold`` -- ``int(answer) - 1`` (presumably turning a 1-based answer
      number into a 0-based index into ``choices`` -- confirm against the data).

    Returns:
        Whatever ``dataset.map`` returns for the per-record converter.
    """

    def _build_example(doc):
        # Assemble the prompt piece by piece; the concatenation must match the
        # expected prompt text exactly, including every newline.
        header = "다음을 읽고 정답으로 알맞은 것을 고르시요.\n"
        question_part = f"### Question: {doc['question']}\n"
        options_part = "### Options:\n" + "".join(
            f"({number}) {doc[f'option#{number}']}\n" for number in range(1, 5)
        )
        answer_lead_in = "### Answer: 주어진 문제의 정답은"
        prompt = header + question_part + options_part + answer_lead_in

        labels = [f"({number})" for number in range(1, 5)]
        gold_index = int(doc["answer"]) - 1

        return {"question": prompt, "choices": labels, "gold": gold_index}

    return dataset.map(_build_example)
...@@ -17,4 +17,4 @@ metric_list: ...@@ -17,4 +17,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -17,4 +17,4 @@ metric_list: ...@@ -17,4 +17,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -17,4 +17,4 @@ metric_list: ...@@ -17,4 +17,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -24,4 +24,4 @@ filter_list: ...@@ -24,4 +24,4 @@ filter_list:
regex_pattern: "^\\s*([A-D])" regex_pattern: "^\\s*([A-D])"
- function: "take_first" - function: "take_first"
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -19,4 +19,4 @@ metric_list: ...@@ -19,4 +19,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -12,4 +12,4 @@ metric_list: ...@@ -12,4 +12,4 @@ metric_list:
- metric: acc - metric: acc
- metric: f1 - metric: f1
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -26,4 +26,4 @@ metric_list: ...@@ -26,4 +26,4 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -28,4 +28,4 @@ filter_list: ...@@ -28,4 +28,4 @@ filter_list:
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)" regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
- function: "take_first" - function: "take_first"
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -28,4 +28,4 @@ filter_list: ...@@ -28,4 +28,4 @@ filter_list:
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)" regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
- function: "take_first" - function: "take_first"
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -94,7 +94,6 @@ LANGUAGES = { ...@@ -94,7 +94,6 @@ LANGUAGES = {
def add_regex_pattern(regex_pattern): def add_regex_pattern(regex_pattern):
if regex_pattern is None: if regex_pattern is None:
return {} return {}
return { return {
......
...@@ -21,4 +21,4 @@ metric_list: ...@@ -21,4 +21,4 @@ metric_list:
higher_is_better: true higher_is_better: true
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -7,7 +7,6 @@ import argparse ...@@ -7,7 +7,6 @@ import argparse
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
SUBJECTS = { SUBJECTS = {
...@@ -82,7 +81,6 @@ def parse_args(): ...@@ -82,7 +81,6 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our "other" YAMLs. # get filename of base_yaml so we can `"include": ` it in our "other" YAMLs.
...@@ -98,7 +96,6 @@ if __name__ == "__main__": ...@@ -98,7 +96,6 @@ if __name__ == "__main__":
ALL_CATEGORIES = [] ALL_CATEGORIES = []
for subject, category in tqdm(SUBJECTS.items()): for subject, category in tqdm(SUBJECTS.items()):
if category not in ALL_CATEGORIES: if category not in ALL_CATEGORIES:
ALL_CATEGORIES.append(category) ALL_CATEGORIES.append(category)
......
...@@ -12,4 +12,4 @@ metric_list: ...@@ -12,4 +12,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -23,4 +23,4 @@ metric_list: ...@@ -23,4 +23,4 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -23,4 +23,4 @@ metric_list: ...@@ -23,4 +23,4 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
metadata: metadata:
- version: 0.0 version: 0.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment