Commit a339ffd8 authored by lintangsutawika's avatar lintangsutawika
Browse files

allow to use alternative methods to use hf datasets, allow configuration with dataset_kwargs

parent 36da9c66
......@@ -45,15 +45,16 @@ class TaskConfig(dict):
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
)
base_task: str = None
dataset_path: str = None
dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
......@@ -79,12 +80,12 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# field names in prompt
if self.template_aliases is not None:
if type(self.doc_to_text) == str:
self.doc_to_text = self.template_aliases + self.doc_to_text
# if self.template_aliases is not None:
# if type(self.doc_to_text) == str:
# self.doc_to_text = self.template_aliases + self.doc_to_text
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# if type(self.doc_to_target) == str:
# self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
......@@ -188,22 +189,13 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
if self.DATASET_PATH in ["json", "csv"]:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
data_files=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
@abc.abstractmethod
def has_training_docs(self):
......@@ -524,7 +516,7 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric",
)
self.download(data_dir, cache_dir, download_mode)
self.download(self._config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
......@@ -559,6 +551,14 @@ class ConfigurableTask(Task):
list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here
def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
def has_training_docs(self):
if self._config.training_split is not None:
return True
......@@ -710,7 +710,7 @@ class ConfigurableTask(Task):
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
result_dict = {"perplexity": ll, "acc": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment