"configs/models/codellama/hf_codellama_7b_instruct.py" did not exist on "5f2e7c3469a1b4d13c84cbd0e20d6141a3165348"
Commit a18104a4 authored by Leo Gao's avatar Leo Gao
Browse files

Move higher_is_better and aggregation into their own functions

parent 0f9c1624
...@@ -130,22 +130,8 @@ class Dataset(abc.ABC): ...@@ -130,22 +130,8 @@ class Dataset(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
list of dicts, each with the following format: dict where keys are the names of submetrics and values are the values of
the metric for that one document
{
"submetric": str,
"value": float,
"higher_is_better": bool,
"aggregation": ([float] -> float),
}
* `submetric` should be the name of the metric
* `value` should be the value of the metric
* `higher_is_better` determines whether a higher metric is better
* `aggregation` should be a function that takes a list of floats and
aggregates them into one float. This should be the same for all
submetrics of the same name; if it differs, an error should be
raised.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
...@@ -154,6 +140,24 @@ class Dataset(abc.ABC): ...@@ -154,6 +140,24 @@ class Dataset(abc.ABC):
""" """
pass pass
@abc.abstractmethod
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
pass
@abc.abstractmethod
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
def fewshot_description(self): def fewshot_description(self):
return "" return ""
......
...@@ -90,11 +90,17 @@ class SATAnalogies(Dataset): ...@@ -90,11 +90,17 @@ class SATAnalogies(Dataset):
acc = 1. if np.argmax(results) == gold else 0. acc = 1. if np.argmax(results) == gold else 0.
return [ return {
{ "acc": acc
"submetric": "acc",
"value": acc,
"higher_is_better": True,
"aggregation": mean
} }
]
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
...@@ -38,18 +38,22 @@ class BoolQ(HFTask): ...@@ -38,18 +38,22 @@ class BoolQ(HFTask):
def process_results(self, doc, results): def process_results(self, doc, results):
ll_yes, ll_no = results ll_yes, ll_no = results
gold = doc["label"] gold = doc["label"]
print(ll_yes > ll_no, gold)
acc = 1. if (ll_yes > ll_no) == gold else 0. acc = 1. if (ll_yes > ll_no) == gold else 0.
return [ return {
{ "acc": acc
"submetric": "acc", }
"value": acc,
"higher_is_better": True, def higher_is_better(self):
"aggregation": mean return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
} }
]
class CommitmentBank(HFTask): class CommitmentBank(HFTask):
......
...@@ -94,20 +94,13 @@ def main(): ...@@ -94,20 +94,13 @@ def main():
doc = docs[(task_name, doc_id)] doc = docs[(task_name, doc_id)]
metrics = task.process_results(doc, requests) metrics = task.process_results(doc, requests)
for metric in metrics: for metric, value in metrics.items():
results[task_name][metric['submetric']] = { vals[(task_name, metric)].append(value)
"higher_is_better": metric["higher_is_better"],
"aggregation": metric["aggregation"]
}
vals[(task_name, metric['submetric'])].append(metric['value'])
# aggregate results # aggregate results
for task_name, submetrics in results.items(): for (task_name, metric), items in vals.items():
for k in submetrics.keys(): task = task_dict[task_name]
submetrics[k]['value'] = submetrics[k]['aggregation'](vals[(task_name, k)]) results[task_name][metric] = task.aggregation()[metric](items)
# can't serialize a function
del submetrics[k]['aggregation']
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
print(dumped) print(dumped)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment