Commit 383318fe authored by KhalidAlt

add lama task

parent 567e24c9
"""
Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference
https://arxiv.org/abs/1902.01007
A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems),
which contains many examples where the heuristics fail.
Homepage: https://github.com/tommccoy1/hans
"""
from lm_eval.base import PromptSourceTask
_CITATION = """\
@article{tydiqa,
title = {TyDi QA: A Benchmark for Information-Seeking Question Answering in Typologically Diverse Languages},
author = {Jonathan H. Clark and Eunsol Choi and Michael Collins and Dan Garrette and Tom Kwiatkowski and Vitaly Nikolaev and Jennimaria Palomaki},
year = {2020},
journal = {Transactions of the Association for Computational Linguistics}
}
"""
class Primary(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "tydiqa"
    DATASET_NAME = "primary_task"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy: compare the stripped greedy completion with the
        # gold label taken from the target document.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["sub_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out
class Secondary(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "tydiqa"
    DATASET_NAME = "secondary_task"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]
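
For context, both classes read the Hugging Face `tydiqa` dataset named by `DATASET_PATH`/`DATASET_NAME` above. A minimal standalone sketch of the splits they consume (illustrative only, not part of the harness code):

# Illustrative check of the underlying Hugging Face dataset configs and splits.
from datasets import load_dataset

primary = load_dataset("tydiqa", "primary_task")      # "train" and "validation" splits
secondary = load_dataset("tydiqa", "secondary_task")  # "train" and "validation" splits
print(list(primary.keys()), list(secondary.keys()))
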
@@ -54,7 +54,8 @@ from . import gsm8k
from . import storycloze
from . import hans
from . import gem_webnlg
from . import TyDiQA
from . import lama
# from . import e2e_nlg_cleaned
########################################
@@ -133,6 +134,10 @@ TASK_REGISTRY = {
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet
"lama_trex": lama.Trex,
"lama_squad": lama.Squad,
"lama_google_re": lama.google_re,
"lama_concptnet": lama.Conceptnet,
"logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag,
"openbookqa": openbookqa.OpenBookQA,
@@ -156,6 +161,8 @@ TASK_REGISTRY = {
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"tydiqa_primary" : TyDiQA.Primary,
"tydiqa_secondary" : TyDiQA.Secondary,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
......
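
For context, the entries added above are looked up through the harness's task registry. A minimal sketch, assuming the registry works as in the upstream lm-evaluation-harness (the promptsource-based task classes may require extra constructor arguments, e.g. a prompt template):

# Illustrative lookup of the newly registered tasks via TASK_REGISTRY.
from lm_eval import tasks

trex_cls = tasks.TASK_REGISTRY["lama_trex"]           # -> lama.Trex
print(trex_cls.DATASET_PATH, trex_cls.DATASET_NAME)   # -> lama trex
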
"""
Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference
https://arxiv.org/abs/1902.01007
A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems),
which contains many examples where the heuristics fail.
Homepage: https://github.com/tommccoy1/hans
"""
from typing import Optional

from lm_eval.base import PromptSourceTask
from lm_eval.metrics import mean
_CITATION = """
@inproceedings{petroni2019language, title={Language Models as Knowledge Bases?},
author={F. Petroni, T. Rockt{"{a}}schel, A. H. Miller, P. Lewis, A. Bakhtin, Y. Wu and S. Riedel},
booktitle={In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2019}, year={2019} }
@inproceedings{petroni2020how,
title={How Context Affects Language Models' Factual Predictions},
author={Fabio Petroni and Patrick Lewis and Aleksandra Piktus and Tim Rockt{"a}schel and Yuxiang Wu and Alexander H. Miller and Sebastian Riedel},
booktitle={Automated Knowledge Base Construction}, year={2020}, url={https://openreview.net/forum?id=025X0zPfn} }
"""
class Trex(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "trex"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # The `lama` dataset ships only a "train" split, so it is reused
            # here as the validation set.
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy: compare the stripped greedy completion
        # against the gold object label of the relation triple.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Return the full document so `process_results` can read `obj_label`.
        return doc
class GoogleRe(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "google_re"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # The `lama` dataset ships only a "train" split, so it is reused
            # here as the validation set.
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label of the relation.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Return the full document so `process_results` can read `obj_label`.
        return doc
class Conceptnet(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "conceptnet"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # The `lama` dataset ships only a "train" split, so it is reused
            # here as the validation set.
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Return the full document so `process_results` can read `obj_label`.
        return doc
class Squad(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "lama"
    DATASET_NAME = "squad"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # The `lama` dataset ships only a "train" split, so it is reused
            # here as the validation set.
            return self.dataset["train"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        out = {}
        # Exact-match accuracy against the gold object label.
        pred = results[0].strip()
        target = self.doc_to_target(doc)["obj_label"]
        out["acc"] = pred == target
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example
        return out

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    def doc_to_target(self, doc):
        # Return the full document so `process_results` can read `obj_label`.
        return doc

    def max_generation_length(self) -> Optional[int]:
        """Denote the max generation length if it is obvious from the task."""
        return 5
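
For reference, a minimal sketch of the exact-match scoring performed by `process_results` above; the document below is made up and only mirrors the LAMA field names (`sub_label`, `obj_label`):

# Hypothetical T-REx style record and a model completion, scored as in Trex.process_results.
doc = {"sub_label": "Eiffel Tower", "obj_label": "Paris"}
results = [" Paris \n"]           # greedy completion returned by the model
pred = results[0].strip()         # "Paris"
target = doc["obj_label"]         # doc_to_target returns the doc itself
print({"acc": pred == target})    # {'acc': True}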