Unverified Commit 2da74953 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #438 from EleutherAI/configurable-tasks

Configurable-Tasks
parents fa686d04 d2b16757
@@ -5,10 +5,9 @@ validation_split: validation
template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypo}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
-metric_list: [
-  [exact_match, mean, true]
-]
-# filters: [
-#   ["none", ["take_first"]]
-# ]
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
\ No newline at end of file
from . import metrics
METRIC_REGISTRY = {
    "matthews_corrcoef": metrics.matthews_corrcoef,
    "f1_score": metrics.f1_score,
    "perplexity": metrics.perplexity,
    "bleu": metrics.bleu,
    "chrf": metrics.chrf,
    "ter": metrics.ter,
}

AGGREGATION_REGISTRY = {
    "mean": metrics.mean,
    "median": metrics.median,
    "perplexity": metrics.perplexity,
}

HIGHER_IS_BETTER_REGISTRY = {
    "matthews_corrcoef": True,
    "f1_score": True,
    "perplexity": False,
    "bleu": True,
    "chrf": True,
    "ter": False,
    "acc": True,
    "acc_norm": True,
    "word_perplexity": False,
    "byte_perplexity": False,
    "bits_per_byte": False,
}
\ No newline at end of file
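For illustration, a minimal sketch of how these registries can be consulted downstream, assuming this module is importable as lm_eval.api (as the later import of HIGHER_IS_BETTER_REGISTRY from lm_eval.api in this diff suggests); the scores are made up:

from lm_eval.api import AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY

# Aggregate per-document scores with the registered "mean" aggregation.
per_doc_scores = [0.0, 1.0, 1.0, 0.5]
aggregated = AGGREGATION_REGISTRY["mean"](per_doc_scores)

# Use HIGHER_IS_BETTER_REGISTRY to decide whether a new score is an improvement.
def improved(metric_name, old_score, new_score):
    return (new_score > old_score) == HIGHER_IS_BETTER_REGISTRY[metric_name]

print(aggregated, improved("perplexity", 12.3, 10.1))  # lower perplexity counts as an improvement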
@@ -11,6 +11,7 @@ class Instance:
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)
+    # initialized after init
    task_name: str = None
    doc_id: str = None
    repeats: str = None
...
@@ -10,7 +10,12 @@ import evaluate
AGGREGATION_REGISTRY = {}
-METRIC_REGISTRY = {}
+METRIC_REGISTRY = {
+    "acc": None,
+    "acc_norm": None,
+    "word_perplexity": None,
+    "byte_perplexity": None,
+}

def register_metric(name):
@@ -45,6 +50,7 @@ searching in HF Evaluate library...")
def register_aggregation(name):
+    # TODO: should we enforce a specific interface to aggregation metrics?
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
@@ -155,6 +161,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
@register_metric("perplexity")
+@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))
@@ -165,10 +172,13 @@ def weighted_mean(items):
@register_metric("weighted_perplexity")
+@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))

+@register_metric("bits_per_byte")
+@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
...
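To show how the decorators above are meant to be used, a minimal sketch of registering a custom metric/aggregation pair; the name "geom_mean_likelihood" is hypothetical and not part of this PR, and the function consumes a list of per-document loglikelihoods the same way perplexity() does:

import math
from lm_eval.api.metrics import mean, register_metric, register_aggregation

@register_metric("geom_mean_likelihood")        # hypothetical name, for illustration only
@register_aggregation("geom_mean_likelihood")
def geom_mean_likelihood(items):
    # items: per-document loglikelihoods, as with perplexity() above
    return math.exp(mean(items))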
import abc
+from typing import Union

from lm_eval import utils

MODEL_REGISTRY = {}

-def register_model(name):
-    # TODO: should fairseq/elk be cited for this design pattern?
+def register_model(*names):
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings
    def decorate(cls):
-        assert (
-            issubclass(cls, LM)
-        ), f"Model '{name}' ({cls.__name__}) must extend LM class"
-        assert (
-            name not in MODEL_REGISTRY
-        ), f"Model named '{name}' conflicts with existing model!"
-        MODEL_REGISTRY[name] = cls
+        for name in names:
+            assert (
+                issubclass(cls, LM)
+            ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+            assert (
+                name not in MODEL_REGISTRY
+            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            MODEL_REGISTRY[name] = cls
        return cls
    return decorate
...
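As a usage sketch for the new multi-alias form, registering a model class under two names; MyLM and the alias strings are hypothetical, and the class must still implement LM's abstract interface:

from lm_eval.api.model import LM, MODEL_REGISTRY, register_model

@register_model("my-model", "my-model-alias")   # hypothetical aliases, for illustration only
class MyLM(LM):
    ...  # implement the LM interface here

# Both aliases resolve to the same class in the registry.
assert MODEL_REGISTRY["my-model"] is MODEL_REGISTRY["my-model-alias"] is MyLM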
@@ -5,15 +5,19 @@ import re
import evaluate
import random
import itertools
+import functools
import datasets
import numpy as np
from typing import List, Union

+from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
+from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
+from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api import samplers
@@ -34,17 +38,20 @@ class TaskConfig(dict):
    doc_to_text: str = ""
    doc_to_target: str = ""
    # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
    # aggregation: dict = None
    # higher_is_better: dict = None
    num_fewshot: int = 0
    batch_size: int = 1
+    repeats: int = 1
    metric_list: str = None
    gold_alias: str = None
    output_type: str = "greedy_until"
    delimiter: str = "\n\n"
    filters: str = None  # TODO: need to make this typehint `list`?
    normalization: str = None  # TODO: add length-normalization of various types, mutual info
-    stop_sequences: list = None  # TODO: allow passing of stop sequences to greedy gen.
+    should_decontaminate: bool = False
+    doc_to_decontamination_query: str = None
+    use_prompt: str = None

    def __post_init__(self):
        # allow user-specified aliases so that users can
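To make the new fields concrete, a rough sketch of the kind of dict a task YAML (such as the gsm8k config later in this diff) turns into before being handed to TaskConfig; it assumes TaskConfig is importable from lm_eval.api.task alongside the other classes in this file, and the values are illustrative:

from lm_eval.api.task import TaskConfig

config = {
    "dataset_path": "gsm8k",
    "dataset_name": "main",
    "output_type": "greedy_until",
    "use_prompt": "qa-basic:question-newline-answer",
    "doc_to_target": "{{answer.split('### ')[-1]}}",
    "should_decontaminate": False,
    "repeats": 1,
    "metric_list": [
        {
            "metric": "exact_match",
            "aggregation": "mean",
            "higher_is_better": True,
            "ignore_case": True,
            "ignore_punctuation": True,
        },
    ],
}
task_config = TaskConfig(**config)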
@@ -118,7 +125,8 @@ class Task(abc.ABC):
            filter_pipeline = build_filter_ensemble(name, components)
            self._filters.append(filter_pipeline)

-        self.sampler = samplers.Sampler(self.training_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here
+        self.sampler = samplers.Sampler(self.fewshot_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
@@ -189,6 +197,19 @@
        """
        return []
+    def fewshot_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
+        if self.has_training_docs():
+            return self.training_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            # TODO: should we allow this case to occur? / should raise a warning here
+            return self.test_docs()
    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
@@ -309,6 +330,16 @@ class Task(abc.ABC):
        """
        pass

+    @classmethod
+    def count_bytes(cls, doc):
+        """Used for byte-level perplexity metrics in rolling loglikelihood"""
+        return len(doc.encode("utf-8"))
+
+    @classmethod
+    def count_words(cls, doc):
+        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
+        return len(re.split(r"\s+", doc))
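A small worked example of what these helpers feed: each loglikelihood_rolling doc yields a (loglikelihood, word count) and a (loglikelihood, byte count) pair, and the weighted_perplexity / bits_per_byte aggregations from earlier in this diff reduce those pairs. The numbers are made up, and the helper below assumes weighted_mean divides the summed loglikelihoods by the summed weights:

import math

docs = ["The quick brown fox.", "Jumps over the lazy dog."]
loglikelihoods = [-9.2, -11.5]   # made-up per-document loglikelihoods

word_items = [(ll, len(doc.split())) for ll, doc in zip(loglikelihoods, docs)]
byte_items = [(ll, len(doc.encode("utf-8"))) for ll, doc in zip(loglikelihoods, docs)]

def weighted_mean_assumed(items):
    # assumption: summed loglikelihoods divided by summed weights
    return sum(ll for ll, _ in items) / sum(weight for _, weight in items)

word_perplexity = math.exp(-weighted_mean_assumed(word_items))
byte_perplexity = math.exp(-weighted_mean_assumed(byte_items))
bits_per_byte = -weighted_mean_assumed(byte_items) / math.log(2)
print(word_perplexity, byte_perplexity, bits_per_byte)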
    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, rnd=None):
        """Returns a fewshot context string that is made up of a prepended description
@@ -332,33 +363,33 @@
            labeled_examples = ""
        else:
            labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            # if self.has_training_docs():
            #     fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            # else:
            #     if self._fewshot_docs is None:
            #         self._fewshot_docs = list(
            #             self.validation_docs()
            #             if self.has_validation_docs()
            #             else self.test_docs()
            #         )
            #     fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
            #     # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            #     fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
            #     labeled_examples = (
            #         "\n\n".join(
            #             [
            #                 self.doc_to_text(doc) + self.doc_to_target(doc)
            #                 for doc in fewshotex
            #             ]
            #         )
            #         + "\n\n"
            #     )

        example = self.doc_to_text(doc)
        return labeled_examples + example
@@ -372,13 +403,17 @@
class ConfigurableTask(Task):
    VERSION = "2.0"
-    OUTPUT_TYPE = "greedy_until"
+    OUTPUT_TYPE = None

    def __init__(
        self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
    ):
        self._config = TaskConfig(**config)

+        if self._config.output_type is not None:
+            self.OUTPUT_TYPE = self._config.output_type
+
        if self._config.dataset_path is not None:
            self.DATASET_PATH = self._config.dataset_path
@@ -387,27 +422,33 @@ class ConfigurableTask(Task):
        if self._config.metric_list is not None:
            self._metric_list = {}
+            self._metric_kwargs = {}
            self._aggregation_list = {}
            self._higher_is_better = {}

-            for (metric_name, aggregation, higher_is_better) in self._config.metric_list:
-                self._aggregation_list[metric_name] = get_aggregation(aggregation)
-                self._higher_is_better[metric_name] = higher_is_better
-                self._metric_list[metric_name] = get_metric(metric_name)
-                # if metric_name in METRIC_REGISTRY.keys():
-                #     self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
-                # else:
-                #     try:
-                #         metric_object = evaluate.load(metric_name)
-                #         self._metric_list[metric_name] = metric_object
-                #     except Exception as ex:
-                #         raise Warning(
-                #             "{} not found in the evaluate library!".format(metric_name),
-                #             "Please check https://huggingface.co/evaluate-metric",
-                #         )
+            for metric_config in self._config.metric_list:
+                metric_name = metric_config['metric']
+                aggregation = metric_config['aggregation']
+                higher_is_better = metric_config['higher_is_better']
+                kwargs = {key: metric_config[key] for key in metric_config if key not in ['metric', 'aggregation', 'higher_is_better']}
+
+                self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
+
+                if metric_name in METRIC_REGISTRY.keys():
+                    self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
+                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
+                else:
+                    self._higher_is_better[metric_name] = higher_is_better
+                    try:
+                        metric_object = evaluate.load(metric_name)
+                        self._metric_list[metric_name] = metric_object
+                        self._metric_kwargs[metric_name] = kwargs
+                    except Exception as ex:
+                        raise Warning(
+                            "{} not found in the evaluate library!".format(metric_name),
+                            "Please check https://huggingface.co/evaluate-metric",
+                        )

        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
@@ -418,6 +459,8 @@ class ConfigurableTask(Task):
        for name, components in self._config.get("filters", [["none", ["take_first"]]]):
            filter_pipeline = build_filter_ensemble(name, components)
            self._filters.append(filter_pipeline)
+
+        self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random())  # TODO: pass the correct docs in here
    def has_training_docs(self):
        if self._config.training_split is not None:
@@ -449,6 +492,20 @@
        if self._config.test_split is not None:
            return self.dataset[self._config.test_split]

+    def fewshot_docs(self):
+        if self._config.fewshot_split:
+            return self.dataset[self._config.fewshot_split]
+        else:
+            # TODO: warn user if fewshot split isn't explicitly set
+            return super().fewshot_docs()
+
+    def should_decontaminate(self):
+        return self._config.should_decontaminate
+
+    def doc_to_decontamination_query(self, doc):
+        if self._config.should_decontaminate:
+            return utils.apply_template(self._config.doc_to_decontamination_query, doc)
    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
@@ -461,40 +518,103 @@
        return doc

    def doc_to_text(self, doc):
-        return utils.apply_template(self._config.doc_to_text, doc)
+        if self._config.use_prompt is not None:
+            doc_to_text = get_prompt(self._config.use_prompt)
+        else:
+            doc_to_text = self._config.doc_to_text
+
+        return utils.apply_template(doc_to_text, doc)

    def doc_to_target(self, doc):
        return utils.apply_template(self._config.doc_to_target, doc)
    def construct_requests(self, doc, ctx, **kwargs):
-        if self.OUTPUT_TYPE == "greedy_until":
-            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), idx=0, **kwargs)
+        if self.OUTPUT_TYPE == "loglikelihood":
+            arguments = (ctx, self.doc_to_target(doc))
+        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
+            arguments = (self.doc_to_target(doc),)
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            import ast
+            return [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(ctx, " {}".format(choice)),
+                    idx=i,
+                    **kwargs,
+                )
+                for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc)))
+                # we pass the user-defined answer_choices var (in aliases) and echo the result. TODO: any cleaner way to do this?
+            ]
+        elif self.OUTPUT_TYPE == "greedy_until":
+            arguments = (ctx, self._config.delimiter)
+
+        return Instance(
+            request_type=self.OUTPUT_TYPE,
+            doc=doc,
+            arguments=arguments,
+            idx=0,
+            **kwargs
+        )
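To illustrate the multiple_choice branch above: a single doc expands into one loglikelihood request per answer choice, each asking the model to score " <choice>" as a continuation of the same context. A rough sketch with a made-up doc, showing just the (context, continuation) argument pairs that end up in the Instance objects:

doc = {
    "question": "The capital of France is?",
    "choices": ["London", "Paris", "Rome", "Berlin"],
    "gold": 1,
}
ctx = "Question: The capital of France is?\nAnswer:"

request_args = [(ctx, " {}".format(choice)) for choice in doc["choices"]]
# -> [(ctx, " London"), (ctx, " Paris"), (ctx, " Rome"), (ctx, " Berlin")]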
    def process_results(self, doc, results):
-        if self._config.gold_alias is not None:
-            gold = doc[self._config.gold_alias]
-        else:
-            gold = self.doc_to_target(doc)
        result_dict = {}
-        for key, result in zip(self._metric_list.keys(), results):
-            _dict = self._metric_list[key](
-                references=[gold],
-                predictions=[result],
-            )
-            result_dict[key] = _dict[key]
+        if self.OUTPUT_TYPE == "loglikelihood":
+            results = results[0]
+            ll, is_greedy = results
+            result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
+        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
+            (loglikelihood,) = results
+            words = self.count_words(self.doc_to_target(doc))
+            bytes_ = self.count_bytes(self.doc_to_target(doc))
+            return {
+                "word_perplexity": (loglikelihood, words),
+                "byte_perplexity": (loglikelihood, bytes_),
+                "bits_per_byte": (loglikelihood, bytes_),
+            }
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            lls = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy TODO: keep is_greedy to report exact_match as well on multiple choice probs
+            gold = int(self.doc_to_target(doc))
+            # TODO: remove dependence on "gold" and "choices" columns
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+            completion_len = np.array([float(len(i)) for i in doc["choices"]])
+            acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
+            # TODO: set which normalization metrics should be reported, and calculate them
+            # TODO: add mutual info.
+            result_dict = {
+                "acc": acc,
+                "acc_norm": acc_norm,
+            }
+        elif self.OUTPUT_TYPE == "greedy_until":
+            if self._config.gold_alias is not None:
+                gold = doc[self._config.gold_alias]
+            else:
+                gold = self.doc_to_target(doc)
+            for key, result in zip(self._metric_list.keys(), results):
+                _dict = self._metric_list[key].compute(
+                    references=[gold],
+                    predictions=[result],
+                    **self._metric_kwargs[key]
+                )
+                result_dict[key] = _dict[key]
+        else:
+            raise ValueError(f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'"
+            )

        return result_dict

    def aggregation(self):
        return self._aggregation_list

    def higher_is_better(self):
        return self._higher_is_better
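A worked example of the acc / acc_norm arithmetic in the multiple_choice branch, with made-up loglikelihoods; acc compares the raw argmax against the gold index, while acc_norm first divides each loglikelihood by the character length of its choice:

import numpy as np

lls = np.array([-12.7, -4.1, -9.8, -11.0])   # made-up loglikelihoods, one per choice
choices = ["London", "Paris", "Rome", "Berlin"]
gold = 1

acc = 1.0 if np.argmax(lls) == gold else 0.0
completion_len = np.array([float(len(c)) for c in choices])
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
print(acc, acc_norm)  # 1.0 1.0 for these numbers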
@@ -515,11 +635,6 @@ class MultipleChoiceTask(Task):
                **kwargs,
            )
            for i, choice in enumerate(doc["choices"])]

-        #lls = [
-        #     rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
-        # ]
-        # return lls

    def process_results(self, doc, results):
        results = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
@@ -595,8 +710,8 @@ class PerplexityTask(Task, abc.ABC):
    def process_results(self, doc, results):
        (loglikelihood,) = results
-        words = self.count_words(doc)
-        bytes_ = self.count_bytes(doc)
+        words = self.count_words(self.doc_to_target(doc))
+        bytes_ = self.count_bytes(self.doc_to_target(doc))
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
@@ -625,19 +740,22 @@ class PerplexityTask(Task, abc.ABC):
TASK_REGISTRY = {}
ALL_TASKS = []

-def register_task(name):
+def register_task(*names):
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings
    def decorate(cls):
-        assert (
-            issubclass(cls, Task)
-        ), f"Task '{name}' ({cls.__name__}) must extend Task class"
-        assert (
-            name not in TASK_REGISTRY
-        ), f"Task named '{name}' conflicts with existing task!"
-        TASK_REGISTRY[name] = cls
-        ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
+        for name in names:
+            assert (
+                issubclass(cls, Task)
+            ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+            assert (
+                name not in TASK_REGISTRY
+            ), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
+            TASK_REGISTRY[name] = cls
+            ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
        return cls
    return decorate
...
@@ -145,7 +145,8 @@ def evaluate(
        # for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
        task.build_all_requests(limit=limit)

        # aggregate Instances by LM method requested to get output.
-        requests[task.OUTPUT_TYPE].extend(task.instances)
+        reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE  # TODO: this is hacky, fix in task.py
+        requests[reqtype].extend(task.instances)

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
...
@@ -9,7 +9,7 @@ from lm_eval import utils
from lm_eval.api.model import LM, register_model

-@register_model("hf-causal")
+@register_model("hf-causal", "gpt2")
class HFLM(LM):
    def __init__(
        self,
...
@@ -41,7 +41,7 @@ from . import lambada
# from . import hendrycks_math
# from . import cbt
# from . import lambada_cloze
-# from . import pile
+from . import pile
from . import wikitext
# from . import lambada_multilingual
# from . import mutual
...
dataset_path: ai2_arc
dataset_name: ARC-Challenge
+output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
-doc_to_text: "Q: {{question}}\nA:"
-doc_to_target: "{% set answer_choices = doc['choices']['text'] %}{{answer_choices[int(doc['answerKey']) - 1]}}"
-metric_list: [
-  [exact_match, mean, true]
-]
+template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
+doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_target: "{{gold}}" # this will be cast to an int.
+metric_list:
+  - metric: acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: acc_norm
+    aggregation: mean
+    higher_is_better: true
\ No newline at end of file
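To show what the new template_aliases / doc_to_target pair evaluates to, a small sketch that renders the same Jinja expressions directly with jinja2 against an ai2_arc-style doc; the harness's apply_template may differ in details, this only demonstrates the template logic:

from jinja2 import Template

doc = {
    "question": "Which gas do plants absorb from the atmosphere?",
    "choices": {
        "text": ["Oxygen", "Carbon dioxide", "Nitrogen", "Helium"],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}

aliases = "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"
doc_to_target = "{{gold}}"

print(Template(aliases + doc_to_target).render(**doc))  # "1", the index of the gold answer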
@@ -2,12 +2,15 @@ dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
-doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{answer}}" # TODO: this field needs to change to account for the regexing that happens etc.
-metric_list: [
-  [acc, mean, true]
-]
-filters: [
-  ["regex", ["regex", "take_first"]]
-]
-stop_sequences: ["\n"]
+doc_to_target: "{{answer.split('### ')[-1]}}"
+use_prompt: "qa-basic:question-newline-answer"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+delimiter: "\n"
+# filters: [
+#   ["regex", ["regex", "take_first"]]
+# ]
\ No newline at end of file
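For clarity on the new doc_to_target above: the Jinja expression is plain str.split on the answer column, and gsm8k answers end in a "#### <number>" line, so the last element of the split is the numeric answer. A quick check in Python (the answer text is illustrative):

answer = "Natalia sold 48 clips in April and half as many in May.\n48 / 2 = 24\n48 + 24 = 72\n#### 72"
print(answer.split("### ")[-1])  # "72"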
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: loglikelihood
test_split: test
template_aliases: ""
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target: "{{' '+text.split(' ')[-1]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{text}}"
metric_list:
  - metric: perplexity
    aggregation: perplexity
    higher_is_better: true
  - metric: accuracy
    aggregation: mean
    higher_is_better: true
"""
The Pile: An 800GB Dataset of Diverse Text for Language Modeling
https://arxiv.org/pdf/2101.00027.pdf
The Pile is an 825 GiB diverse, open source language modelling data set that consists
of 22 smaller, high-quality datasets combined together. To score well on Pile
BPB (bits per byte), a model must be able to understand many disparate domains
including books, github repositories, webpages, chat logs, and medical, physics,
math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask, register_task
_CITATION = """
@article{pile,
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
journal={arXiv preprint arXiv:2101.00027},
year={2020}
}
"""
class PilePerplexityTask(PerplexityTask):
    VERSION = "2.0"
    DATASET_PATH = "EleutherAI/the_pile"
    DATASET_NAME = None

    def has_training_docs(self):
        return False

    def test_docs(self):
        for doc in self.dataset["train"].select(range(100)):
            yield doc

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def doc_to_target(self, doc):
        return doc["text"]

    # def validation_docs(self):
    #     for doc in self.dataset["validation"]:
    #         yield doc["text"]

    # def test_docs(self):
    #     for doc in self.dataset["test"]:
    #         yield doc["text"]
class PileArxiv(PilePerplexityTask):
    DATASET_NAME = "pile_arxiv"

class PileBooks3(PilePerplexityTask):
    DATASET_NAME = "pile_books3"

class PileBookCorpus2(PilePerplexityTask):
    DATASET_NAME = "pile_bookcorpus2"

class PileDmMathematics(PilePerplexityTask):
    DATASET_NAME = "pile_dm-mathematics"

@register_task("pile_enron")
class PileEnron(PilePerplexityTask):
    DATASET_NAME = "enron_emails"

class PileEuroparl(PilePerplexityTask):
    DATASET_NAME = "pile_europarl"

class PileFreeLaw(PilePerplexityTask):
    DATASET_NAME = "pile_freelaw"

class PileGithub(PilePerplexityTask):
    DATASET_NAME = "pile_github"

class PileGutenberg(PilePerplexityTask):
    DATASET_NAME = "pile_gutenberg"

class PileHackernews(PilePerplexityTask):
    DATASET_NAME = "pile_hackernews"

class PileNIHExporter(PilePerplexityTask):
    DATASET_NAME = "pile_nih-exporter"

class PileOpenSubtitles(PilePerplexityTask):
    DATASET_NAME = "pile_opensubtitles"

class PileOpenWebText2(PilePerplexityTask):
    DATASET_NAME = "pile_openwebtext2"

class PilePhilPapers(PilePerplexityTask):
    DATASET_NAME = "pile_philpapers"

class PilePileCc(PilePerplexityTask):
    DATASET_NAME = "pile_pile-cc"

class PilePubmedAbstracts(PilePerplexityTask):
    DATASET_NAME = "pile_pubmed-abstracts"

class PilePubmedCentral(PilePerplexityTask):
    DATASET_NAME = "pile_pubmed-central"

class PileStackExchange(PilePerplexityTask):
    DATASET_NAME = "pile_stackexchange"

class PileUspto(PilePerplexityTask):
    DATASET_NAME = "pile_upsto"

class PileUbuntuIrc(PilePerplexityTask):
    DATASET_NAME = "pile_ubuntu-irc"

class PileWikipedia(PilePerplexityTask):
    DATASET_NAME = "pile_wikipedia"

class PileYoutubeSubtitles(PilePerplexityTask):
    DATASET_NAME = "pile_youtubesubtitles"
\ No newline at end of file
dataset_path: EleutherAI/the_pile
dataset_name: enron_emails
output_type: loglikelihood_rolling
test_split: train
template_aliases: ""
doc_to_text: ""
doc_to_target: "{{text}}"
should_decontaminate: true
doc_to_decontamination_query: "{{text}}"
metric_list:
  - metric: word_perplexity
    aggregation: weighted_perplexity
    higher_is_better: false
  - metric: byte_perplexity
    aggregation: weighted_perplexity
    higher_is_better: false
  - metric: bits_per_byte
    aggregation: bits_per_byte
    higher_is_better: false
\ No newline at end of file
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
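As context for the ignore_case / ignore_punctuation keys: in the ConfigurableTask changes above, any metric_list keys beyond metric, aggregation and higher_is_better are forwarded as keyword arguments to the loaded metric's compute(). With HF evaluate's exact_match that looks roughly like this (a sketch, assuming exact_match is resolved through evaluate.load as in the code above):

import evaluate

exact_match = evaluate.load("exact_match")
score = exact_match.compute(
    references=["Yes"],
    predictions=["yes."],
    ignore_case=True,
    ignore_punctuation=True,
)
print(score["exact_match"])  # 1.0 once case and punctuation are ignored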