Commit 011a0b0c authored by Leo Gao

Finish plumbing

parent e41a082c
@@ -31,18 +31,27 @@ def main():
     else:
         task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)
-    task_dict_items = list(task_dict.items())
+    # TODO: fall back to test docs
+    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
     results = {}
     requests = collections.defaultdict(list)
-    requests_lengths = collections.defaultdict(list)
+    requests_origin = collections.defaultdict(list)
+    # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
+    # we can always modify this plumbing to support that, but i didn't want to include it just yet because overengineering is bad
+    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have)
+    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
+    docs = {}
     for task_name, task in task_dict_items:
-        # TODO: fall back to test docs
-        if not task.has_validation_docs():
-            continue
-        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
+        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
+            docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc,
                 provide_description=args.provide_description,
@@ -51,22 +60,40 @@ def main():
             reqs = task.construct_requests(ctx)
-            lengths = collections.defaultdict(int)
-            for req in reqs:
+            for i, req in enumerate(reqs):
                 requests[req.type].append(req)
-                lengths[req.type] += 1
-            for type, ct in lengths.items():
-                requests_lengths[type].append(ct)
+                requests_origin[req.type].append((i, task_name, doc, doc_id))
 
-    # TODO: finish implementation
-    for reqname, reqs in requests.items():
-        lm_res = getattr(lm, reqname)([req.args for req in reqs])
+    process_res_queue = collections.defaultdict(list)
+    for reqtype, reqs in requests.items():
+        resps = getattr(lm, reqtype)([req.args for req in reqs])
+        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+            process_res_queue[(task_name, doc_id)].append((i, resp))
 
-    for task_name, task in task_dict_items:
-        if not task.has_validation_docs():
-            continue
+    vals = collections.defaultdict(list)
+    for (task_name, doc_id), args in process_res_queue.items():
+        args.sort(key=lambda x: x[0])
+        args = [x[1] for x in args]
+        task = task_dict[task_name]
+        doc = docs[(task_name, doc_id)]
+        metrics = task.process_results(doc, args)
+        for metric in metrics:
+            results[(task_name, metric['submetric'])] = {
+                "higher_is_better": metric["higher_is_better"],
+                "aggregation": metric["aggregation"]
+            }
+            vals[(task_name, metric['submetric'])].append(metric['value'])
+
+    for k in results.keys():
+        results[k]['value'] = results[k]['aggregation'](vals[k])
+        # can't serialize a function
+        del results[k]['aggregation']
 
     dumped = json.dumps(results, indent=2)
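
For reference, here is a minimal, self-contained sketch of the plumbing pattern this commit finishes: flatten per-document requests into per-type batches while recording each request's origin, make one batched LM call per request type, then re-sort responses back into per-document order. The Req and EchoLM classes and the run() helper below are illustrative stand-ins, not the harness's real API.

import collections

class Req:
    # Illustrative request object: a request type plus the args for the LM call.
    def __init__(self, type, args):
        self.type = type
        self.args = args

class EchoLM:
    # Hypothetical LM whose "loglikelihood" just scores each request by prompt
    # length, returning one response per request, in batch order.
    def loglikelihood(self, batch):
        return [len(args[0]) for args in batch]

def run(lm, reqs_per_doc):
    # Phase 1: flatten requests per type, remembering where each came from.
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)
    for doc_key, reqs in reqs_per_doc.items():
        for i, req in enumerate(reqs):
            requests[req.type].append(req)
            requests_origin[req.type].append((i, doc_key))

    # Phase 2: one batched LM call per request type.
    process_res_queue = collections.defaultdict(list)
    for reqtype, reqs in requests.items():
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        for resp, (i, doc_key) in zip(resps, requests_origin[reqtype]):
            process_res_queue[doc_key].append((i, resp))

    # Phase 3: restore each document's original request order before scoring.
    out = {}
    for doc_key, entries in process_res_queue.items():
        entries.sort(key=lambda x: x[0])
        out[doc_key] = [resp for _, resp in entries]
    return out

reqs = {("lambada", 0): [Req("loglikelihood", ("The cat",)),
                         Req("loglikelihood", ("sat",))]}
print(run(EchoLM(), reqs))  # {('lambada', 0): [7, 3]}

Keeping the origin bookkeeping separate from the request lists is what lets the harness batch aggressively by request type without losing track of which document each response belongs to.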