Commit d7076a63 authored by Jonathan Tow
Browse files

Implement `MuTual`

parent be893e9d
...@@ -40,6 +40,7 @@ from . import cbt ...@@ -40,6 +40,7 @@ from . import cbt
from . import lambada_cloze from . import lambada_cloze
from . import pile from . import pile
from . import wikitext from . import wikitext
from . import mutual
######################################## ########################################
# Translation tasks # Translation tasks
...@@ -139,6 +140,10 @@ TASK_REGISTRY = { ...@@ -139,6 +140,10 @@ TASK_REGISTRY = {
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": hendrycks_math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
......
"""
MuTual: A Dataset for Multi-Turn Dialogue Reasoning
https://www.aclweb.org/anthology/2020.acl-main.130/
@inproceedings{mutual,
title = "MuTual: A Dataset for Multi-Turn Dialogue Reasoning",
author = "Cui, Leyang and Wu, Yu and Liu, Shujie and Zhang, Yue and Zhou, Ming" ,
booktitle = "Proceedings of the 58th Conference of the Association for Computational Linguistics",
year = "2020",
publisher = "Association for Computational Linguistics",
}
"""
import json
import zipfile
import shutil
import numpy as np
from pathlib import Path
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval.utils import sh
class MuTualBase(Task):
    """Shared implementation of the MuTual multi-turn dialogue reasoning tasks.

    Subclasses set ``DATASET_NAME`` to the sub-directory of ``BASE_PATH`` that
    holds their ``train`` and ``dev`` splits. Every document is a JSON object
    stored in its own ``.txt`` file; this class reads its ``article``,
    ``options`` and ``answers`` fields.
    """

    VERSION = 0
    BASE_PATH = Path("data/mutual")
    DATASET_NAME = None
    # Answer letters, in the same order as each document's "options" list.
    CHOICES = ['A', 'B', 'C', 'D']

    def __init__(self):
        super().__init__()

    def download(self):
        """Download and unpack the dataset unless ``BASE_PATH`` already exists."""
        if self.BASE_PATH.exists():
            return
        # Create only the parent directory. BASE_PATH itself appears when the
        # extracted data is renamed into place, so a failed download does not
        # leave an empty directory that would make the next call return early.
        self.BASE_PATH.parent.mkdir(parents=True, exist_ok=True)
        sh("wget https://github.com/Nealcly/MuTual/archive/master.zip")
        with zipfile.ZipFile("./master.zip", 'r') as archive:
            archive.extractall("./data")
        # The archive unpacks to "MuTual-master" (same spelling as the
        # repository in the URL above); the case matters on case-sensitive
        # filesystems.
        Path("./data/MuTual-master/data/").rename(self.BASE_PATH)
        # Remove left over files and directories.
        Path("./master.zip").unlink()
        shutil.rmtree(Path("./data/MuTual-master"))

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def _load_docs(self, path):
        """Yield the parsed JSON document of every ``.txt`` file in ``path``.

        Files are visited in sorted order so the document sequence does not
        depend on the filesystem's directory listing order.
        """
        for file in sorted(path.iterdir()):
            if file.suffix != ".txt":
                continue
            with open(file, 'r', encoding='utf-8') as f:
                yield json.load(f)

    def training_docs(self):
        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "train")

    def validation_docs(self):
        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "dev")

    def test_docs(self):
        # No test split is exposed (has_test_docs() is False).
        return NotImplemented

    def fewshot_description(self):
        # TODO: figure out fewshot description
        return ""

    def doc_to_text(self, doc):
        """Return the dialogue context of ``doc``."""
        return doc["article"]

    def doc_to_target(self, doc):
        """Return the gold option text, prefixed with a single space."""
        return " " + doc["options"][self.CHOICES.index(doc["answers"])]

    def construct_requests(self, doc, ctx):
        """Build one loglikelihood request per answer option, in option order."""
        # NOTE(review): process_results compares these results directly as
        # scores; confirm what rf.loglikelihood resolves to for each request.
        return [rf.loglikelihood(ctx, f" {option}") for option in doc["options"]]

    def process_results(self, doc, results):
        """Score one document from its per-option results.

        Returns a dict with:
          * ``r@1`` - whether the highest-scoring option is the gold one
            (accuracy; ties resolve to the first option, as ``np.argmax`` does),
          * ``r@2`` - whether the gold option's score is among the top two,
          * ``mrr`` - reciprocal of the gold option's 1-based rank.
        """
        gold = self.CHOICES.index(doc["answers"])
        r4_1 = np.argmax(results) == gold  # r4_1 = accuracy
        ranks = sorted(results, reverse=True)
        # 0-based position of the gold score in the descending ranking.
        gold_rank = ranks.index(results[gold])
        # Use the rank directly: adding `rank == 1` to r@1 missed the case
        # where the gold option ties for the top score without being the
        # first argmax index.
        r4_2 = gold_rank <= 1
        mrr = 1. / (gold_rank + 1)  # `+ 1` for index offset
        return {
            "r@1": r4_1,
            "r@2": r4_2,
            "mrr": mrr
        }

    def aggregation(self):
        return {
            "r@1": mean,
            "r@2": mean,
            "mrr": mean
        }

    def higher_is_better(self):
        return {
            "r@1": True,
            "r@2": True,
            "mrr": True
        }
class MuTual(MuTualBase):
    """MuTual task variant that reads its splits from the ``mutual`` directory."""

    # Sub-directory of BASE_PATH containing this variant's train/dev folders.
    DATASET_NAME = Path("mutual")
class MuTualPlus(MuTualBase):
    """MuTual task variant that reads its splits from the ``mutual_plus`` directory."""

    # Sub-directory of BASE_PATH containing this variant's train/dev folders.
    DATASET_NAME = Path("mutual_plus")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment