"vscode:/vscode.git/clone" did not exist on "ac9595a9f118a023e248eaffcfa5c324f36fd081"
Unverified commit 815f59e6, authored by Lintang Sutawika and committed by GitHub

Merge pull request #922 from EleutherAI/mmlu_subgroups

[Refactor] Mmlu subgroups and weight avg
parents 3533e4b9 44124d95
@@ -50,7 +50,7 @@ dataset_kwargs: null # any extra keyword arguments that should be passed to the
```
dataset_path: json
dataset_name: null
dataset_kwargs:
data_files: /path/to/my/json
```
-------------------------------
......
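For readers unfamiliar with these fields: they map directly onto Hugging Face `datasets.load_dataset` arguments. A minimal sketch of the correspondence, reusing the placeholder path from the docs snippet above:

```python
from datasets import load_dataset

# dataset_path becomes the positional `path` argument, dataset_name maps to
# `name`, and everything under dataset_kwargs is passed through as extra
# keyword arguments -- here, `data_files` for the built-in JSON loader.
dataset = load_dataset(
    path="json",
    name=None,
    data_files="/path/to/my/json",
)
```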
@@ -52,7 +52,9 @@ ALL_OUTPUT_TYPES = [
class TaskConfig(dict):
# task naming/registry
task: str = None
task_alias: str = None
group: Union[str, list] = None
group_alias: Union[str, list] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
......
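The two new alias fields let a task or group carry a human-readable display name separate from its registry name. A hedged sketch of how a config might set them, assuming `TaskConfig` is a dataclass that accepts its fields as keyword arguments:

```python
config = TaskConfig(
    task="mmlu_abstract_algebra",
    task_alias="abstract algebra",  # shown in the results table
    group="mmlu_stem",
    group_alias="stem",             # shown for the enclosing group
)
```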
@@ -221,14 +221,15 @@ def evaluate(
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict)
task_group_alias = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
task_hierarchy[group_name].append(task_name)
versions[group_name] = "N/A"
else:
task_hierarchy[task_name] = []
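To make the hierarchy bookkeeping concrete, here is a self-contained sketch of the loop above with hypothetical task names; grouped tasks arrive as `(group_name, task)` tuples, standalone tasks as bare objects:

```python
import collections

task_dict = {
    "mmlu_abstract_algebra": ("mmlu_stem", object()),  # grouped task
    "mmlu_anatomy": ("mmlu_stem", object()),           # grouped task
    "arc_easy": object(),                              # standalone task
}

task_hierarchy = collections.defaultdict(list)
for task_name, task in task_dict.items():
    if type(task) == tuple:
        group_name, task = task
        task_hierarchy[group_name].append(task_name)
    else:
        task_hierarchy[task_name] = []

print(dict(task_hierarchy))
# {'mmlu_stem': ['mmlu_abstract_algebra', 'mmlu_anatomy'], 'arc_easy': []}
```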
@@ -238,6 +239,14 @@ def evaluate(
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
if ("group_alias" in configs[task_name]) and (
group_name not in task_group_alias
):
task_group_alias[group_name] = configs[task_name]["group_alias"]
if limit is not None:
if task.has_test_docs():
task_docs = task.test_docs()
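The alias bookkeeping a few lines up keeps one display name per task and adopts a group's alias from the first member config that defines one. A small standalone sketch with made-up configs:

```python
task_group_alias = {}
group_name = "mmlu_stem"  # hypothetical enclosing group

configs = {
    "mmlu_abstract_algebra": {"task_alias": "abstract algebra",
                              "group_alias": "stem"},
    "mmlu_anatomy": {"task_alias": "anatomy",
                     "group_alias": "stem"},
}

for task_name, cfg in configs.items():
    if "task_alias" in cfg:
        task_group_alias[task_name] = cfg["task_alias"]
    # only the first group_alias encountered wins
    if "group_alias" in cfg and group_name not in task_group_alias:
        task_group_alias[group_name] = cfg["group_alias"]

# task_group_alias == {'mmlu_abstract_algebra': 'abstract algebra',
#                      'mmlu_stem': 'stem', 'mmlu_anatomy': 'anatomy'}
```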
@@ -449,23 +458,8 @@ def evaluate(
group_name = None
agg_fn = task.aggregation()[metric]
task_score = agg_fn(items)
if group_name is not None:
sample_metric_key = metric + "(sample agg)," + key
for grouping in task_to_group[task_name]:
if metric_key in results[grouping]:
results[grouping][metric_key].append(task_score)
else:
results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items
else:
results[grouping][sample_metric_key] = items.copy()
sample_agg_fn[grouping][sample_metric_key] = agg_fn
-results[task_name][metric_key] = task_score
+results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them for fewer iterations. still looking for a cleaner way to do this
@@ -481,33 +475,139 @@ def evaluate(
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(results):
for task_or_group in results.keys():
for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
else:
results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A"
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
for group, task_list in reversed(task_hierarchy.items()):
if task_list == []:
total_size = results[group]["samples"]
else:
total_size = 0
for task in task_list:
metrics = results[task]
current_size = metrics.pop("samples")
# TODO: There should be a way for users
# to toggle between weighted and
# unweighted averaging
# For unweighted averaging, use:
# current_size = 1
all_stderr = []
for metric in [
key for key in metrics.keys() if "_stderr" not in key
]:
stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr]
var_score = stderr_score**2
metric_score = results[task][metric]
all_stderr.append(stderr)
if metric in results[group]:
results[group][metric] = (
results[group][metric] * total_size
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
results[group][stderr] = (
(total_size - 1) * results[group][stderr]
+ (current_size - 1) * var_score
) / (
total_size + current_size - 1
) + total_size * current_size / (
(total_size + current_size)
* (total_size + current_size - 1)
) * (
results[group][metric] - metric_score
) ** 2
else:
results[group][metric] = metric_score
results[group][stderr] = var_score
total_size += current_size
for stderr in all_stderr:
results[group][stderr] = np.sqrt(results[group][stderr])
results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items():
order = task_order[group_name]
tabbed_name = "-" * order + group_name
results_agg[tabbed_name] = results[group_name]
versions[tabbed_name] = versions[group_name]
if order == 0:
groups_agg[group_name] = results[group_name]
order = task_order[task_name]
tabbed_name = "-" * order + task_name
results_agg[tabbed_name] = results[task_name]
versions[tabbed_name] = versions[task_name]
results_agg[group_name] = results[group_name].copy()
results_agg[group_name]["tab"] = order
if (order < max(task_order.values())) and (len(task_list) > 0):
groups_agg[group_name] = results[group_name].copy()
groups_agg[group_name]["tab"] = order
if task_list != []:
for task in sorted(task_list):
if task in task_hierarchy:
_task_hierarchy = {task: task_hierarchy[task]}
else:
_task_hierarchy = {task: []}
_results_agg, _groups_agg, task_version = print_tasks(
_task_hierarchy, task_order, task_version, task_group_alias
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg, task_version
results_agg, groups_agg, versions = print_tasks(
task_hierarchy, task_order, versions, task_group_alias
)
_results_agg = collections.defaultdict(dict)
_versions = collections.defaultdict(dict)
for task in results_agg:
task_results = results_agg[task]
if "samples" in task_results:
task_results.pop("samples")
tab_string = ""
if "tab" in task_results:
tab = task_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if task in task_group_alias:
task_alias = task_group_alias[task]
_results_agg[tab_string + task_alias] = task_results
_versions[tab_string + task_alias] = versions[task]
else:
_results_agg[tab_string + task] = task_results
_versions[tab_string + task] = versions[task]
results_agg = _results_agg
versions = _versions
_groups_agg = collections.defaultdict(dict)
for group in groups_agg:
group_results = groups_agg[group]
if "samples" in group_results:
group_results.pop("samples")
tab_string = ""
if "tab" in group_results:
tab = group_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if group in task_group_alias:
group_alias = task_group_alias[group]
_groups_agg[tab_string + group_alias] = group_results
else:
_groups_agg[tab_string + group] = group_results
groups_agg = _groups_agg
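The `tab` value popped above becomes indentation in the final table; a tiny self-contained sketch of the display logic with hypothetical rows:

```python
# Hypothetical rows: (tab depth, display name).
for tab, name in [(0, "mmlu"), (1, "stem"), (2, "abstract algebra")]:
    tab_string = " " * tab + "- " if tab > 0 else ""
    print(tab_string + name)
# mmlu
#  - stem
#   - abstract algebra
```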
results_dict = {
"results": dict(results_agg.items()),
......
"""
-Take in a YAML, and output all other splits with this YAML
+Take in a YAML, and output all "other" splits with this YAML
"""
import os
import yaml
@@ -10,73 +10,74 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
SUBJECTS = {
"abstract_algebra": "stem",
"anatomy": "stem",
"astronomy": "stem",
"business_ethics": "other",
"clinical_knowledge": "other",
"college_biology": "stem",
"college_chemistry": "stem",
"college_computer_science": "stem",
"college_mathematics": "stem",
"college_medicine": "other",
"college_physics": "stem",
"computer_security": "stem",
"conceptual_physics": "stem",
"econometrics": "social_sciences",
"electrical_engineering": "stem",
"elementary_mathematics": "stem",
"formal_logic": "humanities",
"global_facts": "other",
"high_school_biology": "stem",
"high_school_chemistry": "stem",
"high_school_computer_science": "stem",
"high_school_european_history": "humanities",
"high_school_geography": "social_sciences",
"high_school_government_and_politics": "social_sciences",
"high_school_macroeconomics": "social_sciences",
"high_school_mathematics": "stem",
"high_school_microeconomics": "social_sciences",
"high_school_physics": "stem",
"high_school_psychology": "social_sciences",
"high_school_statistics": "stem",
"high_school_us_history": "humanities",
"high_school_world_history": "humanities",
"human_aging": "other",
"human_sexuality": "social_sciences",
"international_law": "humanities",
"jurisprudence": "humanities",
"logical_fallacies": "humanities",
"machine_learning": "stem",
"management": "other",
"marketing": "other",
"medical_genetics": "other",
"miscellaneous": "other",
"moral_disputes": "humanities",
"moral_scenarios": "humanities",
"nutrition": "other",
"philosophy": "humanities",
"prehistory": "humanities",
"professional_accounting": "other",
"professional_law": "humanities",
"professional_medicine": "other",
"professional_psychology": "social_sciences",
"public_relations": "social_sciences",
"security_studies": "social_sciences",
"sociology": "social_sciences",
"us_foreign_policy": "social_sciences",
"virology": "other",
"world_religions": "humanities",
}
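Turning SUBJECTS from a flat list into a subject-to-category mapping makes the four category groupings recoverable in one pass; a sketch, assuming the SUBJECTS dict above is in scope:

```python
from collections import defaultdict

# Invert the subject -> category mapping.
by_category = defaultdict(list)
for subject, category in SUBJECTS.items():
    by_category[category].append(subject)

print(sorted(by_category))       # ['humanities', 'other', 'social_sciences', 'stem']
print(len(by_category["stem"]))  # 19
```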
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True)
parser.add_argument("--save_prefix_path", default="flan")
parser.add_argument("--save_prefix_path", default="mmlu")
parser.add_argument("--cot_prompt_path", default=None)
parser.add_argument("--task_prefix", default="")
parser.add_argument("--group_prefix", default="")
return parser.parse_args()
@@ -84,7 +85,7 @@ if __name__ == "__main__":
args = parse_args()
-# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
+# get filename of base_yaml so we can `"include": ` it in our "other" YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
base_yaml = yaml.full_load(f)
@@ -95,7 +96,12 @@ if __name__ == "__main__":
with open(args.cot_prompt_path) as f:
cot_file = json.load(f)
-for subject in tqdm(SUBJECTS):
+ALL_CATEGORIES = []
+for subject, category in tqdm(SUBJECTS.items()):
+    if category not in ALL_CATEGORIES:
+        ALL_CATEGORIES.append(category)
if args.cot_prompt_path is not None:
description = cot_file[subject]
else:
@@ -103,9 +109,14 @@ if __name__ == "__main__":
yaml_dict = {
"include": base_yaml_name,
"group": f"mmlu_{args.task_prefix}_{category}"
if args.task_prefix != ""
else f"mmlu_{category}",
"group_alias": category.replace("_", " "),
"task": f"mmlu_{args.task_prefix}_{subject}"
if args.task_prefix != ""
else f"mmlu_{subject}",
"task_alias": subject.replace("_", " "),
"dataset_name": subject,
"description": description,
}
@@ -116,7 +127,33 @@ if __name__ == "__main__":
yaml.dump(
yaml_dict,
yaml_file,
width=float("inf"),
# width=float("inf"),
allow_unicode=True,
default_style='"',
)
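With `width=float("inf")` commented out, PyYAML falls back to its default line width, so long double-quoted descriptions wrap with backslash continuations; that is exactly the form visible in the regenerated per-subject YAMLs further down. A quick sketch (wrap position as in those files):

```python
import yaml

doc = {"description": "The following are multiple choice questions "
                      "(with answers) about abstract algebra.\n\n"}
print(yaml.dump(doc, allow_unicode=True, default_style='"'))
# "description": "The following are multiple choice questions (with answers) about abstract\
#   \ algebra.\n\n"
```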
if args.task_prefix != "":
mmlu_subcategories = [
f"mmlu_{args.task_prefix}_{category}" for category in ALL_CATEGORIES
]
else:
mmlu_subcategories = [f"mmlu_{category}" for category in ALL_CATEGORIES]
if args.group_prefix != "":
file_save_path = args.group_prefix + ".yaml"
else:
file_save_path = args.save_prefix_path + ".yaml"
eval_logger.info(f"Saving benchmark config to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
yaml.dump(
{
"group": f"mmlu_{args.task_prefix}"
if args.task_prefix != ""
else "mmlu",
"task": mmlu_subcategories,
},
yaml_file,
indent=4,
default_flow_style=False,
)
@@ -12,6 +12,3 @@ metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
group: mmlu
task:
- mmlu_stem
- mmlu_other
- mmlu_social_sciences
- mmlu_humanities
"dataset_name": "abstract_algebra"
"description": "The following are multiple choice questions (with answers) about abstract algebra.\n\n"
"description": "The following are multiple choice questions (with answers) about abstract\
\ algebra.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_abstract_algebra"
"task_alias": "abstract_algebra"
"dataset_name": "anatomy"
"description": "The following are multiple choice questions (with answers) about anatomy.\n\n"
"description": "The following are multiple choice questions (with answers) about anatomy.\n\
\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_anatomy"
"task_alias": "anatomy"
"dataset_name": "astronomy"
"description": "The following are multiple choice questions (with answers) about astronomy.\n\n"
"description": "The following are multiple choice questions (with answers) about astronomy.\n\
\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_astronomy"
"task_alias": "astronomy"
"dataset_name": "business_ethics"
"description": "The following are multiple choice questions (with answers) about business ethics.\n\n"
"description": "The following are multiple choice questions (with answers) about business\
\ ethics.\n\n"
"group": "mmlu_other"
"group_alias": "other"
"include": "_default_template_yaml"
"task": "mmlu_business_ethics"
"task_alias": "business_ethics"
"dataset_name": "clinical_knowledge"
"description": "The following are multiple choice questions (with answers) about clinical knowledge.\n\n"
"description": "The following are multiple choice questions (with answers) about clinical\
\ knowledge.\n\n"
"group": "mmlu_other"
"group_alias": "other"
"include": "_default_template_yaml"
"task": "mmlu_clinical_knowledge"
"task_alias": "clinical_knowledge"
"dataset_name": "college_biology"
"description": "The following are multiple choice questions (with answers) about college biology.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ biology.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_college_biology"
"task_alias": "college_biology"
"dataset_name": "college_chemistry"
"description": "The following are multiple choice questions (with answers) about college chemistry.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ chemistry.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_college_chemistry"
"task_alias": "college_chemistry"
"dataset_name": "college_computer_science"
"description": "The following are multiple choice questions (with answers) about college computer science.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ computer science.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_college_computer_science"
"task_alias": "college_computer_science"
"dataset_name": "college_mathematics"
"description": "The following are multiple choice questions (with answers) about college mathematics.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ mathematics.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_college_mathematics"
"task_alias": "college_mathematics"
"dataset_name": "college_medicine"
"description": "The following are multiple choice questions (with answers) about college medicine.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ medicine.\n\n"
"group": "mmlu_other"
"group_alias": "other"
"include": "_default_template_yaml"
"task": "mmlu_college_medicine"
"task_alias": "college_medicine"
"dataset_name": "college_physics"
"description": "The following are multiple choice questions (with answers) about college physics.\n\n"
"description": "The following are multiple choice questions (with answers) about college\
\ physics.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_college_physics"
"task_alias": "college_physics"
"dataset_name": "computer_security"
"description": "The following are multiple choice questions (with answers) about computer security.\n\n"
"description": "The following are multiple choice questions (with answers) about computer\
\ security.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_computer_security"
"task_alias": "computer_security"
"dataset_name": "conceptual_physics"
"description": "The following are multiple choice questions (with answers) about conceptual physics.\n\n"
"description": "The following are multiple choice questions (with answers) about conceptual\
\ physics.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_conceptual_physics"
"task_alias": "conceptual_physics"
"dataset_name": "econometrics"
"description": "The following are multiple choice questions (with answers) about econometrics.\n\n"
"description": "The following are multiple choice questions (with answers) about econometrics.\n\
\n"
"group": "mmlu_social_sciences"
"group_alias": "social_sciences"
"include": "_default_template_yaml"
"task": "mmlu_econometrics"
"task_alias": "econometrics"