Unverified commit a1a4a32e, authored by Leo Gao, committed by GitHub

Merge pull request #119 from jon-tow/task-refactor

Refactor `Dataset` naming and `HFTask` properties
Parents: 826d90e2 5cfb7308
@@ -58,10 +58,10 @@ class LM(abc.ABC):
        return cls()


-class Dataset(abc.ABC):
+class Task(abc.ABC):
    def __init__(self):
        self.download()
-        self._traindocs = None
+        self._training_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
@@ -71,7 +71,7 @@ class Dataset(abc.ABC):
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
@@ -84,23 +84,29 @@ class Dataset(abc.ABC):
    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
        return []

    def test_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
        return []

-    def fewshot_examples(self, k):
-        if self._traindocs is None:
-            self._traindocs = list(self.training_docs())
-        return random.sample(self._traindocs, k)
+    def fewshot_examples(self, k):
+        if self._training_docs is None:
+            self._training_docs = list(self.training_docs())
+        return random.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
@@ -123,7 +129,7 @@ class Dataset(abc.ABC):
        part of the document for `doc`.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
@@ -161,7 +167,7 @@ class Dataset(abc.ABC):
    def fewshot_context(self, doc, num_fewshot, provide_description):
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
...
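For context, this is roughly what a downstream task definition looks like against the renamed base class. It is a minimal, hedged sketch: the DemoAddition class, its documents, and the returned metric are illustrative only and are not part of this commit, and only the Task methods visible in the diff above are filled in (the rest of the interface is truncated here).

from lm_eval.base import Task


class DemoAddition(Task):
    """Illustrative toy task; not part of this commit."""

    def download(self):
        # A real task would fetch its dataset here; nothing to do for toy data.
        pass

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def training_docs(self):
        # Any iterable of objects that doc_to_text can handle.
        return [{"context": "2 + 2 =", "completion": " 4"}]

    def doc_to_text(self, doc):
        return doc["context"] + doc["completion"]

    def process_results(self, doc, results):
        # The shape of `results` depends on the requests the task issues;
        # returning a metrics dict here is an assumption for illustration.
        return {"acc": 0.0}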
@@ -2,12 +2,12 @@ import abc
import json
import os
from collections import namedtuple
-from lm_eval.base import Dataset, mean, rf
+from lm_eval.base import Task, mean, rf
from best_download import download_file

ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])

-class Arithmetic(Dataset):
+class Arithmetic(Task):
    directory = 'data/arithmetic/'

    def __init__(self):
...
import datasets
import numpy as np
-import random
-from ..base import Dataset
+from ..base import Task


-class HFTask(Dataset):
+class HFTask(Task):
    DATASET_PATH = None
    DATASET_NAME = None

    def __init__(self):
+        self.data = None
        super().__init__()
-        self._training_docs = None

    def download(self):
        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
...
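With the HFTask changes above, a Hugging Face-backed task only needs to point the class attributes at a dataset; self.data is then populated by download(). A minimal sketch, assuming the module path lm_eval.tasks.common from the relative imports shown in this diff, and using "glue"/"cola" purely as example arguments to datasets.load_dataset, not something this commit adds:

from lm_eval.tasks.common import HFTask  # module path assumed from the relative imports above


class DemoColaTask(HFTask):
    # Forwarded to datasets.load_dataset(path=DATASET_PATH, name=DATASET_NAME) by download().
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def doc_to_text(self, doc):
        # GLUE CoLA rows carry a "sentence" field; illustrative only.
        return doc["sentence"]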
@@ -2,11 +2,11 @@
import json
import random
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh


-class CoQA(Dataset):
+class CoQA(Task):
    def __init__(self):
        self.download()

    def download(self):
...
@@ -5,9 +5,9 @@ from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib

from . common import HFTask, simple_accuracy_metric, yesno
from pathlib import Path
-from ..base import Dataset
+from ..base import Task


-class DROP(Dataset):
+class DROP(Task):
    DATAFOLDER = Path(__file__).parent / "../../data/drop"

    def __init__(self):
...
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean
from lm_eval.utils import sh
import json
import math
from best_download import download_file


-class LAMBADA(Dataset):
+class LAMBADA(Task):
    def download(self):
        sh("mkdir -p data/lambada")
        download_file(
...
@@ -30,10 +30,10 @@ class NaturalQs(HFTask):
    def fewshot_examples(self, k):
        # Data is too large to fit in memory. We just sample from the first bit.
-        if self._traindocs is None:
-            self._traindocs = list(islice(self.training_docs(), 0, 100000))
-        return random.sample(self._traindocs, k)
+        if self._training_docs is None:
+            self._training_docs = list(islice(self.training_docs(), 0, 100000))
+        return random.sample(self._training_docs, k)

    def doc_to_text(self, doc):
        return 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '
...
import json
import random
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean
from ..utils import sh
import os


-class PiQA(Dataset):
+class PiQA(Task):
    def download(self):
        if not os.path.exists('data/piqa'):
            #TODO: use best_download
...
import json
import random
import os
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh


-class QuAC(Dataset):
+class QuAC(Task):
    def __init__(self):
        super().__init__()
...
import json
import random
import os
-from lm_eval.base import Dataset, rf, mean
+from lm_eval.base import Task, rf, mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh


-class SATAnalogies(Dataset):
+class SATAnalogies(Task):
    NEEDS_MANUAL_DL = True

    def __init__(self):
...
import json
import random
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh
import csv


-class StoryCloze(Dataset):
+class StoryCloze(Task):
    NEEDS_MANUAL_DL = True

    def download(self):
...
import os
import json
import random
-from lm_eval.base import Dataset, mean, rf
+from lm_eval.base import Task, mean, rf
from ..utils import sh


-class TriviaQA(Dataset):
+class TriviaQA(Task):
    def download(self):
        if not os.path.exists('data/triviaqa'):
            sh("""
...
import json
import random
import os
-from lm_eval.base import Dataset
+from lm_eval.base import Task
from ..utils import sh


-class WinogradSchemaChallenge273(Dataset):
+class WinogradSchemaChallenge273(Task):
    def __init__(self):
        super().__init__()
...