Unverified Commit 6559ba0f authored by KhalidAlt, committed by GitHub

Merge branch 'add_lama' into master

parents 0816eba0 1e44199c
@@ -57,6 +57,8 @@ from . import gsm8k
from . import storycloze
from . import hans
from . import gem_webnlg
from . import lama
# from . import e2e_nlg_cleaned
from . import gem_xsum
from . import gem_mlsum
from . import wino_bias
@@ -140,6 +142,10 @@ TASK_REGISTRY = {
    "arc_easy": arc.ARCEasy,
    "arc_challenge": arc.ARCChallenge,
    # "quac": quac.QuAC, # not implemented yet
"lama_trex": lama.Trex,
"lama_squad": lama.Squad,
"lama_google_re": lama.google_re,
"lama_concptnet": lama.Conceptnet,
"logiqa": logiqa.LogiQA, "logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag, "hellaswag": hellaswag.HellaSwag,
"openbookqa": openbookqa.OpenBookQA, "openbookqa": openbookqa.OpenBookQA,
...@@ -163,6 +169,8 @@ TASK_REGISTRY = { ...@@ -163,6 +169,8 @@ TASK_REGISTRY = {
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
#"tydiqa_primary" : TyDiQA.Primary, not implemented yet
#"tydiqa_secondary" : TyDiQA.Secondary, not implemented yet
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
...
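For context, each new key resolves to a task class in the new lama module below. A minimal lookup sketch, assuming the harness's usual get_task helper in lm_eval.tasks (PromptSourceTask subclasses may additionally expect a promptsource template when instantiated):

from lm_eval import tasks

task_class = tasks.get_task("lama_trex")  # TASK_REGISTRY["lama_trex"] -> lama.Trex
print(task_class.DATASET_PATH, task_class.DATASET_NAME)  # "lama" "trex"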
"""
https://arxiv.org/abs/1909.01066
https://arxiv.org/abs/2005.04611
LAMA is a prob dataset to test the factual and commonsense knowledge in language models The dataset include a subset of
Google_RE (https://code.google.com/archive/p/relation-extraction-corpus/), TRex (subset of wikidata triples),
Conceptnet (https://github.com/commonsense/conceptnet5/wiki) and Squad.
Homepage: https://github.com/facebookresearch/LAMA
"""
from typing import Optional

from lm_eval.base import PromptSourceTask
from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@inproceedings{petroni2019language, title={Language Models as Knowledge Bases?}, @inproceedings{petroni2019language, title={Language Models as Knowledge Bases?},
author={F. Petroni, T. Rockt{"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel}, author={F. Petroni, T. Rockt{"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel},
...@@ -11,6 +24,7 @@ _CITATION = """ ...@@ -11,6 +24,7 @@ _CITATION = """
""" """
class BigScienceLAMA(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "janck/bigscience-lama"
@@ -30,15 +44,243 @@ class BigScienceLAMA(PromptSourceTask):
        if self.has_training_docs():
            return self.dataset["train"]
class Trex(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "trex"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        # LAMA ships only a "train" split, so it doubles as the validation set.
        if self.has_validation_docs():
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy: the stripped greedy generation must equal the
        # gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Return the full doc; process_results pulls the "obj_label" field out.
        return doc
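
# Sketch of how Trex.process_results scores one example (hypothetical values,
# not taken from the dataset):
#
#   doc = {"obj_label": "Paris"}          # one record from the "train" split
#   results = ["Paris"]                   # the model's greedy generation
#   task.process_results(doc, results)    # -> {"acc": True}, averaged via mean()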
class GoogleRE(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "google_re"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        return doc
class Conceptnet(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "conceptnet"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        return doc
class Squad(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "squad"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            self._test_docs = list(self.dataset["test"])
            return self._test_docs

    def process_results(self, doc, results):
        out = {}
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        return doc

    def max_generation_length(self) -> Optional[int]:
        """Cap generation length; LAMA targets are short object labels."""
        return 5
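
# End-to-end usage sketch (assumes this fork keeps the upstream harness CLI;
# the exact flags may differ):
#
#   python main.py --model gpt2 --tasks lama_trex,lama_squad
#
# Squad.max_generation_length caps decoding at 5 tokens, which is enough for
# the short object labels LAMA expects while keeping exact match meaningful.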