Commit 55e62507 authored by researcher2's avatar researcher2
Browse files

Merge branch 'master' into researcher2

parents bb0eafbb 26f0233f
...@@ -65,10 +65,6 @@ class RACE(HFTask): ...@@ -65,10 +65,6 @@ class RACE(HFTask):
def test_docs(self): def test_docs(self):
return self._collate_data("test") return self._collate_data("test")
def fewshot_description(self):
# TODO: figure out description
return ""
@classmethod @classmethod
def get_answer_option(cls, problem): def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']] answer = cls.letter_to_num[problem['answer']]
......
...@@ -61,10 +61,5 @@ class SATAnalogies(MultipleChoiceTask): ...@@ -61,10 +61,5 @@ class SATAnalogies(MultipleChoiceTask):
} }
yield doc yield doc
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query']) return "{} is to {} as".format(*doc['query'])
...@@ -13,8 +13,8 @@ class SciQ(MultipleChoiceTask): ...@@ -13,8 +13,8 @@ class SciQ(MultipleChoiceTask):
os.makedirs('data/sciq', exist_ok=True) os.makedirs('data/sciq', exist_ok=True)
download_file( download_file(
'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip', 'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
'data/sciq/SciQ.zip', local_file='data/sciq/SciQ.zip',
'7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c', expected_checksum='7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
) )
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf: with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/") zf.extractall("data/sciq/")
...@@ -50,9 +50,6 @@ class SciQ(MultipleChoiceTask): ...@@ -50,9 +50,6 @@ class SciQ(MultipleChoiceTask):
for record in docs: for record in docs:
yield self._convert_standard(record) yield self._convert_standard(record)
def fewshot_description(self):
return ""
def training_docs(self): def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json") return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
......
...@@ -41,10 +41,6 @@ class SQuAD2(HFTask): ...@@ -41,10 +41,6 @@ class SQuAD2(HFTask):
def validation_docs(self): def validation_docs(self):
return self.data["validation"] return self.data["validation"]
def fewshot_description(self):
# TODO: figure out description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:' return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
......
...@@ -27,18 +27,12 @@ class StoryCloze(Task): ...@@ -27,18 +27,12 @@ class StoryCloze(Task):
filereader = csv.reader(file) filereader = csv.reader(file)
return list(filereader) return list(filereader)
def validation_docs(self): def validation_docs(self):
return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv") return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
def test_docs(self): def test_docs(self):
return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv") return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ' '.join([*doc[1:5]]) return ' '.join([*doc[1:5]])
......
...@@ -13,7 +13,7 @@ from ..utils import general_detokenize ...@@ -13,7 +13,7 @@ from ..utils import general_detokenize
class BoolQ(HFTask): class BoolQ(HFTask):
VERSION = 0 VERSION = 1
DATASET_PATH = "super_glue" DATASET_PATH = "super_glue"
DATASET_NAME = "boolq" DATASET_NAME = "boolq"
...@@ -26,12 +26,8 @@ class BoolQ(HFTask): ...@@ -26,12 +26,8 @@ class BoolQ(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}\nAnswer:" return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + yesno(doc['label']) return " " + yesno(doc['label'])
...@@ -65,7 +61,7 @@ class BoolQ(HFTask): ...@@ -65,7 +61,7 @@ class BoolQ(HFTask):
class CommitmentBank(HFTask): class CommitmentBank(HFTask):
VERSION = 0 VERSION = 1
DATASET_PATH = "super_glue" DATASET_PATH = "super_glue"
DATASET_NAME = "cb" DATASET_NAME = "cb"
...@@ -78,11 +74,6 @@ class CommitmentBank(HFTask): ...@@ -78,11 +74,6 @@ class CommitmentBank(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and a hypothesis, classify whether the author of the premise is committed" \
"to the truth of the hypothesis. The three possible labels are true, false or neither."
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format( return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
doc["premise"], doc["premise"],
...@@ -93,14 +84,14 @@ class CommitmentBank(HFTask): ...@@ -93,14 +84,14 @@ class CommitmentBank(HFTask):
# True = entailment # True = entailment
# False = contradiction # False = contradiction
# Neither = neutral # Neither = neutral
return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]]) return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, ' True') ll_true, _ = rf.loglikelihood(ctx, ' True')
ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
ll_false, _ = rf.loglikelihood(ctx, ' False') ll_false, _ = rf.loglikelihood(ctx, ' False')
ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
return ll_true, ll_neither, ll_false return ll_true, ll_false, ll_neither
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["label"] gold = doc["label"]
...@@ -150,11 +141,6 @@ class Copa(HFTask): ...@@ -150,11 +141,6 @@ class Copa(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and one alternative with a causal relation to the premise and another without," \
"choose the more plausible alternative"
def doc_to_text(self, doc): def doc_to_text(self, doc):
# Drop the period # Drop the period
connector = { connector = {
...@@ -202,7 +188,7 @@ class Copa(HFTask): ...@@ -202,7 +188,7 @@ class Copa(HFTask):
class MultiRC(HFTask): class MultiRC(HFTask):
VERSION = 0 VERSION = 1
DATASET_PATH = "super_glue" DATASET_PATH = "super_glue"
DATASET_NAME = "multirc" DATASET_NAME = "multirc"
...@@ -215,10 +201,6 @@ class MultiRC(HFTask): ...@@ -215,10 +201,6 @@ class MultiRC(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return "READING COMPREHENSION ANSWER KEY"
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:" return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
...@@ -228,7 +210,7 @@ class MultiRC(HFTask): ...@@ -228,7 +210,7 @@ class MultiRC(HFTask):
@staticmethod @staticmethod
def format_answer(answer, label): def format_answer(answer, label):
label_str = "yes" if label else "no" label_str = "yes" if label else "no"
return f"{label_str}, {answer}" return f"{answer}\nIs the answer correct? {label_str}"
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
true_choice = self.format_answer(answer=doc["answer"], label=True) true_choice = self.format_answer(answer=doc["answer"], label=True)
...@@ -240,7 +222,8 @@ class MultiRC(HFTask): ...@@ -240,7 +222,8 @@ class MultiRC(HFTask):
return ll_true_choice, ll_false_choice return ll_true_choice, ll_false_choice
def process_results(self, doc, results): def process_results(self, doc, results):
pred = np.argmax(results) ll_true_choice, ll_false_choice = results
pred = ll_true_choice > ll_false_choice
return { return {
"acc": (pred, doc) "acc": (pred, doc)
} }
...@@ -270,10 +253,6 @@ class ReCoRD(HFTask): ...@@ -270,10 +253,6 @@ class ReCoRD(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return ""
def training_docs(self): def training_docs(self):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no. # Each doc consists of multiple answer candidates, each of which is scored yes/no.
...@@ -363,10 +342,6 @@ class WordsInContext(HFTask): ...@@ -363,10 +342,6 @@ class WordsInContext(HFTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \ return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nAnswer:".format( " two sentences above?\nAnswer:".format(
...@@ -432,12 +407,6 @@ class SGWinogradSchemaChallenge(HFTask): ...@@ -432,12 +407,6 @@ class SGWinogradSchemaChallenge(HFTask):
] ]
return self._training_docs return self._training_docs
def fewshot_description(self):
return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \
" refers to.\n====="
def doc_to_text(self, doc): def doc_to_text(self, doc):
raw_passage = doc["text"] raw_passage = doc["text"]
# NOTE: HuggingFace span indices are word-based not character-based. # NOTE: HuggingFace span indices are word-based not character-based.
......
...@@ -166,12 +166,6 @@ class GeneralTranslationTask(Task): ...@@ -166,12 +166,6 @@ class GeneralTranslationTask(Task):
"ter": False, "ter": False,
} }
def fewshot_description(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"Translate these {src_lang} phrases to {tar_lang}."
def __str__(self): def __str__(self):
language_codes = self.sacrebleu_language_pair.split("-") language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0]) src_lang = code_to_language(language_codes[0])
......
...@@ -12,7 +12,7 @@ class TriviaQA(Task): ...@@ -12,7 +12,7 @@ class TriviaQA(Task):
def download(self): def download(self):
if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'): if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
os.makedirs("data/triviaqa/", exist_ok=True) os.makedirs("data/triviaqa/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", "data/triviaqa/triviaqa-unfiltered.tar.gz", "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e") download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", local_file="data/triviaqa/triviaqa-unfiltered.tar.gz", expected_checksum="adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
sh(""" sh("""
cd data/triviaqa/ cd data/triviaqa/
tar -xf triviaqa-unfiltered.tar.gz tar -xf triviaqa-unfiltered.tar.gz
...@@ -36,10 +36,6 @@ class TriviaQA(Task): ...@@ -36,10 +36,6 @@ class TriviaQA(Task):
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Question: {doc['Question']}\nAnswer:" return f"Question: {doc['Question']}\nAnswer:"
...@@ -57,7 +53,6 @@ class TriviaQA(Task): ...@@ -57,7 +53,6 @@ class TriviaQA(Task):
return ret return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] ret = []
for alias in self._remove_prefixes(doc['Answer']['Aliases']): for alias in self._remove_prefixes(doc['Answer']['Aliases']):
......
...@@ -58,7 +58,7 @@ class TruthfulQAMultipleChoice(Task): ...@@ -58,7 +58,7 @@ class TruthfulQAMultipleChoice(Task):
Path.mkdir(self.DATASET_PATH, parents=True) Path.mkdir(self.DATASET_PATH, parents=True)
mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json" mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json"
checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954" checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954"
download_file(mc_url, str(self.DATASET_PATH / "mc_task.json"), checksum) download_file(mc_url, local_file=str(self.DATASET_PATH / "mc_task.json"), expected_checksum=checksum)
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -85,9 +85,14 @@ class TruthfulQAMultipleChoice(Task): ...@@ -85,9 +85,14 @@ class TruthfulQAMultipleChoice(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd): def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd) return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -163,7 +168,7 @@ class TruthfulQAGeneration(Task): ...@@ -163,7 +168,7 @@ class TruthfulQAGeneration(Task):
Path.mkdir(self.DATASET_PATH, parents=True) Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv" url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv"
checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2" checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2"
download_file(url, str(self.DATASET_PATH / "TruthfulQA.csv"), checksum) download_file(url, local_file=str(self.DATASET_PATH / "TruthfulQA.csv"), expected_checksum=checksum)
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -217,9 +222,14 @@ class TruthfulQAGeneration(Task): ...@@ -217,9 +222,14 @@ class TruthfulQAGeneration(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd): def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd) return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
......
...@@ -29,7 +29,7 @@ class WordUnscrambleTask(Task): ...@@ -29,7 +29,7 @@ class WordUnscrambleTask(Task):
if not file.exists(): if not file.exists():
rawfile = file.parent / (file.name + ".gz") rawfile = file.parent / (file.name + ".gz")
base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data" base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
download_file(f"{base_url}/{self.FILENAME}.gz", str(rawfile), self.CHECKSUM) download_file(f"{base_url}/{self.FILENAME}.gz", local_file=str(rawfile), expected_checksum=self.CHECKSUM)
extract_gzip(gz=rawfile, to=file) extract_gzip(gz=rawfile, to=file)
def has_training_docs(self): def has_training_docs(self):
...@@ -45,9 +45,6 @@ class WordUnscrambleTask(Task): ...@@ -45,9 +45,6 @@ class WordUnscrambleTask(Task):
file = self.BASE_PATH / self.FILENAME file = self.BASE_PATH / self.FILENAME
return (json.loads(line) for line in open(file).read().splitlines()) return (json.loads(line) for line in open(file).read().splitlines())
def fewshot_description(self):
return "Please unscramble the letters into a word, and write that word:"
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["context"] return doc["context"]
......
...@@ -17,10 +17,6 @@ class WebQs(HFTask): ...@@ -17,10 +17,6 @@ class WebQs(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_description(self):
# TODO: figure out description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:' return "Question: " + doc['question'] + '\nAnswer:'
...@@ -41,7 +37,6 @@ class WebQs(HFTask): ...@@ -41,7 +37,6 @@ class WebQs(HFTask):
return ret return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] ret = []
for alias in self._remove_prefixes(doc['answers']): for alias in self._remove_prefixes(doc['answers']):
......
...@@ -41,18 +41,14 @@ def wikitext_detokenizer(string): ...@@ -41,18 +41,14 @@ def wikitext_detokenizer(string):
class WikiText(PerplexityTask): class WikiText(PerplexityTask):
VERSION = 0 VERSION = 1
def download(self): def download(self):
if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'): if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
os.makedirs("data/wikitext/", exist_ok=True) os.makedirs("data/wikitext/", exist_ok=True)
download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11") download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip") sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def has_validation_docs(self): def has_validation_docs(self):
return True return True
......
...@@ -35,10 +35,6 @@ class Winogrande(HFTask): ...@@ -35,10 +35,6 @@ class Winogrande(HFTask):
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
return doc["sentence"] return doc["sentence"]
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod @classmethod
def partial_context(cls, doc, option): def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option # Substitute the pronoun in the sentence with the specified option
......
...@@ -53,10 +53,6 @@ class WinogradSchemaChallenge273(HFTask): ...@@ -53,10 +53,6 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are # NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset. # not available for this test-set-only dataset.
......
import os import os
import re import re
import collections import collections
import functools
import inspect
class ExitCodeError(Exception): class ExitCodeError(Exception):
...@@ -139,3 +141,17 @@ class Reorderer: ...@@ -139,3 +141,17 @@ class Reorderer:
assert all(cov) assert all(cov)
return res return res
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    Plain functions are allowed zero positional arguments; bound methods are
    allowed one (the implicit `self`). Anything beyond that triggers a
    deprecation warning, but the call is still forwarded unchanged.
    """
    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # NOTE(review): `inspect.ismethod(fn)` is typically False when the
        # decorator is applied at class-definition time (fn is still a plain
        # function there) — confirm this matches the intended call sites.
        #
        # The conditional expression must be parenthesized: the unparenthesized
        # form `len(args) != 1 if inspect.ismethod(fn) else 0` parses as
        # `(len(args) != 1) if inspect.ismethod(fn) else 0`, which is the falsy
        # constant 0 for plain functions, so the warning would never fire.
        allowed_positional = 1 if inspect.ismethod(fn) else 0
        if len(args) != allowed_positional:
            print(f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!")
        return fn(*args, **kwargs)
    return _wrapper
...@@ -68,12 +68,16 @@ def pattern_match(patterns, source_list): ...@@ -68,12 +68,16 @@ def pattern_match(patterns, source_list):
return list(task_names) return list(task_names)
def main(): def main():
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
def main():
args = parse_args() args = parse_args()
if not ensure_correct_decontamination_params(args): if not ensure_correct_decontamination_params(args):
return return
# assert not args.provide_description # not implemented assert not args.provide_description # not implemented
if args.limit: if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
...@@ -100,11 +104,25 @@ def main(): ...@@ -100,11 +104,25 @@ def main():
print(f"Selected Tasks: {task_names}") print(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate(args.model, args.model_args, task_names, description_dict = {}
num_fewshot=args.num_fewshot, batch_size=args.batch_size, if args.description_dict_path:
device=args.device, no_cache=args.no_cache, limit=args.limit, with open(args.description_dict_path, 'r') as f:
decontaminate=args.decontaminate, ngrams_path=args.ngrams_path, description_dict = json.load(f)
ngrams_n_size=args.ngrams_n_size)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
decontaminate=args.decontaminate,
ngrams_path=args.ngrams_path,
ngrams_n_size=args.ngrams_n_size
)
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
print(dumped) print(dumped)
...@@ -113,8 +131,12 @@ def main(): ...@@ -113,8 +131,12 @@ def main():
with open(args.output_path, "w") as f: with open(args.output_path, "w") as f:
f.write(dumped) f.write(dumped)
print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}") print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
)
print(evaluator.make_table(results)) print(evaluator.make_table(results))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -51,7 +51,14 @@ def main(): ...@@ -51,7 +51,14 @@ def main():
values = [] values = []
for taskname in task_list.split(","): for taskname in task_list.split(","):
lm.tokencost = 0 lm.tokencost = 0
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10) evaluator.evaluate(
lm=lm,
task_dict={taskname: tasks.get_task(taskname)()},
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None
)
print(taskname, lm.tokencost) print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06]) values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
......
import json
import numpy as np
import random
import logging
from lm_eval import models, tasks, evaluator, base
# Silence the openai client's noisy per-request INFO logging.
logging.getLogger("openai").setLevel(logging.WARNING)

# Candidate few-shot descriptions to sweep over; "foo"/"bar" are
# placeholders — replace with real prompts for an actual experiment.
fewshot_descriptions = [
    "foo",
    "bar"
]

# Hard-coded experiment configuration for this scratch script.
task = "lambada"     # task name resolvable by tasks.get_task_dict
num_fewshot = 0      # number of few-shot examples per prompt
model = "gpt2"       # model name resolvable by models.get_model
model_args = ""      # extra model-construction args ("k=v,k=v" string)
limit = None         # cap on number of docs evaluated (testing only)
no_cache = False     # when False, wrap the LM in a disk-backed cache
class CustomDescTask:
    """Transparent proxy around a task that overrides `fewshot_description`.

    Wraps an existing task object, patches its `fewshot_description` so that
    callers get back the custom description string, and delegates every other
    attribute access to the wrapped task.
    """

    def __init__(self, task, desc):
        self.task = task
        self.desc = desc
        # Patch the wrapped task in place: anything invoking
        # task.fewshot_description() now receives our description. Reading
        # `self.desc` at call time means later reassignments are picked up.
        self.task.fewshot_description = lambda: self.desc

    def __getattr__(self, attr):
        # Only reached for attributes not found on the proxy itself;
        # everything else falls through to the wrapped task.
        return getattr(self.task, attr)
def main():
    """Evaluate the configured task once per candidate few-shot description.

    For each description in `fewshot_descriptions`, every task is wrapped in a
    CustomDescTask proxy so the evaluator sees the custom description, then the
    raw JSON results and a markdown summary table are printed.
    """
    # Fix seeds so runs with different descriptions are comparable.
    random.seed(42)
    np.random.seed(42)

    lm = models.get_model(model).create_from_arg_string(model_args)

    if limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if not no_cache:
        # Cache LM calls on disk so repeated sweeps don't re-query the model.
        lm = base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_') + '.db')

    task_dict = tasks.get_task_dict([task])

    for desc in fewshot_descriptions:
        custom_task_dict = {k: CustomDescTask(v, desc) for k, v in task_dict.items()}

        # Keyword arguments throughout: positional calls to `evaluate` are
        # deprecated (see `positional_deprecated`). The old third positional
        # argument (`provide_description=True`) is dropped — descriptions are
        # now injected via the CustomDescTask wrapper above.
        results = evaluator.evaluate(
            lm=lm,
            task_dict=custom_task_dict,
            num_fewshot=num_fewshot,
            limit=limit,
        )

        dumped = json.dumps(results, indent=2)
        print('Description:', desc)
        print(dumped)

        # MAKE TABLE
        from pytablewriter import MarkdownTableWriter

        writer = MarkdownTableWriter()
        writer.headers = ["Task", "Metric", "Value"]

        values = []
        for k, dic in results.items():
            for m, v in dic.items():
                values.append([k, m, '%.4f' % v])
                # Blank out the task name after its first row for readability.
                k = ""
        writer.value_matrix = values
        print(writer.dumps())


if __name__ == "__main__":
    main()
...@@ -9,7 +9,6 @@ for tname, Task in tasks.TASK_REGISTRY.items():#[('record', tasks.superglue.ReCo ...@@ -9,7 +9,6 @@ for tname, Task in tasks.TASK_REGISTRY.items():#[('record', tasks.superglue.ReCo
print('#', tname) print('#', tname)
docs = islice(task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct) docs = islice(task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct)
print() print()
print('**Zero-Shot Prompt**:', "\n```\n" + task.fewshot_description() + "\n```\n")
for i in range(ct): for i in range(ct):
print() print()
doc = next(docs) doc = next(docs)
......
import argparse import argparse
import numpy as np import numpy as np
import json
import os import os
import random import random
from lm_eval import tasks from lm_eval import tasks
...@@ -17,6 +18,7 @@ def parse_args(): ...@@ -17,6 +18,7 @@ def parse_args():
parser.add_argument('--num_fewshot', type=int, default=1) parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=42) parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_examples', type=int, default=1) parser.add_argument('--num_examples', type=int, default=1)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args() return parser.parse_args()
...@@ -29,6 +31,12 @@ def main(): ...@@ -29,6 +31,12 @@ def main():
else: else:
task_names = args.tasks.split(",") task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names) task_dict = tasks.get_task_dict(task_names)
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True) os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
rnd = random.Random() rnd = random.Random()
...@@ -47,14 +55,16 @@ def main(): ...@@ -47,14 +55,16 @@ def main():
docs = join_iters(iters) docs = join_iters(iters)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
with open(os.path.join(args.output_base_path, task_name), "w") as f: with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs): for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i)) f.write(EXAMPLE_DIVIDER.format(i=i))
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc,
provide_description=args.provide_description,
num_fewshot=args.num_fewshot, num_fewshot=args.num_fewshot,
rnd=rnd rnd=rnd,
description=description
) )
f.write(ctx + "\n") f.write(ctx + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment