Commit 53d4f850 authored by sdtblck
Browse files

update nlp class to use HF datasets library

parent 5888a695
import nlp import datasets
import numpy as np import numpy as np
import random import random
from ..base import Dataset from ..base import Dataset
...@@ -11,25 +11,35 @@ class NLP_TASK(Dataset): ...@@ -11,25 +11,35 @@ class NLP_TASK(Dataset):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._training_docs = None self._training_docs = None
self.data = datasets.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME)
def _load_nlp_dataset(self): def has_training_docs(self):
return nlp.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME) """Whether the task has a training set"""
return True if "train" in self.data.keys() else False
def has_validation_docs(self):
"""Whether the task has a validation set"""
return True if "validation" in self.data.keys() else False
def has_test_docs(self):
"""Whether the task has a test set"""
return True if "test" in self.data.keys() else False
def training_docs(self): def training_docs(self):
# Cache training for faster few-shot. # Cache training for faster few-shot.
# If data is too large to fit in memory, override this method. # If data is too large to fit in memory, override this method.
if self.has_training_docs(): if self.has_training_docs():
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self._load_nlp_dataset()["train"]) self._training_docs = list(self.data["train"])
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
if self.has_validation_docs(): if self.has_validation_docs():
return self._load_nlp_dataset()["validation"] return self.data["validation"]
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self._load_nlp_dataset()["test"] return self.data["test"]
def fewshot_examples(self, k): def fewshot_examples(self, k):
training_docs = self.training_docs() training_docs = self.training_docs()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment