Commit 337419b8 authored by haileyschoelkopf

add configurable ppl task

parent 487f7811
@@ -25,4 +25,7 @@ HIGHER_IS_BETTER_REGISTRY = {
     "acc": True,
     "acc_norm": True,
+    "word_perplexity": False,
+    "byte_perplexity": False,
+    "bits_per_byte": False,
 }
\ No newline at end of file
@@ -12,7 +12,9 @@ import evaluate
 AGGREGATION_REGISTRY = {}
 METRIC_REGISTRY = {
     "acc": None,
     "acc_norm": None,
+    "word_perplexity": None,
+    "byte_perplexity": None,
 }
@@ -170,10 +172,13 @@ def weighted_mean(items):
 @register_metric("weighted_perplexity")
+@register_aggregation("weighted_perplexity")
 def weighted_perplexity(items):
     return math.exp(-weighted_mean(items))
 
+@register_metric("bits_per_byte")
+@register_aggregation("bits_per_byte")
 def bits_per_byte(items):
     return -weighted_mean(items) / math.log(2)
...
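These aggregations consume (loglikelihood, weight) pairs rather than precomputed scores. A minimal sketch with hypothetical values, assuming weighted_mean (whose body is not shown in this hunk) is sum-of-loglikelihoods over sum-of-weights:

import math

def weighted_mean(items):
    # items: iterable of (loglikelihood, weight) pairs
    lls, weights = zip(*items)
    return sum(lls) / sum(weights)

items = [(-120.0, 50), (-80.0, 40)]        # hypothetical (loglikelihood, word count) per doc
ppl = math.exp(-weighted_mean(items))      # weighted_perplexity -> ~9.23
bpb = -weighted_mean(items) / math.log(2)  # bits_per_byte -> ~3.21 (weights would be byte counts)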
@@ -330,6 +330,16 @@ class Task(abc.ABC):
         """
         pass
 
+    @classmethod
+    def count_bytes(cls, doc):
+        """Used for byte-level perplexity metrics in rolling loglikelihood"""
+        return len(doc.encode("utf-8"))
+
+    @classmethod
+    def count_words(cls, doc):
+        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
+        return len(re.split(r"\s+", doc))
+
     @utils.positional_deprecated
     def fewshot_context(self, doc, num_fewshot, rnd=None):
         """Returns a fewshot context string that is made up of a prepended description
@@ -555,10 +565,17 @@ class ConfigurableTask(Task):
             ll, is_greedy = results
             result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
-            pass
+            (loglikelihood,) = results
+            words = self.count_words(self.doc_to_target(doc))
+            bytes_ = self.count_bytes(self.doc_to_target(doc))
+            return {
+                "word_perplexity": (loglikelihood, words),
+                "byte_perplexity": (loglikelihood, bytes_),
+                "bits_per_byte": (loglikelihood, bytes_),
+            }
         elif self.OUTPUT_TYPE == "multiple_choice":
-            lls = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy
+            lls = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy TODO: keep is_greedy to report exact_match as well on multiple choice probs
-            gold = int(self.doc_to_target(doc))  # TODO: if `gold` here is an integer/ds label obj, map it to answer_choice
+            gold = int(self.doc_to_target(doc))
             # TODO: remove dependence on "gold" and "choices" columns
             acc = 1.0 if np.argmax(lls) == gold else 0.0
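In the multiple_choice branch, accuracy is just an argmax over per-choice loglikelihoods; a toy example with hypothetical values:

import numpy as np

lls = [-4.2, -1.3, -5.0, -3.1]                # loglikelihood of each answer choice
gold = 1                                      # index of the correct choice
acc = 1.0 if np.argmax(lls) == gold else 0.0  # -> 1.0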
@@ -693,8 +710,8 @@ class PerplexityTask(Task, abc.ABC):
     def process_results(self, doc, results):
         (loglikelihood,) = results
-        words = self.count_words(doc)
-        bytes_ = self.count_bytes(doc)
+        words = self.count_words(self.doc_to_target(doc))
+        bytes_ = self.count_bytes(self.doc_to_target(doc))
         return {
             "word_perplexity": (loglikelihood, words),
             "byte_perplexity": (loglikelihood, bytes_),
...
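This fix matters because count_words/count_bytes operate on strings, while docs here are dataset rows; doc_to_target extracts the text first. A hypothetical illustration:

import re

doc = {"text": "The quick brown fox"}  # a dataset row, not a raw string
target = doc["text"]                   # what doc_to_target(doc) returns here
len(re.split(r"\s+", target))          # 4 words; passing `doc` itself would fail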
@@ -41,7 +41,7 @@ from . import lambada
 # from . import hendrycks_math
 # from . import cbt
 # from . import lambada_cloze
-# from . import pile
+from . import pile
 from . import wikitext
 # from . import lambada_multilingual
 # from . import mutual
...
@@ -6,7 +6,7 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
 doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}"
+doc_to_target: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
...
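The template_aliases line is what lets doc_to_target emit an integer label: gold is derived by finding answerKey within choices.label. The equivalent lookup in plain Python, on a hypothetical ARC-style doc:

doc = {
    "answerKey": "B",
    "choices": {"text": ["a mirror", "a lens"], "label": ["A", "B"]},
}
gold = doc["choices"]["label"].index(doc["answerKey"])  # -> 1
# doc_to_target renders "{{gold}}" as "1", later cast to int by the task code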
"""
The Pile: An 800GB Dataset of Diverse Text for Language Modeling
https://arxiv.org/pdf/2101.00027.pdf
The Pile is a 825 GiB diverse, open source language modelling data set that consists
of 22 smaller, high-quality datasets combined together. To score well on Pile
BPB (bits per byte), a model must be able to understand many disparate domains
including books, github repositories, webpages, chat logs, and medical, physics,
math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask, register_task
_CITATION = """
@article{pile,
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
journal={arXiv preprint arXiv:2101.00027},
year={2020}
}
"""
class PilePerplexityTask(PerplexityTask):
VERSION = "2.0"
DATASET_PATH = "EleutherAI/the_pile"
DATASET_NAME = None
def has_training_docs(self):
return False
def test_docs(self):
for doc in self.dataset["train"].select(range(100)):
yield doc
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def doc_to_target(self, doc):
return doc["text"]
# def validation_docs(self):
# for doc in self.dataset["validation"]:
# yield doc["text"]
# def test_docs(self):
# for doc in self.dataset["test"]:
# yield doc["text"]
class PileArxiv(PilePerplexityTask):
DATASET_NAME = "pile_arxiv"
class PileBooks3(PilePerplexityTask):
DATASET_NAME = "pile_books3"
class PileBookCorpus2(PilePerplexityTask):
DATASET_NAME = "pile_bookcorpus2"
class PileDmMathematics(PilePerplexityTask):
DATASET_NAME = "pile_dm-mathematics"
@register_task("pile_enron")
class PileEnron(PilePerplexityTask):
DATASET_NAME = "enron_emails"
class PileEuroparl(PilePerplexityTask):
DATASET_NAME = "pile_europarl"
class PileFreeLaw(PilePerplexityTask):
DATASET_NAME = "pile_freelaw"
class PileGithub(PilePerplexityTask):
DATASET_NAME = "pile_github"
class PileGutenberg(PilePerplexityTask):
DATASET_NAME = "pile_gutenberg"
class PileHackernews(PilePerplexityTask):
DATASET_NAME = "pile_hackernews"
class PileNIHExporter(PilePerplexityTask):
DATASET_NAME = "pile_nih-exporter"
class PileOpenSubtitles(PilePerplexityTask):
DATASET_NAME = "pile_opensubtitles"
class PileOpenWebText2(PilePerplexityTask):
DATASET_NAME = "pile_openwebtext2"
class PilePhilPapers(PilePerplexityTask):
DATASET_NAME = "pile_philpapers"
class PilePileCc(PilePerplexityTask):
DATASET_NAME = "pile_pile-cc"
class PilePubmedAbstracts(PilePerplexityTask):
DATASET_NAME = "pile_pubmed-abstracts"
class PilePubmedCentral(PilePerplexityTask):
DATASET_NAME = "pile_pubmed-central"
class PileStackExchange(PilePerplexityTask):
DATASET_NAME = "pile_stackexchange"
class PileUspto(PilePerplexityTask):
DATASET_NAME = "pile_upsto"


class PileUbuntuIrc(PilePerplexityTask):
    DATASET_NAME = "pile_ubuntu-irc"


class PileWikipedia(PilePerplexityTask):
    DATASET_NAME = "pile_wikipedia"


class PileYoutubeSubtitles(PilePerplexityTask):
    DATASET_NAME = "pile_youtubesubtitles"
\ No newline at end of file
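Tracing one document through the class, assuming self.dataset is a loaded datasets.DatasetDict: test_docs() yields raw row dicts from the first 100 train-split rows, doc_to_target() pulls out the text, and process_results() counts words and bytes over it. A sketch with hypothetical values:

# doc = {"text": "Subject: meeting moved to 3pm"}   # yielded by test_docs()
# target = task.doc_to_target(doc)                  # "Subject: meeting moved to 3pm"
# results = (-42.7,)                                # loglikelihood_rolling over `target`
# task.process_results(doc, results)
# -> {"word_perplexity": (-42.7, 5), "byte_perplexity": (-42.7, 29), "bits_per_byte": (-42.7, 29)}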
dataset_path: EleutherAI/the_pile
dataset_name: enron_emails
output_type: loglikelihood_rolling
test_split: train
template_aliases: ""
doc_to_text: ""
doc_to_target: "{{text}}"
should_decontaminate: true
doc_to_decontamination_query: "{{text}}"
metric_list:
  - metric: word_perplexity
    aggregation: weighted_perplexity
    higher_is_better: false
  - metric: byte_perplexity
    aggregation: weighted_perplexity
    higher_is_better: false
  - metric: bits_per_byte
    aggregation: bits_per_byte
    higher_is_better: false
\ No newline at end of file
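Each metric_list entry ties together the three registries touched at the top of this commit; a sketch of how one entry would resolve, assuming the register_* decorators populate the corresponding registries by string key:

entry = {"metric": "word_perplexity",
         "aggregation": "weighted_perplexity",
         "higher_is_better": False}
# METRIC_REGISTRY["word_perplexity"]           -> None placeholder for now
# AGGREGATION_REGISTRY["weighted_perplexity"]  -> weighted_perplexity (via @register_aggregation)
# HIGHER_IS_BETTER_REGISTRY["word_perplexity"] -> False, matching higher_is_better here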