"test/vscode:/vscode.git/clone" did not exist on "4dc2f1a1b90b5134730cdae92e8328d365d93874"
Commit 8ac99269 authored by Jonathan Tow

Replace the `fewshot_description` API with a `description_dict` based interface

parent b67aec37
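
In short: the per-task `fewshot_description()` hook is removed, and the prompt description is now passed into `fewshot_context` as an optional `description` argument (replacing the old `provide_description` flag). A minimal sketch of the pattern the task diffs below follow; the class name is hypothetical and shown only for illustration:

    # Sketch of the pattern applied across the task files in this commit.
    # `ExampleZeroShotTask` is a hypothetical task; real tasks (PROST,
    # TruthfulQA, ...) make the same change in the hunks below.
    from lm_eval.base import Task

    class ExampleZeroShotTask(Task):
        # Removed in this commit:
        #     def fewshot_description(self):
        #         return ""   # per-task hard-coded preamble
        #
        # New signature: the description is supplied by the caller
        # (ultimately from a task-name-keyed description dict) instead of
        # the old provide_description flag.
        def fewshot_context(self, doc, num_fewshot, rnd, description=None):
            assert num_fewshot == 0, "This task is zero-shot only."
            return super().fewshot_context(doc, num_fewshot, rnd, description)
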
@@ -36,13 +36,9 @@ class PROST(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
-    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
+    def fewshot_context(self, doc, num_fewshot, rnd, description=None):
         assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
+        return super().fewshot_context(doc, num_fewshot, rnd, description)

     def _convert_standard(self, doc):
         out_doc = {
...
@@ -23,11 +23,6 @@ class Pubmed_QA(HFTask):
         # HF is labelled as train but its really just for testing
         return self.data["train"]

-    def fewshot_description(self):
-        # Average ctx length in labelled dataset is 238.9
-        # 2 few-shot exmamples pushes it beyond context window
-        return ""
-
     def doc_to_text(self, doc):
         ctxs = "\n".join(doc["context"]["contexts"])
         return "Abstract: {}\nQuestion: {}\nAnswer:".format(
...
@@ -67,9 +67,6 @@ class QA4MRE(MultipleChoiceTask):
             out_doc['source'] = src
             yield out_doc

-    def fewshot_description(self):
-        return ""
-
     def test_docs(self):
         return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
...
@@ -51,11 +51,6 @@ class QuAC(Task):
     def test_docs(self):
         raise NotImplementedError("QuAC has no test docs.")

-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        desc = "TITLE: Title of the context passage - subtitle of the passage\nPARAGRAPH: Passage describing the relevant information for answering questions.\n\nQ: Text of a question.\n\nA: Answer to the question, based on the passage. If it cannot be answered based on the passage, write CANNOTANSWER"
-        return desc
-
     def load_doc(self, myjson):
         docs = []
         for item in myjson:
...
@@ -63,10 +63,6 @@ class RACE(HFTask):
     def test_docs(self):
         return self._collate_data("test")

-    def fewshot_description(self):
-        # TODO: figure out description
-        return ""
-
     @classmethod
     def get_answer_option(cls, problem):
         answer = cls.letter_to_num[problem['answer']]
...
@@ -61,10 +61,5 @@ class SATAnalogies(MultipleChoiceTask):
             }
             yield doc

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return ""
-
     def doc_to_text(self, doc):
         return "{} is to {} as".format(*doc['query'])
@@ -50,9 +50,6 @@ class SciQ(MultipleChoiceTask):
         for record in docs:
             yield self._convert_standard(record)

-    def fewshot_description(self):
-        return ""
-
     def training_docs(self):
         return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
...
@@ -41,10 +41,6 @@ class SQuAD2(HFTask):
     def validation_docs(self):
         return self.data["validation"]

-    def fewshot_description(self):
-        # TODO: figure out description
-        return ""
-
     def doc_to_text(self, doc):
         return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
...
@@ -27,18 +27,12 @@ class StoryCloze(Task):
             filereader = csv.reader(file)
             return list(filereader)

     def validation_docs(self):
         return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")

     def test_docs(self):
         return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")

-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
     def doc_to_text(self, doc):
         return ' '.join([*doc[1:5]])
...
@@ -26,10 +26,6 @@ class BoolQ(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return "Read the following passages and answer each question with a yes or a no."
-
     def doc_to_text(self, doc):
         return f"{doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
@@ -78,11 +74,6 @@ class CommitmentBank(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return "Given a premise and a hypothesis, classify whether the author of the premise is committed" \
-               "to the truth of the hypothesis. The three possible labels are true, false or neither."
-
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
             doc["premise"],
@@ -150,11 +141,6 @@ class Copa(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return "Given a premise and one alternative with a causal relation to the premise and another without," \
-               "choose the more plausible alternative"
-
     def doc_to_text(self, doc):
         # Drop the period
         connector = {
@@ -215,10 +201,6 @@ class MultiRC(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return "READING COMPREHENSION ANSWER KEY"
-
     def doc_to_text(self, doc):
         return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
@@ -270,10 +252,6 @@ class ReCoRD(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return ""
-
     def training_docs(self):
         # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
         # Each doc consists of multiple answer candidates, each of which is scored yes/no.
@@ -363,10 +341,6 @@ class WordsInContext(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out actual description
-        return ""
-
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
                " two sentences above?\nAnswer:".format(
@@ -432,12 +406,6 @@ class SGWinogradSchemaChallenge(HFTask):
             ]
         return self._training_docs

-    def fewshot_description(self):
-        return "Final Exam with Answer Key\n" \
-               "Instructions: Please carefully read the following passages. " \
-               "For each passage, you must identify which noun the pronoun marked in *bold*" \
-               " refers to.\n====="
-
     def doc_to_text(self, doc):
         raw_passage = doc["text"]
         # NOTE: HuggingFace span indices are word-based not character-based.
...
@@ -166,12 +166,6 @@ class GeneralTranslationTask(Task):
             "ter": False,
         }

-    def fewshot_description(self):
-        language_codes = self.sacrebleu_language_pair.split("-")
-        src_lang = code_to_language(language_codes[0])
-        tar_lang = code_to_language(language_codes[1])
-        return f"Translate these {src_lang} phrases to {tar_lang}."
-
     def __str__(self):
         language_codes = self.sacrebleu_language_pair.split("-")
         src_lang = code_to_language(language_codes[0])
...
@@ -36,10 +36,6 @@ class TriviaQA(Task):
     def test_docs(self):
         raise NotImplementedError()

-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
     def doc_to_text(self, doc):
         return f"Question: {doc['Question']}\nAnswer:"
@@ -57,7 +53,6 @@ class TriviaQA(Task):
         return ret

     def construct_requests(self, doc, ctx):
         ret = []
         for alias in self._remove_prefixes(doc['Answer']['Aliases']):
...
@@ -85,9 +85,9 @@ class TruthfulQAMultipleChoice(Task):
     def doc_to_target(self, doc):
         return " "

-    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
+    def fewshot_context(self, doc, num_fewshot, rnd, description=None):
         assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
+        return super().fewshot_context(doc, num_fewshot, rnd, description)

     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -213,9 +213,9 @@ class TruthfulQAGeneration(Task):
     def doc_to_target(self, doc):
         return " "

-    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
+    def fewshot_context(self, doc, num_fewshot, rnd, description=None):
         assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
+        return super().fewshot_context(doc, num_fewshot, rnd, description)

     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
...
@@ -45,9 +45,6 @@ class WordUnscrambleTask(Task):
         file = self.BASE_PATH / self.FILENAME
         return (json.loads(line) for line in open(file).read().splitlines())

-    def fewshot_description(self):
-        return "Please unscramble the letters into a word, and write that word:"
-
     def doc_to_text(self, doc):
         return doc["context"]
...
@@ -17,10 +17,6 @@ class WebQs(HFTask):
     def has_test_docs(self):
         return True

-    def fewshot_description(self):
-        # TODO: figure out description
-        return ""
-
     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'
@@ -41,7 +37,6 @@ class WebQs(HFTask):
         return ret

     def construct_requests(self, doc, ctx):
         ret = []
         for alias in self._remove_prefixes(doc['answers']):
...
@@ -49,10 +49,6 @@ class WikiText(PerplexityTask):
             download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
             sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")

-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
     def has_validation_docs(self):
         return True
...
@@ -29,10 +29,6 @@ class Winogrande(HFTask):
     def doc_to_text(self, doc):
         return self.partial_context(doc, doc["option" + doc["answer"]])

-    def fewshot_description(self):
-        # TODO: redo description
-        return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
-
     @classmethod
     def partial_context(cls, doc, option):
         # Substitute the pronoun in the sentence with the specified option
...
@@ -53,10 +53,6 @@ class WinogradSchemaChallenge273(HFTask):
     def has_test_docs(self):
         return True

-    def fewshot_description(self):
-        # TODO: redo description
-        return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
-
     def fewshot_examples(self, k, rnd):
         # NOTE: `super().fewshot_examples` samples from training docs which are
         # not available for this test-set-only dataset.
...
@@ -13,7 +13,7 @@ def parse_args():
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
     parser.add_argument('--tasks', default="all_tasks")
-    parser.add_argument('--provide_description', action="store_true")
+    parser.add_argument('--description_path', default=None)
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
     parser.add_argument('--device', type=str, default=None)
@@ -26,8 +26,6 @@ def main():
     args = parse_args()

-    assert not args.provide_description # not implemented
-
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
@@ -36,7 +34,17 @@ def main():
     else:
         task_names = args.tasks.split(",")

-    results = evaluator.simple_evaluate(args.model, args.model_args, task_names, args.num_fewshot, args.batch_size, args.device, args.no_cache, args.limit)
+    results = evaluator.simple_evaluate(
+        args.model,
+        args.model_args,
+        task_names,
+        args.description_path,
+        args.num_fewshot,
+        args.batch_size,
+        args.device,
+        args.no_cache,
+        args.limit
+    )

     dumped = json.dumps(results, indent=2)
@@ -46,7 +54,7 @@ def main():
         with open(args.output_path, "w") as f:
             f.write(dumped)

-    print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
+    print(f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
     print(evaluator.make_table(results))

 if __name__ == "__main__":
...
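
For reference, a sketch of how the new `--description_path` flag could be used. The evaluator-side handling of the path is not shown in this diff, so the JSON layout (task name mapped to a description string) and the helper below are assumptions based on the commit message; model and task names in the invocation are placeholders.

    import json

    def load_description_dict(description_path):
        # Hypothetical helper: turn the --description_path JSON file into the
        # task-name -> description mapping implied by the commit message.
        if description_path is None:
            return {}
        with open(description_path) as f:
            return json.load(f)

    # Example invocation of the updated CLI, replacing the old
    # --provide_description switch:
    #   python main.py --model gpt2 --tasks boolq,copa \
    #       --description_path descriptions.json --num_fewshot 0
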
@@ -51,7 +51,7 @@ def main():
     values = []
     for taskname in task_list.split(","):
         lm.tokencost = 0
-        evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)
+        evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, 0, None, bootstrap_iters=10)
         print(taskname, lm.tokencost)
         values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
...