Unverified Commit 574e565a authored by Lintang Sutawika, committed by GitHub

Merge branch 'big-refactor' into verbosity-rework

parents 73f3029c b7a4ea06
@@ -20,12 +20,12 @@ Task naming + registration:
 
 Dataset configuration options:
 - **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
-- **dataset_name** (`str`, *optional*, defaults to None) — The name of, what HF calls, a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
+- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
 - **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
 - **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
 - **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
 - **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
-- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
+- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. This must not be None if `num_fewshot` > 0.
 - **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to that expected by a prompt template.
 
 Prompting / in-context formatting options:
...
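The dataset options documented above map directly onto Hugging Face's `datasets.load_dataset`. A minimal sketch of that mapping (editorial illustration, not part of this patch; the local `data/train.json` path is hypothetical):

```python
import datasets

# dataset_path and dataset_name are the first two arguments to load_dataset:
gsm8k = datasets.load_dataset("gsm8k", "main")

# dataset_kwargs are forwarded as extra keyword arguments, e.g. for local files:
local = datasets.load_dataset("json", data_files={"train": "data/train.json"})

# training_split / validation_split / test_split / fewshot_split then name the
# keys of the resulting DatasetDict that are used for each purpose:
train_docs = gsm8k["train"]
test_docs = gsm8k["test"]
```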
@@ -5,3 +5,4 @@ maka
 mor
 te
 ond
+extraversion
@@ -55,7 +55,9 @@ eval_logger = logging.getLogger("lm-eval")
 class TaskConfig(dict):
     # task naming/registry
     task: str = None
+    task_alias: str = None
     group: Union[str, list] = None
+    group_alias: Union[str, list] = None
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
@@ -72,7 +74,6 @@ class TaskConfig(dict):
     doc_to_text: Union[Callable, str] = None
     doc_to_target: Union[Callable, str] = None
     doc_to_choice: Union[Callable, str, dict, list] = None
-    gold_alias: Union[Callable, str] = None
     process_results: Union[Callable, str] = None
     use_prompt: str = None
     description: str = ""
@@ -896,26 +897,6 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
 
-    def gold_alias(self, doc):
-        # returns a version of the gold target answer to a document,
-        # which should be passed into metric for scoring as the ground truth.
-        # in multiple_choice tasks, this should be castable to an int corresponding to the index
-        # within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
-        if self.config.gold_alias is not None:
-            doc_to_target = self.config.gold_alias
-        else:
-            return self.doc_to_target(doc)
-
-        if type(doc_to_target) == str:
-            return utils.apply_template(doc_to_target, doc)
-        elif callable(doc_to_target):
-            return doc_to_target(doc)
-        elif hasattr(doc_to_target, "apply"):
-            return doc_to_target.apply(doc)[1]
-        else:
-            raise TypeError
-
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
...
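With `gold_alias` removed above, the post-processing it used to perform now lives directly in `doc_to_target` templates (see the gsm8k YAML changes further down). A rough sketch of how such a Jinja template resolves against a document, using `jinja2` directly (illustrative only; the harness's own template helper may differ in detail, and the record below is a made-up gsm8k-style example):

```python
from jinja2 import Template

# hypothetical gsm8k-style record; gold answers end with "#### <number>"
doc = {"answer": "Natalia sold 48/2 = 24 clips in May.\n#### 72"}

# the template splits off everything before the final answer marker
template = Template("{{ answer.split('### ')[-1].rstrip() }}")
print(template.render(**doc))  # -> "72"
```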
@@ -2,7 +2,6 @@ import random
 import itertools
 import json
 import collections
-import logging
 import sys
 
 import torch
@@ -218,14 +217,15 @@ def evaluate(
     task_hierarchy = collections.defaultdict(list)
     # store the ordering of tasks and groups
     task_order = collections.defaultdict(int)
-    # store the aggregation for aggregating across tasks in the same group
-    sample_agg_fn = collections.defaultdict(dict)
+    task_group_alias = collections.defaultdict(dict)
 
     # get lists of each type of request
     for task_name, task in task_dict.items():
         if type(task) == tuple:
             group_name, task = task
             task_hierarchy[group_name].append(task_name)
+            versions[group_name] = "N/A"
         else:
             task_hierarchy[task_name] = []
@@ -235,6 +235,14 @@ def evaluate(
         versions[task_name] = task.VERSION
         configs[task_name] = dict(task.dump_config())
 
+        if "task_alias" in configs[task_name]:
+            task_group_alias[task_name] = configs[task_name]["task_alias"]
+
+        if ("group_alias" in configs[task_name]) and (
+            group_name not in task_group_alias
+        ):
+            task_group_alias[group_name] = configs[task_name]["group_alias"]
+
         if limit is not None:
             if task.has_test_docs():
                 task_docs = task.test_docs()
@@ -446,23 +454,8 @@ def evaluate(
             group_name = None
 
         agg_fn = task.aggregation()[metric]
-        task_score = agg_fn(items)
+        results[task_name][metric_key] = agg_fn(items)
+        results[task_name]["samples"] = len(items)
 
-        if group_name is not None:
-            sample_metric_key = metric + "(sample agg)," + key
-            for grouping in task_to_group[task_name]:
-                if metric_key in results[grouping]:
-                    results[grouping][metric_key].append(task_score)
-                else:
-                    results[grouping][metric_key] = [task_score]
-
-                if sample_metric_key in results[grouping]:
-                    results[grouping][sample_metric_key] += items
-                else:
-                    results[grouping][sample_metric_key] = items.copy()
-                    sample_agg_fn[grouping][sample_metric_key] = agg_fn
-
-        results[task_name][metric_key] = task_score
-
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
@@ -478,33 +471,139 @@ def evaluate(
             results[task_name][metric + "_stderr" + "," + key] = stderr(items)
 
     if bool(results):
-        for task_or_group in results.keys():
-            for metric in results[task_or_group].keys():
-                if type(results[task_or_group][metric]) == list:
-                    if "(sample agg)" in metric:
-                        results[task_or_group][metric] = sample_agg_fn[
-                            task_or_group
-                        ][metric](results[task_or_group][metric])
-                    else:
-                        results[task_or_group][metric] = np.average(
-                            results[task_or_group][metric]
-                        )
-
-                    versions[task_or_group] = "N/A"
-    for task_name, task in task_dict.items():
-        if type(task) == tuple:
-            group_name, task = task
+        for group, task_list in reversed(task_hierarchy.items()):
+            if task_list == []:
+                total_size = results[group]["samples"]
+            else:
+                total_size = 0
+
+                for task in task_list:
+                    metrics = results[task]
+
+                    current_size = metrics.pop("samples")
+                    # TODO: There should be a way for users
+                    # to toggle between weighted and
+                    # unweighted averaging
+                    # For unweighted averaging, use:
+                    # current_size = 1
+
+                    all_stderr = []
+                    for metric in [
+                        key for key in metrics.keys() if "_stderr" not in key
+                    ]:
+                        stderr = "_stderr,".join(metric.split(","))
+                        stderr_score = results[task][stderr]
+                        var_score = stderr_score**2
+                        metric_score = results[task][metric]
+
+                        all_stderr.append(stderr)
+                        if metric in results[group]:
+                            results[group][metric] = (
+                                results[group][metric] * total_size
+                                + metric_score * current_size
+                            ) / (total_size + current_size)
+                            # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
+                            results[group][stderr] = (
+                                (total_size - 1) * results[group][stderr]
+                                + (current_size - 1) * var_score
+                            ) / (
+                                total_size + current_size - 1
+                            ) + total_size * current_size / (
+                                (total_size + current_size)
+                                * (total_size + current_size - 1)
+                            ) * (
+                                results[group][metric] - metric_score
+                            ) ** 2
+                        else:
+                            results[group][metric] = metric_score
+                            results[group][stderr] = var_score
+
+                    total_size += current_size
+
+                for stderr in all_stderr:
+                    results[group][stderr] = np.sqrt(results[group][stderr])
+
+            results[group]["samples"] = total_size
+
+    def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
+        results_agg = collections.defaultdict(dict)
+        groups_agg = collections.defaultdict(dict)
+
+        for group_name, task_list in task_hierarchy.items():
             order = task_order[group_name]
-            tabbed_name = "-" * order + group_name
-            results_agg[tabbed_name] = results[group_name]
-            versions[tabbed_name] = versions[group_name]
-            if order == 0:
-                groups_agg[group_name] = results[group_name]
-        order = task_order[task_name]
-        tabbed_name = "-" * order + task_name
-        results_agg[tabbed_name] = results[task_name]
-        versions[tabbed_name] = versions[task_name]
+            results_agg[group_name] = results[group_name].copy()
+            results_agg[group_name]["tab"] = order
+
+            if (order < max(task_order.values())) and (len(task_list) > 0):
+                groups_agg[group_name] = results[group_name].copy()
+                groups_agg[group_name]["tab"] = order
+
+            if task_list != []:
+                for task in sorted(task_list):
+                    if task in task_hierarchy:
+                        _task_hierarchy = {task: task_hierarchy[task]}
+                    else:
+                        _task_hierarchy = {task: []}
+
+                    _results_agg, _groups_agg, task_version = print_tasks(
+                        _task_hierarchy, task_order, task_version, task_group_alias
+                    )
+
+                    results_agg = {**results_agg, **_results_agg}
+                    groups_agg = {**groups_agg, **_groups_agg}
+
+        return results_agg, groups_agg, task_version
+
+    results_agg, groups_agg, versions = print_tasks(
+        task_hierarchy, task_order, versions, task_group_alias
+    )
+
+    _results_agg = collections.defaultdict(dict)
+    _versions = collections.defaultdict(dict)
+    for task in results_agg:
+        task_results = results_agg[task]
+
+        if "samples" in task_results:
+            task_results.pop("samples")
+
+        tab_string = ""
+        if "tab" in task_results:
+            tab = task_results.pop("tab")
+            tab_string = " " * tab + "- " if tab > 0 else ""
+
+        if task in task_group_alias:
+            task_alias = task_group_alias[task]
+            _results_agg[tab_string + task_alias] = task_results
+            _versions[tab_string + task_alias] = versions[task]
+        else:
+            _results_agg[tab_string + task] = task_results
+            _versions[tab_string + task] = versions[task]
+
+    results_agg = _results_agg
+    versions = _versions
+
+    _groups_agg = collections.defaultdict(dict)
+    for group in groups_agg:
+        group_results = groups_agg[group]
+
+        if "samples" in group_results:
+            group_results.pop("samples")
+
+        tab_string = ""
+        if "tab" in group_results:
+            tab = group_results.pop("tab")
+            tab_string = " " * tab + "- " if tab > 0 else ""
+
+        if group in task_group_alias:
+            group_alias = task_group_alias[group]
+            _groups_agg[tab_string + group_alias] = group_results
+        else:
+            _groups_agg[tab_string + group] = group_results
+
+    groups_agg = _groups_agg
+
     results_dict = {
         "results": dict(results_agg.items()),
...
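For reference, the group aggregation added above size-weights each subtask mean and pools variances according to the formula quoted in the code comment. A standalone sketch of that combination on made-up numbers (editorial illustration, not the harness's API):

```python
import numpy as np

def pool_two(mean_x, stderr_x, n, mean_y, stderr_y, m):
    """Combine two subtask (mean, stderr, size) triples into group-level stats."""
    mean_z = (mean_x * n + mean_y * m) / (n + m)
    # s_z^2 = ((n-1)s_x^2 + (m-1)s_y^2)/(n+m-1) + n*m*(mean_x-mean_y)^2/((n+m)(n+m-1))
    var_z = ((n - 1) * stderr_x**2 + (m - 1) * stderr_y**2) / (n + m - 1) + (
        n * m * (mean_x - mean_y) ** 2 / ((n + m) * (n + m - 1))
    )
    return mean_z, np.sqrt(var_z)

# e.g. two subtasks: 60% accuracy on 100 samples and 40% on 50 samples
print(pool_two(0.60, 0.02, 100, 0.40, 0.03, 50))  # approx. (0.533, 0.098)
```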
@@ -675,7 +675,8 @@ class HFLM(LM):
             else None,
         )
 
-        for chunk in tqdm(chunks, disable=(disable_tqdm or (self.rank != 0))):
+        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        for chunk in chunks:
             inps = []
             cont_toks_list = []
             inplens = []
@@ -812,6 +813,9 @@ class HFLM(LM):
                 res.append(answer)
 
                 self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+                pbar.update(1)
+
+        pbar.close()
 
         return re_ord.get_original(res)
@@ -857,7 +861,7 @@ class HFLM(LM):
             if self.batch_size == "auto" and not adaptive_batch_size
             else None,
         )
 
-        for chunk in tqdm(chunks, disable=self.rank != 0):
+        for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
...
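The model-side change above swaps a per-chunk tqdm bar for one that tracks individual requests while iteration happens over batched chunks. The pattern in isolation, on toy data (not the harness code itself):

```python
from tqdm import tqdm

requests = list(range(10))  # toy stand-ins for per-request work items
chunks = [requests[i : i + 4] for i in range(0, len(requests), 4)]

pbar = tqdm(total=len(requests))
for chunk in chunks:
    for _request in chunk:
        # ... run the batch through the model and collect one answer ...
        pbar.update(1)
pbar.close()
```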
@@ -14,8 +14,7 @@ Q: There were nine computers in the server room. Five more computers were instal
 Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n\nA: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.\n\n\
 Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n\nA: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.\n\n\
 Q: {{question}}\n\nA:"
-doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
-gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
+doc_to_target: " {{answer.split('### ')[-1].rstrip()}}"
 metric_list:
   - metric: exact_match
     aggregation: mean
@@ -25,6 +24,8 @@ metric_list:
     regexes_to_ignore:
       - ","
       - "\\$"
+      - "(?s).*#### "
+      - "\n\n"
 generation_kwargs:
   until:
     - "Q:"
@@ -37,5 +38,5 @@ filter_list:
   - name: "get-answer"
     filter:
      - function: "regex"
-        regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
+        regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)."
      - function: "take_first"
 group:
   - math_word_problems
-task: gsm8k_yaml
+task: gsm8k
 dataset_path: gsm8k
 dataset_name: main
 output_type: generate_until
@@ -9,7 +9,6 @@ fewshot_split: train
 test_split: test
 doc_to_text: "Question: {{question}}\nAnswer:"
 doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
-gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
 metric_list:
   - metric: exact_match
     aggregation: mean
@@ -19,7 +18,7 @@ metric_list:
     regexes_to_ignore:
       - ","
       - "\\$"
-      - ".*### "
+      - "(?s).*#### "
 generation_kwargs:
   until:
     - "\n\n"
@@ -28,9 +27,9 @@ generation_kwargs:
   temperature: 0.0
 repeats: 1
 num_fewshot: 5
-# filter_list:
-# - name: "get-answer"
-# filter:
-# - function: "regex"
-# regex_pattern: "### (\\-?[0-9\\.\\,]+)"
-# - function: "take_first"
+filter_list:
+  - name: "get-answer"
+    filter:
+      - function: "regex"
+        regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
+      - function: "take_first"
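The gsm8k configs above rely on two post-processing stages: the "get-answer" regex filter extracts the final number from the generation, and `regexes_to_ignore` strips characters from prediction and reference before `exact_match` is scored. Roughly, in plain Python (simplified sketch; the harness wires this up through its filter and metric machinery, and the generation string below is invented):

```python
import re

# toy model output in the style the few-shot prompt elicits
generation = "So she has 23 - 15 = 8 dollars left. The answer is 8."

# the "get-answer" regex filter keeps only the captured group
match = re.search(r"The answer is (\-?[0-9\.\,]+).", generation)
prediction = match.group(1) if match else generation  # -> "8"

# exact_match then applies regexes_to_ignore before comparing strings
reference = "#### 8"  # gsm8k gold answers end with "#### <number>"
for pattern in [",", r"\$", r"(?s).*#### "]:
    prediction = re.sub(pattern, "", prediction)
    reference = re.sub(pattern, "", reference)

print(prediction == reference)  # -> True
```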
@@ -9,7 +9,6 @@
 # template_aliases: #"{% set answer_choices = range(1, 11)|list %}"
 # doc_to_text: 'Activity: "{{activity}}"\nRating:'
 # doc_to_target: "{{answer_choices[label]}}"
-# gold_alias: "{{label}}" # this will be cast to an int.
 # metric_list:
 #   - metric: acc
 # TODO: we want this to be implemented as a winograd_schema task type, actually
 """
-Take in a YAML, and output all other splits with this YAML
+Take in a YAML, and output all "other" splits with this YAML
 """
 import os
 import yaml
@@ -10,73 +10,74 @@ from tqdm import tqdm
 
 from lm_eval import utils
 from lm_eval.logger import eval_logger
 
-SUBJECTS = [
-    "abstract_algebra",
-    "anatomy",
-    "astronomy",
-    "business_ethics",
-    "clinical_knowledge",
-    "college_biology",
-    "college_chemistry",
-    "college_computer_science",
-    "college_mathematics",
-    "college_medicine",
-    "college_physics",
-    "computer_security",
-    "conceptual_physics",
-    "econometrics",
-    "electrical_engineering",
-    "elementary_mathematics",
-    "formal_logic",
-    "global_facts",
-    "high_school_biology",
-    "high_school_chemistry",
-    "high_school_computer_science",
-    "high_school_european_history",
-    "high_school_geography",
-    "high_school_government_and_politics",
-    "high_school_macroeconomics",
-    "high_school_mathematics",
-    "high_school_microeconomics",
-    "high_school_physics",
-    "high_school_psychology",
-    "high_school_statistics",
-    "high_school_us_history",
-    "high_school_world_history",
-    "human_aging",
-    "human_sexuality",
-    "international_law",
-    "jurisprudence",
-    "logical_fallacies",
-    "machine_learning",
-    "management",
-    "marketing",
-    "medical_genetics",
-    "miscellaneous",
-    "moral_disputes",
-    "moral_scenarios",
-    "nutrition",
-    "philosophy",
-    "prehistory",
-    "professional_accounting",
-    "professional_law",
-    "professional_medicine",
-    "professional_psychology",
-    "public_relations",
-    "security_studies",
-    "sociology",
-    "us_foreign_policy",
-    "virology",
-    "world_religions",
-]
+SUBJECTS = {
+    "abstract_algebra": "stem",
+    "anatomy": "stem",
+    "astronomy": "stem",
+    "business_ethics": "other",
+    "clinical_knowledge": "other",
+    "college_biology": "stem",
+    "college_chemistry": "stem",
+    "college_computer_science": "stem",
+    "college_mathematics": "stem",
+    "college_medicine": "other",
+    "college_physics": "stem",
+    "computer_security": "stem",
+    "conceptual_physics": "stem",
+    "econometrics": "social_sciences",
+    "electrical_engineering": "stem",
+    "elementary_mathematics": "stem",
+    "formal_logic": "humanities",
+    "global_facts": "other",
+    "high_school_biology": "stem",
+    "high_school_chemistry": "stem",
+    "high_school_computer_science": "stem",
+    "high_school_european_history": "humanities",
+    "high_school_geography": "social_sciences",
+    "high_school_government_and_politics": "social_sciences",
+    "high_school_macroeconomics": "social_sciences",
+    "high_school_mathematics": "stem",
+    "high_school_microeconomics": "social_sciences",
+    "high_school_physics": "stem",
+    "high_school_psychology": "social_sciences",
+    "high_school_statistics": "stem",
+    "high_school_us_history": "humanities",
+    "high_school_world_history": "humanities",
+    "human_aging": "other",
+    "human_sexuality": "social_sciences",
+    "international_law": "humanities",
+    "jurisprudence": "humanities",
+    "logical_fallacies": "humanities",
+    "machine_learning": "stem",
+    "management": "other",
+    "marketing": "other",
+    "medical_genetics": "other",
+    "miscellaneous": "other",
+    "moral_disputes": "humanities",
+    "moral_scenarios": "humanities",
+    "nutrition": "other",
+    "philosophy": "humanities",
+    "prehistory": "humanities",
+    "professional_accounting": "other",
+    "professional_law": "humanities",
+    "professional_medicine": "other",
+    "professional_psychology": "social_sciences",
+    "public_relations": "social_sciences",
+    "security_studies": "social_sciences",
+    "sociology": "social_sciences",
+    "us_foreign_policy": "social_sciences",
+    "virology": "other",
+    "world_religions": "humanities",
+}
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--base_yaml_path", required=True)
-    parser.add_argument("--save_prefix_path", default="flan")
+    parser.add_argument("--save_prefix_path", default="mmlu")
     parser.add_argument("--cot_prompt_path", default=None)
     parser.add_argument("--task_prefix", default="")
+    parser.add_argument("--group_prefix", default="")
     return parser.parse_args()
@@ -84,7 +85,7 @@ if __name__ == "__main__":
     args = parse_args()
 
-    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
+    # get filename of base_yaml so we can `"include": ` it in our "other" YAMLs.
     base_yaml_name = os.path.split(args.base_yaml_path)[-1]
     with open(args.base_yaml_path) as f:
         base_yaml = yaml.full_load(f)
@@ -95,7 +96,12 @@ if __name__ == "__main__":
         with open(args.cot_prompt_path) as f:
             cot_file = json.load(f)
 
-    for subject in tqdm(SUBJECTS):
+    ALL_CATEGORIES = []
+    for subject, category in tqdm(SUBJECTS.items()):
+        if category not in ALL_CATEGORIES:
+            ALL_CATEGORIES.append(category)
+
         if args.cot_prompt_path is not None:
             description = cot_file[subject]
         else:
@@ -103,9 +109,14 @@
 
         yaml_dict = {
             "include": base_yaml_name,
+            "group": f"mmlu_{args.task_prefix}_{category}"
+            if args.task_prefix != ""
+            else f"mmlu_{category}",
+            "group_alias": category.replace("_", " "),
             "task": f"mmlu_{args.task_prefix}_{subject}"
             if args.task_prefix != ""
             else f"mmlu_{subject}",
+            "task_alias": subject.replace("_", " "),
             "dataset_name": subject,
             "description": description,
         }
@@ -116,7 +127,33 @@
         yaml.dump(
             yaml_dict,
             yaml_file,
-            width=float("inf"),
+            # width=float("inf"),
             allow_unicode=True,
             default_style='"',
         )
+
+    if args.task_prefix != "":
+        mmlu_subcategories = [
+            f"mmlu_{args.task_prefix}_{category}" for category in ALL_CATEGORIES
+        ]
+    else:
+        mmlu_subcategories = [f"mmlu_{category}" for category in ALL_CATEGORIES]
+
+    if args.group_prefix != "":
+        file_save_path = args.group_prefix + ".yaml"
+    else:
+        file_save_path = args.save_prefix_path + ".yaml"
+
+    eval_logger.info(f"Saving benchmark config to {file_save_path}")
+    with open(file_save_path, "w") as yaml_file:
+        yaml.dump(
+            {
+                "group": f"mmlu_{args.task_prefix}"
+                if args.task_prefix != ""
+                else "mmlu",
+                "task": mmlu_subcategories,
+            },
+            yaml_file,
+            indent=4,
+            default_flow_style=False,
+        )
@@ -12,6 +12,3 @@ metric_list:
   - metric: acc
     aggregation: mean
     higher_is_better: true
-  - metric: acc_norm
-    aggregation: mean
-    higher_is_better: true
+group: mmlu
+task:
+- mmlu_stem
+- mmlu_other
+- mmlu_social_sciences
+- mmlu_humanities
 "dataset_name": "abstract_algebra"
-"description": "The following are multiple choice questions (with answers) about abstract algebra.\n\n"
+"description": "The following are multiple choice questions (with answers) about abstract\
+  \ algebra.\n\n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_abstract_algebra"
+"task_alias": "abstract_algebra"
 "dataset_name": "anatomy"
-"description": "The following are multiple choice questions (with answers) about anatomy.\n\n"
+"description": "The following are multiple choice questions (with answers) about anatomy.\n\
+  \n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_anatomy"
+"task_alias": "anatomy"
 "dataset_name": "astronomy"
-"description": "The following are multiple choice questions (with answers) about astronomy.\n\n"
+"description": "The following are multiple choice questions (with answers) about astronomy.\n\
+  \n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_astronomy"
+"task_alias": "astronomy"
 "dataset_name": "business_ethics"
-"description": "The following are multiple choice questions (with answers) about business ethics.\n\n"
+"description": "The following are multiple choice questions (with answers) about business\
+  \ ethics.\n\n"
+"group": "mmlu_other"
+"group_alias": "other"
 "include": "_default_template_yaml"
 "task": "mmlu_business_ethics"
+"task_alias": "business_ethics"
 "dataset_name": "clinical_knowledge"
-"description": "The following are multiple choice questions (with answers) about clinical knowledge.\n\n"
+"description": "The following are multiple choice questions (with answers) about clinical\
+  \ knowledge.\n\n"
+"group": "mmlu_other"
+"group_alias": "other"
 "include": "_default_template_yaml"
 "task": "mmlu_clinical_knowledge"
+"task_alias": "clinical_knowledge"
 "dataset_name": "college_biology"
-"description": "The following are multiple choice questions (with answers) about college biology.\n\n"
+"description": "The following are multiple choice questions (with answers) about college\
+  \ biology.\n\n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_college_biology"
+"task_alias": "college_biology"
 "dataset_name": "college_chemistry"
-"description": "The following are multiple choice questions (with answers) about college chemistry.\n\n"
+"description": "The following are multiple choice questions (with answers) about college\
+  \ chemistry.\n\n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_college_chemistry"
+"task_alias": "college_chemistry"
 "dataset_name": "college_computer_science"
-"description": "The following are multiple choice questions (with answers) about college computer science.\n\n"
+"description": "The following are multiple choice questions (with answers) about college\
+  \ computer science.\n\n"
+"group": "mmlu_stem"
+"group_alias": "stem"
 "include": "_default_template_yaml"
 "task": "mmlu_college_computer_science"
+"task_alias": "college_computer_science"