Unverified Commit 0f6cd358 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #865 from EleutherAI/csatqa-refactor

[Refactor] Port CSATQA to refactor
parents c2848879 020c4115
group: csatqa
dataset_path: EleutherAI/csatqa
test_split: test
output_type: multiple_choice
process_docs: !function utils.process_docs
doc_to_text: "{{question}}"
doc_to_choice: "{{choices}}"
doc_to_target: "{{gold}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
"""
Take in a YAML, and output all other splits with this YAML
"""
import os
import yaml
import argparse
from tqdm import tqdm
from lm_eval.logger import eval_logger
SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True)
parser.add_argument("--save_prefix_path", default="csatqa")
parser.add_argument("--task_prefix", default="")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
base_yaml = yaml.full_load(f)
for name in tqdm(SUBSETS):
yaml_dict = {
"include": base_yaml_name,
"task": f"csatqa_{args.task_prefix}_{name}"
if args.task_prefix != ""
else f"csatqa_{name.lower()}",
"dataset_name": name,
}
file_save_path = args.save_prefix_path + f"_{name.lower()}.yaml"
eval_logger.info(f"Saving yaml for subset {name} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
width=float("inf"),
allow_unicode=True,
default_style='"',
)
"dataset_name": "GR"
"include": "_default_csatqa_yaml"
"task": "csatqa_gr"
"dataset_name": "LI"
"include": "_default_csatqa_yaml"
"task": "csatqa_li"
"dataset_name": "RCH"
"include": "_default_csatqa_yaml"
"task": "csatqa_rch"
"dataset_name": "RCS"
"include": "_default_csatqa_yaml"
"task": "csatqa_rcs"
"dataset_name": "RCSS"
"include": "_default_csatqa_yaml"
"task": "csatqa_rcss"
"dataset_name": "WR"
"include": "_default_csatqa_yaml"
"task": "csatqa_wr"
import datasets
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
def _process_doc(doc):
instruction = f"""다음을 읽고 정답으로 알맞은 것을 고르시요.
### Context: {doc["context"]}
### Question: {doc["question"]}
### Options:
(1) {doc['option#1']}\n(2) {doc["option#2"]}\n(3) {doc["option#3"]}\n(4) {doc['option#4']}\n(5) {doc['option#5']}
### Answer: 주어진 문제의 정답은"""
out_doc = {
"question": instruction,
"choices": ["(1)", "(2)", "(3)", "(4)", "(5)"],
"gold": int(doc["gold"]) - 1,
}
return out_doc
return dataset.map(_process_doc)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment