Commit a339ffd8 authored by lintangsutawika
Browse files

Allow alternative ways of loading HF datasets, and allow configuring the load via dataset_kwargs

parent 36da9c66
...@@ -45,15 +45,16 @@ class TaskConfig(dict): ...@@ -45,15 +45,16 @@ class TaskConfig(dict):
task_name: str = ( task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0] None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
) )
base_task: str = None
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None training_split: str = None
validation_split: str = None validation_split: str = None
test_split: str = None test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
...@@ -79,12 +80,12 @@ class TaskConfig(dict): ...@@ -79,12 +80,12 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can # allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of # force prompt-compatibility for some prompt regardless of
# field names in prompt # field names in prompt
if self.template_aliases is not None: # if self.template_aliases is not None:
if type(self.doc_to_text) == str: # if type(self.doc_to_text) == str:
self.doc_to_text = self.template_aliases + self.doc_to_text # self.doc_to_text = self.template_aliases + self.doc_to_text
if type(self.doc_to_target) == str: # if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target # self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set # set "task_name" metadata field based on the "primary" name set
if self.names: if self.names:
...@@ -188,22 +189,13 @@ class Task(abc.ABC): ...@@ -188,22 +189,13 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD` - `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset. Fresh download and fresh dataset.
""" """
if self.DATASET_PATH in ["json", "csv"]: self.dataset = datasets.load_dataset(
self.dataset = datasets.load_dataset( path=self.DATASET_PATH,
path=self.DATASET_PATH, name=self.DATASET_NAME,
data_files=self.DATASET_NAME, data_dir=data_dir,
data_dir=data_dir, cache_dir=cache_dir,
cache_dir=cache_dir, download_mode=download_mode,
download_mode=download_mode, )
)
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
@abc.abstractmethod @abc.abstractmethod
def has_training_docs(self): def has_training_docs(self):
...@@ -524,7 +516,7 @@ class ConfigurableTask(Task): ...@@ -524,7 +516,7 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric", "Please check https://huggingface.co/evaluate-metric",
) )
self.download(data_dir, cache_dir, download_mode) self.download(self._config.dataset_kwargs)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
...@@ -559,6 +551,14 @@ class ConfigurableTask(Task): ...@@ -559,6 +551,14 @@ class ConfigurableTask(Task):
list(self.fewshot_docs()), self, rnd=random.Random() list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here ) # TODO: pass the correct docs in here
def download(self, dataset_kwargs=None):
    """Load this task's dataset and store it on ``self.dataset``.

    The dataset is fetched with ``datasets.load_dataset`` using
    ``self.DATASET_PATH`` as ``path`` and ``self.DATASET_NAME`` as ``name``.

    Args:
        dataset_kwargs: Optional mapping of extra keyword arguments that are
            forwarded unchanged to ``datasets.load_dataset``. ``None`` means
            no extra arguments are passed.
    """
    # Normalise a missing mapping to an empty one so it can be unpacked.
    extra_kwargs = {} if dataset_kwargs is None else dataset_kwargs
    self.dataset = datasets.load_dataset(
        path=self.DATASET_PATH,
        name=self.DATASET_NAME,
        **extra_kwargs,
    )
def has_training_docs(self): def has_training_docs(self):
if self._config.training_split is not None: if self._config.training_split is not None:
return True return True
...@@ -710,7 +710,7 @@ class ConfigurableTask(Task): ...@@ -710,7 +710,7 @@ class ConfigurableTask(Task):
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
results = results[0] results = results[0]
ll, is_greedy = results ll, is_greedy = results
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)} result_dict = {"perplexity": ll, "acc": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results (loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc)) words = self.count_words(self.doc_to_target(doc))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment