Commit 52ee80a6 authored by lintangsutawika
Browse files

adjusted to adapt to modified registry process

parent 4bc837d4
...@@ -12,10 +12,9 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi ...@@ -12,10 +12,9 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc Homepage: https://allenai.org/data/arc
""" """
from lm_eval.api.task import MultipleChoiceTask, register_task
from lm_eval.prompts import get_prompt
from lm_eval import utils from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
_CITATION = """ _CITATION = """
...@@ -28,9 +27,10 @@ _CITATION = """ ...@@ -28,9 +27,10 @@ _CITATION = """
} }
""" """
@register_task("arc_easy") @utils.register_task
class ARCEasy(MultipleChoiceTask): class ARCEasy(MultipleChoiceTask):
VERSION = "2.0" VERSION = "2.0"
TASK_NAME = "arc_easy"
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy" DATASET_NAME = "ARC-Easy"
...@@ -80,7 +80,8 @@ class ARCEasy(MultipleChoiceTask): ...@@ -80,7 +80,8 @@ class ARCEasy(MultipleChoiceTask):
return doc["query"] return doc["query"]
@register_task("arc_challenge") @utils.register_task
class ARCChallenge(ARCEasy): class ARCChallenge(ARCEasy):
TASK_NAME = "arc_challenge"
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge" DATASET_NAME = "ARC-Challenge"
...@@ -17,9 +17,9 @@ model's sample/generation function. ...@@ -17,9 +17,9 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math Homepage: https://github.com/openai/grade-school-math
""" """
import re import re
from lm_eval.api.task import Task, register_task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean from lm_eval.api.metrics import mean
from lm_eval.api.instance import Instance
from lm_eval.api.task import Task
from lm_eval import utils from lm_eval import utils
from lm_eval.prompts import get_prompt from lm_eval.prompts import get_prompt
...@@ -41,9 +41,10 @@ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") ...@@ -41,9 +41,10 @@ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]" INVALID_ANS = "[invalid]"
@register_task("gsm8k") @utils.register_task
class GradeSchoolMath8K(Task): class GradeSchoolMath8K(Task):
VERSION = 0 VERSION = 0
TASK_NAME = "gsm8k"
DATASET_PATH = "gsm8k" DATASET_PATH = "gsm8k"
DATASET_NAME = "main" DATASET_NAME = "main"
......
...@@ -12,10 +12,11 @@ in the broader discourse. ...@@ -12,10 +12,11 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
from lm_eval.api.task import Task, register_task from lm_eval.api.task import Task
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity from lm_eval.api.metrics import mean, perplexity
from lm_eval import utils
_CITATION = """ _CITATION = """
@misc{ @misc{
...@@ -75,11 +76,12 @@ class LambadaBase(Task): ...@@ -75,11 +76,12 @@ class LambadaBase(Task):
return {"ppl": False, "acc": True} return {"ppl": False, "acc": True}
@register_task("lambada_standard") @utils.register_task
class LambadaStandard(LambadaBase): class LambadaStandard(LambadaBase):
"""The LAMBADA task using the standard original LAMBADA dataset.""" """The LAMBADA task using the standard original LAMBADA dataset."""
VERSION = "2.0" VERSION = "2.0"
TASK_NAME = "lambada_standard"
DATASET_PATH = "lambada" DATASET_PATH = "lambada"
def has_training_docs(self): def has_training_docs(self):
...@@ -91,7 +93,8 @@ class LambadaStandard(LambadaBase): ...@@ -91,7 +93,8 @@ class LambadaStandard(LambadaBase):
def has_test_docs(self): def has_test_docs(self):
return True return True
@register_task("lambada_openai")
@utils.register_task
class LambadaOpenAI(LambadaBase): class LambadaOpenAI(LambadaBase):
"""The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model. original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.
...@@ -100,6 +103,7 @@ class LambadaOpenAI(LambadaBase): ...@@ -100,6 +103,7 @@ class LambadaOpenAI(LambadaBase):
""" """
VERSION = "2.0" VERSION = "2.0"
TASK_NAME = "lambada_openai"
DATASET_PATH = "EleutherAI/lambada_openai" DATASET_PATH = "EleutherAI/lambada_openai"
def has_training_docs(self): def has_training_docs(self):
......
...@@ -10,7 +10,8 @@ math, computer science, and philosophy papers. ...@@ -10,7 +10,8 @@ math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/ Homepage: https://pile.eleuther.ai/
""" """
from lm_eval.api.task import PerplexityTask, register_task from lm_eval import utils
from lm_eval.api.task import PerplexityTask
_CITATION = """ _CITATION = """
...@@ -69,8 +70,9 @@ class PileDmMathematics(PilePerplexityTask): ...@@ -69,8 +70,9 @@ class PileDmMathematics(PilePerplexityTask):
DATASET_NAME = "pile_dm-mathematics" DATASET_NAME = "pile_dm-mathematics"
@register_task("pile_enron") @utils.register_task
class PileEnron(PilePerplexityTask): class PileEnron(PilePerplexityTask):
TASK_NAME = "pile_enron"
DATASET_NAME = "enron_emails" DATASET_NAME = "enron_emails"
......
...@@ -10,7 +10,9 @@ NOTE: This `Task` is based on WikiText-2. ...@@ -10,7 +10,9 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/ Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
""" """
import re import re
from lm_eval.api.task import PerplexityTask, register_task
from lm_eval import utils
from lm_eval.api.task import PerplexityTask
_CITATION = """ _CITATION = """
...@@ -58,9 +60,10 @@ def wikitext_detokenizer(string): ...@@ -58,9 +60,10 @@ def wikitext_detokenizer(string):
return string return string
@register_task("wikitext") @utils.register_task
class WikiText(PerplexityTask): class WikiText(PerplexityTask):
VERSION = "2.0" VERSION = "2.0"
TASK_NAME = "wikitext"
DATASET_PATH = "EleutherAI/wikitext_document_level" DATASET_PATH = "EleutherAI/wikitext_document_level"
DATASET_NAME = "wikitext-2-raw-v1" DATASET_NAME = "wikitext-2-raw-v1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment