Commit 34f591af authored by jon-tow

Add multiple tasks

parent 2bfa4518
@@ -52,6 +52,7 @@ from . import blimp
 from . import asdiv
 from . import gsm8k
 from . import storycloze
+from . import e2e_nlg_cleaned
 ########################################
 # Translation tasks
@@ -124,6 +125,7 @@ TASK_REGISTRY = {
     # Science related
     "pubmedqa" : pubmedqa.Pubmed_QA,
     "sciq" : sciq.SciQ,
+    "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned,
     "qasper": qasper.QASPER,
...
@@ -10,7 +10,7 @@ provided explanations.
 Homepage: "https://github.com/facebookresearch/anli"
 """
 import numpy as np
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
 from lm_eval.metrics import mean
@@ -30,7 +30,7 @@ _CITATION = """
 """
-class ANLIBase(Task):
+class ANLIBase(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "anli"
     DATASET_NAME = None
@@ -59,51 +59,6 @@ class ANLIBase(Task):
         if self.has_test_docs():
             return self.dataset["test_r" + str(self.SPLIT)]
-    def doc_to_text(self, doc):
-        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
-        # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
-        # appended onto the question, with no "Answer:" or even a newline. Do we *really*
-        # want to do it exactly as OA did?
-        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " " + ["True", "Neither", "False"][doc['label']]
-
-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_neither, ll_false
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        gold = doc["label"]
-        pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
-
     def aggregation(self):
         """
         :returns: {str: [float] -> float}
...
@@ -67,6 +67,7 @@ class CoQA(PromptSourceTask):
     #         answers.append(additional_answer_for_turn)
     #     return answers
+
     @staticmethod
     def compute_scores(gold_list, pred):
         # tests for exact match and on the normalised answer (compute_exact)
@@ -90,19 +91,21 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }
-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        cont_request = rf.greedy_until(ctx, ["\nQ:"])
-        return cont_request
+    def eos_token(self):
+        return "\n"
+
+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.
+
+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     return cont_request
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -116,6 +119,13 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        # print(f"PS: {self.prompt.apply(doc)}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"TARGET: {target} END TARGET")
+        print(pred)
+        print("*" * 80)
         # turn_id = len(doc["questions"]["input_text"])
         # gold_list = self.get_answers(doc, turn_id)
...
@@ -39,7 +39,7 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
 class DROP(PromptSourceTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop)
     DATASET_NAME = None
     def has_training_docs(self):
@@ -52,51 +52,13 @@ class DROP(PromptSourceTask):
         return False
     def training_docs(self):
-        if self._training_docs is None:
-            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
-        return self._training_docs
+        # if self._training_docs is None:
+        #     self._training_docs = list()
+        # return self._training_docs
+        return self.dataset["train"]
     def validation_docs(self):
-        return map(self._process_doc, self.dataset["validation"])
+        return self.dataset["validation"]
-    def _process_doc(self, doc):
-        return {
-            "id": doc["query_id"],
-            "passage": doc["passage"],
-            "question": doc["question"],
-            "answers": self.get_answers(doc),
-        }
-
-    @classmethod
-    def get_answers(cls, qa):
-        def _flatten_validated_answers(validated_answers):
-            """Flattens a dict of lists of validated answers.
-            {"number": ['1', '8'], ...}
-            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
-            """
-            vas = []
-            for i in range(len(validated_answers["number"])):
-                vas.append(
-                    {
-                        "number": validated_answers["number"][i],
-                        "date": validated_answers["date"][i],
-                        "spans": validated_answers["spans"][i],
-                    }
-                )
-            return vas
-
-        answers = []
-        answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(
-            qa["validated_answers"]
-        )
-        for candidate in candidates:
-            answer = cls.parse_answer(candidate)
-            if answer in answers_set:
-                continue
-            answers_set.add(answer)
-            answers.append(answer)
-        return answers
     @classmethod
     def parse_answer(cls, answer):
@@ -117,19 +79,21 @@ class DROP(PromptSourceTask):
     # def doc_to_target(self, doc):
     #     return " " + ", ".join(doc["answers"][0])
-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        conts = [rf.greedy_until(ctx, ["."])]
-        return conts
+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.
+
+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     conts = [rf.greedy_until(ctx, ["."])]
+    #     return conts
+    def eos_token(self):
+        return "."
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -145,6 +109,15 @@ class DROP(PromptSourceTask):
         pred = results[0].strip()
         target = self.doc_to_target(doc).strip()
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        print(f"PS: {self.prompt.apply(doc)}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"TARGET: {target} END TARGET")
+        print(pred)
+        print("*" * 80)
         preds = [pred]
         golds = [target]
...
@@ -45,7 +45,7 @@ _CITATION = """
 # Single-Sentence Tasks
-class CoLA(Task):
+class CoLA(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "cola"
@@ -67,23 +67,20 @@ class CoLA(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
-
-    def doc_to_target(self, doc):
-        return " {}".format({1: "yes", 0: "no"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " yes")
-        ll_false, _ = rf.loglikelihood(ctx, " no")
-        return ll_true, ll_false
-
-    def process_results(self, doc, results):
-        ll_true, ll_false = results
-        pred = ll_true > ll_false
-        gold = doc["label"]
+    def process_results(self, doc, results):
+        answer_choices_list = self.prompt.get_answer_choices_list(doc)
+        pred = np.argmax(results)
+        target = answer_choices_list.index(self.doc_to_target(doc).strip())
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"STRING TARGET: {self.doc_to_target(doc)} END TARGET")
+        print(f"TARGET: {target} END TARGET")
+        print(f"PRED: {pred}")
+        print("*" * 80)
         return {
-            "mcc": (gold, pred)
+            "mcc": (target, pred)
         }
     def higher_is_better(self):
@@ -97,7 +94,7 @@ class CoLA(Task):
     }
-class SST(Task):
+class SST(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "sst2"
@@ -119,27 +116,6 @@ class SST(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
-            general_detokenize(doc["sentence"]),
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format({1: "positive", 0: "negative"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_positive, _ = rf.loglikelihood(ctx, " positive")
-        ll_negative, _ = rf.loglikelihood(ctx, " negative")
-        return ll_positive, ll_negative
-
-    def process_results(self, doc, results):
-        ll_positive, ll_negative = results
-        pred = ll_positive > ll_negative
-        gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -154,7 +130,7 @@ class SST(Task):
 # Inference Tasks
-class MNLI(Task):
+class MNLI(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mnli"
@@ -181,24 +157,6 @@ class MNLI(Task):
         if self.has_test_docs():
             return self.dataset["test_matched"]
-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
-            doc["premise"],
-            doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'),
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_neither, ll_false
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -251,22 +209,6 @@ class QNLI(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
-            doc["question"],
-            doc["sentence"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = not entailment
-        return " {}".format({0: "yes", 1: "no"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, " yes")
-        ll_no, _ = rf.loglikelihood(ctx, " no")
-        return ll_yes, ll_no
-
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_no > ll_yes
@@ -342,14 +284,6 @@ class RTE(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]
-    # def process_results(self, doc, results):
-    #     ll_true, ll_false = results
-    #     pred = ll_false > ll_true
-    #     gold = doc["label"]
-    #     return {
-    #         "acc": pred == gold
-    #     }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -386,20 +320,6 @@ class MRPC(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
-            general_detokenize(doc["sentence1"]),
-            general_detokenize(doc["sentence2"]),
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format(yesno(doc["label"]))
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, " yes")
-        ll_no, _ = rf.loglikelihood(ctx, " no")
-        return ll_yes, ll_no
-
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         gold = doc["label"]
...
@@ -12,7 +12,7 @@ TODO: WSC requires free-form generation.
 import numpy as np
 import sklearn
 import transformers.data.metrics.squad_metrics as squad_metrics
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
 from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
 from lm_eval.utils import general_detokenize
@@ -32,7 +32,7 @@ _CITATION = """
 """
-class BoolQ(Task):
+class BoolQ(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "boolq"
@@ -54,29 +54,6 @@ class BoolQ(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -88,7 +65,7 @@ class BoolQ(Task):
     }
-class CommitmentBank(Task):
+class CommitmentBank(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "cb"
@@ -110,25 +87,6 @@ class CommitmentBank(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
-            doc["premise"],
-            doc["hypothesis"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, ' True')
-        ll_false, _ = rf.loglikelihood(ctx, ' False')
-        ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
-        return ll_true, ll_false, ll_neither
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -163,7 +121,7 @@ class CommitmentBank(Task):
     }
-class Copa(Task):
+class Copa(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "copa"
@@ -185,28 +143,6 @@ class Copa(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        # Drop the period
-        connector = {
-            "cause": "because",
-            "effect": "therefore",
-        }[doc["question"]]
-        return doc["premise"].strip()[:-1] + f" {connector}"
-
-    def doc_to_target(self, doc):
-        correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
-        # Connect the sentences
-        return " " + self.convert_choice(correct_choice)
-
-    def construct_requests(self, doc, ctx):
-        choice1 = " " + self.convert_choice(doc["choice1"])
-        choice2 = " " + self.convert_choice(doc["choice2"])
-        ll_choice1, _ = rf.loglikelihood(ctx, choice1)
-        ll_choice2, _ = rf.loglikelihood(ctx, choice2)
-        return ll_choice1, ll_choice2
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -231,7 +167,7 @@ class Copa(Task):
         return choice[0].lower() + choice[1:]
-class MultiRC(Task):
+class MultiRC(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "multirc"
@@ -253,26 +189,6 @@ class MultiRC(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])
-
-    @staticmethod
-    def format_answer(answer, label):
-        label_str = "yes" if label else "no"
-        return f"{answer}\nIs the answer correct? {label_str}"
-
-    def construct_requests(self, doc, ctx):
-        true_choice = self.format_answer(answer=doc["answer"], label=True)
-        false_choice = self.format_answer(answer=doc["answer"], label=False)
-        ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}')
-        ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}')
-        return ll_true_choice, ll_false_choice
-
     def process_results(self, doc, results):
         ll_true_choice, ll_false_choice = results
         pred = ll_true_choice > ll_false_choice
@@ -291,7 +207,7 @@ class MultiRC(Task):
     }
-class ReCoRD(Task):
+class ReCoRD(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "record"
@@ -328,33 +244,13 @@ class ReCoRD(Task):
             "answers": sorted(list(set(doc["answers"]))),
         }
-    def doc_to_text(self, doc):
-        initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
-        text = initial_text + "\n\n"
-        for highlight in highlights:
-            text += f" - {highlight}.\n"
-        return text
-
-    @classmethod
-    def format_answer(cls, query, entity):
-        return f' - {query}'.replace("@placeholder", entity)
-
-    def doc_to_target(self, doc):
-        # We only output the first correct entity in a doc
-        return self.format_answer(query=doc["query"], entity=doc["answers"][0])
-
-    def construct_requests(self, doc, ctx):
-        requests = [
-            rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
-            for entity in doc["entities"]
-        ]
-        return requests
     def process_results(self, doc, results):
         # ReCoRD's evaluation is actually deceptively simple:
         # - Pick the maximum likelihood prediction entity
         # - Evaluate the accuracy and token F1 PER EXAMPLE
         # - Average over all examples
+        # TODO (jon-tow): Look at result
         max_idx = np.argmax(np.array([result[0] for result in results]))
         prediction = doc["entities"][max_idx]
@@ -380,7 +276,7 @@ class ReCoRD(Task):
     }
-class WordsInContext(Task):
+class WordsInContext(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "wic"
@@ -402,33 +298,6 @@ class WordsInContext(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
-               " two sentences above?\nAnswer:".format(
-            doc["sentence1"],
-            doc["sentence2"],
-            doc["sentence1"][doc["start1"]:doc["end1"]],
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format({0: "no", 1: "yes"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -440,7 +309,7 @@ class WordsInContext(Task):
     }
-class SGWinogradSchemaChallenge(Task):
+class SGWinogradSchemaChallenge(PromptSourceTask):
     VERSION = 0
     # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
     # binary version of the task.
@@ -470,41 +339,6 @@ class SGWinogradSchemaChallenge(Task):
     def validation_docs(self):
         return self.dataset["validation"]
-    def doc_to_text(self, doc):
-        raw_passage = doc["text"]
-        # NOTE: HuggingFace span indices are word-based not character-based.
-        pre = " ".join(raw_passage.split()[:doc["span2_index"]])
-        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
-        passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post)
-        noun = doc["span1_text"]
-        pronoun = doc["span2_text"]
-        text = (
-            f"Passage: {passage}\n"
-            + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n"
-            + "Answer:"
-        )
-        return text
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
...
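
Taken together, the conversion pattern this commit applies to each task is: subclass PromptSourceTask instead of Task, drop the hand-written doc_to_text / doc_to_target / construct_requests, and override only what the promptsource template cannot supply (an eos_token stop sequence for generation tasks, or a process_results that maps predictions back onto the template's answer choices). Below is a minimal illustrative sketch of that pattern, not part of the commit; it assumes the fork's lm_eval.base.PromptSourceTask exposes self.prompt (a promptsource template) and the doc_to_* helpers used above, and ExampleTask / "some_dataset" are hypothetical names.

# Illustrative sketch only; ExampleTask and "some_dataset" are hypothetical.
import numpy as np
from lm_eval.base import PromptSourceTask


class ExampleTask(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "some_dataset"  # hypothetical Hugging Face dataset name
    DATASET_NAME = None

    def eos_token(self):
        # Generation-style tasks in this commit stop decoding at a
        # task-specific delimiter (CoQA uses "\n", DROP uses ".").
        return "\n"

    def process_results(self, doc, results):
        # Ranked-choice tasks map the highest-scoring completion back onto the
        # promptsource template's answer choices, as CoLA does above.
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
        pred = np.argmax(results)
        target = answer_choices_list.index(self.doc_to_target(doc).strip())
        return {"acc": pred == target}

Generation tasks (CoQA, DROP) only need the stop delimiter, while likelihood-ranked tasks (CoLA) remap the argmax prediction onto the template's answer choices before scoring, which is why the converted classes keep so little task-specific code.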