Unverified Commit 761f0087 authored by Lintang Sutawika, committed by GitHub

Merge pull request #560 from EleutherAI/dataset-metric-log

Dataset metric log [WIP]
parents 232632c6 ae4d9ed2
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
  higher_is_better: false
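For readers unfamiliar with these metric names, here is a hedged sketch of what the weighted_perplexity and bits_per_byte aggregations plausibly compute; the (loglikelihood, unit count) item format is an assumption for illustration, not taken from this diff.

import math

def weighted_perplexity(items):
    # items: assumed (loglikelihood, unit_count) pairs, where unit_count is
    # words for word_perplexity and bytes for byte_perplexity
    loglikelihoods, counts = zip(*items)
    return math.exp(-sum(loglikelihoods) / sum(counts))

def bits_per_byte(items):
    # same assumed item format, with unit_count measured in bytes
    loglikelihoods, byte_counts = zip(*items)
    return -sum(loglikelihoods) / sum(byte_counts) / math.log(2)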
group:
group:
- piqa_yaml_grp
task: piqa_yaml
dataset_path: piqa
......@@ -9,7 +9,7 @@ validation_split: validation
test_split: null
template_aliases: "{% set question = goal %}{% set answer_choices = [sol1, sol2] %}{% set gold = label %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
  aggregation: mean
......
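To make the template_aliases / doc_to_text / doc_to_target interplay above concrete, here is a plain-Jinja sketch with a made-up PIQA-style document; the harness's own rendering code may differ.

import jinja2

aliases = (
    '{% set question = goal %}'
    '{% set answer_choices = [sol1, sol2] %}'
    '{% set gold = label %}'
)
doc = {
    "goal": "Keep bread fresh longer.",  # hypothetical example fields
    "sol1": "Store it in the freezer.",
    "sol2": "Leave it uncovered on the counter.",
    "label": 0,
}
env = jinja2.Environment()
prompt = env.from_string(aliases + "Question: {{question}}\nAnswer:").render(**doc)
target = env.from_string(aliases + "{{gold}}").render(**doc)  # "0", later cast to int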
group:
group:
- sciq_yaml_grp
task: sciq_yaml
dataset_path: sciq
......@@ -9,7 +9,7 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
  aggregation: mean
......@@ -19,4 +19,4 @@ metric_list:
  higher_is_better: true
- metric: acc_mutual_info
  aggregation: mean
  higher_is_better: true
\ No newline at end of file
  higher_is_better: true
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
import numpy as np
import sklearn.metrics  # sklearn.metrics must be imported explicitly; `import sklearn` alone does not expose it


def cb_multi_fi(items):
    # Macro-average the per-class F1 scores over CB's three labels (0, 1, 2).
    preds, golds = zip(*items)
    preds = np.array(preds)
    golds = np.array(golds)
    f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0)
    f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1)
    f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
    avg_f1 = np.mean([f11, f12, f13])
    return avg_f1
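A quick sanity check for cb_multi_fi; the (prediction, gold) label pairs below are invented for illustration.

items = [(0, 0), (1, 1), (2, 1), (2, 2), (0, 0)]  # toy (pred, gold) indices
score = cb_multi_fi(items)  # macro-average of per-class F1 over labels 0, 1, 2
print(round(score, 4))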
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
- metric: acc
- metric: f1
  aggregation: !function "aggregate.cb_multi_fi"
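The !function tag points the f1 metric at the aggregator defined in aggregate.py above. As a hedged sketch (the harness's actual config loader may do this differently), such a tag could be resolved with a PyYAML constructor:

import importlib
import yaml

def _function_constructor(loader, node):
    # "aggregate.cb_multi_fi" -> import module "aggregate", return the named attribute
    module_name, _, func_name = loader.construct_scalar(node).rpartition(".")
    return getattr(importlib.import_module(module_name), func_name)

yaml.SafeLoader.add_constructor("!function", _function_constructor)
cfg = yaml.safe_load('aggregation: !function "aggregate.cb_multi_fi"')
# cfg["aggregation"] is now the callable itself (assuming aggregate.py is importable)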
......@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.register import register_task
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
_CITATION = """
......@@ -29,10 +30,11 @@ _CITATION = """
}
"""
@register_task("triviaqa")
class TriviaQA(Task):
    VERSION = 1
    DATASET_PATH = "trivia_qa" #inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
    DATASET_PATH = "trivia_qa"  # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
    DATASET_NAME = "unfiltered.nocontext"
    OUTPUT_TYPE = "greedy_until"
......@@ -90,18 +92,29 @@ class TriviaQA(Task):
        continuation = Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(ctx, {
                "until": ["\n", ".", ","],
                "do_sample": False,
            }),
            arguments=(
                ctx,
                {
                    "until": ["\n", ".", ","],
                    "do_sample": False,
                },
            ),
            idx=0,
            **kwargs,
        )
        return continuation
    def process_results(self, doc, results):
        continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
        list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
        continuation = (
            results[0]
            .strip()
            .lower()
            .translate(str.maketrans("", "", string.punctuation))
        )
        list_of_candidates = [
            alias.lower().translate(str.maketrans("", "", string.punctuation))
            for alias in self._remove_prefixes(doc["answer"]["aliases"])
        ]
        return {"em": float(continuation in list_of_candidates)}

    def aggregation(self):
......@@ -110,4 +123,4 @@ class TriviaQA(Task):
        }

    def higher_is_better(self):
        return {"em": True}
\ No newline at end of file
        return {"em": True}
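A toy illustration of the exact-match normalization in process_results above (strip, lowercase, drop punctuation, then compare against the answer aliases); the prediction and aliases are made up.

import string

table = str.maketrans("", "", string.punctuation)
prediction = " Paris. "
aliases = ["Paris", "City of Paris"]
normalized = prediction.strip().lower().translate(table)
candidates = [alias.lower().translate(table) for alias in aliases]
print({"em": float(normalized in candidates)})  # {'em': 1.0}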
"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf
The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.
NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
from lm_eval.api.task import PerplexityTask
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@misc{merity2016pointer,
title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016},
eprint={1609.07843},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
@register_task("wikitext")
class WikiText(PerplexityTask):
    VERSION = "2.0"
    DATASET_PATH = "EleutherAI/wikitext_document_level"
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return map(self._process_doc, self.dataset["train"])

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        return doc["page"]

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def should_decontaminate(self):
        return True

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
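A small usage example for wikitext_detokenizer defined above; the input is an invented WikiText-style fragment with the tokenized spacing and @.@ number separator that the function undoes.

raw = "The game 's visuals were well received , scoring 9 @.@ 5 out of 10 . "
print(wikitext_detokenizer(raw))
# -> "The game's visuals were well received, scoring 9.5 out of 10. "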
......@@ -33,4 +33,4 @@ Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-
- [x] Is in Eval-harness v1.0 ?
- [x] Has been checked for regression from v1.0?
- [ ] Has been checked for equivalence with original paper methodology?
- [ ] "Main" checked variant clearly denoted?
\ No newline at end of file
- [ ] "Main" checked variant clearly denoted?
import re


def wikitext_detokenizer(doc):
    string = doc["page"]
    # contractions
......
group:
- wikitext_group
task: wikitext_yaml
task: default
dataset_path: EleutherAI/wikitext_document_level
dataset_name: wikitext-2-raw-v1
output_type: loglikelihood_rolling
......@@ -14,11 +14,5 @@ should_decontaminate: true
doc_to_decontamination_query: "{{page}}"
metric_list:
- metric: word_perplexity
  aggregation: weighted_perplexity
  higher_is_better: false
- metric: byte_perplexity
  aggregation: weighted_perplexity
  higher_is_better: false
- metric: bits_per_byte
  aggregation: bits_per_byte
  higher_is_better: false
\ No newline at end of file
......@@ -158,7 +158,15 @@ def make_table(result_dict):
    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = [
        "Task",
        "Version",
        "Filter",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]
    values = []
......@@ -166,7 +174,7 @@ def make_table(result_dict):
        version = result_dict["versions"][k]
        for (mf), v in dic.items():
            m, _, f = mf.partition(",")
            print(m,f)
            print(m, f)
            if m.endswith("_stderr"):
                continue
......
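For context on the writers used in make_table, a minimal sketch of how pytablewriter renders a header row plus values; the result row below is made up, not output from an actual run.

from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
writer.value_matrix = [["wikitext", "2.0", "none", "word_perplexity", "10.5", "±", "0.0"]]
print(writer.dumps())  # the assembled markdown table as a string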
......@@ -19,7 +19,9 @@ class MultiChoice:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.warning("{} is not in task list.".format(value))
                # eval_logger.info(f"{choices} is this")
                eval_logger.info(f"Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  {choice}")
        return True
......
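A toy illustration of the fnmatch-based membership check in MultiChoice above; the task names and patterns are placeholders rather than the harness's actual registry contents.

import fnmatch

choices = ["wikitext", "piqa", "sciq", "triviaqa"]
for value in "wiki*,piqa,not_a_task".split(","):
    if len(fnmatch.filter(choices, value)) == 0:
        print("{} is not in task list.".format(value))  # only fires for "not_a_task"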