Commit baff2568 authored by haileyschoelkopf

fix matthews corr. metric

parent 50267992
@@ -20,6 +20,8 @@ def median(arr):
     return arr[len(arr) // 2]
 
+# Certain metrics must be calculated across all documents in a benchmark.
+# We use them as aggregation metrics, paired with no-op passthrough metric fns.
 @register_aggregation("perplexity")
 def perplexity(items):
     return math.exp(-mean(items))
@@ -35,6 +37,25 @@ def bits_per_byte(items):
     return -weighted_mean(items) / math.log(2)
 
+@register_aggregation("f1")
+def f1_score(items):
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    fscore = sklearn.metrics.f1_score(golds, preds)
+
+    return np.max(fscore)
+
+
+@register_aggregation("matthews_corrcoef")
+def matthews_corrcoef(items):
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    # print(preds)
+    return sklearn.metrics.matthews_corrcoef(golds, preds)
+
+
 @register_metric(
     metric="acc",
     higher_is_better=True,
@@ -119,27 +140,24 @@ def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))
 
-@register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
-def matthews_corrcoef(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    return sklearn.metrics.matthews_corrcoef(golds, preds)
+@register_metric(
+    metric="mcc",
+    higher_is_better=True,
+    output_type="multiple_choice",
+    aggregation="matthews_corrcoef",
+)
+def mcc_fn(items):  # This is a passthrough function
+    return items
 
 
 @register_metric(
     metric="f1",
     higher_is_better=True,
     output_type="multiple_choice",
-    aggregation="mean",
+    aggregation="f1",
 )
-def f1_score(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    fscore = sklearn.metrics.f1_score(golds, preds)
-    return np.max(fscore)
+def f1_fn(items):  # This is a passthrough function
+    return items
 
 
 @register_metric(
......
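Corpus-level scores like MCC and F1 cannot be averaged per document, which is why this commit routes them through no-op passthrough metric functions paired with aggregation functions that see every (gold, pred) pair at once. A minimal sketch of that pattern, using a simplified stand-in registry rather than the harness's real one:

```python
# Sketch only: a simplified stand-in for the harness's aggregation registry.
import sklearn.metrics

AGGREGATION_REGISTRY = {}

def register_aggregation(name):
    def decorate(fn):
        AGGREGATION_REGISTRY[name] = fn
        return fn
    return decorate

@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    golds, preds = zip(*items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)

def mcc_fn(items):  # passthrough: defer scoring until aggregation time
    return items

# Per-document results accumulate as (gold, pred) pairs; the corpus-level
# score is then computed once, over all documents.
items = [(1, 1), (0, 0), (1, 0), (0, 1), (1, 1)]
print(AGGREGATION_REGISTRY["matthews_corrcoef"](mcc_fn(items)))
```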
@@ -26,7 +26,10 @@ def register_model(*names):
 
 def get_model(model_name):
-    return MODEL_REGISTRY[model_name]
+    try:
+        return MODEL_REGISTRY[model_name]
+    except KeyError:
+        raise ValueError(f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}")
 
 
 TASK_REGISTRY = {}
@@ -133,7 +136,7 @@ searching in HF Evaluate library..."
 
 def register_aggregation(name):
+    # TODO: should we enforce a specific interface to aggregation metrics?
     def decorate(fn):
         assert (
             name not in AGGREGATION_REGISTRY
......
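With the `try`/`except` in place, an unknown model name fails fast with an actionable error instead of a bare `KeyError`. A hypothetical call (the import path is assumed, and the listed names depend on what is actually registered):

```python
from lm_eval.api.registry import get_model  # import path assumed

lm_class = get_model("openai")  # resolves normally
get_model("opena")  # typo now raises, e.g.:
# ValueError: Attempted to load model 'opena', but no model for this name
# found! Supported model names: openai, openai-completions, gooseai
```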
@@ -106,7 +106,21 @@ class TaskConfig(dict):
         return getattr(self, item)
 
     def to_dict(self):
-        return asdict(self)
+        """dumps the current config as a dictionary object, as a printable format.
+        null fields will not be printed.
+        Used for dumping results alongside full task configuration
+
+        :return: dict
+            A printable dictionary version of the TaskConfig object.
+
+        # TODO: should any default value in the TaskConfig not be printed?
+        """
+        cfg_dict = asdict(self)
+        # remove values that are `None`
+        for k, v in list(cfg_dict.items()):
+            if v is None:
+                cfg_dict.pop(k)
+        return cfg_dict
 
 
 class Task(abc.ABC):
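The effect of the new `to_dict` can be seen with a toy dataclass standing in for `TaskConfig` (illustration only, not the real config class):

```python
from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class ToyConfig:  # stand-in for TaskConfig, for illustration only
    task: str = "cola"
    num_fewshot: int = 0
    fewshot_split: Optional[str] = None

cfg_dict = asdict(ToyConfig())
# drop null fields so the dumped config stays readable
for k, v in list(cfg_dict.items()):
    if v is None:
        cfg_dict.pop(k)

print(cfg_dict)  # {'task': 'cola', 'num_fewshot': 0} -- fewshot_split omitted
```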
@@ -653,6 +667,7 @@ class ConfigurableTask(Task):
         else:
             if self._config.num_fewshot > 0:
                 eval_logger.warning(
+                    f"Task '{self._config.task}': "
                     "num_fewshot > 0 but fewshot_split is None. "
                     "using preconfigured rule."
                 )
@@ -842,7 +857,8 @@ class ConfigurableTask(Task):
 
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
-                **({"f1": (pred, gold)} if "f1" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
             }
......
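For readers unfamiliar with the `**({...} if ... else {})` idiom above: each entry is merged into `result_dict` only when its metric was requested, and the raw `(gold, pred)` tuple is what later reaches the corpus-level aggregations. A standalone illustration with made-up values:

```python
# Hypothetical values, just to show the conditional-merge idiom.
use_metric = {"acc", "mcc"}
acc, gold, pred = 1.0, 1, 1

result_dict = {
    **({"acc": acc} if "acc" in use_metric else {}),
    **({"f1": (gold, pred)} if "f1" in use_metric else {}),
    **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
}
print(result_dict)  # {'acc': 1.0, 'mcc': (1, 1)}
```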
@@ -57,7 +57,7 @@ def oa_completion(**kwargs):
             backoff_time *= 1.5
 
 
-@register_model("openai", "gooseai")
+@register_model("openai", "openai-completions", "gooseai")
 class GPT3LM(LM):
     REQ_CHUNK_SIZE = 20
......
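`register_model` takes `*names` so one backend class can be reached under several aliases; this hunk adds `openai-completions` as a second alias for the same class. A self-contained sketch of the mechanism, with a stub class and simplified registry:

```python
# Simplified sketch; the harness's real registry lives elsewhere.
MODEL_REGISTRY = {}

def register_model(*names):
    def decorate(cls):
        for name in names:
            assert name not in MODEL_REGISTRY, f"duplicate model name: {name}"
            MODEL_REGISTRY[name] = cls
        return cls
    return decorate

@register_model("openai", "openai-completions", "gooseai")
class GPT3LM:  # stub standing in for the real LM subclass
    REQ_CHUNK_SIZE = 20

# all three aliases resolve to the same class
assert MODEL_REGISTRY["openai"] is MODEL_REGISTRY["openai-completions"]
```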
 # v1.0 Tasks
 This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
 
-Boxes should be checked iff tasks are implemented in v2.0 and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation.
+Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation.
 - [ ] Glue
-- [ ] SuperGlue
+- [x] SuperGlue
 - [ ] CoQA
 - [ ] DROP
 - [x] ~~Lambada~~
@@ -31,7 +31,7 @@ Boxes should be checked iff tasks are implemented in v2.0 and tested for regress
 - [ ] WebQs
 - [ ] WSC273
 - [ ] Winogrande
-- [ ] ANLI
+- [x] ANLI
 - [ ] Hendrycks Ethics
 - [ ] TruthfulQA
 - [ ] MuTual
......
@@ -3,6 +3,7 @@ from typing import List, Union
 from .gsm8k import *
 from .triviaqa import *
+from .glue import *
 
 from lm_eval import utils
 from lm_eval.logger import eval_logger
......