Unverified Commit 43978e3b authored by Anish Thite, committed by GitHub
Browse files

Merge pull request #37 from EleutherAI/download_on_demand

download_on_demand
parents d6b91191 4e2d1498
...@@ -54,6 +54,12 @@ class LM(abc.ABC): ...@@ -54,6 +54,12 @@ class LM(abc.ABC):
class Dataset(abc.ABC): class Dataset(abc.ABC):
@abc.abstractmethod
def download(self):
    """Downloads the task dataset if necessary.

    Abstract hook: there is no implementation here and the call returns
    None. Concrete subclasses provide the actual download logic.
    """
@abc.abstractmethod @abc.abstractmethod
def has_training_docs(self): def has_training_docs(self):
"""Whether the task has a training set""" """Whether the task has a training set"""
...@@ -121,4 +127,4 @@ class Dataset(abc.ABC): ...@@ -121,4 +127,4 @@ class Dataset(abc.ABC):
map(self.doc_to_text, self.fewshot_examples(k=num_fewshot)) map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
) + "\n\n" ) + "\n\n"
example = self.doc_to_text(doc, include_target=False).strip() example = self.doc_to_text(doc, include_target=False).strip()
return description + labeled_examples + example return description + labeled_examples + example
\ No newline at end of file
import nlp import datasets
import numpy as np import numpy as np
import random import random
from ..base import Dataset from ..base import Dataset
...@@ -11,25 +11,35 @@ class HFNLPTask(Dataset): ...@@ -11,25 +11,35 @@ class HFNLPTask(Dataset):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._training_docs = None self._training_docs = None
self.data = datasets.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME)
def _load_nlp_dataset(self): def has_training_docs(self):
return nlp.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME) """Whether the task has a training set"""
return True if "train" in self.data.keys() else False
def has_validation_docs(self):
    """Whether the task has a validation set.

    Returns:
        bool: True when ``self.data`` contains a ``"validation"`` key.
    """
    # A membership test on the mapping already yields a bool, so the
    # ``True if ... else False`` wrapper and the ``.keys()`` call are dropped.
    return "validation" in self.data
def has_test_docs(self):
    """Whether the task has a test set.

    Returns:
        bool: True when ``self.data`` contains a ``"test"`` key.
    """
    # A membership test on the mapping already yields a bool, so the
    # ``True if ... else False`` wrapper and the ``.keys()`` call are dropped.
    return "test" in self.data
def training_docs(self): def training_docs(self):
# Cache training for faster few-shot. # Cache training for faster few-shot.
# If data is too large to fit in memory, override this method. # If data is too large to fit in memory, override this method.
if self.has_training_docs(): if self.has_training_docs():
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self._load_nlp_dataset()["train"]) self._training_docs = list(self.data["train"])
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
if self.has_validation_docs(): if self.has_validation_docs():
return self._load_nlp_dataset()["validation"] return self.data["validation"]
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self._load_nlp_dataset()["test"] return self.data["test"]
def fewshot_examples(self, k): def fewshot_examples(self, k):
training_docs = self.training_docs() training_docs = self.training_docs()
......
import json import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset
from ..utils import sh
class CoQA(Dataset): class CoQA(Dataset):
def download(self):
    """Download the CoQA train and dev JSON files into ``data/coqa``.

    The directory is created if missing and both files are fetched with
    ``wget``, overwriting any existing copies (``-O``).

    Raises:
        ExitCodeError: Propagated from ``sh`` when the script exits with a
            non-zero status.
    """
    # ``set -e`` aborts the script at the first failing command. Without it
    # only the status of the last ``wget`` reaches ``sh``, so a failed
    # ``mkdir`` or first download would go unnoticed.
    sh("""
set -e
mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
""")
def has_training_docs(self): def has_training_docs(self):
return True return True
......
import os
class ExitCodeError(Exception):
    """Raised by ``sh`` when a shell command reports a non-zero status."""


def sh(x):
    """Run ``x`` through ``os.system`` and raise if it fails.

    Args:
        x: Command string handed to the system shell; may span several lines.

    Raises:
        ExitCodeError: If ``os.system`` returns a non-zero status. The message
            carries that status and the command so the failure can be traced.
    """
    status = os.system(x)
    if status:
        # Include the status and the command; a bare exception gave the
        # caller no clue about which invocation failed.
        raise ExitCodeError(
            "command failed with status {}: {!r}".format(status, x)
        )
def simple_parse_args_string(args_string): def simple_parse_args_string(args_string):
""" """
Parses something like Parses something like
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment