Commit a18104a4 authored by Leo Gao
Browse files

Move higher_is_better and aggregation into their own functions

parent 0f9c1624
......@@ -130,22 +130,8 @@ class Dataset(abc.ABC):
@abc.abstractmethod
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
list of dicts, each with the following format:
{
"submetric": str,
"value": float,
"higher_is_better": bool,
"aggregation": ([float] -> float),
}
* `submetric` should be the name of the metric
* `value` should be the value of the metric
* `higher_is_better` determines whether a higher metric is better
* `aggregation` should be a function that takes a list of floats and
aggregates them into one float. This should be the same for all
submetrics of the same name; if it differs, an error should be
raised.
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
......@@ -154,6 +140,24 @@ class Dataset(abc.ABC):
"""
pass
@abc.abstractmethod
def aggregation(self):
    """Return the aggregation function to use for each submetric.

    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metric values into one float
    """
    pass
@abc.abstractmethod
def higher_is_better(self):
    """Return, for each submetric, whether a higher value is better.

    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
    """
    pass
def fewshot_description(self):
    """Return the few-shot description; this implementation returns an
    empty string.

    :returns: str
    """
    return ""
......
......@@ -90,11 +90,17 @@ class SATAnalogies(Dataset):
acc = 1. if np.argmax(results) == gold else 0.
return [
{
"submetric": "acc",
"value": acc,
"higher_is_better": True,
"aggregation": mean
}
]
return {
"acc": acc
}
def higher_is_better(self):
    """Return, for each submetric, whether a higher value is better.

    :returns: {str: bool}
        ``acc`` is the only submetric, and a higher value is better
    """
    return {
        "acc": True
    }
def aggregation(self):
    """Return the aggregation function to use for each submetric.

    :returns: {str: [float] -> float}
        ``acc`` is the only submetric; its values are aggregated with
        ``mean``
    """
    return {
        "acc": mean
    }
......@@ -38,18 +38,22 @@ class BoolQ(HFTask):
def process_results(self, doc, results):
    """Score a single document and return its submetric values.

    :param doc:
        The document being scored; its ``"label"`` entry is the gold answer.
    :param results:
        A pair ``(ll_yes, ll_no)``. Presumably the LM log-likelihoods of the
        "yes" and "no" answers -- TODO confirm against the request builder.
    :returns: {str: float}
        ``{"acc": 1.0}`` when ``ll_yes > ll_no`` equals the gold label,
        ``{"acc": 0.0}`` otherwise.
    """
    ll_yes, ll_no = results
    gold = doc["label"]

    acc = 1. if (ll_yes > ll_no) == gold else 0.

    # Return the plain {submetric: value} dict: the evaluation loop iterates
    # ``metrics.items()``, and the direction/aggregation metadata is exposed
    # through higher_is_better() and aggregation() instead of per-result.
    return {
        "acc": acc
    }
def higher_is_better(self):
    """Return, for each submetric, whether a higher value is better.

    :returns: {str: bool}
        ``acc`` is the only submetric, and a higher value is better
    """
    return {
        "acc": True
    }
def aggregation(self):
    """Return the aggregation function to use for each submetric.

    :returns: {str: [float] -> float}
        ``acc`` is the only submetric; its values are aggregated with
        ``mean``
    """
    return {
        "acc": mean
    }
class CommitmentBank(HFTask):
......
......@@ -94,20 +94,13 @@ def main():
doc = docs[(task_name, doc_id)]
metrics = task.process_results(doc, requests)
for metric in metrics:
results[task_name][metric['submetric']] = {
"higher_is_better": metric["higher_is_better"],
"aggregation": metric["aggregation"]
}
vals[(task_name, metric['submetric'])].append(metric['value'])
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
# aggregate results
for task_name, submetrics in results.items():
for k in submetrics.keys():
submetrics[k]['value'] = submetrics[k]['aggregation'](vals[(task_name, k)])
# can't serialize a function
del submetrics[k]['aggregation']
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
dumped = json.dumps(results, indent=2)
print(dumped)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment