Commit 9484eecc authored by jon-tow

Fix coqa

parent 7d282b5f
import abc
from typing import Iterable
import promptsource
import numpy as np
import random
import re
@@ -639,11 +640,12 @@ class PromptSourceTask(Task):
        self.prompt = prompt

    def doc_to_target(self, doc):
-       _, target = prompt.apply(doc)
+       _, target = self.prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc):
-       text, _ = prompt.apply(doc)
+       print(doc)
+       text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
@@ -660,7 +662,7 @@ class PromptSourceTask(Task):
        _requests = []
        if self.prompt.metadata.choices_in_prompt:
-           for answer_choice in prompt.get_fixed_answer_choices_list():
+           for answer_choice in self.prompt.get_fixed_answer_choices_list():
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
......
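For context, the `self.prompt` object these methods call into is a promptsource `Template`. Below is a minimal sketch of that API, separate from this commit; the dataset, template choice, and example fields are hypothetical.

```python
from promptsource.templates import DatasetTemplates

# Load the prompt templates promptsource ships for a dataset
# (dataset/subset names here are only an example).
templates = DatasetTemplates("super_glue", "boolq")
template = templates[templates.all_template_names[0]]

# apply() renders the Jinja template into [input_text, target_text] pieces.
example = {"passage": "The sky is blue.", "question": "Is the sky blue?", "label": 1}
parts = template.apply(example)
text, target = parts[0], parts[-1]

# Metadata consulted by construct_requests above.
if template.metadata.choices_in_prompt:
    answer_choices = template.get_fixed_answer_choices_list()
```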
@@ -169,8 +169,10 @@ def evaluate(
    docs = {}

    # get lists of each type of request
-   for task_name, task in task_dict_items:
-       versions[task_name] = task.VERSION
+   for task_prompt_name, task in task_dict_items:
+       print(f"TASK PROMPT NAME: {task_prompt_name}")
+       versions[task_prompt_name] = task.VERSION

        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
@@ -187,13 +189,13 @@ def evaluate(
        rnd.shuffle(task_docs)

        description = (
-           description_dict[task_name]
-           if description_dict and task_name in description_dict
+           description_dict[task_prompt_name]
+           if description_dict and task_prompt_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
-           docs[(task_name, doc_id)] = doc
+           docs[(task_prompt_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
@@ -204,7 +206,7 @@ def evaluate(
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
-               requests_origin[req.request_type].append((i, task_name, doc, doc_id))
+               requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id))

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
@@ -222,32 +224,33 @@ def evaluate(
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

-       for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
-           process_res_queue[(task_name, doc_id)].append((i, resp))
+       for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+           process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
-   for (task_name, doc_id), requests in process_res_queue.items():
+   for (task_prompt_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

-       task = task_dict[task_name]
-       doc = docs[(task_name, doc_id)]
+       task = task_dict[task_prompt_name]
+       doc = docs[(task_prompt_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
-           vals[(task_name, metric)].append(value)
-           task_name, prompt_name = task_name.split("+")
-           results[task_name]["task_name"] = task_name
-           results[task_name]["prompt_name"] = prompt_name
+           vals[(task_prompt_name, metric)].append(value)

    # aggregate results
-   for (task_name, metric), items in vals.items():
+   for (task_prompt_name, metric), items in vals.items():
+       task_name, prompt_name = task_prompt_name.split("+")
+       results[task_prompt_name]["task_name"] = task_name
+       results[task_prompt_name]["prompt_name"] = prompt_name
        task = task_dict[task_name]
-       results[task_name][metric] = task.aggregation()[metric](items)
+       results[task_prompt_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
@@ -258,7 +261,7 @@ def evaluate(
            else bootstrap_iters,
        )

        if stderr is not None:
-           results[task_name][metric + "_stderr"] = stderr(items)
+           results[task_prompt_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
......
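The evaluator changes above assume every entry in `task_dict` is keyed by a combined `"<task_name>+<prompt_name>"` string: metrics are accumulated per (task, prompt) pair, and the two parts are split back out only for reporting. A small sketch of that convention, with made-up names:

```python
import collections

results = collections.defaultdict(dict)

# Hypothetical combined key, as produced when tasks are registered per prompt.
task_prompt_name = "coqa+my_prompt"

task_name, prompt_name = task_prompt_name.split("+")
results[task_prompt_name]["task_name"] = task_name       # -> "coqa"
results[task_prompt_name]["prompt_name"] = prompt_name    # -> "my_prompt"

print(dict(results))
```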
@@ -58,10 +58,11 @@ class Arithmetic(Task):
    def construct_requests(self, doc, ctx):
        ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
-       return is_prediction
+       return ll, is_prediction

    def process_results(self, doc, results):
-       is_prediction, = results
+       print(results)
+       results = results
        return {
            "acc": is_prediction
        }
......
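With `construct_requests` now returning both values from `rf.loglikelihood`, `process_results` receives the realized (log-likelihood, greedy-match) pair in the same order. A sketch of that unpacking, not the code in this commit, with illustrative values:

```python
# The harness passes the executed request results back to process_results
# in the order construct_requests returned them (values below are made up).
def process_results(doc, results):
    ll, is_prediction = results       # log-likelihood, greedy-match flag
    return {"acc": is_prediction}

print(process_results({}, (-1.23, True)))   # {'acc': True}
```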
@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
-from lm_eval.base import PromptSourceTask, rf, mean
+from lm_eval.base import PromptSourceTask, Task, rf, mean
from itertools import zip_longest
......
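The coqa task scores free-form answers with the SQuAD-style helpers imported here. A minimal sketch of those two functions; the gold and prediction strings are made-up examples:

```python
import transformers.data.metrics.squad_metrics as squad_metrics

gold = "the blue jay flew away"
pred = "blue jay"

em = squad_metrics.compute_exact(gold, pred)   # exact match after normalization -> 0
f1 = squad_metrics.compute_f1(gold, pred)      # token-overlap F1 -> roughly 0.67
print(em, f1)
```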