Unverified Commit 761f0087 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #560 from EleutherAI/dataset-metric-log

Dataset metric log [WIP]
parents 232632c6 ae4d9ed2
# SuperGLUE BoolQ task configuration for the lm-evaluation-harness YAML format.
group:
- super-glue-lm-eval-v1
task: "default"
# HuggingFace dataset coordinates: load_dataset("super_glue", "boolq").
dataset_path: super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
validation_split: validation
# Jinja templates rendering each doc into a prompt and its gold target.
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
# answer_choices[label] is the continuation scored for each class (0 = no, 1 = yes).
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
import sklearn
import numpy as np
def cb_multi_fi(items):
    """Aggregate CommitmentBank results into a macro-averaged multi-class F1.

    Args:
        items: iterable of ``(prediction, gold)`` label pairs, with labels
            drawn from ``{0, 1, 2}`` (the three CB classes).

    Returns:
        The unweighted mean of the three per-class one-vs-rest F1 scores.

    NOTE: computed directly with numpy rather than sklearn. ``import sklearn``
    alone does not reliably make the ``sklearn.metrics`` submodule available,
    and a binary F1 needs only tp/fp/fn counts. A class with no true and no
    predicted samples contributes an F1 of 0.0, matching
    ``sklearn.metrics.f1_score`` with ``zero_division=0``.
    """
    preds, golds = zip(*items)
    preds = np.array(preds)
    golds = np.array(golds)

    def _binary_f1(label):
        # One-vs-rest F1 for a single class: 2*tp / (2*tp + fp + fn).
        p = preds == label
        g = golds == label
        tp = np.sum(p & g)
        fp = np.sum(p & ~g)
        fn = np.sum(~p & g)
        denom = 2 * tp + fp + fn
        return (2 * tp / denom) if denom else 0.0

    return np.mean([_binary_f1(label) for label in (0, 1, 2)])
# SuperGLUE CommitmentBank (CB) task configuration for the lm-evaluation-harness.
group:
- super-glue-lm-eval-v1
task: "default"
# HuggingFace dataset coordinates: load_dataset("super_glue", "cb").
dataset_path: super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
# Jinja templates rendering each doc into a prompt and its gold target.
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
# answer_choices[label] is the continuation scored for each of the 3 classes.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
- metric: acc
- metric: f1
  # Multi-class macro F1 supplied by a Python hook in the task's aggregate module.
  aggregation: !function "aggregate.cb_multi_fi"
...@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions. ...@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/ Homepage: https://nlp.cs.washington.edu/triviaqa/
""" """
import inspect import inspect
# import lm_eval.datasets.triviaqa.triviaqa # import lm_eval.datasets.triviaqa.triviaqa
import string import string
from lm_eval.api.task import Task from lm_eval.api.task import Task
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.register import register_task from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean from lm_eval.api.metrics import mean
_CITATION = """ _CITATION = """
...@@ -29,10 +30,11 @@ _CITATION = """ ...@@ -29,10 +30,11 @@ _CITATION = """
} }
""" """
@register_task("triviaqa") @register_task("triviaqa")
class TriviaQA(Task): class TriviaQA(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = "trivia_qa" #inspect.getfile(lm_eval.datasets.triviaqa.triviaqa) DATASET_PATH = "trivia_qa" # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext" DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until" OUTPUT_TYPE = "greedy_until"
...@@ -90,18 +92,29 @@ class TriviaQA(Task): ...@@ -90,18 +92,29 @@ class TriviaQA(Task):
continuation = Instance( continuation = Instance(
request_type=self.OUTPUT_TYPE, request_type=self.OUTPUT_TYPE,
doc=doc, doc=doc,
arguments=(ctx, { arguments=(
ctx,
{
"until": ["\n", ".", ","], "until": ["\n", ".", ","],
"do_sample": False, "do_sample": False,
}), },
),
idx=0, idx=0,
**kwargs, **kwargs,
) )
return continuation return continuation
def process_results(self, doc, results): def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation)) continuation = (
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])] results[0]
.strip()
.lower()
.translate(str.maketrans("", "", string.punctuation))
)
list_of_candidates = [
alias.lower().translate(str.maketrans("", "", string.punctuation))
for alias in self._remove_prefixes(doc["answer"]["aliases"])
]
return {"em": float(continuation in list_of_candidates)} return {"em": float(continuation in list_of_candidates)}
def aggregation(self): def aggregation(self):
......
"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf
The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.
NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
from lm_eval.api.task import PerplexityTask
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@misc{merity2016pointer,
title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016},
eprint={1609.07843},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
def wikitext_detokenizer(string):
    """Reverse the whitespace-heavy tokenization of the raw WikiText dumps.

    Collapses spaced contractions, "@"-escaped number separators, padded
    punctuation, spaced brackets/quotes, and split-up section-heading
    markers back into natural running text. Replacement order matters and
    is preserved exactly.
    """
    text = string

    # Contractions: "s '" -> "s'", plus spaced digit contractions.
    text = text.replace("s '", "s'")
    text = re.sub(r"/' [0-9]/", r"/'[0-9]/", text)

    # WikiText escapes "-", "," and "." inside numbers as " @x@ ".
    for escaped, plain in ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", ".")):
        text = text.replace(escaped, plain)

    # Drop the space *before* each padded punctuation mark.
    for mark in (":", ";", ".", "!", "?", ","):
        text = text.replace(" " + mark + " ", mark + " ")

    # Pull bracketed and quoted spans tight against their delimiters.
    text = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", text)
    text = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", text)
    text = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", text)
    text = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', text)
    text = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", text)

    # Section headings ("= = Title = =" -> "== Title =="); longest run first.
    text = text.replace("= = = =", "====")
    text = text.replace("= = =", "===")
    text = text.replace("= =", "==")

    # Degree sign spacing, whitespace around newlines, the "N" number
    # sentinel, and possessives.
    text = text.replace(" " + chr(176) + " ", chr(176))
    text = text.replace(" \n", "\n")
    text = text.replace("\n ", "\n")
    text = text.replace(" N ", " 1 ")
    text = text.replace(" 's", "'s")

    return text
@register_task("wikitext")
class WikiText(PerplexityTask):
VERSION = "2.0"
DATASET_PATH = "EleutherAI/wikitext_document_level"
DATASET_NAME = "wikitext-2-raw-v1"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
return map(self._process_doc, self.dataset["train"])
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
return doc["page"]
def doc_to_target(self, doc):
return wikitext_detokenizer(doc)
def should_decontaminate(self):
return True
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
import re import re
def wikitext_detokenizer(doc): def wikitext_detokenizer(doc):
string = doc["page"] string = doc["page"]
# contractions # contractions
......
group: group:
- wikitext_group - wikitext_group
task: wikitext_yaml task: default
dataset_path: EleutherAI/wikitext_document_level dataset_path: EleutherAI/wikitext_document_level
dataset_name: wikitext-2-raw-v1 dataset_name: wikitext-2-raw-v1
output_type: loglikelihood_rolling output_type: loglikelihood_rolling
...@@ -14,11 +14,5 @@ should_decontaminate: true ...@@ -14,11 +14,5 @@ should_decontaminate: true
doc_to_decontamination_query: "{{page}}" doc_to_decontamination_query: "{{page}}"
metric_list: metric_list:
- metric: word_perplexity - metric: word_perplexity
aggregation: weighted_perplexity
higher_is_better: false
- metric: byte_perplexity - metric: byte_perplexity
aggregation: weighted_perplexity
higher_is_better: false
- metric: bits_per_byte - metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
...@@ -158,7 +158,15 @@ def make_table(result_dict): ...@@ -158,7 +158,15 @@ def make_table(result_dict):
md_writer = MarkdownTableWriter() md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter() latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"] md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"] latex_writer.headers = [
"Task",
"Version",
"Filter",
"Metric",
"Value",
"",
"Stderr",
]
values = [] values = []
...@@ -166,7 +174,7 @@ def make_table(result_dict): ...@@ -166,7 +174,7 @@ def make_table(result_dict):
version = result_dict["versions"][k] version = result_dict["versions"][k]
for (mf), v in dic.items(): for (mf), v in dic.items():
m, _, f = mf.partition(",") m, _, f = mf.partition(",")
print(m,f) print(m, f)
if m.endswith("_stderr"): if m.endswith("_stderr"):
continue continue
......
...@@ -19,7 +19,9 @@ class MultiChoice: ...@@ -19,7 +19,9 @@ class MultiChoice:
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value)) eval_logger.warning("{} is not in task list.".format(value))
# eval_logger.info(f"{choices} is this") eval_logger.info(f"Available tasks to choose:")
for choice in self.choices:
eval_logger.info(f" {choice}")
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment