Commit 9b4bfb6a authored by jon-tow's avatar jon-tow
Browse files

Merge changes

parent e0cfeb90
......@@ -57,7 +57,7 @@ class Arithmetic(Task):
return True
def doc_to_decontamination_query(self, doc):
return doc.context
return doc["context"]
def doc_to_target(self, doc):
return doc["completion"]
......
......@@ -65,7 +65,7 @@ class CoQA(Task):
return True
def doc_to_decontamination_query(self, doc):
return doc["story"] + " " + doc["questions"]
return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
@classmethod
def get_answers(cls, doc, turn_id):
......
......@@ -94,7 +94,7 @@ class EthicsCM(Ethics):
return True
def doc_to_decontamination_query(self, doc):
return doc[1]
return doc["input"]
def doc_to_target(self, doc):
return " {}".format(yesno(int(doc["label"])))
......@@ -135,7 +135,7 @@ class EthicsDeontology(Ethics):
return True
def doc_to_decontamination_query(self, doc):
return " ".join([doc[1], doc[2]])
return " ".join([doc["scenario"], doc["excuse"]])
def doc_to_target(self, doc):
target = ["unreasonable", "reasonable"][int(doc["label"])]
......@@ -186,7 +186,7 @@ class EthicsJustice(Ethics):
return True
def doc_to_decontamination_query(self, doc):
return doc[1]
return doc["scenario"]
def doc_to_target(self, doc):
target = ["unreasonable", "reasonable"][int(doc["label"])]
......
......@@ -117,4 +117,3 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def doc_to_decontamination_query(self, doc):
return doc["query"]
......@@ -72,7 +72,7 @@ class LogiQA(MultipleChoiceTask):
return prompt
choices = ['a', 'b', 'c', 'd']
return {
"passage": doc["passage"], # Used for decontamination
"passage": doc["context"], # Used for decontamination
"query": format_example(doc, choices),
"choices": doc["options"],
"gold": choices.index(doc["label"])
......
......@@ -69,4 +69,3 @@ class OpenBookQA(MultipleChoiceTask):
def doc_to_decontamination_query(self, doc):
return doc["query"]
......@@ -77,4 +77,3 @@ class PROST(MultipleChoiceTask):
def doc_to_decontamination_query(self, doc):
return doc["query"]
......@@ -72,5 +72,4 @@ class SATAnalogies(MultipleChoiceTask):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
return doc["source"] + "\n" + " ".join(doc["query"])
......@@ -76,7 +76,12 @@ class StoryCloze(Task):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
return ' '.join([
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
])
def doc_to_target(self, doc):
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
......
......@@ -93,9 +93,6 @@ class WikiText(PerplexityTask):
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["text"]
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment