Commit a21df355 authored by thefazzer

Merge remote-tracking branch 'origin/master' into fazz/refactor-task-coqa

parents efa810f0 1815286c
......@@ -2,6 +2,7 @@ import abc
import random
import numpy as np
import sklearn
import math
class LM(abc.ABC):
......@@ -58,10 +59,10 @@ class LM(abc.ABC):
return cls()
class Dataset(abc.ABC):
class Task(abc.ABC):
def __init__(self):
self.download()
self._traindocs = None
self._training_docs = None
def download(self):
"""Downloads the task dataset if necessary"""
......@@ -84,23 +85,29 @@ class Dataset(abc.ABC):
def training_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def validation_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def test_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def fewshot_examples(self, k):
if self._traindocs is None:
self._traindocs = list(self.training_docs())
return random.sample(self._traindocs, k)
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return random.sample(self._training_docs, k)
@abc.abstractmethod
def doc_to_text(self, doc):
......@@ -193,7 +200,8 @@ def f1_score(items):
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return max(fscore)
return np.max(fscore)
def acc_all(items):
......@@ -223,6 +231,9 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def perplexity(items):
return math.exp(-mean(items))
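A minimal usage sketch of the `perplexity` aggregator above, assuming `items` is a list of per-document loglikelihoods (numbers are illustrative only):
example_loglikelihoods = [-2.3, -1.7, -2.0]  # hypothetical per-document loglikelihoods
# exp(-mean(items)) = exp(2.0) ≈ 7.39
assert abs(perplexity(example_loglikelihoods) - math.exp(2.0)) < 1e-9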
req_ret_lens = {
'loglikelihood': 2
}
......
import collections
import itertools
def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
results = collections.defaultdict(dict)
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
# if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
# we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad
# (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have)
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
# get lists of each type of request
for task_name, task in task_dict_items:
# default to validation docs, falling back to test docs if validation is unavailable
# TODO: the val-fallback-to-test system isn't final; we should revisit it at some point
if task.has_validation_docs():
task_doc_func = task.validation_docs
elif task.has_test_docs():
task_doc_func = task.test_docs
for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, limit)):
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
)
reqs = task.construct_requests(doc, ctx)
for i, req in enumerate(reqs):
requests[req.type].append(req)
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.type].append((i, task_name, doc, doc_id))
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
# execute each type of request
for reqtype, reqs in requests.items():
# TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
# only in index. We could implement some kind of caching, but that would be more of a band-aid
# solution. We could also implement some kind of autogrouping here; they should end up next to each other.
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
for (task_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
task = task_dict[task_name]
doc = docs[(task_name, doc_id)]
metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
return results
\ No newline at end of file
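A small sketch of the nested structure `evaluate` returns, using tasks and metrics that appear elsewhere in this diff (numbers are illustrative only):
# results = evaluate(lm, task_dict, provide_description=True, num_fewshot=0, limit=10)
# results -> {"lambada": {"perplexity": 7.4, "accuracy": 0.68},
#             "webqs":   {"acc": 0.21}}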
from . import gpt2
from . import gpt3
from . import dummy
MODEL_REGISTRY = {
"gpt2": gpt2.GPT2LM,
"gpt3": gpt3.GPT3LM,
"dummy": dummy.DummyLM,
}
......
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import transformers
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
def get_result(response, ctxlen):
is_greedy = True
logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])
for i in range(ctxlen, len(response["logprobs"]["tokens"])):
token = response["logprobs"]["tokens"][i]
top_tokens = response["logprobs"]["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
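A minimal sketch of the response shape `get_result` consumes, mirroring the OpenAI completion `logprobs` fields used above (values are hypothetical):
fake_choice = {"logprobs": {
    "tokens": ["Q", ":", " Paris"],
    "token_logprobs": [-0.1, -0.2, -0.5],
    "top_logprobs": [{"Q": -0.1}, {":": -0.2}, {" Paris": -0.5, " London": -0.9}],
}}
# With ctxlen=2 only the continuation token " Paris" is scored:
# get_result(fake_choice, 2) -> (-0.5, True), since " Paris" is also the top candidate.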
class GPT3LM(LM):
MAX_LENGTH = 2048
REQ_CHUNK_SIZE = 64
MAX_GEN_TOKS = 256
def __init__(self, engine, truncate=False):
"""
......@@ -31,23 +48,52 @@ class GPT3LM(LM):
args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci"))
def loglikelihood(self, context, continuation):
# TODO: implement new framework
def loglikelihood(self, requests):
import openai
res = []  # accumulates the (loglikelihood, is_greedy) tuples returned by get_result
for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
inps = []
ctxlens = []
for context, continuation in chunk:
print(context)
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = (context_enc + continuation_enc)[-1024:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
inps.append(inp)
ctxlens.append(ctxlen)
response = openai.Completion.create(
engine=self.engine,
prompt=inp,
prompt=inps,
echo=True,
max_tokens=0, temperature=0.0,
logprobs=0,
max_tokens=0, temperature=0.,
logprobs=10,
)
logprobs = response.choices[0]["logprobs"]["token_logprobs"]
continuation_logprobs = logprobs[ctxlen:]
return sum(continuation_logprobs)
for resp, ctxlen in zip(response.choices, ctxlens):
res.append(get_result(resp, ctxlen))
return res
def greedy_until(self, requests):
import openai
res = []
for context, until in tqdm(requests):
context_enc = self.tokenizer.encode(context)
inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
response = openai.Completion.create(
engine=self.engine,
prompt=[inp],
max_tokens=self.MAX_GEN_TOKS,
temperature=0.,
logprobs=10,
)
res.append(response.choices[0]['text'])
return res
......@@ -18,6 +18,7 @@ from . import lambada
from . import race
from . import piqa
from . import triviaqa
from . import webqs
TASK_REGISTRY = {
......@@ -37,7 +38,7 @@ TASK_REGISTRY = {
"cb": superglue.CommitmentBank,
"copa": superglue.Copa,
"multirc": superglue.MultiRC,
"record": superglue.ReCoRD,
#"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
......@@ -56,9 +57,9 @@ TASK_REGISTRY = {
# "squad": squad.SQuAD, # not implemented yet
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "webqs": webqs.WebQs, # not implemented yet
# "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
"webqs": webqs.WebQs,
"wsc273": wsc273.WinogradSchemaChallenge273,
"winogrande": winogrande.Winogrande,
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
......
......@@ -2,12 +2,12 @@ import abc
import json
import os
from collections import namedtuple
from lm_eval.base import Dataset, mean, rf
from lm_eval.base import Task, mean, rf
from best_download import download_file
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Dataset):
class Arithmetic(Task):
directory = 'data/arithmetic/'
def __init__(self):
......
import datasets
import numpy as np
import random
from ..base import Dataset
from ..base import Task
class HFTask(Dataset):
class HFTask(Task):
DATASET_PATH = None
DATASET_NAME = None
def __init__(self):
self.data = None
super().__init__()
self._training_docs = None
def download(self):
self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
......
......@@ -5,9 +5,9 @@ from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from pathlib import Path
from ..base import Dataset
from ..base import Task
class DROP(Dataset):
class DROP(Task):
DATAFOLDER = Path(__file__).parent / "../../data/drop"
def __init__(self):
......
from lm_eval.base import Dataset, rf, mean
from lm_eval.base import Task, rf, mean, perplexity
from lm_eval.utils import sh
import json
import math
from best_download import download_file
class LAMBADA(Dataset):
class LAMBADA(Task):
def download(self):
sh("mkdir -p data/lambada")
download_file(
......@@ -45,7 +45,7 @@ class LAMBADA(Dataset):
return ""
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(doc, self.doc_to_target(doc))
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
......@@ -53,13 +53,13 @@ class LAMBADA(Dataset):
ll, is_greedy = results
return {
'perplexity': math.exp(-ll),
'perplexity': ll,
'accuracy': int(is_greedy)
}
def aggregation(self):
return {
'perplexity': mean,
'perplexity': perplexity,
'accuracy': mean
}
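A brief sketch of why the per-document value changes from `math.exp(-ll)` to the raw loglikelihood: corpus perplexity is the exponential of the negative mean loglikelihood, which is not the same as the mean of per-document perplexities (numbers are illustrative):
lls = [-1.0, -3.0]  # hypothetical per-document loglikelihoods
# old aggregation: mean of per-doc perplexities = (e**1 + e**3) / 2 ≈ 11.4
# new aggregation: perplexity(lls) = exp(-mean(lls)) = exp(2.0) ≈ 7.39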
......
......@@ -30,10 +30,10 @@ class NaturalQs(HFTask):
def fewshot_examples(self, k):
# Data is too large to fit in memory. We just sample from the first bit.
if self._traindocs is None:
self._traindocs = list(islice(self.training_docs(), 0, 100000))
if self._training_docs is None:
self._training_docs = list(islice(self.training_docs(), 0, 100000))
return random.sample(self._traindocs, k)
return random.sample(self._training_docs, k)
def doc_to_text(self, doc):
return 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '
......
import json
import random
from lm_eval.base import Dataset, rf, mean
from lm_eval.base import Task, rf, mean
from ..utils import sh
import os
class PiQA(Dataset):
class PiQA(Task):
def download(self):
if not os.path.exists('data/piqa'):
#TODO: use best_download
......
import json
import random
import os
from lm_eval.base import Dataset
from lm_eval.base import Task
from ..utils import sh
class QuAC(Dataset):
class QuAC(Task):
def __init__(self):
super().__init__()
......
......@@ -3,7 +3,19 @@ import datasets
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
from ..utils_stream import each
import os
from functools import reduce
import operator
from tqdm import tqdm
import json
class each:
def __init__(self, f):
self.f = f
def __rrshift__(self, other):
return list(map(self.f, other))
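A short usage sketch of the `>>` piping that `__rrshift__` enables (values are illustrative):
# `xs >> each(f)` maps f over xs: a plain list defines no matching __rshift__,
# so Python falls back to each.__rrshift__.
doubled = [1, 2, 3] >> each(lambda x: x * 2)
assert doubled == [2, 4, 6]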
class RACE(HFTask):
......
import json
import random
import os
from lm_eval.base import Dataset, rf, mean
from lm_eval.base import Task, rf, mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh
class SATAnalogies(Dataset):
class SATAnalogies(Task):
NEEDS_MANUAL_DL = True
def __init__(self):
......
import json
import random
from lm_eval.base import Dataset
from lm_eval.base import Task
from ..utils import sh
import csv
class StoryCloze(Dataset):
class StoryCloze(Task):
NEEDS_MANUAL_DL = True
def download(self):
......
......@@ -261,7 +261,7 @@ class ReCoRD(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
# TODO: figure out actual description
......@@ -322,6 +322,7 @@ class ReCoRD(HFTask):
# - Evaluate the accuracy and token F1 PER EXAMPLE
# - Average over all examples
max_idx = np.argmax(np.array(results))
prediction = doc["entities"][max_idx]
gold_label_set = list(set(doc["answers"]))
f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set)
......
import os
import json
import random
from lm_eval.base import Dataset, mean, rf
from lm_eval.base import Task, mean, rf
from ..utils import sh
class TriviaQA(Dataset):
class TriviaQA(Task):
def download(self):
if not os.path.exists('data/triviaqa'):
sh("""
......
from . common import HFTask
from lm_eval.base import mean, rf
class WebQs(HFTask):
DATASET_PATH = "web_questions"
......@@ -18,7 +19,6 @@ class WebQs(HFTask):
return ""
def doc_to_text(self, doc):
print(doc)
return "Q: " + doc['question'] + '\nA:'
def doc_to_target(self, doc):
......@@ -27,47 +27,36 @@ class WebQs(HFTask):
# TODO: make sure we're actually handling multi-answer correctly
return " " + doc['answers'][0]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
def _remove_prefixes(self, aliases):
# Optimization: remove any alias that has a strict prefix elsewhere in the list.
# We can do this because if the prefix is acceptable to is_greedy, we can stop looking.
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return ret
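A tiny worked sketch of the prefix-stripping optimization above (hypothetical aliases):
aliases = ["New York", "New York City", "NYC"]
# sorted: ["NYC", "New York", "New York City"]; "New York City" starts with the
# already-kept "New York", so it is dropped:
# self._remove_prefixes(aliases) -> ["NYC", "New York"]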
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def construct_requests(self, doc, ctx):
ret = []
for alias in self._remove_prefixes(doc['answers']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
def process_results(self, doc, results):
return {
"acc": float(any(results))
}
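A small sketch of the per-document accuracy above: `results` holds one is_greedy flag per surviving alias, and the document scores 1.0 if the model greedily produced any acceptable answer (flags are hypothetical):
results = [False, True, False]  # hypothetical is_greedy flags, one per alias
assert float(any(results)) == 1.0  # counted as correct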
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": True
}
\ No newline at end of file
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from . common import HFTask
from lm_eval.base import rf, mean
"""
This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in A Simple Method for Commonsense Reasoning (2018).
Reference: https://arxiv.org/abs/1806.02847
"""
class Winogrande(HFTask):
DATASET_PATH = "winogrande"
......@@ -17,35 +22,31 @@ class Winogrande(HFTask):
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
return self.data["train"]
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation"]
def test_docs(self):
if self.has_test_docs():
return self.data["test"]
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the sentence with each candidate choice
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
return context1, context2
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def doc_to_text(self, doc):
return doc['sentence']
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
text = doc['sentence']
answer_n = doc['answer']
if answer_n == '1':
answer = doc['option1']
elif answer_n == '2':
answer = doc['option2']
else:
raise ValueError("Winogrande from HF datasets contained an invalid answer key")
return text.replace("_", answer)
return self.partial_target(doc)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -58,8 +59,11 @@ class Winogrande(HFTask):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -71,8 +75,10 @@ class Winogrande(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
answer = int(doc["answer"]) - 1 # `- 1` b/c doc["answer"] ∈ {'1', '2'}
return {
"acc": np.argmax(results) == answer
}
def aggregation(self):
"""
......@@ -80,8 +86,9 @@ class Winogrande(HFTask):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean
}
def higher_is_better(self):
"""
......@@ -89,5 +96,6 @@ class Winogrande(HFTask):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": True
}
import json
import numpy as np
import random
import os
from lm_eval.base import Dataset
from ..utils import sh
from lm_eval.base import rf, mean
from . common import HFTask
"""
NOTE: This evaluation of Winograd Schema Challenge is based on `partial evaluation`
as described by Trinh & Le in A Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847
"""
class WinogradSchemaChallenge273(HFTask):
DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273"
upper_pronouns = ["A", "An", "The", "She", "He",
"It", "They", "My", "His", "Her", "Their"]
class WinogradSchemaChallenge273(Dataset):
def __init__(self):
super().__init__()
def download(self):
if not os.path.exists('data/wsc273'):
sh("""
mkdir -p data/wsc273
wget https://git.cse.msu.edu/bakerb15/nlp-final-project/raw/master/Winogard/reproduce/commonsense_test/wsc273.json -O data/wsc273/wsc273.json
""")
self.data = self.__clean_data()
def __clean_data(self):
# The HF implementation of `wsc273` is not `partial evaluation` friendly.
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
data.append(doc)
return {"test": data}
def __normalize_option(self, option, doc):
# Append `'s` to possessive determiner based options.
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
option += "'s"
# Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0]
start_of_sentence = doc["text"][doc['pronoun_loc'] - 2] == '.'
if not start_of_sentence and pronoun in self.upper_pronouns:
return option.replace(pronoun, pronoun.lower())
return option
def has_training_docs(self):
return False
......@@ -25,60 +51,35 @@ class WinogradSchemaChallenge273(Dataset):
def has_test_docs(self):
return True
def training_docs(self):
return []
def validation_docs(self):
return []
def test_docs(self):
myjson = json.load(open('data/wsc273/wsc273.json'))
return self.load_doc(myjson)
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def load_doc(self, myjson):
docs = []
for i in range(0, 273 * 2, 2):
item1 = myjson[i]
item2 = myjson[i+1]
if item1['question_id'] != item2['question_id']:
raise ValueError("WSC273 has missing completion pair.")
question_id = item1['question_id']
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the original text with each candidate
# choice and ignore everything after.
context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
return context1, context2
if item1['correctness'] == True:
doc = {
'id': question_id,
'completions': {
'T': item1['substitution'],
'F': item2['substitution'],
},
}
if item2['correctness'] == True:
doc = {
'id': question_id,
'completions': {
'F': item1['substitution'],
'T': item2['substitution'],
},
}
docs.append(doc)
return docs
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def doc_to_text(self, doc):
# TODO: implement
pass
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
# TODO: implement
pass
return self.partial_target(doc)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -91,8 +92,11 @@ class WinogradSchemaChallenge273(Dataset):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -104,8 +108,9 @@ class WinogradSchemaChallenge273(Dataset):
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": np.argmax(results) == doc["label"]
}
def aggregation(self):
"""
......@@ -113,8 +118,9 @@ class WinogradSchemaChallenge273(Dataset):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean
}
def higher_is_better(self):
"""
......@@ -122,5 +128,6 @@ class WinogradSchemaChallenge273(Dataset):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": True
}