Unverified commit c5ed8cdc authored by Lintang Sutawika, committed by GitHub

Merge pull request #501 from EleutherAI/update-config

Update config
parents f6b76f5d c17e3659
import os
from typing import List, Union
from lm_eval.api.task import register_yaml_task
from .arc import *
from .vanilla import *
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import (
    register_task,
    register_group,
    task_registry,
    group_registry,
)
# we want to register all yaml tasks in our .yaml folder.
yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml"
def get_task_name_from_config(task_config):
    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)


for yaml in sorted(os.listdir(yaml_dir)):
    yaml = os.path.join(yaml_dir, yaml)
    register_yaml_task(yaml)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
for root, subdirs, file_list in os.walk(task_dir):
    if (subdirs == []) and (len(file_list) > 0):
        for file in file_list:
            if "yaml" in file:
                yaml_path = os.path.join(root, file)
                try:
                    config = utils.load_yaml_config(yaml_path)
                    SubClass = type(
                        config["task"] + "ConfigurableTask",
                        (ConfigurableTask,),
                        {"CONFIG": TaskConfig(**config)},
                    )

                    if "task" in config:
                        task_name = "{}:{}".format(
                            get_task_name_from_config(config), config["task"]
                        )
                        register_task(task_name)(SubClass)

                    if "group" in config:
                        for group in config["group"]:
                            register_group(group)(SubClass)
                except Exception as err:
                    print(f"Unexpected {err=}, {type(err)=}")
TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
def get_task(task_name, config):
    try:
        return TASK_REGISTRY[task_name](config=config)
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(TASK_REGISTRY)
        raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name not in task_name_from_registry_dict:
                        task_name_from_registry_dict = {
                            **task_name_from_registry_dict,
                            task_name: get_task(task_name=task_name, config=config),
                        }
            else:
                task_name = task_element
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict = {
                        **task_name_from_registry_dict,
                        task_name: get_task(task_name=task_element, config=config),
                    }

        elif isinstance(task_element, dict):
            task_element.update(config)
            task_name_from_config_dict = {
                **task_name_from_config_dict,
                get_task_name_from_config(task_element): ConfigurableTask(
                    config=task_element
                ),
            }

        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    # task_name_from_registry_dict = {
    #     task_name: get_task(
    #         task_name=task_name,
    #         task_config=config
    #     )
    #     for group_name in task_name_list for task_name in GROUP_REGISTRY[group_name]
    #     if (isinstance(group_name, str)) and (group_name in GROUP_REGISTRY)
    # }

    # task_name_from_config_dict = {
    #     get_task_name_from_config(task_config): ConfigurableTask(
    #         config=task_config
    #     )
    #     for task_config in task_name_list
    #     if isinstance(task_config, dict)
    # }

    # # TODO: Do we still need this?
    # task_name_from_object_dict = {
    #     get_task_name_from_object(task_object): task_object
    #     for task_object in task_name_list
    #     if isinstance(task_object, Task)
    # }

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }
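Taken together, the registration loop and `get_task_dict` above let a caller mix registered task names, group names, ad-hoc config dicts, and pre-built `Task` objects in a single call. A minimal usage sketch, assuming this module is importable as `lm_eval.tasks` and that the keys shown in the ad-hoc dict are accepted by `TaskConfig` (illustrative only, not the canonical API):

```python
from lm_eval.tasks import ALL_TASKS, get_task_dict

# Names registered via @register_task / @register_group or the YAML walker above.
print(ALL_TASKS[:10])

task_dict = get_task_dict(
    [
        "arc_easy",  # a registered task -> looked up in TASK_REGISTRY
        "arc",       # a registered group -> expanded to its member tasks
        # An ad-hoc config dict is wrapped in a ConfigurableTask on the fly;
        # dataset_path/dataset_name are needed for the generated task name.
        {"task": "my_arc", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy"},
    ],
    num_fewshot=0,  # **kwargs become config overrides, per the TODO above
)
```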
@@ -12,11 +12,11 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval.api.task import MultipleChoiceTask, register_task
from lm_eval.prompts import get_prompt
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
@@ -28,6 +28,8 @@ _CITATION = """
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
    VERSION = "2.0"
@@ -80,6 +82,7 @@ class ARCEasy(MultipleChoiceTask):
        return doc["query"]


@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
    DATASET_PATH = "ai2_arc"
......
# ARC
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
### Citation
```
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
```
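The YAML configs that follow drive prompting with Jinja templates (`template_aliases`, `doc_to_text`, `doc_to_target`). As a rough, self-contained illustration of how those templates resolve against an ARC-style document (rendered here with plain `jinja2`, not the harness's own rendering code; the sample doc just mimics the `ai2_arc` schema):

```python
from jinja2 import Template

# A document shaped like an ai2_arc record, with the fields the templates reference.
doc = {
    "question": "Which factor will most likely cause a person to develop a fever?",
    "choices": {
        "text": [
            "a leg muscle relaxing after exercise",
            "a bacterial population in the bloodstream",
            "several viral particles on the skin",
            "carbohydrates being digested in the stomach",
        ],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}

# template_aliases defines answer_choices and the gold index; doc_to_text builds the prompt.
aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey) %}"
)
prompt = Template(aliases + "Question: {{question}}\nAnswer:").render(**doc)
target = Template(aliases + "{{gold}}").render(**doc)

print(prompt)  # Question: Which factor ... \nAnswer:
print(target)  # 1 -- index of the gold choice, later cast to an int
```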
names:
- arc_challenge_yaml
group:
- arc_yaml
task: arc_challenge_yaml
dataset_path: ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
@@ -8,7 +9,7 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set which of them is this doc's gold answer (i.e. which dataset columns supply the choices and the gold label)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
  - metric: acc
    aggregation: mean
@@ -19,6 +20,3 @@ metric_list:
  - metric: acc_mutual_info
    aggregation: mean
    higher_is_better: true
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
\ No newline at end of file
names:
- arc_easy_yaml
group:
- arc_yaml
task: arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
@@ -8,7 +9,7 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set which of them is this doc's gold answer (i.e. which dataset columns supply the choices and the gold label)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
  - metric: acc
    aggregation: mean
@@ -19,6 +20,3 @@ metric_list:
  - metric: acc_mutual_info
    aggregation: mean
    higher_is_better: true
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
\ No newline at end of file
@@ -17,13 +17,14 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
import re
from lm_eval.api.task import Task, register_task
from lm_eval.api.instance import Instance
from lm_eval import utils
from lm_eval.api.task import Task
from lm_eval.api.metrics import mean
from lm_eval.api.instance import Instance
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{cobbe2021training,
@@ -88,7 +89,13 @@ class GradeSchoolMath8K(Task):
        """
        # NOTE: The paper implements "verifiers" that assign a score to multiple
        # solutions and output the highest ranked solution.
        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(ctx, ["\n"]),
            idx=0,
            **kwargs
        )
        # completion = rf.greedy_until(ctx, ["\n"])
        # return completion
......
# "Training Verifiers to Solve Math Word Problems"
# https://arxiv.org/abs/2110.14168
# State-of-the-art language models can match human performance on many tasks, but
# they still struggle to robustly perform multi-step mathematical reasoning. To
# diagnose the failures of current models and support research, we introduce GSM8K,
# a dataset of 8.5K high quality linguistically diverse grade school math word problems.
# We find that even the largest transformer models fail to achieve high test performance,
# despite the conceptual simplicity of this problem distribution.
# NOTE: See the official implementation of the task:
# https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
# for how to make use of the dataset's calculator annotations in your language
# model's sample/generation function.
# Homepage: https://github.com/openai/grade-school-math
# _CITATION = """
# @misc{cobbe2021training,
# title={Training Verifiers to Solve Math Word Problems},
# author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
# year={2021},
# eprint={2110.14168},
# archivePrefix={arXiv},
# primaryClass={cs.LG}
# }
# """
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
use_prompt: "qa-basic:question-newline-answer"
doc_to_target: "{{answer.split('### ')[-1]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
delimiter: "\n"
repeats: 4
# filter_list:
# - name: "get-answer"
# filter:
# - function: "regex"
# regex_pattern: "#### (\-?[0-9\.\,]+)"
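The `doc_to_target` template above extracts the final answer from GSM8K's rationale format, where the gold number follows a `#### ` marker (the same marker the commented-out regex filter targets). A quick illustration of that split on an answer string in the dataset's format:

```python
# An answer string in GSM8K's format: free-form reasoning, then "#### <answer>".
answer = (
    "Natalia sold 48/2 = 24 clips in May.\n"
    "Natalia sold 48+24 = 72 clips altogether in April and May.\n"
    "#### 72"
)

# Mirrors doc_to_target: "{{answer.split('### ')[-1]}}".
# Splitting on "### " leaves the first "#" of "####" attached to the previous
# chunk, so the last element is just the final answer.
target = answer.split("### ")[-1]
print(target)  # 72
```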
@@ -12,10 +12,11 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task, register_task
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{
@@ -59,11 +60,18 @@ class LambadaBase(Task):
        return " " + doc["text"].rsplit(" ", 1)[1]

    def construct_requests(self, doc, ctx, **kwargs):
        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(ctx, self.doc_to_target(doc)),
            **kwargs
        )

    def process_results(self, doc, results):
        # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
        results = results[0]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
        results = results[
            0
        ]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
        ll, is_greedy = results

        return {"ppl": ll, "acc": int(is_greedy)}
@@ -91,6 +99,7 @@ class LambadaStandard(LambadaBase):
    def has_test_docs(self):
        return True


@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
    """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
......
names:
- lambada_openai_yaml
task: lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: loglikelihood
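For context, `LambadaBase.doc_to_target` above takes everything after the last space of the passage as the target word, prefixed with a leading space, and `construct_requests` scores it as a loglikelihood continuation (matching `output_type: loglikelihood` in this config). A tiny illustration with an invented sentence:

```python
doc = {"text": "He poured himself another cup of strong black coffee"}

# Mirrors doc_to_target: " " + doc["text"].rsplit(" ", 1)[1]
context, last_word = doc["text"].rsplit(" ", 1)
target = " " + last_word

print(repr(context))  # 'He poured himself another cup of strong black'
print(repr(target))   # ' coffee' -- scored via (ll, is_greedy) for ppl and acc
```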
......
@@ -10,8 +10,9 @@ math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask, register_task
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{pile,
@@ -34,7 +35,7 @@ class PilePerplexityTask(PerplexityTask):
    def test_docs(self):
        for doc in self.dataset["train"].select(range(100)):
            yield doc

    def has_validation_docs(self):
        return False
@@ -139,4 +140,4 @@ class PileWikipedia(PilePerplexityTask):
class PileYoutubeSubtitles(PilePerplexityTask):
    DATASET_NAME = "pile_youtubesubtitles"
\ No newline at end of file
    DATASET_NAME = "pile_youtubesubtitles"
# The Pile: An 800GB Dataset of Diverse Text for Language Modeling
# https://arxiv.org/pdf/2101.00027.pdf
# The Pile is a 825 GiB diverse, open source language modelling data set that consists
# of 22 smaller, high-quality datasets combined together. To score well on Pile
# BPB (bits per byte), a model must be able to understand many disparate domains
# including books, github repositories, webpages, chat logs, and medical, physics,
# math, computer science, and philosophy papers.
# Homepage: https://pile.eleuther.ai/
# _CITATION = """
# @article{pile,
# title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
# author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
# journal={arXiv preprint arXiv:2101.00027},
# year={2020}
# }
# """
names:
- pile_enron_yaml
dataset_path: EleutherAI/the_pile
@@ -18,4 +37,4 @@ metric_list:
    higher_is_better: false
  - metric: bits_per_byte
    aggregation: bits_per_byte
    higher_is_better: false
\ No newline at end of file
    higher_is_better: false
names:
- sglue_cb_yamltest
group:
- super-glue-cb
task: based on the previous passage
reference: "Adapted from the BoolQ prompts in Schick & Sch\xFCtze 2021."
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?"
doc_to_text: "{{premise}} Based on the previous passage, is it true that \"{{hypothesis}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
......
group:
- super-glue-cb
include: based_on_previous_passage.yaml
task: can we infer
reference: Webson & Pavlick 2021
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
group:
- super-glue-cb
include: based_on_previous_passage.yaml
task: claim true/false/inconclusive
reference: Sanh et al. 2021
doc_to_text: "{{premise}} Based on that information, is the claim: \"{{hypothesis}}\" \"true\", \"false\", or \"inconclusive\"?"
doc_to_target: "{% set answer_choices = ['True', 'False', 'Inconclusive'] %}{{answer_choices[label]}}"
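Several of the SuperGLUE CB variants above start from `include: based_on_previous_passage.yaml` and override only the prompt-specific keys. This diff doesn't show how the loader resolves `include:`, so the sketch below is only a guess at the likely semantics (recursive load of the base file, with the including file's keys winning); the function name and merge rule are assumptions, not the harness API:

```python
import os
import yaml

def load_yaml_with_include(path):
    """Hypothetical include-aware loader: recursively load the file named by
    `include:` and let the including file's keys override the base config."""
    with open(path) as f:
        config = yaml.safe_load(f)
    base_name = config.pop("include", None)
    if base_name is not None:
        base = load_yaml_with_include(os.path.join(os.path.dirname(path), base_name))
        base.update(config)  # overriding keys (task, reference, doc_to_text, ...) win
        config = base
    return config

# e.g. the "can we infer" config would inherit dataset_path, dataset_name and the
# split names from based_on_previous_passage.yaml, replacing task/reference/doc_to_text.
```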
group:
- super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypo}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
\ No newline at end of file
    ignore_punctuation: true
group:
- super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} question: {{question}}"
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
group:
- super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: record
training_split: train
validation_split: validation
doc_to_text: "record query: {{query}} entities: {{entities}} passage: {{passage}}"
doc_to_target: "{{answers}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
names:
- gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
group:
- t0-eval
task: "does the pronoun refer to"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
test_split: test
doc_to_target: "{{answer.split('### ')[-1]}}"
use_prompt: "qa-basic:question-newline-answer"
validation_split: validation
use_prompt: "promptsource:does the pronoun refer to"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
delimiter: "\n"
# filters: [
# ["regex", ["regex", "take_first"]]
# ]
\ No newline at end of file
group:
- t0-eval
task: "by p they mean"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:by p they mean"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
group:
- t0-eval
task: "in other words"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:in other words"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true