"...mmpose-rtmo_pytorch.git" did not exist on "ca8a762a7fd9b5e2641add9254e43b4ef0b44526"
Commit d79a4389 authored by Jon Tow
Browse files

Add specific `load_dataset` args

parent 4f8b99f6
...@@ -348,19 +348,41 @@ class Task(abc.ABC): ...@@ -348,19 +348,41 @@ class Task(abc.ABC):
{"question": ..., question, answer) {"question": ..., question, answer)
""" """
# The name of the `Task` benchmark as denoted in the HuggingFace `datasets` # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# API or a path to a custom `datasets` loading script. # or a path to a custom `datasets` loading script.
DATASET_PATH: str = None DATASET_PATH: str = None
# The name of a subset within `DATASET_PATH`. # The name of a subset within `DATASET_PATH`.
DATASET_NAME: str = None DATASET_NAME: str = None
def __init__(self, **kwargs): def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
self.download(**kwargs) """
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
Use this to specify the path to manually downloaded data (usually when
the dataset is not publicly accessible).
:param cache_dir: str
The directory to read/write the `Task` dataset. This follows the
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
by setting the shell environment variable, `HF_DATASETS_CACHE`,
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
How to treat pre-existing `Task` downloads and data.
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
Reuse download and reuse dataset.
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
Reuse download with fresh dataset.
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
def download(self, **load_dataset_kwargs): def download(self, data_dir=None, cache_dir=None, download_mode=None):
""" Downloads and returns the task dataset. """ Downloads and returns the task dataset.
Override this method to download the dataset from a custom API. Override this method to download the dataset from a custom API.
...@@ -370,7 +392,9 @@ class Task(abc.ABC): ...@@ -370,7 +392,9 @@ class Task(abc.ABC):
self.dataset = datasets.load_dataset( self.dataset = datasets.load_dataset(
path=self.DATASET_PATH, path=self.DATASET_PATH,
name=self.DATASET_NAME, name=self.DATASET_NAME,
**load_dataset_kwargs data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode
) )
@abstractmethod @abstractmethod
...@@ -532,7 +556,8 @@ class Task(abc.ABC): ...@@ -532,7 +556,8 @@ class Task(abc.ABC):
return description + labeled_examples + example return description + labeled_examples + example
class MultipleChoiceTask(Task, abc.ABC): class MultipleChoiceTask(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['choices'][doc['gold']] return " " + doc['choices'][doc['gold']]
......
...@@ -90,7 +90,7 @@ class GeneralTranslationTask(Task): ...@@ -90,7 +90,7 @@ class GeneralTranslationTask(Task):
super().__init__() super().__init__()
def download(self): def download(self, data_dir=None, cache_dir=None, download_mode=None):
# This caches in the users home dir automatically # This caches in the users home dir automatically
self.src_file, self.ref_file = \ self.src_file, self.ref_file = \
sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair) sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment