Unverified Commit 232632c6 authored by Hailey Schoelkopf, committed by GitHub

[Refactor] Non-greedy generation ; WIP GSM8k yaml (#559)

* add wip gsm8k yaml

* cleanup tasks dir

* push gsm8k yaml changes

* rename gpt2.py

* add updated gsm8k, triviaqa baseline

* add new cot yaml

* allow for multiple filter pipelines, new filter types

* updated gsm8k + sampling gen configs

* cleanup self-consistency yaml
parent 4e9412d5
......@@ -13,7 +13,7 @@ class Filter:
"""
def __init__(self):
def __init__(self, *args, **kwargs):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
......@@ -47,10 +47,7 @@ class FilterEnsemble:
] # operate just on the model responses
for f in self.filters:
# apply filters in sequence
out = f.apply(resps)
resps = (
out # TODO: handle the case where a filter returns multiple "buckets"
)
resps = f.apply(resps)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
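For context, the FilterEnsemble change above simply threads each filter's output into the next filter's input. A minimal standalone sketch of that sequencing, using two hypothetical filters that are not part of the harness:

```python
# Minimal sketch of FilterEnsemble-style sequential application.
# StripFilter and LowercaseFilter are hypothetical, for illustration only.
class StripFilter:
    def apply(self, resps):
        # resps: one list of model responses per document
        return [[r.strip() for r in doc_resps] for doc_resps in resps]

class LowercaseFilter:
    def apply(self, resps):
        return [[r.lower() for r in doc_resps] for doc_resps in resps]

resps = [["  The answer is 6.  "], ["  The answer is 5.  "]]
for f in [StripFilter(), LowercaseFilter()]:
    resps = f.apply(resps)  # apply filters in sequence, as in FilterEnsemble.apply
print(resps)  # [['the answer is 6.'], ['the answer is 5.']]
```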
import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
import re
import ast
......@@ -39,7 +39,6 @@ class TaskConfig(dict):
task: str = None
group: str = None
names: str = None
reference: str = None
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
......@@ -63,6 +62,7 @@ class TaskConfig(dict):
metric_list: str = None
gold_alias: str = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
delimiter: str = "\n\n"
filter_list: Union[str, list] = None
normalization: str = (
......@@ -85,9 +85,12 @@ class TaskConfig(dict):
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
self.task_name = self.names[0]
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.gold_alias
if not self.generation_kwargs:
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
def __getitem__(self, item):
return getattr(self, item)
......@@ -243,7 +246,7 @@ class Task(abc.ABC):
else:
eval_logger.warning(
"has_training_docs and has_validation_docs are False"
"using test_docs but this is not recommended."
", using test_docs but this is not recommended."
)
return self.test_docs()
......@@ -308,7 +311,7 @@ class Task(abc.ABC):
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self._config["task_name"], doc_id, self._config.repeats),
metadata=(self._config["task"], doc_id, self._config.repeats),
)
if not isinstance(inst, list):
......@@ -527,8 +530,8 @@ class ConfigurableTask(Task):
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
self._filters = [
build_filter_ensemble("take_first", [["take_first", None]])
......@@ -639,6 +642,23 @@ class ConfigurableTask(Task):
else:
raise TypeError
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias:
doc_to_target = self._config.gold_alias
else:
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
raise TypeError
def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "loglikelihood":
......@@ -686,7 +706,7 @@ class ConfigurableTask(Task):
return request_list
elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.delimiter)
arguments = (ctx, self._config.generation_kwargs)
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
......@@ -759,7 +779,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None:
gold = doc[self._config.gold_alias]
gold = self.gold_alias(doc)
else:
gold = self.doc_to_target(doc)
......
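To make the TaskConfig changes above concrete: with `generation_kwargs` in place, a `greedy_until` request's `arguments` tuple now carries the whole generation dict instead of a bare delimiter string. A rough sketch of what gets packed, with illustrative values:

```python
# Illustrative only: the arguments now attached to a greedy_until Instance.
generation_kwargs = {
    "until": ["Q:", "\n\n"],  # stop sequences taken from the task YAML
    "do_sample": False,       # greedy decoding is the default when unspecified
    "temperature": 0.0,
}
ctx = "Q: What is 2 + 3?\nA:"
arguments = (ctx, generation_kwargs)  # previously: (ctx, delimiter)
print(arguments)
```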
......@@ -274,7 +274,7 @@ def evaluate(
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + " - filter=" + key] = task.aggregation()[
results[task_name][metric + "," + key] = task.aggregation()[
metric
](items)
......@@ -289,7 +289,7 @@ def evaluate(
)
if stderr is not None:
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(
results[task_name][metric + "_stderr" + "," + key] = stderr(
items
)
......
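The key-format change above means each filter pipeline's metrics land under a comma-separated `metric,filter_name` key. A sketch with made-up numbers:

```python
# Sketch of the new result-key layout; the scores are invented for illustration.
results = {"gsm8k_cot": {}}
task_name, metric, key = "gsm8k_cot", "exact_match", "get-answer"
results[task_name][metric + "," + key] = 0.17
results[task_name][metric + "_stderr" + "," + key] = 0.01
print(results)
# {'gsm8k_cot': {'exact_match,get-answer': 0.17, 'exact_match_stderr,get-answer': 0.01}}
```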
......@@ -6,6 +6,8 @@ from . import extraction
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
......
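For reference, the registry names above are what task YAMLs refer to in their `filter_list`. A hedged sketch of the `components` list that `ConfigurableTask` would assemble for the `maj@8` pipeline defined in the self-consistency YAML further down (kwargs handling simplified to what the diff shows; `None` stands in for "no extra kwargs"):

```python
# Hypothetical reconstruction of the components passed to build_filter_ensemble.
filter_name = "maj@8"
components = [
    ["take_first_k", {"k": 8}],
    ["regex", {"regex_pattern": r"The answer is (\-?[0-9\.\,]*[0-9]+)"}],
    ["majority_vote", None],
    ["take_first", None],
]
# filter_pipeline = build_filter_ensemble(filter_name, components)
print(filter_name, components)
```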
......@@ -26,8 +26,6 @@ class RegexFilter(Filter):
match = self.regex.search(resp)
if match:
match = match.group(1).strip()
match.replace(",", "")
# TODO: should we assume any other filtering is performed?
else:
match = self.fallback
filtered.append(match)
......
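As a quick sanity check of the extraction pattern used by the GSM8k CoT configs below (the `[invalid]` string here is a placeholder standing in for the filter's configured `self.fallback`):

```python
import re

# Same pattern as in the gsm8k-cot YAMLs later in this commit.
pattern = re.compile(r"The answer is (\-?[0-9\.\,]*[0-9]+)")
fallback = "[invalid]"  # placeholder for RegexFilter's configured fallback

resp = "Olivia had 23 dollars. Five bagels cost 5 x 3 = 15 dollars. The answer is 8."
match = pattern.search(resp)
extracted = match.group(1).strip() if match else fallback
print(extracted)  # 8
```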
from collections import Counter
from lm_eval.api.filter import Filter
class TakeFirstFilter:
class TakeFirstFilter(Filter):
def __init__(self):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
......@@ -12,3 +14,35 @@ class TakeFirstFilter:
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take the first {self.k}, but got only {len(resps[0])}! Please increase TaskConfig.repeats."
return map(lambda r: r[:self.k], resps)
class MajorityVoteFilter(Filter):
def __init__(self):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps):
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
"""
def select_majority(resp):
counts = Counter(resp)
vote = counts.most_common(1)[0][0]
return vote
return map(lambda r: [select_majority(r)], resps)
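Putting the two new selection filters together: with `repeats > 1`, a maj@k pipeline keeps the first k sampled responses and reports the most common extracted answer. A minimal standalone sketch (regex extraction omitted; the inner lists are assumed to already hold extracted answers):

```python
from collections import Counter

# One inner list of sampled, already-extracted answers per document.
resps = [
    ["72", "72", "68", "72", "72", "70", "72", "72"],
    ["5", "4", "5", "5", "5", "5", "5", "3"],
]
k = 8

resps = [r[:k] for r in resps]                              # take_first_k
resps = [[Counter(r).most_common(1)[0][0]] for r in resps]  # majority_vote
resps = [r[0] for r in resps]                               # take_first
print(resps)  # ['72', '5']
```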
from . import gpt2
from . import hf_causal
from . import gpt3
from . import textsynth
from . import dummy
......
import torch
import transformers
import copy
from tqdm import tqdm
import torch.nn.functional as F
......@@ -55,10 +56,10 @@ class HFLM(LM):
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
).to(self.device)
self.gpt2.eval()
self.model.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
......@@ -84,7 +85,7 @@ class HFLM(LM):
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
else:
self.gpt2 = accelerator.prepare(self.gpt2)
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
......@@ -103,17 +104,17 @@ class HFLM(LM):
def max_length(self):
try:
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
return self.accelerator.unwrap_model(self.model).config.n_ctx
else:
return self.gpt2.config.n_ctx
return self.model.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(
self.gpt2
self.model
).config.max_position_embeddings
else:
return self.gpt2.config.max_position_embeddings
return self.model.config.max_position_embeddings
@property
def max_gen_toks(self):
......@@ -150,15 +151,19 @@ class HFLM(LM):
logits returned from the model
"""
with torch.no_grad():
return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
**generation_kwargs,
)
def loglikelihood(self, requests):
......@@ -267,7 +272,7 @@ class HFLM(LM):
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
......@@ -347,18 +352,42 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, until in tqdm(re_ord.get_reordered()):
if isinstance(until, str):
until = [until]
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}")
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
(primary_until,) = self.tok_encode(until[0])
try:
(primary_until,) = self.tok_encode(until[0])
except ValueError:
# if our primary until sequence encodes to multiple tokens, the unpacking above fails.
# TODO: handling this better would let us stop generating earlier and more often.
primary_until = self.eot_token_id
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
......
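To summarize the request handling above in isolation: `greedy_until` now pops `until` and `max_gen_toks` out of the per-request dict and forwards whatever remains to `model.generate`. A simplified, self-contained sketch of that unpacking (EOT handling reduced to a placeholder string):

```python
import copy

# Simplified reconstruction of the gen_kwargs unpacking in greedy_until above.
gen_kwargs = {"until": ["Q:", "\n\n"], "do_sample": True, "temperature": 0.2, "max_gen_toks": 256}

gen_kwargs = copy.deepcopy(gen_kwargs)  # requests can repeat; avoid mutating shared dicts
until = gen_kwargs.pop("until", None)
if isinstance(until, str):
    until = [until]
if not until:
    until = ["<|endoftext|>"]           # placeholder for tok_decode(eot_token_id)
max_gen_toks = gen_kwargs.pop("max_gen_toks", 256)

# What would reach self.model.generate(...) alongside context/max_length/eos_token_id:
print(gen_kwargs)           # {'do_sample': True, 'temperature': 0.2}
print(until, max_gen_toks)  # ['Q:', '\n\n'] 256
```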
# v1.0 Tasks
This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
Boxes should be checked iff tasks are implemented in v2.0 and tested for regression. Tasks should be struck through once they have been checked *against the original introducing paper's implementation or a popularizing implementation*.
- [ ] Glue
- [ ] SuperGlue
- [ ] CoQA
- [ ] DROP
- [x] ~~Lambada~~
- [x] Lambada (Cloze variants)
- [ ] Lambada (Multilingual)
- [x] Wikitext
- [x] PiQA
- [ ] PROST
- [ ] MCTACO
- [ ] Pubmed QA
- [x] SciQ
- [ ] QASPER
- [ ] QA4MRE
- [ ] TriviaQA
- [x] AI2 ARC
- [ ] LogiQA
- [ ] HellaSwag
- [ ] SWAG
- [ ] OpenBookQA
- [ ] SQuADv2
- [ ] RACE
- [ ] HeadQA
- [ ] MathQA
- [ ] WebQs
- [ ] WSC273
- [ ] Winogrande
- [ ] ANLI
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
- [ ] MuTual
- [ ] Hendrycks Math
- [ ] Asdiv
- [ ] GSM8k
- [ ] Arithmetic
- [ ] MMMLU
- [ ] Translation (WMT) suite
- [ ] Unscramble
- [x] ~~Pile (perplexity)~~
- [ ] BLiMP
- [ ] ToxiGen
- [ ] CrowS-Pairs
- [ ] XCopa
- [ ] BIG-Bench
- [ ] XStoryCloze
- [ ] XWinograd
- [ ] PAWS-X
- [ ] XNLI
- [ ] MGSM
# Novel Tasks
Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*.
# Task Wishlist
- [ ] TheoremQA
- [ ] Theorem Proving evaluations
- [ ] Chain of Thought
- [ ] Self-consistency ; Least-to-Most prompting, etc.
- [ ] Summarization Tasks
- [ ] Anthropic Model-Written Evals
\ No newline at end of file
import os
from typing import List, Union
from .gsm8k import *
from .triviaqa import *
from lm_eval import utils
from lm_eval.logger import eval_logger
......@@ -33,17 +35,18 @@ for root, subdirs, file_list in os.walk(task_dir):
)
if "task" in config:
task_name = "{}:{}".format(
get_task_name_from_config(config), config["task"]
task_name = "{}".format(
config["task"]
)
register_task(task_name)(SubClass)
if "group" in config:
for group in config["group"]:
register_group(group)(SubClass)
except Exception:
except Exception as e:
raise e
eval_logger.warning(
"Failed to load config at in\n"
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry"
)
......
"""
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
OUTPUT_TYPE = "loglikelihood"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
out_doc = {
"id": doc["id"],
"question": doc["question"],
"choices": doc["choices"]["text"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc
def doc_to_text(self, doc):
doc_to_text = get_prompt("qa-basic:question-newline-answer")
return utils.apply_template(doc_to_text, doc)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge"
......@@ -92,7 +92,7 @@ class GradeSchoolMath8K(Task):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, ["\n"]),
arguments=(ctx, ["\n\n"]),
idx=0,
**kwargs
)
......@@ -113,7 +113,7 @@ class GradeSchoolMath8K(Task):
assert gold != INVALID_ANS, "No ground truth answer found in the document."
# return self._extract_answer(completion) == gold
# print(completion)
return completion == gold
return self._extract_answer(completion) == gold
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
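For context on the scoring fix above: the diff does not show `_extract_answer` itself, but the commented-out YAML filter suggests the usual `#### <number>` pattern. A hedged sketch of that extraction:

```python
import re

# Pattern inferred from the commented-out filter "#### (\-?[0-9\.\,]+)";
# the task's real _extract_answer / INVALID_ANS constant may differ.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"

def extract_answer(completion: str) -> str:
    match = ANS_RE.search(completion)
    if match:
        return match.group(1).strip().replace(",", "")
    return INVALID_ANS

completion = "Natalia sold 48 / 2 = 24 clips in May. 48 + 24 = 72.\n#### 72"
print(extract_answer(completion) == "72")  # True
```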
# GSM8k
## Paper
Training Verifiers to Solve Math Word Problems
https://arxiv.org/abs/2110.14168
State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To
diagnose the failures of current models and support research, we introduce GSM8K,
a dataset of 8.5K high quality linguistically diverse grade school math word problems.
We find that even the largest transformer models fail to achieve high test performance,
despite the conceptual simplicity of this problem distribution.
NOTE: See the official implementation of the task:
https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
for how to make use of the dataset's calculator annotations in your language
model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
## Citation
```
@misc{cobbe2021training,
title={Training Verifiers to Solve Math Word Problems},
author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
year={2021},
eprint={2110.14168},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
\ No newline at end of file
# "Training Verifiers to Solve Math Word Problems"
# https://arxiv.org/abs/2110.14168
# State-of-the-art language models can match human performance on many tasks, but
# they still struggle to robustly perform multi-step mathematical reasoning. To
# diagnose the failures of current models and support research, we introduce GSM8K,
# a dataset of 8.5K high quality linguistically diverse grade school math word problems.
# We find that even the largest transformer models fail to achieve high test performance,
# despite the conceptual simplicity of this problem distribution.
# NOTE: See the official implementation of the task:
# https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
# for how to make use of the dataset's calculator annotations in your language
# model's sample/generation function.
# Homepage: https://github.com/openai/grade-school-math
# _CITATION = """
# @misc{cobbe2021training,
# title={Training Verifiers to Solve Math Word Problems},
# author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
# year={2021},
# eprint={2110.14168},
# archivePrefix={arXiv},
# primaryClass={cs.LG}
# }
# """
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
use_prompt: "qa-basic:question-newline-answer"
doc_to_target: "{{answer.split('### ')[-1]}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
delimiter: "\n"
repeats: 4
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "#### (\-?[0-9\.\,]+)"
include: gsm8k-cot.yaml
group:
- chain_of_thought
- self_consistency
task: gsm8k_cot_self_consistency
generation_kwargs:
until:
- "Q:"
- "\n\n"
do_sample: true
temperature: 0.2
repeats: 8
filter_list:
- name: "score-first" # pick only the first response, and report metrics on that
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "take_first"
- name: "maj@64"
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote"
- function: "take_first"
- name: "maj@8" # get Maj@8 , via selecting the first 8 responses. Using a better estimator would be optimal.
filter:
- function: "take_first_k"
k: 8
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote"
- function: "take_first"
\ No newline at end of file
group:
- chain_of_thought
task: gsm8k_cot
dataset_path: gsm8k
dataset_name: main
output_type: greedy_until
test_split: test
doc_to_text: "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\n\
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\n\
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\n\nA: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.\n\n\
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n\nA: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.\n\n\
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n\nA: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.\n\n\
Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n\nA: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.\n\n\
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n\nA: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.\n\n\
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n\nA: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.\n\n\
Q: {{question}}\n\nA:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
delimiter: "\n\n"
generation_kwargs:
until:
- "Q:"
- "\n\n"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 0
filter_list:
- name: "get-answer"
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
- function: "take_first"
\ No newline at end of file
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
output_type: greedy_until
training_split: train
fewshot_split: train
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- ".*### "
delimiter: "\n\n"
generation_kwargs:
until:
- "\n\n"
- "Question:"
do_sample: false
temperature: 0.0
repeats: 2
num_fewshot: 5
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "### (\\-?[0-9\\.\\,]+)"
# - function: "take_first"
\ No newline at end of file
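A hedged sketch of how the `exact_match` options above would normalize strings before comparison; this mirrors common `ignore_case` / `regexes_to_ignore` semantics rather than the harness's actual metric implementation, and it also illustrates why the commented-out filter pipeline (or `gold_alias`) is still needed to isolate the final number on the prediction side:

```python
import re

def normalize(text, regexes_to_ignore=(",", r"\$", r".*### "), ignore_case=True):
    # Apply each ignore-regex, then optionally lowercase, per the YAML options above.
    for pattern in regexes_to_ignore:
        text = re.sub(pattern, "", text)
    return text.lower() if ignore_case else text

prediction = "She has $1,080 left"
reference = "Step-by-step reasoning... ### 1080"  # raw doc_to_target-style answer
print(normalize(reference))   # '1080'  <- ".*### " strips everything before the final answer
print(normalize(prediction) == normalize(reference))
# False: the prediction still needs a regex filter to pull out the bare number.
```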
"""
The LAMBADA dataset: Word prediction requiring a broad discourse context∗
https://arxiv.org/pdf/1606.06031.pdf
LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
title={The LAMBADA dataset},
DOI={10.5281/zenodo.2630551},
publisher={Zenodo},
year={2016},
month={Aug}
}
"""
class LambadaBase(Task):
VERSION = None
OUTPUT_TYPE = "loglikelihood"
def training_docs(self):
if self.has_training_docs():
return self.dataset["train"]
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def doc_to_text(self, doc):
return doc["text"].rsplit(" ", 1)[0]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["text"]
def doc_to_target(self, doc):
return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx, **kwargs):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, self.doc_to_target(doc)),
**kwargs
)
def process_results(self, doc, results):
# TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
results = results[
0
] # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
ll, is_greedy = results
return {"ppl": ll, "acc": int(is_greedy)}
def aggregation(self):
return {"ppl": perplexity, "acc": mean}
def higher_is_better(self):
return {"ppl": False, "acc": True}
@register_task("lambada_standard")
class LambadaStandard(LambadaBase):
"""The LAMBADA task using the standard original LAMBADA dataset."""
VERSION = "2.0"
DATASET_PATH = "lambada"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
"""The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.
Reference: https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
"""
VERSION = "2.0"
DATASET_PATH = "EleutherAI/lambada_openai"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
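A tiny worked example of the context/target split performed by `LambadaBase.doc_to_text` and `doc_to_target` (passage invented):

```python
# Invented passage; LAMBADA targets are always the final word.
text = "He looked around the empty room and finally whispered her name"
context = text.rsplit(" ", 1)[0]       # doc_to_text
target = " " + text.rsplit(" ", 1)[1]  # doc_to_target keeps the leading space for tokenization
print(repr(context))  # '...finally whispered her'
print(repr(target))   # ' name'
```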
"""
The Pile: An 800GB Dataset of Diverse Text for Language Modeling
https://arxiv.org/pdf/2101.00027.pdf
The Pile is a 825 GiB diverse, open source language modelling data set that consists
of 22 smaller, high-quality datasets combined together. To score well on Pile
BPB (bits per byte), a model must be able to understand many disparate domains
including books, github repositories, webpages, chat logs, and medical, physics,
math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{pile,
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
journal={arXiv preprint arXiv:2101.00027},
year={2020}
}
"""
class PilePerplexityTask(PerplexityTask):
VERSION = "2.0"
DATASET_PATH = "EleutherAI/the_pile"
DATASET_NAME = None
def has_training_docs(self):
return False
def test_docs(self):
for doc in self.dataset["train"].select(range(100)):
yield doc
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def doc_to_target(self, doc):
return doc["text"]
# def validation_docs(self):
# for doc in self.dataset["validation"]:
# yield doc["text"]
# def test_docs(self):
# for doc in self.dataset["test"]:
# yield doc["text"]
class PileArxiv(PilePerplexityTask):
DATASET_NAME = "pile_arxiv"
class PileBooks3(PilePerplexityTask):
DATASET_NAME = "pile_books3"
class PileBookCorpus2(PilePerplexityTask):
DATASET_NAME = "pile_bookcorpus2"
class PileDmMathematics(PilePerplexityTask):
DATASET_NAME = "pile_dm-mathematics"
@register_task("pile_enron")
class PileEnron(PilePerplexityTask):
DATASET_NAME = "enron_emails"
class PileEuroparl(PilePerplexityTask):
DATASET_NAME = "pile_europarl"
class PileFreeLaw(PilePerplexityTask):
DATASET_NAME = "pile_freelaw"
class PileGithub(PilePerplexityTask):
DATASET_NAME = "pile_github"
class PileGutenberg(PilePerplexityTask):
DATASET_NAME = "pile_gutenberg"
class PileHackernews(PilePerplexityTask):
DATASET_NAME = "pile_hackernews"
class PileNIHExporter(PilePerplexityTask):
DATASET_NAME = "pile_nih-exporter"
class PileOpenSubtitles(PilePerplexityTask):
DATASET_NAME = "pile_opensubtitles"
class PileOpenWebText2(PilePerplexityTask):
DATASET_NAME = "pile_openwebtext2"
class PilePhilPapers(PilePerplexityTask):
DATASET_NAME = "pile_philpapers"
class PilePileCc(PilePerplexityTask):
DATASET_NAME = "pile_pile-cc"
class PilePubmedAbstracts(PilePerplexityTask):
DATASET_NAME = "pile_pubmed-abstracts"
class PilePubmedCentral(PilePerplexityTask):
DATASET_NAME = "pile_pubmed-central"
class PileStackExchange(PilePerplexityTask):
DATASET_NAME = "pile_stackexchange"
class PileUspto(PilePerplexityTask):
DATASET_NAME = "pile_upsto"
class PileUbuntuIrc(PilePerplexityTask):
DATASET_NAME = "pile_ubuntu-irc"
class PileWikipedia(PilePerplexityTask):
DATASET_NAME = "pile_wikipedia"
class PileYoutubeSubtitles(PilePerplexityTask):
DATASET_NAME = "pile_youtubesubtitles"
......@@ -7,7 +7,6 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
# TODO: we should see how shuffling answer choices affects perf.
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
......
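A hedged sketch of how `template_aliases` and `doc_to_text` above would render for a toy SciQ-style doc, assuming the harness applies these templates with Jinja-style rendering (field values invented):

```python
from jinja2 import Template

template_aliases = (
    "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}"
    "{% set gold = 3 %}"
)
doc_to_text = "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"

doc = {
    "support": "  Mesophiles grow best at moderate temperatures.",
    "question": "What do you call organisms that grow best at moderate temperatures?",
    "distractor1": "thermophiles",
    "distractor2": "psychrophiles",
    "distractor3": "halophiles",
    "correct_answer": "mesophiles",
}
print(Template(template_aliases + doc_to_text).render(**doc))
# Mesophiles grow best at moderate temperatures.
# Question: What do you call organisms that grow best at moderate temperatures?
# Answer:
```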