Commit f692caa9 authored by lintangsutawika

updated to appease the pre-commit

parent ab96fc7e
@@ -11,7 +11,7 @@ from typing import Union
 import numpy as np
 from lm_eval import evaluator, utils
-from lm_eval.tasks import TaskManager, include_path, initialize_tasks
+from lm_eval.tasks import TaskManager, initialize_tasks
 from lm_eval.utils import make_table
......
 import logging
 import math
 import random
-from collections.abc import Iterable
 from collections import defaultdict
+from collections.abc import Iterable
 from typing import List
+import evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import evaluate
 from lm_eval.api.registry import register_aggregation, register_metric
@@ -119,7 +119,6 @@ def ter(items):
 @register_aggregation("brier_score")
 def brier_score(items):  # This is a passthrough function
-    # Certain datasets like arc_easy can have a different number of choices.
     golds, predictions = list(zip(*items))
......
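For reference, brier_score is registered as a passthrough aggregation: each item it receives is a (gold, prediction) pair collected per sample, and zip(*items) splits those back into two parallel tuples. The body beyond the unzip line is not shown in this hunk, so the following is only a sketch of how such pairs could be scored, assuming the golds arrive one-hot encoded and the predictions are probability vectors:

import numpy as np

def brier_score_sketch(items):
    # items: list of (gold_one_hot, predicted_probs) pairs, one per sample
    golds, predictions = list(zip(*items))
    golds = np.asarray(golds)
    predictions = np.asarray(predictions)
    # multi-class Brier score: squared distance between the probability
    # vector and the one-hot target, averaged over samples
    return np.mean(np.sum((predictions - golds) ** 2, axis=1))

# hypothetical usage with two 3-way samples
items = [
    (np.array([1.0, 0.0, 0.0]), np.array([0.7, 0.2, 0.1])),
    (np.array([0.0, 1.0, 0.0]), np.array([0.1, 0.6, 0.3])),
]
print(brier_score_sketch(items))  # roughly 0.2 for these made-up numbers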
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, Dict
+import evaluate
 from lm_eval.api.model import LM
......
@@ -1193,8 +1193,8 @@ class ConfigurableTask(Task):
             **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
             **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
             **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
-            # {"brier_score": (gold, prob_norm)}
             **(
+                # {"brier_score": (gold, prob_norm)}
                 {"brier_score": [np.eye(len(prob_norm))[gold], prob_norm]}
                 if "brier_score" in use_metric
                 else {}
......
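The new brier_score entry above hands the aggregation a one-hot encoding of the gold index alongside the normalized probabilities: np.eye(len(prob_norm))[gold] selects one row of an identity matrix, which is exactly that one-hot vector. A tiny illustration with assumed values for gold and prob_norm:

import numpy as np

prob_norm = [0.7, 0.2, 0.1]  # assumed normalized per-choice probabilities
gold = 1                     # assumed index of the correct choice

one_hot_gold = np.eye(len(prob_norm))[gold]
print(one_hot_gold)  # [0. 1. 0.] -- a 1 at the gold index, zeros elsewhere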
@@ -498,7 +498,6 @@ def evaluate(
         metric_key = f"{metric},{key}"
         agg_fn = task.aggregation()[metric]
         results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items)
@@ -524,19 +523,37 @@ def evaluate(
                 # or `task_name: []`.
                 # we only want to operate on groups here.
                 continue
-            for metric in [
-                key
-                for key in results[task_list[0]].keys()
-                if "_stderr" not in key and key not in ["alias", "samples"]
-            ]:  # TODO: what if tasks don't all share the same metrics
+            group_metrics = list(
+                dict.fromkeys(
+                    [
+                        key
+                        for task in task_list
+                        for key in results[task].keys()
+                        if "_stderr" not in key and key not in ["alias", "samples"]
+                    ]
+                )
+            )
+            for metric in group_metrics:
+                # TODO: what if tasks don't all share the same metrics
                 stderr = "_stderr,".join(metric.split(","))
                 # gather metrics, sizes, and stderrs from subtasks
                 metrics = [
-                    results[task][metric] for task in task_list
+                    results[task][metric]
+                    for task in task_list
+                    if metric in results[task]
                 ]  # TODO: copy?
-                stderrs = [results[task][stderr] for task in task_list]
-                sizes = [results[task]["samples"] for task in task_list]
+                stderrs = [
+                    results[task][stderr]
+                    for task in task_list
+                    if stderr in results[task]
+                ]
+                sizes = [
+                    results[task]["samples"]
+                    for task in task_list
+                    if metric in results[task]
+                ]
                 # compute group's pooled metric and stderr
                 results[group][
......
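The rewritten group loop drops the old assumption that every subtask reports the same metrics: list(dict.fromkeys([...])) takes the union of metric keys across subtasks while de-duplicating and preserving first-seen order, and the added "if metric in results[task]" guards skip subtasks that lack a given metric. A sketch of the de-duplication plus the size-weighted pooling hinted at by the trailing comment (the weighted-mean formula is an assumption here, not the exact pooling code, which is truncated in this hunk):

import numpy as np

# hypothetical per-subtask results for one group
results = {
    "sub_a": {"acc,none": 0.60, "samples": 100},
    "sub_b": {"acc,none": 0.80, "f1,none": 0.75, "samples": 300},
}
task_list = ["sub_a", "sub_b"]

# order-preserving, de-duplicated union of metric keys across subtasks
group_metrics = list(
    dict.fromkeys(
        key
        for task in task_list
        for key in results[task].keys()
        if "_stderr" not in key and key not in ["alias", "samples"]
    )
)
print(group_metrics)  # ['acc,none', 'f1,none']

for metric in group_metrics:
    # only subtasks that actually report this metric contribute
    metrics = [results[task][metric] for task in task_list if metric in results[task]]
    sizes = [results[task]["samples"] for task in task_list if metric in results[task]]
    pooled = float(np.average(metrics, weights=sizes))  # assumed pooling rule
    print(metric, pooled)
# acc,none 0.75  -> (0.6 * 100 + 0.8 * 300) / 400
# f1,none  0.75  -> only sub_b reports it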
@@ -16,4 +16,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
       - function: "take_first"
@@ -15,4 +15,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
      - function: "take_first"
@@ -15,4 +15,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 0.0
\ No newline at end of file
+  version: 0.0
@@ -17,7 +17,6 @@ from typing import (
 )
 import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
......