Commit 9484eecc authored by jon-tow

Fix coqa

parent 7d282b5f
 import abc
 from typing import Iterable
+import promptsource
 import numpy as np
 import random
 import re

@@ -639,11 +640,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt

     def doc_to_target(self, doc):
-        _, target = prompt.apply(doc)
+        _, target = self.prompt.apply(doc)
         return f" {target}"

     def doc_to_text(self, doc):
-        text, _ = prompt.apply(doc)
+        print(doc)
+        text, _ = self.prompt.apply(doc)
         return text

     def construct_requests(self, doc, ctx):
@@ -660,7 +662,7 @@ class PromptSourceTask(Task):
         _requests = []
         if self.prompt.metadata.choices_in_prompt:
-            for answer_choice in prompt.get_fixed_answer_choices_list():
+            for answer_choice in self.prompt.get_fixed_answer_choices_list():
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
...
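Note: the hunk above replaces the undefined module-level `prompt` with the task's own `self.prompt`, a promptsource template whose `apply()` returns the rendered input text and target. A minimal sketch of that template API, with the dataset, subset, and template name chosen purely for illustration:

```python
# Minimal sketch of applying a promptsource template to a document, mirroring
# doc_to_text / doc_to_target above. Dataset/template names are assumptions.
from promptsource.templates import DatasetTemplates

template = DatasetTemplates("super_glue", "boolq")["GPT-3 Style"]  # assumed names

doc = {
    "passage": "Water covers most of the Earth's surface.",
    "question": "Does water cover most of the Earth?",
    "label": 1,
}

text, target = template.apply(doc)  # -> [rendered prompt, rendered target]
print(text)
print(f" {target}")                 # space-prefixed, as in doc_to_target
```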
@@ -169,8 +169,10 @@ def evaluate(
     docs = {}

     # get lists of each type of request
-    for task_name, task in task_dict_items:
-        versions[task_name] = task.VERSION
+    for task_prompt_name, task in task_dict_items:
+        print(f"TASK PROMPT NAME: {task_prompt_name}")
+        versions[task_prompt_name] = task.VERSION

         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
@@ -187,13 +189,13 @@ def evaluate(
         rnd.shuffle(task_docs)

         description = (
-            description_dict[task_name]
-            if description_dict and task_name in description_dict
+            description_dict[task_prompt_name]
+            if description_dict and task_prompt_name in description_dict
             else ""
         )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
-            docs[(task_name, doc_id)] = doc
+            docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
@@ -204,7 +206,7 @@ def evaluate(
                 requests[req.request_type].append(req)
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
-                requests_origin[req.request_type].append((i, task_name, doc, doc_id))
+                requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id))

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
@@ -222,32 +224,33 @@ def evaluate(
             x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
         ]

-        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
-            process_res_queue[(task_name, doc_id)].append((i, resp))
+        for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
-    for (task_name, doc_id), requests in process_res_queue.items():
+    for (task_prompt_name, doc_id), requests in process_res_queue.items():
         requests.sort(key=lambda x: x[0])
         requests = [x[1] for x in requests]

-        task = task_dict[task_name]
-        doc = docs[(task_name, doc_id)]
+        task = task_dict[task_prompt_name]
+        doc = docs[(task_prompt_name, doc_id)]

         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
-            vals[(task_name, metric)].append(value)
-
-        task_name, prompt_name = task_name.split("+")
-        results[task_name]["task_name"] = task_name
-        results[task_name]["prompt_name"] = prompt_name
+            vals[(task_prompt_name, metric)].append(value)

     # aggregate results
-    for (task_name, metric), items in vals.items():
+    for (task_prompt_name, metric), items in vals.items():
+        task_name, prompt_name = task_prompt_name.split("+")
+        results[task_prompt_name]["task_name"] = task_name
+        results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_name]
-        results[task_name][metric] = task.aggregation()[metric](items)
+        results[task_prompt_name][metric] = task.aggregation()[metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
@@ -258,7 +261,7 @@ def evaluate(
                 else bootstrap_iters,
             )
             if stderr is not None:
-                results[task_name][metric + "_stderr"] = stderr(items)
+                results[task_prompt_name][metric + "_stderr"] = stderr(items)

     return {"results": dict(results), "versions": dict(versions)}
...
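Note: the evaluator hunks above rename `task_name` to `task_prompt_name`, a combined "<task>+<prompt>" key, and only split it back apart when filling the results table. A small self-contained sketch of that keying convention, with made-up task/prompt names and metric values:

```python
# Sketch of the "<task>+<prompt>" result keying used in evaluate() above.
# Task name, prompt name, and metric values are placeholders.
import collections

results = collections.defaultdict(dict)
vals = {("coqa+basic_qa", "f1"): [0.5, 0.75]}  # fabricated per-doc metric values

for (task_prompt_name, metric), items in vals.items():
    task_name, prompt_name = task_prompt_name.split("+")
    results[task_prompt_name]["task_name"] = task_name      # "coqa"
    results[task_prompt_name]["prompt_name"] = prompt_name  # "basic_qa"
    results[task_prompt_name][metric] = sum(items) / len(items)  # stand-in aggregation

print(dict(results))
# {'coqa+basic_qa': {'task_name': 'coqa', 'prompt_name': 'basic_qa', 'f1': 0.625}}
```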
@@ -58,10 +58,11 @@ class Arithmetic(Task):
     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
-        return is_prediction
+        return ll, is_prediction

     def process_results(self, doc, results):
-        is_prediction, = results
+        print(results)
+        results = results
         return {
             "acc": is_prediction
         }
...
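Note: with `construct_requests` now returning the full `(ll, is_prediction)` pair from `rf.loglikelihood`, `process_results` receives a 2-tuple rather than a lone flag; the hunk above still carries debugging statements. A sketch of what consuming both values could look like, not the repository's final code:

```python
# Sketch only: unpack the (log-likelihood, is-greedy) pair that the new
# construct_requests return value implies, keeping accuracy keyed off the
# exact-match flag. The doc and result values below are fabricated.
def process_results(doc, results):
    ll, is_prediction = results  # log-likelihood and greedy-match flag
    return {"acc": is_prediction}

print(process_results({"completion": " 42"}, (-1.3, True)))  # {'acc': True}
```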
@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
 import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics

 import lm_eval.datasets.coqa.coqa
-from lm_eval.base import PromptSourceTask, rf, mean
+from lm_eval.base import PromptSourceTask, Task, rf, mean
 from itertools import zip_longest
...