Unverified commit 6803e647 authored by Leo Gao, committed by GitHub

Merge pull request #79 from EleutherAI/bmk_refactor

Bmk refactor
parents 2e1b05d2 041ea8a7
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
......@@ -26,6 +28,7 @@ class SQuAD(HFTask):
return self.data["validation"]
def fewshot_description(self):
# TODO: redo description
return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer."
def doc_to_text(self, doc, include_target=True):
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import json
import random
from lm_eval.base import Dataset
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import numpy as np
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from lm_eval.base import rf, mean
class BoolQ(HFTask):
DATASET_PATH = "super_glue"
......@@ -19,21 +21,33 @@ class BoolQ(HFTask):
def fewshot_description(self):
return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc, include_target=True):
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " \
+ (yesno(doc['label']) if include_target else "")
def doc_to_text(self, doc):
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
def doc_to_target(self, doc):
return yesno(doc['label'])
def evaluate(self, docs, lm, provide_description, num_fewshot):
golds = [doc["label"] for doc in docs]
preds = []
for doc in docs:
ctx = self.fewshot_context(
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
)
preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
return simple_accuracy_metric(preds=preds, golds=golds)
def construct_requests(self, ctx):
ll_yes = rf.loglikelihood(ctx, ' yes')
ll_no = rf.loglikelihood(ctx, ' no')
return ll_yes, ll_no
def process_results(self, doc, results):
ll_yes, ll_no = results
gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0.
return [
{
"submetric": "acc",
"value": acc,
"higher_is_better": True,
"aggregation": mean
}
]
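For context on the shape of the new task interface shown above: construct_requests(ctx) returns deferred request objects built with rf, and process_results(doc, results) later receives the executed responses in the same order and returns a list of per-document metric entries. Below is a hedged sketch of another fixed-label task written against this interface; the class name, dataset fields, and label strings are hypothetical and not part of this commit, and it assumes rf.loglikelihood and mean behave as they do in the BoolQ code above.

class ExampleSentimentTask(HFTask):
    # Hypothetical task, for illustration only: docs are assumed to look like
    # {"text": str, "label": 0 or 1}.
    DATASET_PATH = "example_sentiment"

    def fewshot_description(self):
        return "Classify each passage as positive or negative."

    def doc_to_text(self, doc):
        return f"{doc['text']}\nsentiment: "

    def doc_to_target(self, doc):
        return " positive" if doc["label"] else " negative"

    def construct_requests(self, ctx):
        # Only the context is available here, so the continuations must be
        # fixed label strings, as in BoolQ above.
        ll_pos = rf.loglikelihood(ctx, " positive")
        ll_neg = rf.loglikelihood(ctx, " negative")
        return ll_pos, ll_neg

    def process_results(self, doc, results):
        # Responses arrive in the same order the requests were returned.
        ll_pos, ll_neg = results
        acc = 1. if (ll_pos > ll_neg) == doc["label"] else 0.
        return [
            {"submetric": "acc", "value": acc, "higher_is_better": True, "aggregation": mean},
        ]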
class CommitmentBank(HFTask):
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import json
import random
from lm_eval.base import Dataset
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
from . common import HFTask
class WebQs(HFTask):
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import json
import random
import os
......
......@@ -3,6 +3,7 @@ import json
import numpy as np
import random
import itertools
import collections
from lm_eval import models, tasks
......@@ -16,7 +17,7 @@ def parse_args():
parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', default=None)
parser.add_argument('--limit', type=int, default=None)
return parser.parse_args()
def main():
......@@ -30,17 +31,70 @@ def main():
else:
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
results = {}
for task_name, task in task_dict.items():
if not task.has_validation_docs():
continue
result = task.evaluate(
docs=itertools.islice(task.validation_docs(), 0, args.limit),
lm=lm,
provide_description=args.provide_description,
num_fewshot=args.num_fewshot,
)
results[task_name] = result
# TODO: fall back to test docs
task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
results = collections.defaultdict(dict)
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
# If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with more memory,
# we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad
# (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have).
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
for task_name, task in task_dict_items:
for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
provide_description=args.provide_description,
num_fewshot=args.num_fewshot,
)
reqs = task.construct_requests(ctx)
for i, req in enumerate(reqs):
requests[req.type].append(req)
requests_origin[req.type].append((i, task_name, doc, doc_id))
process_res_queue = collections.defaultdict(list)
for reqtype, reqs in requests.items():
resps = getattr(lm, reqtype)([req.args for req in reqs])
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
vals = collections.defaultdict(list)
for (task_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
task = task_dict[task_name]
doc = docs[(task_name, doc_id)]
metrics = task.process_results(doc, requests)
for metric in metrics:
results[task_name][metric['submetric']] = {
"higher_is_better": metric["higher_is_better"],
"aggregation": metric["aggregation"]
}
vals[(task_name, metric['submetric'])].append(metric['value'])
for task_name, submetrics in results.items():
for k in submetrics.keys():
submetrics[k]['value'] = submetrics[k]['aggregation'](vals[(task_name, k)])
# can't serialize a function
del submetrics[k]['aggregation']
dumped = json.dumps(results, indent=2)
print(dumped)
......
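The comment in the loop above floats the idea of spooling requests to disk (probably an sqlite db) if the eval tasks ever stop fitting in memory. Purely as an illustration of that idea, not something this commit implements, here is a minimal sketch using the standard-library sqlite3 module; the table name, schema, and function names are invented for the sketch.

import json
import sqlite3

def spool_requests(db_path, reqtype, arg_tuples):
    # Write request args to disk so they don't all have to sit in memory.
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE IF NOT EXISTS requests (reqtype TEXT, idx INTEGER, args TEXT)")
    conn.executemany(
        "INSERT INTO requests VALUES (?, ?, ?)",
        [(reqtype, i, json.dumps(args)) for i, args in enumerate(arg_tuples)],
    )
    conn.commit()
    conn.close()

def read_requests(db_path, reqtype):
    # Read the args back in their original order for batched execution.
    conn = sqlite3.connect(db_path)
    rows = conn.execute(
        "SELECT args FROM requests WHERE reqtype = ? ORDER BY idx", (reqtype,)
    ).fetchall()
    conn.close()
    return [tuple(json.loads(args)) for (args,) in rows]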