Commit a21df355 authored by thefazzer

Merge remote-tracking branch 'origin/master' into fazz/refactor-task-coqa

parents efa810f0 1815286c
@@ -2,6 +2,7 @@ import abc
import random
import numpy as np
import sklearn
+import math


class LM(abc.ABC):
@@ -58,10 +59,10 @@ class LM(abc.ABC):
        return cls()


-class Dataset(abc.ABC):
+class Task(abc.ABC):
    def __init__(self):
        self.download()
-        self._traindocs = None
+        self._training_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
@@ -71,7 +72,7 @@ class Dataset(abc.ABC):
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
@@ -84,23 +85,29 @@ class Dataset(abc.ABC):
    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
        return []

    def test_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
        return []

    def fewshot_examples(self, k):
-        if self._traindocs is None:
-            self._traindocs = list(self.training_docs())
-        return random.sample(self._traindocs, k)
+        if self._training_docs is None:
+            self._training_docs = list(self.training_docs())
+        return random.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
@@ -123,7 +130,7 @@ class Dataset(abc.ABC):
        part of the document for `doc`.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
@@ -161,7 +168,7 @@ class Dataset(abc.ABC):
    def fewshot_context(self, doc, num_fewshot, provide_description):
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
@@ -193,7 +200,8 @@ def f1_score(items):
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
-    return max(fscore)
+    return np.max(fscore)


def acc_all(items):
@@ -223,6 +231,9 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)


+def perplexity(items):
+    return math.exp(-mean(items))
+
req_ret_lens = {
    'loglikelihood': 2
}
...
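For orientation only (not part of the diff): against the renamed Task base class, a new task fills in the hooks seen above. The following is a minimal sketch under stated assumptions; MyYesNoTask, the data path, and the doc fields are hypothetical, and it assumes the lm_eval.base names shown in this hunk (Task, rf, mean).

import json
from lm_eval.base import Task, rf, mean

class MyYesNoTask(Task):
    def download(self):
        pass  # assume data/my_task.json already exists locally (hypothetical path)

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def validation_docs(self):
        # each doc is assumed to look like {"question": ..., "answer": "yes" or "no"}
        return json.load(open("data/my_task.json"))

    def fewshot_description(self):
        return ""

    def doc_to_text(self, doc):
        return "Q: " + doc["question"] + "\nA:"

    def doc_to_target(self, doc):
        return " " + doc["answer"]

    def construct_requests(self, doc, ctx):
        # one loglikelihood request per candidate answer
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        return ll_yes, ll_no

    def process_results(self, doc, results):
        ll_yes, ll_no = results
        pred = "yes" if ll_yes > ll_no else "no"
        return {"acc": float(pred == doc["answer"])}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}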
+import collections
+import itertools
+
+
+def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
+    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
+    task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
+
+    results = collections.defaultdict(dict)
+
+    requests = collections.defaultdict(list)
+    requests_origin = collections.defaultdict(list)
+
+    # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
+    # we can always modify this plumbing to support that, but i didn't want to include it just yet because overengineering is bad
+    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have
+
+    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
+    docs = {}
+
+    # get lists of each type of requeste
+    for task_name, task in task_dict_items:
+        # default to validation doc, fall back to test doc if validation unavailable
+        # TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
+        if task.has_validation_docs():
+            task_doc_func = task.validation_docs
+        elif task.has_test_docs():
+            task_doc_func = task.test_docs
+
+        for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, limit)):
+            docs[(task_name, doc_id)] = doc
+
+            ctx = task.fewshot_context(
+                doc=doc,
+                provide_description=provide_description,
+                num_fewshot=num_fewshot,
+            )
+
+            reqs = task.construct_requests(doc, ctx)
+
+            for i, req in enumerate(reqs):
+                requests[req.type].append(req)
+                # i: index in requests for a single task instance
+                # doc_id: unique id that we can get back to a doc using `docs`
+                requests_origin[req.type].append((i, task_name, doc, doc_id))
+
+    # all responses for each (task, doc)
+    process_res_queue = collections.defaultdict(list)
+
+    # execute each type of request
+    for reqtype, reqs in requests.items():
+        # TODO: right now, this code runs multiple seperate LM requests for multiple Requests differing
+        # only in index. We could implement some kind of caching, but that would be more of a bandaid
+        # solution. we could also implement some kind of autogrouping here; they should end up next to each other.
+        resps = getattr(lm, reqtype)([req.args for req in reqs])
+        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+
+        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+            process_res_queue[(task_name, doc_id)].append((i, resp))
+
+    vals = collections.defaultdict(list)
+
+    # unpack results and sort back in order and return control to Task
+    for (task_name, doc_id), requests in process_res_queue.items():
+        requests.sort(key=lambda x: x[0])
+        requests = [x[1] for x in requests]
+
+        task = task_dict[task_name]
+        doc = docs[(task_name, doc_id)]
+
+        metrics = task.process_results(doc, requests)
+        for metric, value in metrics.items():
+            vals[(task_name, metric)].append(value)
+
+    # aggregate results
+    for (task_name, metric), items in vals.items():
+        task = task_dict[task_name]
+        results[task_name][metric] = task.aggregation()[metric](items)
+
+    return results
\ No newline at end of file
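The core of evaluate() above is the per-request-type bookkeeping: requests are grouped by type, dispatched in one batched LM call per type, routed back to their originating (task, doc) pair, and re-sorted by per-doc index. A stripped-down, self-contained illustration of just that plumbing (fake_lm_loglikelihood and the toy request tuples are hypothetical stand-ins, not harness code):

import collections

# toy "requests": (req_type, task_name, doc_id, index_within_doc, payload)
toy_requests = [
    ("loglikelihood", "taskA", 0, 0, "ctx0 a"),
    ("loglikelihood", "taskA", 0, 1, "ctx0 b"),
    ("loglikelihood", "taskA", 1, 0, "ctx1 a"),
]

def fake_lm_loglikelihood(payloads):
    # stand-in for a batched LM call; returns one score per payload
    return [float(-len(p)) for p in payloads]

requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
for req_type, task_name, doc_id, i, payload in toy_requests:
    requests[req_type].append(payload)
    requests_origin[req_type].append((i, task_name, doc_id))

process_res_queue = collections.defaultdict(list)
for req_type, payloads in requests.items():
    resps = fake_lm_loglikelihood(payloads)  # one batched call per request type
    for resp, (i, task_name, doc_id) in zip(resps, requests_origin[req_type]):
        process_res_queue[(task_name, doc_id)].append((i, resp))

for key, resps in process_res_queue.items():
    resps.sort(key=lambda x: x[0])  # restore per-doc request order
    print(key, [r for _, r in resps])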
from . import gpt2
from . import gpt3
+from . import dummy

MODEL_REGISTRY = {
    "gpt2": gpt2.GPT2LM,
    "gpt3": gpt3.GPT3LM,
+    "dummy": dummy.DummyLM,
}
...
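A hedged usage sketch of the registry (not shown in this diff): assuming the models live in an lm_eval.models subpackage and the LM base class exposes a create_from_arg_string factory classmethod, model construction presumably looks roughly like this.

from lm_eval.models import MODEL_REGISTRY

# "dummy" is the new entry added in this commit
lm = MODEL_REGISTRY["dummy"].create_from_arg_string("")
gpt3 = MODEL_REGISTRY["gpt3"].create_from_arg_string("engine=davinci")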
-# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import transformers
from lm_eval.base import LM
from lm_eval import utils
+from tqdm import tqdm
+
+
+def get_result(response, ctxlen):
+    is_greedy = True
+    logprobs = response["logprobs"]["token_logprobs"]
+    continuation_logprobs = sum(logprobs[ctxlen:])
+
+    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
+        token = response["logprobs"]["tokens"][i]
+        top_tokens = response["logprobs"]["top_logprobs"][i]
+        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
+        if top_token != token:
+            is_greedy = False
+            break
+
+    return continuation_logprobs, is_greedy


class GPT3LM(LM):
    MAX_LENGTH = 2048
+    REQ_CHUNK_SIZE = 64
+    MAX_GEN_TOKS = 256

    def __init__(self, engine, truncate=False):
        """
@@ -31,23 +48,52 @@ class GPT3LM(LM):
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))

-    def loglikelihood(self, context, continuation):
-        # TODO: implement new framework
-        import openai
-
-        context_enc = self.tokenizer.encode(context)
-        continuation_enc = self.tokenizer.encode(continuation)
-
-        inp = (context_enc + continuation_enc)[-1024:]
-        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-
-        response = openai.Completion.create(
-            engine=self.engine,
-            prompt=inp,
-            echo=True,
-            max_tokens=0, temperature=0.0,
-            logprobs=0,
-        )
-
-        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
-        continuation_logprobs = logprobs[ctxlen:]
-
-        return sum(continuation_logprobs)
+    def loglikelihood(self, requests):
+        import openai
+        res = []
+        for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
+            inps = []
+            ctxlens = []
+            for context, continuation in chunk:
+                print(context)
+                context_enc = self.tokenizer.encode(context)
+                continuation_enc = self.tokenizer.encode(continuation)
+
+                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
+
+                inps.append(inp)
+                ctxlens.append(ctxlen)
+
+            response = openai.Completion.create(
+                engine=self.engine,
+                prompt=inps,
+                echo=True,
+                max_tokens=0, temperature=0.,
+                logprobs=10,
+            )
+
+            for resp, ctxlen in zip(response.choices, ctxlens):
+                res.append(get_result(resp, ctxlen))
+
+        return res
+
+    def greedy_until(self, requests):
+        import openai
+        res = []
+
+        for context, until in tqdm(requests):
+            context_enc = self.tokenizer.encode(context)
+            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
+            ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
+
+            response = openai.Completion.create(
+                engine=self.engine,
+                prompt=[inp],
+                max_tokens=self.MAX_GEN_TOKS,
+                temperature=0.,
+                logprobs=10,
+            )
+
+            res.append(response.choices[0]['text'])
+
+        return res
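The context/continuation packing used in loglikelihood above is plain list arithmetic: keep the last MAX_LENGTH tokens and record how many of them are context so the continuation logprobs can be sliced back out. A small standalone check of the same formula (the token id lists are made up):

MAX_LENGTH = 2048

def pack(context_enc, continuation_enc, max_length=MAX_LENGTH):
    # keep the last max_length tokens; ctxlen counts how many of them are context
    inp = (context_enc + continuation_enc)[-max_length:]
    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
    return inp, ctxlen

# fits entirely: nothing is cut, ctxlen is the full context length
inp, ctxlen = pack(list(range(10)), list(range(5)))
assert len(inp) == 15 and ctxlen == 10

# overflows by 3 tokens: the 3 oldest context tokens are dropped
inp, ctxlen = pack(list(range(2046)), list(range(5)))
assert len(inp) == 2048 and ctxlen == 2043  # logprobs[ctxlen:] still covers the 5 continuation tokens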
@@ -18,6 +18,7 @@ from . import lambada
from . import race
from . import piqa
from . import triviaqa
+from . import webqs

TASK_REGISTRY = {
@@ -37,7 +38,7 @@ TASK_REGISTRY = {
    "cb": superglue.CommitmentBank,
    "copa": superglue.Copa,
    "multirc": superglue.MultiRC,
-    "record": superglue.ReCoRD,
+    #"record": superglue.ReCoRD,
    "wic": superglue.WordsInContext,
    "wsc": superglue.SGWinogradSchemaChallenge,
@@ -56,9 +57,9 @@ TASK_REGISTRY = {
    # "squad": squad.SQuAD, # not implemented yet
    "race": race.RACE,
    # "naturalqs": naturalqs.NaturalQs, # not implemented yet
-    # "webqs": webqs.WebQs, # not implemented yet
-    # "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
-    # "winogrande": winogrande.Winogrande, # not implemented yet
+    "webqs": webqs.WebQs,
+    "wsc273": wsc273.WinogradSchemaChallenge273,
+    "winogrande": winogrande.Winogrande,
    "anli_r1": anli.ANLIRound1,
    "anli_r2": anli.ANLIRound2,
    "anli_r3": anli.ANLIRound3,
...
@@ -2,12 +2,12 @@ import abc
import json
import os
from collections import namedtuple
-from lm_eval.base import Dataset, mean, rf
+from lm_eval.base import Task, mean, rf
from best_download import download_file

ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])

-class Arithmetic(Dataset):
+class Arithmetic(Task):
    directory = 'data/arithmetic/'

    def __init__(self):
...
import datasets
import numpy as np
-import random
-from ..base import Dataset
+from ..base import Task


-class HFTask(Dataset):
+class HFTask(Task):
    DATASET_PATH = None
    DATASET_NAME = None

    def __init__(self):
+        self.data = None
        super().__init__()
-        self._training_docs = None

    def download(self):
        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
...
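With download() centralized in HFTask, a HuggingFace-backed task mostly just points at a dataset. A minimal sketch, assuming the module path lm_eval.tasks.common and using an illustrative path/name pair (any values accepted by datasets.load_dataset would do):

from lm_eval.tasks.common import HFTask

class MyBoolQ(HFTask):
    # hypothetical example; after __init__, self.data is the loaded DatasetDict
    DATASET_PATH = "super_glue"
    DATASET_NAME = "boolq"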
@@ -5,9 +5,9 @@ from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from pathlib import Path
-from ..base import Dataset
+from ..base import Task

-class DROP(Dataset):
+class DROP(Task):
    DATAFOLDER = Path(__file__).parent / "../../data/drop"

    def __init__(self):
...
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean, perplexity
from lm_eval.utils import sh
import json
import math
from best_download import download_file

-class LAMBADA(Dataset):
+class LAMBADA(Task):
    def download(self):
        sh("mkdir -p data/lambada")
        download_file(
@@ -45,7 +45,7 @@ class LAMBADA(Dataset):
        return ""

    def construct_requests(self, doc, ctx):
-        ll, is_greedy = rf.loglikelihood(doc, self.doc_to_target(doc))
+        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
        return ll, is_greedy
@@ -53,13 +53,13 @@ class LAMBADA(Dataset):
        ll, is_greedy = results
        return {
-            'perplexity': math.exp(-ll),
+            'perplexity': ll,
            'accuracy': int(is_greedy)
        }

    def aggregation(self):
        return {
-            'perplexity': mean,
+            'perplexity': perplexity,
            'accuracy': mean
        }
...
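The LAMBADA change above moves exp(-ll) out of per-document scoring and into the new corpus-level perplexity aggregator, i.e. exp of the negative mean log-likelihood rather than a mean of per-document perplexities. A standalone illustration of why the two differ (the example log-likelihoods are made up):

import math

def mean(xs):
    return sum(xs) / len(xs)

def perplexity(items):
    # corpus-level aggregator, matching the new helper in base.py
    return math.exp(-mean(items))

lls = [-1.0, -5.0]                                 # per-document log-likelihoods
old_style = mean([math.exp(-ll) for ll in lls])    # mean of per-doc perplexities
new_style = perplexity(lls)                        # exp of mean NLL

print(old_style)  # ~75.6, dominated by the worst document
print(new_style)  # ~20.1 = exp(3.0)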
@@ -30,10 +30,10 @@ class NaturalQs(HFTask):
    def fewshot_examples(self, k):
        # Data is too large to fit in memory. We just sample from the first bit.
-        if self._traindocs is None:
-            self._traindocs = list(islice(self.training_docs(), 0, 100000))
-        return random.sample(self._traindocs, k)
+        if self._training_docs is None:
+            self._training_docs = list(islice(self.training_docs(), 0, 100000))
+        return random.sample(self._training_docs, k)

    def doc_to_text(self, doc):
        return 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '
...
import json
import random
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean
from ..utils import sh
import os

-class PiQA(Dataset):
+class PiQA(Task):
    def download(self):
        if not os.path.exists('data/piqa'):
            #TODO: use best_download
...
import json
import random
import os
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh

-class QuAC(Dataset):
+class QuAC(Task):
    def __init__(self):
        super().__init__()
...
@@ -3,7 +3,19 @@ import datasets
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
-from ..utils_stream import each
+import os
+from functools import reduce
+import operator
+from tqdm import tqdm
+import json
+
+
+class each:
+    def __init__(self, f):
+        self.f = f
+
+    def __rrshift__(self, other):
+        return list(map(self.f, other))


class RACE(HFTask):
...
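The each helper added above overloads >> so a list can be piped through a function; __rrshift__ fires because list itself defines no __rshift__. A quick standalone usage example (same three-method class, copied here so it runs on its own):

class each:
    def __init__(self, f):
        self.f = f

    def __rrshift__(self, other):
        # called for `other >> each(f)` when `other` has no __rshift__
        return list(map(self.f, other))

assert [1, 2, 3] >> each(lambda x: x * 2) == [2, 4, 6]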
import json
import random
import os
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh

-class SATAnalogies(Dataset):
+class SATAnalogies(Task):
    NEEDS_MANUAL_DL = True

    def __init__(self):
...
import json
import random
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh
import csv

-class StoryCloze(Dataset):
+class StoryCloze(Task):
    NEEDS_MANUAL_DL = True

    def download(self):
...
@@ -261,7 +261,7 @@ class ReCoRD(HFTask):
        return True

    def has_test_docs(self):
-        return True
+        return False

    def fewshot_description(self):
        # TODO: figure out actual description
@@ -322,6 +322,7 @@ class ReCoRD(HFTask):
        # - Evaluate the accuracy and token F1 PER EXAMPLE
        # - Average over all examples
        max_idx = np.argmax(np.array(results))
        prediction = doc["entities"][max_idx]
        gold_label_set = list(set(doc["answers"]))
        f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set)
...
import os
import json
import random
-from lm_eval.base import Dataset, mean, rf
+from lm_eval.base import Task, mean, rf
from ..utils import sh

-class TriviaQA(Dataset):
+class TriviaQA(Task):
    def download(self):
        if not os.path.exists('data/triviaqa'):
            sh("""
...
from . common import HFTask
+from lm_eval.base import mean, rf


class WebQs(HFTask):
    DATASET_PATH = "web_questions"
@@ -18,7 +19,6 @@ class WebQs(HFTask):
        return ""

    def doc_to_text(self, doc):
-        print(doc)
        return "Q: " + doc['question'] + '\nA:'

    def doc_to_target(self, doc):
@@ -26,48 +26,37 @@ class WebQs(HFTask):
        # multiple correct answers being possible.
        # TODO: make sure we're actually handling multi-answer correctly
        return " " + doc['answers'][0]

+    def _remove_prefixes(self, aliases):
+        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
+        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
+        aliases.sort()
+        ret = [aliases[0]]
+        for alias in aliases[1:]:
+            if not alias.startswith(ret[-1]):
+                ret.append(alias)
+        return ret
+
    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ret = []
+        for alias in self._remove_prefixes(doc['answers']):
+            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
+            ret.append(is_prediction)
+        return ret

    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": float(any(results))
+        }

    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean,
+        }

    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }
\ No newline at end of file
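The prefix-removal step added above relies on sorting: once sorted, any alias that merely extends an earlier, shorter alias sits next to it and can be dropped, since a greedy match on the shorter prefix already decides the answer. A standalone copy with a toy alias list:

def remove_prefixes(aliases):
    # keep only aliases that are not strict extensions of an already-kept alias
    aliases.sort()
    ret = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(ret[-1]):
            ret.append(alias)
    return ret

print(remove_prefixes(["Seattle, WA", "Seattle", "Washington"]))
# ['Seattle', 'Washington'] -- "Seattle, WA" is redundant once "Seattle" is checked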
import numpy as np
-from scipy.stats import pearsonr, spearmanr
-from sklearn.metrics import f1_score, matthews_corrcoef
-from tqdm import auto as tqdm_lib
-from . common import HFTask, simple_accuracy_metric, yesno
+from . common import HFTask
+from lm_eval.base import rf, mean
+
+"""
+This evaluation of Winogrande uses partial evaluation as described by
+Trinh & Le in Simple Method for Commonsense Reasoning (2018).
+Reference: https://arxiv.org/abs/1806.02847
+"""


class Winogrande(HFTask):
    DATASET_PATH = "winogrande"
@@ -17,35 +22,31 @@ class Winogrande(HFTask):
    def has_test_docs(self):
        return True

-    def training_docs(self):
-        if self.has_training_docs():
-            return self.data["train"]
-
-    def validation_docs(self):
-        if self.has_validation_docs():
-            return self.data["validation"]
-
-    def test_docs(self):
-        if self.has_test_docs():
-            return self.data["test"]
-
    def fewshot_description(self):
        # TODO: redo description
        return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."

+    @classmethod
+    def partial_context(cls, doc):
+        # Substitute the pronoun in the sentence with each candidate choice
+        # and ignore everything after.
+        pronoun_loc = doc["sentence"].index("_")
+        context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
+        context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
+        return context1, context2
+
+    @classmethod
+    def partial_target(cls, doc):
+        # The target is everything after the document specified pronoun.
+        pronoun_loc = doc["sentence"].index("_") + 1
+        return doc["sentence"][pronoun_loc:].strip()
+
    def doc_to_text(self, doc):
-        return doc['sentence']
+        context1, context2 = self.partial_context(doc)
+        return context1 + '\n' + context2 + '\n'

    def doc_to_target(self, doc):
-        text = doc['sentence']
-        answer_n = doc['answer']
-        if answer_n == '1':
-            answer = doc['option1']
-        elif answer_n == '2':
-            answer = doc['option2']
-        else:
-            raise ValueError("Winogrande from HF datasets contained an invalid answer key")
-        return text.replace("_", answer)
+        return self.partial_target(doc)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -58,9 +59,12 @@ class Winogrande(HFTask):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        target = self.partial_target(doc)
+        context1, context2 = self.partial_context(doc)
+        ll_context1, _ = rf.loglikelihood(context1, " " + target)
+        ll_context2, _ = rf.loglikelihood(context2, " " + target)
+        return ll_context1, ll_context2

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
@@ -71,8 +75,10 @@ class Winogrande(HFTask):
        :param results:
            The results of the requests created in construct_requests.
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        answer = int(doc["answer"]) - 1  # `- 1` b/c doc["answer"] ∈ {'1', '2'}
+        return {
+            "acc": np.argmax(results) == answer
+        }

    def aggregation(self):
        """
@@ -80,8 +86,9 @@ class Winogrande(HFTask):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }

    def higher_is_better(self):
        """
@@ -89,5 +96,6 @@ class Winogrande(HFTask):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }
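Partial evaluation, as wired up above, scores P(rest of sentence | sentence with the blank replaced by each option) and picks the option whose continuation gets the higher log-likelihood. The string surgery itself is easy to check standalone (the toy doc below just mimics the HF winogrande fields):

doc = {
    "sentence": "The trophy doesn't fit in the suitcase because _ is too big.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
target = doc["sentence"][pronoun_loc + 1:].strip()

print(context1)  # The trophy doesn't fit in the suitcase because the trophy
print(context2)  # The trophy doesn't fit in the suitcase because the suitcase
print(target)    # is too big.
# the task then compares loglikelihood(context1, " " + target) vs loglikelihood(context2, " " + target)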
-import json
-import random
-import os
-from lm_eval.base import Dataset
-from ..utils import sh
-
-
-class WinogradSchemaChallenge273(Dataset):
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if not os.path.exists('data/wsc273'):
-            sh("""
-                mkdir -p data/wsc273
-                wget https://git.cse.msu.edu/bakerb15/nlp-final-project/raw/master/Winogard/reproduce/commonsense_test/wsc273.json -O data/wsc273/wsc273.json
-                """)
-
-    def has_training_docs(self):
-        return False
-
-    def has_validation_docs(self):
-        return False
-
-    def has_test_docs(self):
-        return True
-
-    def training_docs(self):
-        return []
-
-    def validation_docs(self):
-        return []
-
-    def test_docs(self):
-        myjson = json.load(open('data/wsc273/wsc273.json'))
-        return self.load_doc(myjson)
-
-    def fewshot_description(self):
-        # TODO: redo description
-        return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
-
-    def load_doc(self, myjson):
-        docs = []
-        for i in range(0, 273 * 2, 2):
-            item1 = myjson[i]
-            item2 = myjson[i+1]
-            if item1['question_id'] != item2['question_id']:
-                raise ValueError("WSC273 has missing completion pair.")
-            question_id = item1['question_id']
-
-            if item1['correctness'] == True:
-                doc = {
-                    'id': question_id,
-                    'completions': {
-                        'T': item1['substitution'],
-                        'F': item2['substitution'],
-                    },
-                }
-
-            if item2['correctness'] == True:
-                doc = {
-                    'id': question_id,
-                    'completions': {
-                        'F': item1['substitution'],
-                        'T': item2['substitution'],
-                    },
-                }
-            docs.append(doc)
-
-        return docs
-
-    def doc_to_text(self, doc):
-        # TODO: implement
-        pass
-
-    def doc_to_target(self, doc):
-        # TODO: implement
-        pass
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+import numpy as np
+import random
+from lm_eval.base import rf, mean
+from . common import HFTask
+
+"""
+NOTE: This evaluation of Winograd Schema Challenge is based on `partial evaluation`
+as described by Trinh & Le in Simple Method for Commonsense Reasoning (2018).
+See: https://arxiv.org/abs/1806.02847
+"""
+
+
+class WinogradSchemaChallenge273(HFTask):
+    DATASET_PATH = "winograd_wsc"
+    DATASET_NAME = "wsc273"
+
+    upper_pronouns = ["A", "An", "The", "She", "He",
+                      "It", "They", "My", "His", "Her", "Their"]
+
+    def __init__(self):
+        super().__init__()
+        self.data = self.__clean_data()
+
+    def __clean_data(self):
+        # The HF implementation of `wsc273` is not `partial evaluation` friendly.
+        data = []
+        for doc in self.data["test"]:
+            doc["text"] = doc["text"].replace("  ", " ")
+            doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
+            doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
+            data.append(doc)
+        return {"test": data}
+
+    def __normalize_option(self, option, doc):
+        # Append `'s` to possessive determiner based options.
+        if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
+            option += "'s"
+        # Appropriately lowercase the pronoun in the option.
+        pronoun = option.split()[0]
+        start_of_sentence = doc["text"][doc['pronoun_loc'] - 2] == '.'
+        if not start_of_sentence and pronoun in self.upper_pronouns:
+            return option.replace(pronoun, pronoun.lower())
+        return option
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def fewshot_examples(self, k):
+        # NOTE: `super().fewshot_examples` samples from training docs which are
+        # not available for this test-set-only dataset.
+        return random.sample(list(self.test_docs()), k)
+
+    def fewshot_description(self):
+        # TODO: redo description
+        return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
+
+    @classmethod
+    def partial_context(cls, doc):
+        # Substitute the pronoun in the original text with each candidate
+        # choice and ignore everything after.
+        context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
+        context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
+        return context1, context2
+
+    @classmethod
+    def partial_target(cls, doc):
+        # The target is everything after the document specified pronoun.
+        start_index = doc["pronoun_loc"] + len(doc["pronoun"])
+        return doc["text"][start_index:].strip()
+
+    def doc_to_text(self, doc):
+        context1, context2 = self.partial_context(doc)
+        return context1 + '\n' + context2 + '\n'
+
+    def doc_to_target(self, doc):
+        return self.partial_target(doc)
+
+    def construct_requests(self, doc, ctx):
+        """ Uses RequestFactory to construct Requests and returns an iterable of
+        Requests which will be sent to the LM.
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param ctx: str
+            The context string, generated by fewshot_context. This includes the natural
+            language description, as well as the few shot examples, and the question
+            part of the document for `doc`.
+        """
+        target = self.partial_target(doc)
+        context1, context2 = self.partial_context(doc)
+        ll_context1, _ = rf.loglikelihood(context1, " " + target)
+        ll_context2, _ = rf.loglikelihood(context2, " " + target)
+        return ll_context1, ll_context2
+
+    def process_results(self, doc, results):
+        """Take a single document and the LM results and evaluates, returning a
+        dict where keys are the names of submetrics and values are the values of
+        the metric for that one document
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param results:
+            The results of the requests created in construct_requests.
+        """
+        return {
+            "acc": np.argmax(results) == doc["label"]
+        }
+
+    def aggregation(self):
+        """
+        :returns: {str: [float] -> float}
+            A dictionary where keys are the names of submetrics and values are
+            functions that aggregate a list of metrics
+        """
+        return {
+            "acc": mean
+        }
+
+    def higher_is_better(self):
+        """
+        :returns: {str: bool}
+            A dictionary where keys are the names of submetrics and values are
+            whether a higher value of the submetric is better
+        """
+        return {
+            "acc": True
+        }
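The __clean_data step above exists because the HF winograd_wsc options are capitalized, stand-alone noun phrases; to splice an option into the middle of a sentence for partial evaluation, it is lowercased (unless it starts a sentence) and possessive pronouns gain an 's. A toy illustration of those two rules under the assumption that doc mirrors the HF schema (the pronoun_loc value below was counted by hand for this example):

upper_pronouns = ["A", "An", "The", "She", "He",
                  "It", "They", "My", "His", "Her", "Their"]

def normalize_option(option, doc):
    # append 's when the pronoun being replaced is a possessive determiner
    if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
        option += "'s"
    # lowercase the leading word unless the option starts a new sentence
    pronoun = option.split()[0]
    start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
    if not start_of_sentence and pronoun in upper_pronouns:
        return option.replace(pronoun, pronoun.lower())
    return option

doc = {"text": "The city councilmen refused the demonstrators a permit because they feared violence.",
       "pronoun": "they",
       "pronoun_loc": 63}
print(normalize_option("The city councilmen", doc))  # "the city councilmen"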