Commit 8171906d authored by lintangsutawika

resolved latest merge conflict

parents a2e41158 232632c6
......@@ -7,13 +7,38 @@
This project provides a unified framework to test generative language models on a large number of different evaluation tasks.
Features:
**Features:**
- 200+ tasks implemented. See the [task-table](./docs/task_table.md) for a complete list.
- Support for the Hugging Face `transformers` library, GPT-NeoX, Megatron-DeepSpeed, and the OpenAI API, with flexible tokenization-agnostic interface.
- Support for evaluation on adapters (e.g. LoRA) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
- Task versioning to ensure reproducibility.
**Evaluation Overview**
`Task` and `Prompt` classes contain information that, when combined, produces the input to the language model. The language model is then queried to obtain an output. One or more `Filters` can then be applied to perform arbitrary operations on the model's raw output, such as selecting the final answer (for chain of thought) or calling an external API. This final output is then evaluated using a `Metric` to obtain the final result.
```mermaid
graph LR;
classDef empty width:0px,height:0px;
T[Task]
I[Input]
F[Filter]
M[Model]
O[Output]:::empty
P[Prompt]
Me[Metric]
R[Result]
T --- I:::empty
P --- I
I --> M
M --> O
O --> F
Me --> R:::empty
F --> R
```
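For orientation, here is a minimal, self-contained sketch of that flow in plain Python; every name in it (`build_input`, `fake_model`, and so on) is an illustrative stand-in rather than the harness's actual API.

```python
# Illustrative sketch of the Task -> Model -> Filter -> Metric flow.
# All names here are hypothetical stand-ins, not lm-eval's real classes.

def build_input(doc: dict, prompt: str) -> str:
    # Task supplies the document, Prompt supplies the template.
    return prompt.format(**doc)

def fake_model(inp: str) -> str:
    # Stand-in for querying a real language model.
    return "There must have been 21 - 15 = 6. The answer is 6."

def filter_final_answer(raw: str) -> str:
    # A Filter post-processes raw output, e.g. pulling the final answer
    # out of a chain-of-thought generation.
    return raw.rsplit("The answer is", 1)[-1].strip().rstrip(".")

def exact_match_metric(pred: str, ref: str) -> float:
    # A Metric compares the filtered output against the reference.
    return float(pred == ref)

doc = {"question": "How many trees did the grove workers plant today?", "answer": "6"}
prompt = "Q: {question}\nA:"

model_input = build_input(doc, prompt)
raw_output = fake_model(model_input)
final = filter_final_answer(raw_output)
print(exact_match_metric(final, doc["answer"]))  # 1.0
```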
## Install
To install `lm-eval` from the main branch of the GitHub repository, run:
......
......@@ -13,7 +13,7 @@ class Filter:
"""
def __init__(self):
def __init__(self, *args, **kwargs):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
......@@ -47,10 +47,7 @@ class FilterEnsemble:
] # operate just on the model responses
for f in self.filters:
# apply filters in sequence
out = f.apply(resps)
resps = (
out # TODO: handle the case where a filter returns multiple "buckets"
)
resps = f.apply(resps)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
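The widened `__init__(self, *args, **kwargs)` signature lets a filter carry per-instance configuration, and `FilterEnsemble` now simply applies each filter's `apply` to the running responses in sequence. A minimal sketch of both ideas with hypothetical filters (`LowercaseFilter` and `TruncateFilter` are not part of the harness):

```python
# Hypothetical filters following the apply() interface shown in the diff:
# each filter receives a list of per-document response lists and returns
# the same structure, so filters can be chained in sequence.

class LowercaseFilter:
    def __init__(self, *args, **kwargs):
        pass  # stateless

    def apply(self, resps):
        return [[r.lower() for r in doc_resps] for doc_resps in resps]

class TruncateFilter:
    def __init__(self, *args, **kwargs):
        # per-instance state, as the docstring above suggests
        self.max_chars = kwargs.pop("max_chars", 20)

    def apply(self, resps):
        return [[r[: self.max_chars] for r in doc_resps] for doc_resps in resps]

resps = [["The Answer Is 42.", "THE ANSWER IS 41."]]
for f in [LowercaseFilter(), TruncateFilter(max_chars=16)]:
    resps = f.apply(resps)  # apply filters in sequence, like FilterEnsemble
print(resps)  # [['the answer is 42', 'the answer is 41']]
```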
import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
import re
import ast
......@@ -52,7 +52,6 @@ class TaskConfig(dict):
task: str = None
group: str = None
names: str = None
reference: str = None
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
......@@ -77,6 +76,7 @@ class TaskConfig(dict):
metric_list: str = None
gold_alias: str = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
delimiter: str = "\n\n"
filter_list: Union[str, list] = None
normalization: str = (
......@@ -99,9 +99,12 @@ class TaskConfig(dict):
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
self.task_name = self.names[0]
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.gold_alias
if not self.generation_kwargs:
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
def __getitem__(self, item):
return getattr(self, item)
......@@ -257,7 +260,7 @@ class Task(abc.ABC):
else:
eval_logger.warning(
"has_training_docs and has_validation_docs are False"
"using test_docs but this is not recommended."
", using test_docs but this is not recommended."
)
return self.test_docs()
......@@ -322,7 +325,7 @@ class Task(abc.ABC):
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self._config["task_name"], doc_id, self._config.repeats),
metadata=(self._config["task"], doc_id, self._config.repeats),
)
if not isinstance(inst, list):
......@@ -697,6 +700,23 @@ class ConfigurableTask(Task):
else:
raise TypeError
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias:
doc_to_target = self._config.gold_alias
else:
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
raise TypeError
def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "loglikelihood":
......@@ -744,7 +764,7 @@ class ConfigurableTask(Task):
return request_list
elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.delimiter)
arguments = (ctx, self._config.generation_kwargs)
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
......@@ -833,7 +853,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None:
gold = doc[self._config.gold_alias]
gold = self.gold_alias(doc)
else:
gold = self.doc_to_target(doc)
......
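The new `gold_alias` hook above dispatches on whether the configured value is a template string, a callable, or a prompt object with an `.apply` method. Assuming the template strings are Jinja-style (as the `{{answer.split('### ')[-1]}}` examples later in this commit suggest), the string branch behaves roughly like this sketch, with `jinja2` standing in for `utils.apply_template`:

```python
# Rough sketch of applying a Jinja-style target template to a document,
# assuming utils.apply_template behaves like jinja2 rendering.
from jinja2 import Template

doc = {"answer": "Natalia sold 48 clips.\n#### 72"}

# gold_alias from the gsm8k configs in this commit: keep only the final answer.
gold_alias = "{{answer.split('### ')[-1].rstrip()}}"

print(Template(gold_alias).render(**doc))  # -> "72"
```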
......@@ -274,9 +274,7 @@ def evaluate(
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + " - filter=" + key] = task.aggregation()[
metric
](items)
results[task_name][metric + "," + key] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
......@@ -289,9 +287,7 @@ def evaluate(
)
if stderr is not None:
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(
items
)
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
......
......@@ -6,6 +6,8 @@ from . import extraction
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
......
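With the new `majority_vote` and `take_first_k` entries, a filter step named in a YAML `filter_list` can be resolved from the registry and instantiated with its keyword arguments (e.g. `take_first_k` with `k: 8`). A hedged sketch of that lookup, using simplified stand-ins for the real registry and filter classes:

```python
# Simplified stand-in for FILTER_REGISTRY lookup; the real registry maps the
# same names ("take_first", "regex", "majority_vote", "take_first_k") to
# Filter classes, instantiated here with kwargs as written in a YAML config.

class TakeKFilter:
    def __init__(self, **kwargs):
        self.k = kwargs.pop("k")

    def apply(self, resps):
        return [doc_resps[: self.k] for doc_resps in resps]

FILTER_REGISTRY = {"take_first_k": TakeKFilter}

step = {"function": "take_first_k", "k": 8}   # as it appears in a filter_list
name = step.pop("function")
filter_obj = FILTER_REGISTRY[name](**step)    # TakeKFilter(k=8)
print(filter_obj.apply([list(range(16))]))    # [[0, 1, 2, 3, 4, 5, 6, 7]]
```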
......@@ -26,8 +26,6 @@ class RegexFilter(Filter):
match = self.regex.search(resp)
if match:
match = match.group(1).strip()
match.replace(",", "")
# TODO: should we assume any other filtering is performed?
else:
match = self.fallback
filtered.append(match)
......
from collections import Counter
from lm_eval.api.filter import Filter
class TakeFirstFilter:
class TakeFirstFilter(Filter):
def __init__(self):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
......@@ -12,3 +14,35 @@ class TakeFirstFilter:
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take the first {self.k}, but got only {len(resps[0])}! Please increase TaskConfig.repeats."
return map(lambda r: r[:self.k], resps)
class MajorityVoteFilter(Filter):
def __init__(self):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps):
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
"""
def select_majority(resp):
counts = Counter(resp)
vote = counts.most_common(1)[0][0]
return vote
return map(lambda r: [select_majority(r)], resps)
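Chained together, `take_first_k`, `regex`, and `majority_vote` give the maj@k scoring used by the self-consistency config later in this commit. A toy illustration of the truncate-then-vote step (the sampled answers are made up):

```python
from collections import Counter

# Toy maj@4: keep the first 4 sampled answers per doc, then majority-vote.
resps = [["6", "6", "5", "6", "7", "6"]]   # one doc, several sampled answers

k = 4
resps = [r[:k] for r in resps]                               # TakeKFilter-style truncation
resps = [[Counter(r).most_common(1)[0][0]] for r in resps]   # MajorityVoteFilter-style vote
print(resps)  # [['6']]
```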
from . import gpt2
from . import hf_causal
from . import gpt3
from . import textsynth
from . import dummy
......
import torch
import transformers
import copy
from tqdm import tqdm
import torch.nn.functional as F
......@@ -56,10 +57,10 @@ class HFLM(LM):
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
).to(self.device)
self.gpt2.eval()
self.model.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
......@@ -84,7 +85,7 @@ class HFLM(LM):
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
else:
self.gpt2 = accelerator.prepare(self.gpt2)
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
......@@ -103,17 +104,17 @@ class HFLM(LM):
def max_length(self):
try:
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
return self.accelerator.unwrap_model(self.model).config.n_ctx
else:
return self.gpt2.config.n_ctx
return self.model.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(
self.gpt2
self.model
).config.max_position_embeddings
else:
return self.gpt2.config.max_position_embeddings
return self.model.config.max_position_embeddings
@property
def max_gen_toks(self):
......@@ -150,25 +151,28 @@ class HFLM(LM):
logits returned from the model
"""
with torch.no_grad():
return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id):
return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.gpt2).generate(
return self.accelerator.unwrap_model(self.model).generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
**generation_kwargs,
)
else:
return self.gpt2.generate(
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
**generation_kwargs,
)
def loglikelihood(self, requests):
......@@ -277,7 +281,7 @@ class HFLM(LM):
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
......@@ -357,18 +361,44 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, until in tqdm(re_ord.get_reordered()):
if isinstance(until, str):
until = [until]
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
(primary_until,) = self.tok_encode(until[0])
try:
(primary_until,) = self.tok_encode(until[0])
except Exception:
# if our primary until would be multiple tokens long, we'll have errors.
# TODO: handling this better will let us stop generating earlier + often.
primary_until = self.eot_token_id
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
......
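In the reworked `greedy_until`, each request now carries a `gen_kwargs` dict instead of a bare stop-string list: `until` and `max_gen_toks` are popped off, `do_sample` defaults to `False`, and the rest is forwarded to `generate`. A hedged sketch of that flow against a small Hugging Face model (the model, prompt, and the 32-token fallback are placeholder choices, not values from the harness):

```python
# Sketch of how a request's gen_kwargs flow into HF generate(), mirroring
# _model_generate above. The model and prompt here are placeholders.
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

gen_kwargs = {"until": ["Q:", "\n\n"], "do_sample": False}
gen_kwargs = copy.deepcopy(gen_kwargs)       # don't mutate the shared request args
until = gen_kwargs.pop("until")              # normally turned into a stop token; unused here
max_gen_toks = gen_kwargs.pop("max_gen_toks", 32)  # placeholder fallback
gen_kwargs.setdefault("do_sample", False)    # greedy unless explicitly overridden

context_enc = tok("Q: What is 2 + 2?\nA:", return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(
        context_enc,
        max_length=context_enc.shape[1] + max_gen_toks,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        **gen_kwargs,
    )
print(tok.decode(out[0][context_enc.shape[1]:]))
```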
# v1.0 Tasks
This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
Boxes should be checked iff tasks are implemented in v2.0 and tested for regression. Tasks should be struck through once they have been checked *against the implementation from the original introducing paper* or the popularizing implementation.
- [ ] Glue
- [ ] SuperGlue
- [ ] CoQA
- [ ] DROP
- [x] ~~Lambada~~
- [x] Lambada (Cloze variants)
- [ ] Lambada (Multilingual)
- [x] Wikitext
- [x] PiQA
- [ ] PROST
- [ ] MCTACO
- [ ] Pubmed QA
- [x] SciQ
- [ ] QASPER
- [ ] QA4MRE
- [ ] TriviaQA
- [x] AI2 ARC
- [ ] LogiQA
- [ ] HellaSwag
- [ ] SWAG
- [ ] OpenBookQA
- [ ] SQuADv2
- [ ] RACE
- [ ] HeadQA
- [ ] MathQA
- [ ] WebQs
- [ ] WSC273
- [ ] Winogrande
- [ ] ANLI
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
- [ ] MuTual
- [ ] Hendrycks Math
- [ ] Asdiv
- [ ] GSM8k
- [ ] Arithmetic
- [ ] MMMLU
- [ ] Translation (WMT) suite
- [ ] Unscramble
- [x] ~~Pile (perplexity)~~
- [ ] BLiMP
- [ ] ToxiGen
- [ ] CrowS-Pairs
- [ ] XCopa
- [ ] BIG-Bench
- [ ] XStoryCloze
- [ ] XWinograd
- [ ] PAWS-X
- [ ] XNLI
- [ ] MGSM
# Novel Tasks
Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*.
# Task Wishlist
- [ ] TheoremQA
- [ ] Theorem Proving evaluations
- [ ] Chain of Thought
- [ ] Self-consistency ; Least-to-Most prompting, etc.
- [ ] Summarization Tasks
- [ ] Anthropic Model-Written Evals
\ No newline at end of file
import os
from typing import List, Union
from .gsm8k import *
from .triviaqa import *
from lm_eval import utils
from lm_eval.logger import eval_logger
......@@ -33,9 +35,7 @@ for root, subdirs, file_list in os.walk(task_dir):
)
if "task" in config:
task_name = "{}:{}".format(
get_task_name_from_config(config), config["task"]
)
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
......@@ -43,7 +43,7 @@ for root, subdirs, file_list in os.walk(task_dir):
register_group(group)(SubClass)
except Exception as error:
eval_logger.warning(
"Failed to load config at in\n"
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry"
f" Error: {error}"
......
......@@ -92,7 +92,7 @@ class GradeSchoolMath8K(Task):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, ["\n"]),
arguments=(ctx, ["\n\n"]),
idx=0,
**kwargs
)
......@@ -113,7 +113,7 @@ class GradeSchoolMath8K(Task):
assert gold != INVALID_ANS, "No ground truth answer found in the document."
# return self._extract_answer(completion) == gold
# print(completion)
return completion == gold
return self._extract_answer(completion) == gold
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
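The fix above scores the *extracted* answer rather than the raw completion; extraction is a regex over the generated text, in the spirit of the `The answer is (\-?[0-9\.\,]*[0-9]+)` pattern used by the chain-of-thought configs in this commit. A standalone illustration (the completion text is made up):

```python
import re

# Pull the final numeric answer out of a chain-of-thought completion, using
# the "The answer is ..." pattern from the gsm8k CoT configs in this commit.
ANS_RE = re.compile(r"The answer is (\-?[0-9\.\,]*[0-9]+)")

completion = "Leah had 32 and her sister 42, so 74 - 35 = 39. The answer is 39."
match = ANS_RE.search(completion)
extracted = match.group(1).replace(",", "") if match else "[invalid]"
print(extracted == "39")  # True
```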
# GSM8k
## Paper
Training Verifiers to Solve Math Word Problems
https://arxiv.org/abs/2110.14168
State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To
diagnose the failures of current models and support research, we introduce GSM8K,
a dataset of 8.5K high quality linguistically diverse grade school math word problems.
We find that even the largest transformer models fail to achieve high test performance,
despite the conceptual simplicity of this problem distribution.
NOTE: See the official implementation of the task:
https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
for how to make use of the dataset's calculator annotations in your language
model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
## Citation
```
@misc{cobbe2021training,
title={Training Verifiers to Solve Math Word Problems},
author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
year={2021},
eprint={2110.14168},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
\ No newline at end of file
# "Training Verifiers to Solve Math Word Problems"
# https://arxiv.org/abs/2110.14168
# State-of-the-art language models can match human performance on many tasks, but
# they still struggle to robustly perform multi-step mathematical reasoning. To
# diagnose the failures of current models and support research, we introduce GSM8K,
# a dataset of 8.5K high quality linguistically diverse grade school math word problems.
# We find that even the largest transformer models fail to achieve high test performance,
# despite the conceptual simplicity of this problem distribution.
# NOTE: See the official implementation of the task:
# https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
# for how to make use of the dataset's calculator annotations in your language
# model's sample/generation function.
# Homepage: https://github.com/openai/grade-school-math
# _CITATION = """
# @misc{cobbe2021training,
# title={Training Verifiers to Solve Math Word Problems},
# author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
# year={2021},
# eprint={2110.14168},
# archivePrefix={arXiv},
# primaryClass={cs.LG}
# }
# """
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
use_prompt: "qa-basic:question-newline-answer"
doc_to_target: "{{answer.split('### ')[-1]}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
delimiter: "\n"
repeats: 4
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "#### (\-?[0-9\.\,]+)"
include: gsm8k-cot.yaml
group:
- chain_of_thought
- self_consistency
task: gsm8k_cot_self_consistency
generation_kwargs:
until:
- "Q:"
- "\n\n"
do_sample: true
temperature: 0.2
repeats: 8
filter_list:
- name: "score-first" # pick only the first response, and report metrics on that
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "take_first"
- name: "maj@64"
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote"
- function: "take_first"
- name: "maj@8" # get Maj@8 , via selecting the first 8 responses. Using a better estimator would be optimal.
filter:
- function: "take_first_k"
k: 8
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote"
- function: "take_first"
\ No newline at end of file
group:
- chain_of_thought
task: gsm8k_cot
dataset_path: gsm8k
dataset_name: main
output_type: greedy_until
test_split: test
doc_to_text: "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\n\
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\n\
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\n\nA: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.\n\n\
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n\nA: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.\n\n\
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n\nA: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.\n\n\
Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n\nA: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.\n\n\
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n\nA: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.\n\n\
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n\nA: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.\n\n\
Q: {{question}}\n\nA:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
delimiter: "\n\n"
generation_kwargs:
until:
- "Q:"
- "\n\n"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 0
filter_list:
- name: "get-answer"
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)"
- function: "take_first"
\ No newline at end of file
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
output_type: greedy_until
training_split: train
fewshot_split: train
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
gold_alias: "{{answer.split('### ')[-1].rstrip()}}" # this post-processes the reference that we'll score against
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- ".*### "
delimiter: "\n\n"
generation_kwargs:
until:
- "\n\n"
- "Question:"
do_sample: false
temperature: 0.0
repeats: 2
num_fewshot: 5
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "### (\\-?[0-9\\.\\,]+)"
# - function: "take_first"
\ No newline at end of file
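The `ignore_case`, `ignore_punctuation`, and `regexes_to_ignore` options in `metric_list` are what let a prediction like `$1,234` match a reference of `1234`. Assuming the metric is ultimately backed by Hugging Face `evaluate`'s `exact_match` (which accepts the same options), the effect is roughly:

```python
# Sketch: how the ignore_* options in metric_list affect exact-match scoring,
# assuming the metric is backed by Hugging Face evaluate's "exact_match".
import evaluate

exact_match = evaluate.load("exact_match")

result = exact_match.compute(
    predictions=["$1,234"],
    references=["1234"],
    ignore_case=True,
    ignore_punctuation=False,
    regexes_to_ignore=[",", "\\$"],  # mirrors the regexes_to_ignore list above
)
print(result["exact_match"])  # 1.0
```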
......@@ -7,7 +7,6 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
# TODO: we should see how shuffling answer choices affects perf.
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
......
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.register import register_task
from lm_eval.api.metrics import mean
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
@register_task("triviaqa")
class TriviaQA(Task):
VERSION = 1
DATASET_PATH = "trivia_qa" #inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Q: {doc['question']}\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, {
"until": ["\n", ".", ","],
"do_sample": False,
}),
idx=0,
**kwargs,
)
return continuation
def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self):
return {
"em": mean,
}
def higher_is_better(self):
return {"em": True}
\ No newline at end of file
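Two details of the TriviaQA scoring above are easy to miss: aliases that merely extend a shorter alias are pruned before matching, and both the continuation and the aliases are lower-cased and stripped of punctuation. A compact illustration with made-up aliases:

```python
import string

# Same prefix-pruning idea as TriviaQA._remove_prefixes above: after sorting,
# an alias is kept only if it does not merely extend the previously kept one.
def remove_prefixes(aliases):
    aliases = sorted(aliases)
    kept = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(kept[-1]):
            kept.append(alias)
    return kept

aliases = ["Paris", "Paris, France", "City of Light"]
print(remove_prefixes(aliases))  # ['City of Light', 'Paris']

# Normalization used when checking the model's continuation against the aliases.
def normalize(s):
    return s.strip().lower().translate(str.maketrans("", "", string.punctuation))

continuation = " Paris."
candidates = [normalize(a) for a in remove_prefixes(aliases)]
print(float(normalize(continuation) in candidates))  # 1.0
```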
......@@ -157,22 +157,24 @@ def make_table(result_dict):
md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
values = []
for k, dic in result_dict["results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
for mf, v in dic.items():
m, _, f = mf.partition(",")
if m.endswith("_stderr"):
continue
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, "%.4f" % v, "", ""])
values.append([k, version, f, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
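The table builder relies on the new `metric,filter` key convention produced by `evaluate()`: each key is split once on the comma, and a matching `metric_stderr,filter` entry, if present, fills the stderr column. A toy example with made-up numbers:

```python
# How the "metric,filter" result keys are split for the results table.
dic = {"exact_match,get-answer": 0.1200, "exact_match_stderr,get-answer": 0.0150}  # toy values

for mf, v in dic.items():
    m, _, f = mf.partition(",")
    if m.endswith("_stderr"):
        continue
    se = dic.get(m + "_stderr" + "," + f)
    print([f, m, "%.4f" % v, "±", "%.4f" % se])
# ['get-answer', 'exact_match', '0.1200', '±', '0.0150']
```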