Unverified commit 25dfd3f6, authored by Hailey Schoelkopf and committed by GitHub

Merge pull request #389 from Muennighoff/babi

Add bAbI
parents a8996a2a caadcbc9
@@ -4,6 +4,7 @@ from typing import List, Union
import sacrebleu
import lm_eval.base
from . import babi
from . import superglue
from . import glue
from . import arc
@@ -93,6 +94,7 @@ all_translation_benchmarks = {
TASK_REGISTRY = {
"babi": babi.Babi,
# GLUE
"cola": glue.CoLA,
"mnli": glue.MNLI,
lm_eval/tasks/babi.py (new file)
"""
Inspired by https://github.com/stanford-crfm/helm/blob/0eaaa62a2263ddb94e9850ee629423b010f57e4a/src/helm/benchmark/scenarios/babi_qa_scenario.py
"""
import numpy as np
from collections import defaultdict
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@article{weston2015towards,
title={Towards ai-complete question answering: A set of prerequisite toy tasks},
author={Weston, Jason and Bordes, Antoine and Chopra, Sumit and Rush, Alexander M and Van Merri{\"e}nboer, Bart and Joulin, Armand and Mikolov, Tomas},
journal={arXiv preprint arXiv:1502.05698},
year={2015}
}
"""


class Babi(Task):
    VERSION = 0
    DATASET_PATH = "Muennighoff/babi"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            return self.dataset["train"]

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["valid"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def doc_to_text(self, doc):
        return doc['passage'] + doc['question']

    def should_decontaminate(self):
        return False  # TODO Necessary?

    def doc_to_decontamination_query(self, doc):
        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

    def doc_to_target(self, doc):
        return " " + doc['answer']

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return rf.greedy_until(ctx, ["\n"])
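
    # Note: the greedy_until request asks the model to generate greedily from the
    # few-shot context and stop at the first newline; process_results below then
    # compares that single completion against the gold answer.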

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        gold = doc["answer"]
        pred = gold.strip() == results[0].strip()
        return {"em": pred}

    def aggregation(self):
        return {
            "em": mean,
        }

    def higher_is_better(self):
        return {
            "em": True,
        }
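
For reference, the prompt/target format the new task produces and the exact-match check applied to a model completion can be traced on a toy record; the field values below are made up and only mirror the shape of the Muennighoff/babi rows the methods above expect:

# Hypothetical record with the fields that Babi's methods read.
doc = {
    "passage": "Mary moved to the bathroom.\nJohn went to the hallway.\n",
    "question": "Where is Mary?",
    "answer": "bathroom",
}

prompt = doc['passage'] + doc['question']           # as in Babi.doc_to_text
target = " " + doc['answer']                        # as in Babi.doc_to_target
completion = " bathroom"                            # e.g. model output before the "\n" stop sequence
em = doc["answer"].strip() == completion.strip()    # as in Babi.process_results -> {"em": True}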