Commit 9f1cb1e7 authored by lintangsutawika's avatar lintangsutawika

merged with conflict resolved

parents 8f859cd2 0375b792
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates

# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?

# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts via a "category:name" identifier.
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    # unpack prompt name
    try:
        category_name, prompt_name = prompt_id.split(":")
    except ValueError:
        raise ValueError(
            f"expected only a single `:` as separator between "
            f"prompt category and name, but got `{prompt_id}` instead"
        )
    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
    if category_name == "promptsource":
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except KeyError:
            raise ValueError(
                f"prompt `{prompt_name}` not found in category `{category_name}` of the prompt registry"
            )
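Below is a minimal usage sketch of the two lookup paths in `get_prompt` (not part of the diff): a `qa-basic` id resolves against `PROMPT_REGISTRY`, while a `promptsource` id assumes the promptsource templates for `super_glue`/`boolq` are available locally.

```python
# Usage sketch for get_prompt; illustrative only.
from lm_eval.prompts import get_prompt

# Registry-backed prompt: returns the raw template string.
qa_template = get_prompt("qa-basic:question-newline-answer")
# -> "Question: {{question}}\nAnswer:"

# promptsource-backed prompt: returns a promptsource Template object,
# assuming the super_glue/boolq templates are installed.
boolq_template = get_prompt(
    "promptsource:GPT-3 Style", dataset_name="super_glue", subset_name="boolq"
)
```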
......@@ -12,11 +12,11 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
......@@ -28,6 +28,8 @@ _CITATION = """
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
......@@ -80,6 +82,7 @@ class ARCEasy(MultipleChoiceTask):
return doc["query"]
@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
......
# ARC
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
### Citation
```
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
```
group:
- arc_yaml
task: arc_challenge_yaml
dataset_path: ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
......@@ -6,11 +9,14 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and record which of them is this doc's gold answer
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
- metric: acc_mutual_info
aggregation: mean
higher_is_better: true
group:
- arc_yaml
task: arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and record which of them is this doc's gold answer
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
- metric: acc_mutual_info
aggregation: mean
higher_is_better: true
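As an aside, the Jinja2 pieces above compose like this for one document. This is a sketch assuming the `ai2_arc` fields named in the YAML (`question`, `choices.text`, `choices.label`, `answerKey`) and a plain Jinja2 environment; the toy doc is made up for illustration and this is not the harness's own rendering code.

```python
# Sketch: rendering template_aliases + doc_to_text / doc_to_target with plain Jinja2.
from jinja2 import Environment

doc = {
    "question": "Which gas do plants primarily absorb during photosynthesis?",  # toy example
    "choices": {
        "text": ["Oxygen", "Carbon dioxide", "Nitrogen", "Helium"],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}

env = Environment()
aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey) %}"
)

doc_to_text = env.from_string(aliases + "Question: {{question}}\nAnswer:")
doc_to_target = env.from_string(aliases + "{{gold}}")

print(doc_to_text.render(**doc))    # Question: Which gas ...?\nAnswer:
print(doc_to_target.render(**doc))  # 1  (index of the gold choice, cast to int downstream)
```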
......@@ -17,13 +17,14 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
import re
from lm_eval.api.task import Task
from lm_eval.api.metrics import mean
from lm_eval.api.instance import Instance
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{cobbe2021training,
......@@ -88,7 +89,13 @@ class GradeSchoolMath8K(Task):
"""
# NOTE: The paper implements "verifiers" that assign a score to multiple
# solutions and output the highest ranked solution.
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, ["\n"]),
idx=0,
**kwargs
)
# completion = rf.greedy_until(ctx, ["\n"])
# return completion
......
# "Training Verifiers to Solve Math Word Problems"
# https://arxiv.org/abs/2110.14168
# State-of-the-art language models can match human performance on many tasks, but
# they still struggle to robustly perform multi-step mathematical reasoning. To
# diagnose the failures of current models and support research, we introduce GSM8K,
# a dataset of 8.5K high quality linguistically diverse grade school math word problems.
# We find that even the largest transformer models fail to achieve high test performance,
# despite the conceptual simplicity of this problem distribution.
# NOTE: See the official implementation of the task:
# https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
# for how to make use of the dataset's calculator annotations in your language
# model's sample/generation function.
# Homepage: https://github.com/openai/grade-school-math
# _CITATION = """
# @misc{cobbe2021training,
# title={Training Verifiers to Solve Math Word Problems},
# author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
# year={2021},
# eprint={2110.14168},
# archivePrefix={arXiv},
# primaryClass={cs.LG}
# }
# """
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
use_prompt: "qa-basic:question-newline-answer"
doc_to_target: "{{answer.split('### ')[-1]}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
delimiter: "\n"
repeats: 4
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "#### (\-?[0-9\.\,]+)"
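The commented-out `filter_list` points at regex-based answer extraction. The sketch below shows what that post-processing could look like; it is a hypothetical helper reusing the same pattern, not the harness's filter implementation.

```python
# Sketch of a regex answer extractor matching the commented-out filter above.
import re

ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def extract_answer(generation: str) -> str:
    """Return the first '#### <number>' match, with commas stripped."""
    match = ANSWER_RE.search(generation)
    if match is None:
        return "[invalid]"  # no parseable answer in the generation
    return match.group(1).replace(",", "")

# e.g. extract_answer("... so the total is 18 + 54 = 72.\n#### 72") -> "72"
```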
......@@ -12,10 +12,11 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{
......@@ -59,11 +60,18 @@ class LambadaBase(Task):
return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx, **kwargs):
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, self.doc_to_target(doc)),
**kwargs
)
def process_results(self, doc, results):
# TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
results = results[0]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
ll, is_greedy = results
return {"ppl": ll, "acc": int(is_greedy)}
......@@ -91,6 +99,7 @@ class LambadaStandard(LambadaBase):
def has_test_docs(self):
return True
@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
"""The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
......
task: lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: loglikelihood
......
......@@ -10,8 +10,9 @@ math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{pile,
......@@ -34,7 +35,7 @@ class PilePerplexityTask(PerplexityTask):
def test_docs(self):
for doc in self.dataset["train"].select(range(100)):
yield doc
def has_validation_docs(self):
return False
......@@ -139,4 +140,4 @@ class PileWikipedia(PilePerplexityTask):
class PileYoutubeSubtitles(PilePerplexityTask):
DATASET_NAME = "pile_youtubesubtitles"
# The Pile: An 800GB Dataset of Diverse Text for Language Modeling
# https://arxiv.org/pdf/2101.00027.pdf
# The Pile is a 825 GiB diverse, open source language modelling data set that consists
# of 22 smaller, high-quality datasets combined together. To score well on Pile
# BPB (bits per byte), a model must be able to understand many disparate domains
# including books, github repositories, webpages, chat logs, and medical, physics,
# math, computer science, and philosophy papers.
# Homepage: https://pile.eleuther.ai/
# _CITATION = """
# @article{pile,
# title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
# author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
# journal={arXiv preprint arXiv:2101.00027},
# year={2020}
# }
# """
names:
- pile_enron_yaml
dataset_path: EleutherAI/the_pile
dataset_name: enron_emails
output_type: loglikelihood_rolling
......@@ -16,4 +37,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
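For reference, bits per byte relates the summed log-likelihood of a document to its length in bytes. The sketch below is a generic illustration of that arithmetic, not the harness's metric code.

```python
# Sketch: computing bits per byte from a summed natural-log likelihood.
import math

def bits_per_byte(total_loglikelihood_nats: float, text: str) -> float:
    """Lower is better: average number of bits the model spends per UTF-8 byte."""
    num_bytes = len(text.encode("utf-8"))
    return -total_loglikelihood_nats / (num_bytes * math.log(2))

# e.g. a 1,000-byte document scored at -800 nats total:
# bits_per_byte(-800.0, "x" * 1000) -> ~1.154 bits/byte
```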
group:
- super-glue-promptsource
task: "GPT-3 Style"
dataset_path: super_glue
dataset_name: boolq
training_split: train
validation_split: validation
test_split: test
use_prompt: "promptsource:GPT-3 Style"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
delimiter: "\n"
# filters: [
# ["regex", ["regex", "take_first"]]
# ]
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "based on the following passage"
use_prompt: "promptsource:based on the following passage"
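These per-prompt files rely on an `include` key to pull in the shared boolq config and then override `task` and `use_prompt`. Below is a sketch of one plausible way such an include could be resolved; the loader, file names, and merge semantics are assumptions for illustration, not the harness's actual config code.

```python
# Hypothetical include resolution: load the base YAML, then let the including
# file's keys (task, use_prompt, ...) override it. Path handling is simplified.
import yaml

def load_task_config(path: str) -> dict:
    with open(path) as f:
        config = yaml.safe_load(f)
    base_path = config.pop("include", None)
    if base_path is None:
        return config
    merged = load_task_config(base_path)  # e.g. "promptsource-00.yaml"
    merged.update(config)                 # overriding keys win
    return merged

# A file containing only include/group/task/use_prompt would inherit
# dataset_path, dataset_name, splits, and metric_list from the base config.
```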
group:
- super-glue-promptsource
task: "GPT-3 style"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 style"
metric_list:
- metric: exact_match
aggregation: mean
......
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "MNLI crowdsource"
use_prompt: "promptsource:MNLI crowdsource"
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
group:
- super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
group:
- super-glue-promptsource
task: "C1 or C2? premise, so/because…"
dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
use_prompt: "promptsource:C1 or C2? premise, so/because…"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true