Unverified Commit 372ca6f5 authored by Stella Biderman, committed by GitHub

Merge pull request #20 from JanKalo/master

Added bigscience-LAMA evaluation 
parents 2e0b659a 49f117ed
@@ -5,6 +5,7 @@ from typing import List, Union
 import sacrebleu
 import lm_eval.base
 from . import superglue
 from . import glue
 from . import arc
@@ -54,12 +55,14 @@ from . import gsm8k
 from . import storycloze
 from . import hans
 from . import gem_webnlg
+from . import lama
+# from . import e2e_nlg_cleaned
 from . import gem_xsum
 from . import gem_mlsum
 from . import wino_bias
 from . import e2e_nlg_cleaned
 from . import gem_asset_turk
 from . import crows_pairs_multilingual
 from . import HuffPost
 ########################################
@@ -139,6 +143,10 @@ TASK_REGISTRY = {
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC, # not implemented yet
+    "lama_trex": lama.Trex,
+    "lama_squad": lama.Squad,
+    "lama_google_re": lama.google_re,
+    "lama_conceptnet": lama.Conceptnet,
     "logiqa": logiqa.LogiQA,
     "hellaswag": hellaswag.HellaSwag,
     "openbookqa": openbookqa.OpenBookQA,
@@ -162,6 +170,8 @@ TASK_REGISTRY = {
     "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
     "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
     "ethics_virtue": hendrycks_ethics.EthicsVirtue,
+    # "tydiqa_primary": TyDiQA.Primary,  # not implemented yet
+    # "tydiqa_secondary": TyDiQA.Secondary,  # not implemented yet
     "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
     "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
     # dialogue
@@ -314,6 +324,12 @@ TASK_REGISTRY = {
     "gem_xsum_challenge_test_nopunc": gem_xsum.GEMXSUMChallgeTestNopunc,
     "gem_xsum_challenge_test_covid": gem_xsum.GEMXSUMChallgeTestCovid,
+    # LAMA
+    "lama-trex": lama.Trex,
+    "lama-squad": lama.Squad,
+    "lama-google_re": lama.google_re,
+    "lama-conceptnet": lama.Conceptnet,
+    "bigscience-lama": lama.BigScienceLAMA,
     # WinoBias
     "wino_bias_type1_pro": wino_bias.WinoBiasType1Pro,
     "wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
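With these registry entries in place, the new tasks resolve by name like any other harness task. A minimal lookup check against the TASK_REGISTRY patched above (a sketch only: instantiating a PromptSourceTask in this fork may additionally require a promptsource template, so just the registry lookup is exercised here):

from lm_eval.tasks import TASK_REGISTRY

# Both naming variants added by this PR should resolve to task classes.
for name in ("lama_trex", "lama-trex", "bigscience-lama"):
    print(name, "->", TASK_REGISTRY[name])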
"""
https://arxiv.org/abs/1909.01066
https://arxiv.org/abs/2005.04611
LAMA is a prob dataset to test the factual and commonsense knowledge in language models. The dataset includes a subset of
Google_RE (https://code.google.com/archive/p/relation-extraction-corpus/), TRex (subset of wikidata triples),
Conceptnet (https://github.com/commonsense/conceptnet5/wiki) and Squad.
Homepage: https://github.com/facebookresearch/LAMA
"""
from typing import Optional

from lm_eval.base import PromptSourceTask
from lm_eval.metrics import mean
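
# Shape of a typical record in the HF `lama` subsets, sketched from the field
# this module actually reads (`obj_label`) plus commonly present companions
# (assumed; exact columns vary by subset):
#
#   {"masked_sentence": "The capital of France is [MASK].",
#    "sub_label": "France",
#    "obj_label": "Paris",
#    ...}
#
# `process_results` in each task below compares the model's stripped
# generation against `obj_label`.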
_CITATION = """
@inproceedings{petroni2019language, title={Language Models as Knowledge Bases?},
author={F. Petroni, T. Rockt{"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel},
booktitle={In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2019}, year={2019} }
@inproceedings{petroni2020how,
title={How Context Affects Language Models' Factual Predictions},
author={Fabio Petroni and Patrick Lewis and Aleksandra Piktus and Tim Rockt{"a}schel and Yuxiang Wu and Alexander H. Miller and Sebastian Riedel},
booktitle={Automated Knowledge Base Construction}, year={2020}, url={https://openreview.net/forum?id=025X0zPfn} }
"""

class BigScienceLAMA(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "janck/bigscience-lama"
    DATASET_NAME = None

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        if self.has_test_docs():
            # Assumption: like the other LAMA datasets on the HF hub, this
            # one ships a single "train" split, used here for evaluation.
            return self.dataset["train"]

class Trex(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "trex"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            # The HF `lama` dataset exposes only a "train" split; it is
            # used here as the evaluation set.
            return self.dataset["train"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {"pred": pred, "target": target}
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Returns the raw document; callers index the gold answer via
        # `doc_to_target(doc)["obj_label"]`.
        return doc
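
# Worked example of the metric above (values hypothetical): for a document
# with obj_label "Paris" and a greedy completion " Paris\n", pred strips to
# "Paris", so out["acc"] is True; `mean` then aggregates these booleans
# into the reported accuracy.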

class google_re(PromptSourceTask):
    # Note: the lower-case class name is kept as-is because TASK_REGISTRY
    # refers to it as `lama.google_re`.
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "google_re"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            # The HF `lama` dataset exposes only a "train" split; it is
            # used here as the evaluation set.
            return self.dataset["train"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {"pred": pred, "target": target}
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Returns the raw document; callers index the gold answer via
        # `doc_to_target(doc)["obj_label"]`.
        return doc

class Conceptnet(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "conceptnet"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            # The HF `lama` dataset exposes only a "train" split; it is
            # used here as the evaluation set.
            return self.dataset["train"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {"pred": pred, "target": target}
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Returns the raw document; callers index the gold answer via
        # `doc_to_target(doc)["obj_label"]`.
        return doc

class Squad(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "squad"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            # The HF `lama` dataset exposes only a "train" split; cache it
            # and use it as the evaluation set.
            self._test_docs = list(self.dataset["train"])
            return self._test_docs

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {"pred": pred, "target": target}
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Returns the raw document; callers index the gold answer via
        # `doc_to_target(doc)["obj_label"]`.
        return doc

    def max_generation_length(self) -> Optional[int]:
        """Maximum number of tokens to generate, when obvious from the task."""
        return 5
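
Trex, google_re, Conceptnet, and Squad differ only in DATASET_NAME (plus Squad's cached test_docs and max_generation_length), so the shared logic could be hoisted into one base class in a follow-up. A minimal sketch of that refactor (hypothetical, not part of this commit; it assumes the same PromptSourceTask interface used above):

class _LamaExactMatchTask(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        if self.has_test_docs():
            # HF `lama` only ships a "train" split; use it for evaluation.
            return self.dataset["train"]

    def process_results(self, doc, results):
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out = {"acc": pred == target}
        if self.save_examples:
            return out, {"pred": pred, "target": target}
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        return doc


class Trex(_LamaExactMatchTask):
    DATASET_NAME = "trex"

Each concrete task would then shrink to a DATASET_NAME declaration, with Squad additionally overriding max_generation_length.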