Unverified Commit 96d185fa authored by Lintang Sutawika, committed by GitHub

Cont metrics (#1475)



* add brier_score

* process brier_score

* brier score is working for N-sized class

* fixed brier score

* add TED to BigBench and Brier score to MMLU

* format

* Update metrics.py

* Update task.py

* Update generate_until_template_yaml

* Delete lm_eval/tasks/bigbench/aux_metric.py

* Update generate_until_template_yaml

* Update _default_template_yaml

* Update _generate_configs.py

* Update _generate_configs.py

* Update _generate_configs.py

* fix (format?)

* format?

* format, once more

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 1e6c9272
metrics.py:

```diff
@@ -116,6 +116,25 @@ def ter(items):
     return sacrebleu.corpus_ter(preds, refs).score
 
 
+@register_aggregation("brier_score")
+def brier_score(items):
+    gold, predictions = list(zip(*items))
+    gold = list(gold)
+    predictions = np.array(predictions)
+    gold_one_hot = np.eye(np.max(gold) + 1)[gold]
+    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
+
+
+@register_metric(
+    metric="brier_score",
+    higher_is_better=False,
+    output_type=["multiple_choice"],
+    aggregation="brier_score",
+)
+def brier_score_fn(items):  # This is a passthrough function
+    return items
+
+
 @register_metric(
     metric="acc",
     higher_is_better=true,
```
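The Brier score for one example is the squared error between the predicted probability vector and the one-hot gold vector, summed over classes; the aggregation averages this over all examples, and lower is better. A minimal sketch of what the aggregation above computes, on hypothetical data:

```python
import numpy as np

# Hypothetical (gold_index, predicted_probabilities) items for a
# 3-way multiple-choice task.
items = [
    (0, [0.7, 0.2, 0.1]),  # confident and correct: small penalty
    (2, [0.5, 0.4, 0.1]),  # confident and wrong: large penalty
]

gold, predictions = zip(*items)
predictions = np.array(predictions)
# One-hot encode the gold labels, sized to the largest class index seen.
gold_one_hot = np.eye(np.max(gold) + 1)[list(gold)]

# Per-example scores are 0.14 and 1.22, so the mean is 0.68.
print(np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1)))
```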
task.py:

```diff
@@ -1227,12 +1227,21 @@ class ConfigurableTask(Task):
             # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
             exact_match = int(is_greedy[gold]) if gold != -100 else 0
 
+            prob_norm = utils.softmax(lls)
+
+            # TODO use keyword arguments to the metric?
+            # gold, pred, norm stuff, the original lls,
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
                 **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                 **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                 **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
+                **(
+                    {"brier_score": (gold, prob_norm)}
+                    if "brier_score" in use_metric
+                    else {}
+                ),
             }
 
             if "acc_mutual_info" in use_metric:
```
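For `brier_score`, the per-document value stored here is the `(gold, prob_norm)` tuple; the metric function is a passthrough, and the registered aggregation consumes all such tuples at the end of the run. A small illustration with hypothetical log-likelihoods for one document:

```python
import numpy as np

# Hypothetical per-choice log-likelihoods for a single document.
lls = np.array([-2.3, -0.4, -1.9])
gold = 1

# Equivalent to utils.softmax(lls), which this commit adds below.
prob_norm = np.exp(lls - np.max(lls)) / np.exp(lls - np.max(lls)).sum()

# The value placed under "brier_score" in result_dict for this document.
item = (gold, prob_norm)
```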
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import os
import yaml
import argparse import argparse
import os
import yaml
from tqdm import tqdm from tqdm import tqdm
...@@ -68,6 +68,7 @@ SUBJECTS = { ...@@ -68,6 +68,7 @@ SUBJECTS = {
"world_religions": "العلوم الانسانية", "world_religions": "العلوم الانسانية",
} }
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True) parser.add_argument("--base_yaml_path", required=True)
...@@ -95,9 +96,7 @@ if __name__ == "__main__": ...@@ -95,9 +96,7 @@ if __name__ == "__main__":
if args.cot_prompt_path is not None: if args.cot_prompt_path is not None:
description = cot_file[subject_eng] description = cot_file[subject_eng]
else: else:
description = ( description = f"فم بعملية التقييم في مجال {category} \n\n"
f"فم بعملية التقييم في مجال {category} \n\n"
)
yaml_dict = { yaml_dict = {
"include": base_yaml_name, "include": base_yaml_name,
......
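(The description string translates roughly to "Perform the evaluation in the field of {category}".) For context, this script loops over SUBJECTS and writes one YAML per subject. A hedged sketch of the dict it emits; every key except `include` is an assumption, since the diff truncates after the first entry:

```python
# Hypothetical reconstruction of one generated entry; only the
# "include" key is visible in the diff, the rest are assumptions.
base_yaml_name = "_default_template_yaml"
subject_eng, category = "world_religions", "العلوم الانسانية"

yaml_dict = {
    "include": base_yaml_name,  # shown in the diff
    "task": f"ammlu_{subject_eng}",  # assumed task-name scheme
    "description": f"فم بعملية التقييم في مجال {category} \n\n",  # assumed
}
```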
utils.py:

```diff
@@ -11,6 +11,7 @@ from itertools import islice
 from pathlib import Path
 from typing import Any, Callable, List
 
+import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -104,6 +105,12 @@ def pattern_match(patterns, source_list):
     return sorted(list(task_names))
 
 
+def softmax(x):
+    """Compute softmax values for each set of scores in x."""
+    e_x = np.exp(x - np.max(x))
+    return e_x / e_x.sum()
+
+
 def general_detokenize(string):
     string = string.replace(" n't", "n't")
     string = string.replace(" )", ")")
```
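Subtracting `np.max(x)` before exponentiating does not change the result (the common factor cancels when dividing by the sum) but keeps `np.exp` from overflowing on large scores:

```python
import numpy as np

def softmax(x):
    # Shift by the max for numerical stability; the shift cancels in the ratio.
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

x = np.array([1000.0, 1001.0, 1002.0])
# A naive np.exp(x) overflows to inf here; the shifted version is exact.
print(softmax(x))  # [0.09003057 0.24472847 0.66524096]
```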