Commit baff2568 authored by haileyschoelkopf

fix matthews corr. metric

parent 50267992
@@ -20,6 +20,8 @@ def median(arr):
return arr[len(arr) // 2]
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@@ -35,6 +37,25 @@ def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_aggregation("f1")
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
# print(preds)
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric(
metric="acc",
higher_is_better=True,
@@ -119,27 +140,24 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric(
metric="mcc",
higher_is_better=True,
output_type="multiple_choice",
aggregation="matthews_corrcoef",
)
def mcc_fn(items): # This is a passthrough function
return items
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
aggregation="f1",
)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def f1_fn(items): # This is a passthrough function
return items
@register_metric(
......
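As an aside on the pattern the new comment describes (per-document metric fns acting as no-op passthroughs, with the real computation done once by a corpus-level aggregation), here is a minimal self-contained sketch of how the mcc pieces fit together. The toy label data is purely illustrative.

```python
# Minimal sketch (standalone; toy data) of the passthrough + aggregation
# pattern from the diff above: each document contributes a (gold, pred)
# pair via a no-op metric fn, and MCC is computed once over the whole set.
import sklearn.metrics


def mcc_fn(items):  # passthrough metric: just forward the (gold, pred) pair
    return items


def matthews_corrcoef(items):  # aggregation: one corpus-level score
    golds, preds = zip(*items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)


per_doc = [mcc_fn(pair) for pair in [(1, 1), (0, 1), (1, 0), (0, 0)]]
print(matthews_corrcoef(per_doc))  # corpus-level MCC, not a per-doc mean
```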
@@ -26,7 +26,10 @@ def register_model(*names):
def get_model(model_name):
return MODEL_REGISTRY[model_name]
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}")
TASK_REGISTRY = {}
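For illustration, a standalone sketch of what the friendlier lookup error added above buys a caller; the registry contents and the mistyped name are invented.

```python
# Standalone sketch of the get_model error path patched above; the registry
# contents and the caller are invented for illustration.
MODEL_REGISTRY = {"openai-completions": object, "gooseai": object}


def get_model(model_name):
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! "
            f"Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        )


try:
    get_model("opnai")  # mistyped name
except ValueError as err:
    print(err)  # lists the supported names instead of surfacing a bare KeyError
```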
@@ -133,7 +136,7 @@ searching in HF Evaluate library..."
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
......
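For context on the truncated hunk above, a hedged sketch of the duplicate-name guard that register_aggregation appears to implement; the registry layout and the assertion message are assumptions, not the file's exact contents.

```python
# Assumed shape of the aggregation registry and its duplicate-name check;
# the real module may differ in details (the message below is a guess).
AGGREGATION_REGISTRY = {}


def register_aggregation(name):
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation '{name}' is already registered!"
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)
```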
@@ -106,7 +106,21 @@ class TaskConfig(dict):
return getattr(self, item)
def to_dict(self):
return asdict(self)
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
cfg_dict = asdict(self)
# remove values that are `None`
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
return cfg_dict
class Task(abc.ABC):
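A small standalone illustration of the new to_dict behavior (None-valued fields are dropped before dumping); DemoConfig and its fields are invented stand-ins for TaskConfig.

```python
# Invented stand-in for TaskConfig, only to demonstrate the None-stripping
# to_dict pattern introduced in the hunk above.
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class DemoConfig:
    task: str = "demo"
    num_fewshot: Optional[int] = None
    metric_list: Optional[list] = None

    def to_dict(self):
        cfg_dict = asdict(self)
        # drop null fields so the dumped config stays readable
        return {k: v for k, v in cfg_dict.items() if v is not None}


print(DemoConfig().to_dict())  # {'task': 'demo'} -- None fields omitted
```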
@@ -653,6 +667,7 @@ class ConfigurableTask(Task):
else:
if self._config.num_fewshot > 0:
eval_logger.warning(
f"Task '{self._config.task}': "
"num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule."
)
@@ -842,7 +857,8 @@ class ConfigurableTask(Task):
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (pred, gold)} if "f1" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
......
@@ -57,7 +57,7 @@ def oa_completion(**kwargs):
backoff_time *= 1.5
@register_model("openai"., "gooseai")
@register_model("openai", "openai-completions", "gooseai")
class GPT3LM(LM):
REQ_CHUNK_SIZE = 20
......
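The change above registers the completions model under an additional "openai-completions" alias. Below is a hedged sketch of how the multi-name register_model decorator (whose signature, register_model(*names), appears earlier in the registry diff) maps several aliases to one class; the decorator body is an assumption based on that signature.

```python
# Hedged sketch: one class registered under several aliases, mirroring
# register_model(*names) from the registry diff; the body is assumed.
MODEL_REGISTRY = {}


def register_model(*names):
    def decorate(cls):
        for name in names:
            assert name not in MODEL_REGISTRY, f"model '{name}' already registered!"
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


@register_model("openai", "openai-completions", "gooseai")
class GPT3LM:
    REQ_CHUNK_SIZE = 20


assert MODEL_REGISTRY["openai-completions"] is MODEL_REGISTRY["gooseai"]
```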
# v1.0 Tasks
This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
Boxes should be checked iff tasks are implemented in v2.0 and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation.
Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation.
- [ ] Glue
- [ ] SuperGlue
- [x] SuperGlue
- [ ] CoQA
- [ ] DROP
- [x] ~~Lambada~~
@@ -31,7 +31,7 @@ Boxes should be checked iff tasks are implemented in v2.0 and tested for regress
- [ ] WebQs
- [ ] WSC273
- [ ] Winogrande
- [ ] ANLI
- [x] ANLI
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
- [ ] MuTual
......
@@ -3,6 +3,7 @@ from typing import List, Union
from .gsm8k import *
from .triviaqa import *
from .glue import *
from lm_eval import utils
from lm_eval.logger import eval_logger
......