Commit f692caa9 authored by lintangsutawika

updated to appease the pre-commit

parent ab96fc7e
@@ -11,7 +11,7 @@ from typing import Union
 import numpy as np
 from lm_eval import evaluator, utils
-from lm_eval.tasks import TaskManager, include_path, initialize_tasks
+from lm_eval.tasks import TaskManager, initialize_tasks
 from lm_eval.utils import make_table
...
 import logging
 import math
 import random
-from collections.abc import Iterable
 from collections import defaultdict
+from collections.abc import Iterable
 from typing import List
+import evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import evaluate
 from lm_eval.api.registry import register_aggregation, register_metric
@@ -119,7 +119,6 @@ def ter(items):
 @register_aggregation("brier_score")
 def brier_score(items):  # This is a passthrough function
     # Certain datasets like arc_easy can have a different number of choices.
     golds, predictions = list(zip(*items))
...
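The hunk above only shows the top of the brier_score aggregation, which zips the collected items back into golds and predictions. For reference, a minimal sketch of a multiclass Brier computation over such (one-hot gold, normalized probability) pairs could look like the following; this is an illustration of the standard definition, not the harness's exact implementation, and it scores each document separately because the number of choices can differ across documents (the arc_easy caveat in the comment).

import numpy as np

def brier_score_sketch(items):
    # items: list of (gold_one_hot, predicted_probs) pairs, as assembled in
    # process_results below; choice counts may vary between documents.
    golds, predictions = list(zip(*items))
    per_doc = [
        np.sum((np.asarray(pred) - np.asarray(gold)) ** 2)  # squared error over choices
        for gold, pred in zip(golds, predictions)
    ]
    return float(np.mean(per_doc))

# A confident, correct 3-choice prediction scores close to 0:
print(brier_score_sketch([([1.0, 0.0, 0.0], [0.9, 0.05, 0.05])]))  # 0.015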
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, Dict
 import evaluate
 from lm_eval.api.model import LM
...
@@ -1193,8 +1193,8 @@ class ConfigurableTask(Task):
                 **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                 **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
-                **(
                 # {"brier_score": (gold, prob_norm)}
+                **(
                     {"brier_score": [np.eye(len(prob_norm))[gold], prob_norm]}
                     if "brier_score" in use_metric
                     else {}
...
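The np.eye(len(prob_norm))[gold] expression builds the one-hot target that gets paired with the normalized choice probabilities for the brier_score aggregation. A quick standalone illustration with made-up values:

import numpy as np

prob_norm = [0.2, 0.7, 0.1]          # hypothetical normalized per-choice probabilities
gold = 1                             # index of the correct choice

one_hot = np.eye(len(prob_norm))[gold]
print(one_hot)                       # [0. 1. 0.]

# The value stored under "brier_score" is then the pair handed to the aggregation:
item = [one_hot, prob_norm]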
@@ -498,7 +498,6 @@ def evaluate(
                 metric_key = f"{metric},{key}"
                 agg_fn = task.aggregation()[metric]
                 results[task_name][metric_key] = agg_fn(items)
                 results[task_name]["samples"] = len(items)
@@ -524,19 +523,37 @@ def evaluate(
                 # or `task_name: []`.
                 # we only want to operate on groups here.
                 continue
-            for metric in [
-                key
-                for key in results[task_list[0]].keys()
-                if "_stderr" not in key and key not in ["alias", "samples"]
-            ]:  # TODO: what if tasks don't all share the same metrics
+            group_metrics = list(
+                dict.fromkeys(
+                    [
+                        key
+                        for task in task_list
+                        for key in results[task].keys()
+                        if "_stderr" not in key and key not in ["alias", "samples"]
+                    ]
+                )
+            )
+            for metric in group_metrics:
+                # TODO: what if tasks don't all share the same metrics
                 stderr = "_stderr,".join(metric.split(","))
                 # gather metrics, sizes, and stderrs from subtasks
                 metrics = [
-                    results[task][metric] for task in task_list
+                    results[task][metric]
+                    for task in task_list
+                    if metric in results[task]
                 ]  # TODO: copy?
-                stderrs = [results[task][stderr] for task in task_list]
-                sizes = [results[task]["samples"] for task in task_list]
+                stderrs = [
+                    results[task][stderr]
+                    for task in task_list
+                    if stderr in results[task]
+                ]
+                sizes = [
+                    results[task]["samples"]
+                    for task in task_list
+                    if metric in results[task]
+                ]
                 # compute group's pooled metric and stderr
                 results[group][
...
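The rewritten block takes the union of metric keys across all subtasks (via dict.fromkeys, which removes duplicates while preserving first-seen order) instead of reading only the first subtask's keys, and it then skips subtasks that do not report a given metric when gathering values, stderrs, and sizes. A toy illustration of the key-union step, with made-up subtask results; the size-weighted mean at the end is shown only as one plausible pooling rule, not as the harness's exact formula:

results = {
    "subtask_a": {"alias": "a", "acc,none": 0.61, "acc_stderr,none": 0.02, "samples": 500},
    "subtask_b": {"alias": "b", "acc,none": 0.58, "f1,none": 0.44, "samples": 300},
}
task_list = ["subtask_a", "subtask_b"]

group_metrics = list(
    dict.fromkeys(
        key
        for task in task_list
        for key in results[task].keys()
        if "_stderr" not in key and key not in ["alias", "samples"]
    )
)
print(group_metrics)  # ['acc,none', 'f1,none']  (first-seen order, no duplicates)

for metric in group_metrics:
    # only subtasks that actually report this metric contribute
    pairs = [
        (results[task][metric], results[task]["samples"])
        for task in task_list
        if metric in results[task]
    ]
    pooled = sum(value * size for value, size in pairs) / sum(size for _, size in pairs)
    print(metric, round(pooled, 5))  # acc,none 0.59875 / f1,none 0.44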
@@ -16,4 +16,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
       - function: "take_first"
@@ -15,4 +15,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
       - function: "take_first"
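Both YAML changes leave the filter pipeline itself untouched: the regex filter pulls out whatever follows one of the "The answer is ..." cues (the trailing lookahead drops the final character, typically the closing period), and take_first keeps the first extracted candidate. A small standalone check of the pattern follows; how it is wired into the harness's filter classes is not reproduced here, and the fallback behaviour is an assumption of this sketch.

import re

# regex_pattern copied verbatim from the YAML above
ANSWER_PATTERN = (
    r"((?<=The answer is )(.*)(?=.)"
    r"|(?<=the answer is )(.*)(?=.)"
    r"|(?<=The answer: )(.*)(?=.)"
    r"|(?<=The final answer: )(.*)(?=.))"
)

def extract_answer(generation: str) -> str:
    # Fall back to the raw generation when no cue matches; that fallback is a
    # choice made for this sketch, not necessarily the harness's behaviour.
    match = re.search(ANSWER_PATTERN, generation)
    return match.group(1) if match else generation

print(extract_answer("3 + 4 = 7. The answer is 7."))   # -> 7
print(extract_answer("So the answer is (B)."))          # -> (B)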
@@ -17,7 +17,6 @@ from typing import (
 )
 import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
...