Commit 8ac99269 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Replace the `fewshot_description` API with a `description_dict` based interface

parent b67aec37
......@@ -36,13 +36,9 @@ class PROST(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
def fewshot_context(self, doc, num_fewshot, rnd, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
return super().fewshot_context(doc, num_fewshot, rnd, description)
def _convert_standard(self, doc):
out_doc = {
......
......@@ -23,11 +23,6 @@ class Pubmed_QA(HFTask):
# HF is labelled as train but its really just for testing
return self.data["train"]
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
# 2 few-shot exmamples pushes it beyond context window
return ""
def doc_to_text(self, doc):
ctxs = "\n".join(doc["context"]["contexts"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(
......
......@@ -67,9 +67,6 @@ class QA4MRE(MultipleChoiceTask):
out_doc['source'] = src
yield out_doc
def fewshot_description(self):
return ""
def test_docs(self):
return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
......
......@@ -51,11 +51,6 @@ class QuAC(Task):
def test_docs(self):
raise NotImplementedError("QuAC has no test docs.")
def fewshot_description(self):
# TODO: figure out fewshot description
desc = "TITLE: Title of the context passage - subtitle of the passage\nPARAGRAPH: Passage describing the relevant information for answering questions.\n\nQ: Text of a question.\n\nA: Answer to the question, based on the passage. If it cannot be answered based on the passage, write CANNOTANSWER"
return desc
def load_doc(self, myjson):
docs = []
for item in myjson:
......
......@@ -63,10 +63,6 @@ class RACE(HFTask):
def test_docs(self):
return self._collate_data("test")
def fewshot_description(self):
# TODO: figure out description
return ""
@classmethod
def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']]
......
......@@ -61,10 +61,5 @@ class SATAnalogies(MultipleChoiceTask):
}
yield doc
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query'])
......@@ -50,9 +50,6 @@ class SciQ(MultipleChoiceTask):
for record in docs:
yield self._convert_standard(record)
def fewshot_description(self):
return ""
def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
......@@ -63,4 +60,4 @@ class SciQ(MultipleChoiceTask):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
\ No newline at end of file
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
......@@ -41,10 +41,6 @@ class SQuAD2(HFTask):
def validation_docs(self):
return self.data["validation"]
def fewshot_description(self):
# TODO: figure out description
return ""
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
......
......@@ -27,18 +27,12 @@ class StoryCloze(Task):
filereader = csv.reader(file)
return list(filereader)
def validation_docs(self):
return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
def test_docs(self):
return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
return ' '.join([*doc[1:5]])
......
......@@ -26,10 +26,6 @@ class BoolQ(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
......@@ -78,11 +74,6 @@ class CommitmentBank(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and a hypothesis, classify whether the author of the premise is committed" \
"to the truth of the hypothesis. The three possible labels are true, false or neither."
def doc_to_text(self, doc):
return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
doc["premise"],
......@@ -150,11 +141,6 @@ class Copa(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and one alternative with a causal relation to the premise and another without," \
"choose the more plausible alternative"
def doc_to_text(self, doc):
# Drop the period
connector = {
......@@ -215,10 +201,6 @@ class MultiRC(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return "READING COMPREHENSION ANSWER KEY"
def doc_to_text(self, doc):
return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
......@@ -270,10 +252,6 @@ class ReCoRD(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return ""
def training_docs(self):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no.
......@@ -363,10 +341,6 @@ class WordsInContext(HFTask):
def has_test_docs(self):
return False
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nAnswer:".format(
......@@ -432,12 +406,6 @@ class SGWinogradSchemaChallenge(HFTask):
]
return self._training_docs
def fewshot_description(self):
return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \
" refers to.\n====="
def doc_to_text(self, doc):
raw_passage = doc["text"]
# NOTE: HuggingFace span indices are word-based not character-based.
......
......@@ -166,12 +166,6 @@ class GeneralTranslationTask(Task):
"ter": False,
}
def fewshot_description(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"Translate these {src_lang} phrases to {tar_lang}."
def __str__(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
......
......@@ -36,10 +36,6 @@ class TriviaQA(Task):
def test_docs(self):
raise NotImplementedError()
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
return f"Question: {doc['Question']}\nAnswer:"
......@@ -56,7 +52,6 @@ class TriviaQA(Task):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
ret = []
......
......@@ -85,9 +85,9 @@ class TruthfulQAMultipleChoice(Task):
def doc_to_target(self, doc):
return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
def fewshot_context(self, doc, num_fewshot, rnd, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
return super().fewshot_context(doc, num_fewshot, rnd, description)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -213,9 +213,9 @@ class TruthfulQAGeneration(Task):
def doc_to_target(self, doc):
return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
def fewshot_context(self, doc, num_fewshot, rnd, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
return super().fewshot_context(doc, num_fewshot, rnd, description)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......
......@@ -45,9 +45,6 @@ class WordUnscrambleTask(Task):
file = self.BASE_PATH / self.FILENAME
return (json.loads(line) for line in open(file).read().splitlines())
def fewshot_description(self):
return "Please unscramble the letters into a word, and write that word:"
def doc_to_text(self, doc):
return doc["context"]
......
......@@ -17,10 +17,6 @@ class WebQs(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: figure out description
return ""
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
......@@ -40,7 +36,6 @@ class WebQs(HFTask):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
ret = []
......@@ -62,4 +57,4 @@ class WebQs(HFTask):
def higher_is_better(self):
return {
"acc": True
}
\ No newline at end of file
}
......@@ -49,10 +49,6 @@ class WikiText(PerplexityTask):
download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def has_validation_docs(self):
return True
......@@ -87,4 +83,4 @@ class WikiText(PerplexityTask):
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
\ No newline at end of file
return len(re.split(r"\s+", doc))
......@@ -29,10 +29,6 @@ class Winogrande(HFTask):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
......
......@@ -53,10 +53,6 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
......
......@@ -13,7 +13,7 @@ def parse_args():
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--description_path', default=None)
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
......@@ -26,8 +26,6 @@ def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
......@@ -36,7 +34,17 @@ def main():
else:
task_names = args.tasks.split(",")
results = evaluator.simple_evaluate(args.model, args.model_args, task_names, args.num_fewshot, args.batch_size, args.device, args.no_cache, args.limit)
results = evaluator.simple_evaluate(
args.model,
args.model_args,
task_names,
args.description_path,
args.num_fewshot,
args.batch_size,
args.device,
args.no_cache,
args.limit
)
dumped = json.dumps(results, indent=2)
......@@ -46,7 +54,7 @@ def main():
with open(args.output_path, "w") as f:
f.write(dumped)
print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
print(f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
print(evaluator.make_table(results))
if __name__ == "__main__":
......
......@@ -51,7 +51,7 @@ def main():
values = []
for taskname in task_list.split(","):
lm.tokencost = 0
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, 0, None, bootstrap_iters=10)
print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment