Unverified commit 84ef60ee, authored by Stella Biderman, committed by GitHub
Browse files

Merge pull request #481 from janEbert/json-task

Add perplexity task on arbitrary JSON data
parents bda68845 4de8a74e
...@@ -52,6 +52,7 @@ from . import gsm8k ...@@ -52,6 +52,7 @@ from . import gsm8k
from . import storycloze from . import storycloze
from . import toxigen from . import toxigen
from . import crowspairs from . import crowspairs
from . import json
from . import xcopa from . import xcopa
from . import bigbench from . import bigbench
from . import xstorycloze from . import xstorycloze
...@@ -329,9 +330,42 @@ TASK_REGISTRY = { ...@@ -329,9 +330,42 @@ TASK_REGISTRY = {
ALL_TASKS = sorted(list(TASK_REGISTRY)) ALL_TASKS = sorted(list(TASK_REGISTRY))
_EXAMPLE_JSON_PATH = "split:key:/absolute/path/to/data.json"


def add_json_task(task_name):
    """Register a JSON perplexity task if `task_name` is a JSON task spec.

    A JSON task spec is either the bare name ``json`` (rejected with a
    usage hint, because the path argument is missing) or
    ``json=<split>:<key>:<path>``. Every other name, including names that
    merely begin with "json", is left alone so that the regular registry
    lookup handles it.

    On success, ``TASK_REGISTRY[task_name]`` is set to a zero-argument
    factory that constructs `json.JsonPerplexity` with the path argument.

    See `json.JsonPerplexity`.

    :param task_name: str
        The task name as requested by the user.
    :raises ValueError:
        If the path argument is missing or empty, or if it is the
        documentation example path copied verbatim.
    """
    prefix, _, json_path = task_name.partition("=")
    # Only "json" and "json=..." are JSON task specs; a plain prefix check
    # would also capture unrelated names such as "jsonl_custom".
    if prefix != "json":
        return

    if not json_path:
        raise ValueError(
            "json tasks need a path argument pointing to the local "
            "dataset, specified like this: json="
            + _EXAMPLE_JSON_PATH
            + ' (if there are no splits, use "train")'
        )
    if json_path == _EXAMPLE_JSON_PATH:
        raise ValueError(
            "please do not copy the example path directly, but substitute "
            "it with a path to your local dataset"
        )

    # `json` is the sibling task module (`from . import json`), not the
    # standard-library module. The task is built lazily by the factory.
    TASK_REGISTRY[task_name] = lambda: json.JsonPerplexity(json_path)
def get_task(task_name): def get_task(task_name):
try: try:
add_json_task(task_name)
return TASK_REGISTRY[task_name] return TASK_REGISTRY[task_name]
except KeyError: except KeyError:
print("Available tasks:") print("Available tasks:")
......
import datasets
from lm_eval.base import PerplexityTask
from lm_eval.utils import escaped_split
class JsonPerplexity(PerplexityTask):
    """Perplexity task over the documents of a local JSON dataset."""

    VERSION = 0
    DATASET_NAME = "json"

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Use this to specify the path to manually downloaded JSON test data.
            This also needs to include the split key and text key for the data
            in the following format:
            ```
            split:text:/absolute/path/to/data.json
            ```

            If you do not have splits inside the JSON file, it should be "train".
            Colons in the split or text key can be escaped by backslashes.
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        :raises ValueError:
            If `data_dir` is missing or does not have the
            `split:text:/path` format.
        """
        if not data_dir:
            raise ValueError(
                "JsonPerplexity needs `data_dir` in the format "
                "split:text:/absolute/path/to/data.json"
            )
        # At most two splits: everything after the second unescaped colon
        # is the file path, which may itself contain colons.
        parts = escaped_split(data_dir, ":", 2)
        if len(parts) != 3 or not parts[2]:
            raise ValueError(
                "could not parse JSON task specification "
                + repr(data_dir)
                + "; expected the format split:text:/absolute/path/to/data.json"
            )
        split, key, data_file = parts
        # `escaped_split` leaves the escaping backslashes in place; strip
        # them so the keys match the names used inside the dataset.
        self._split = split.replace("\\:", ":")
        self._key = key.replace("\\:", ":")
        self.load(data_file, cache_dir=cache_dir, download_mode=download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Always raise: the data is a local file and is read by `load`."""
        raise TypeError("cannot download an arbitrary JSON dataset")

    def load(self, data_file, cache_dir=None, download_mode=None):
        """Load `data_file` with the `datasets` JSON loader into `self.dataset`.

        `cache_dir` and `download_mode` are forwarded to
        `datasets.load_dataset`; `None` keeps the library defaults.
        """
        self.dataset = datasets.load_dataset(
            "json",
            data_files=data_file,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        """Return a lazy iterator over the text of each document in the split."""
        return map(self._process_doc, self.dataset[self._split])

    def _process_doc(self, doc):
        """Extract the text stored under the configured key of `doc`."""
        return doc[self._key]
...@@ -21,6 +21,29 @@ def sh(x): ...@@ -21,6 +21,29 @@ def sh(x):
raise ExitCodeError() raise ExitCodeError()
def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The escaping backslashes are not removed from the returned parts.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).

    :raises ValueError: if `sep_char` is not exactly one character long.
    """
    if len(sep_char) != 1:
        raise ValueError(
            "separation string must be a single character for escaped splitting"
        )

    if maxsplit == 0:
        # Mirror `str.split(sep, 0)`: no splits, but still return a list.
        return [text]
    # `re.split` treats 0 as "no limit", which is what a negative
    # `maxsplit` means here.
    maxsplit = max(0, maxsplit)

    # Escape the separator so regex metacharacters such as "." or "|" are
    # matched literally; the lookbehind skips backslash-escaped separators.
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)
def simple_parse_args_string(args_string): def simple_parse_args_string(args_string):
""" """
Parses something like Parses something like
......
...@@ -9,6 +9,10 @@ from lm_eval import tasks, evaluator ...@@ -9,6 +9,10 @@ from lm_eval import tasks, evaluator
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
def _is_json_task(task_name):
return task_name == "json" or task_name.startswith("json=")
class MultiChoice:
    """Set of allowed choices supporting comma-separated wildcard membership tests."""

    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        """Return True if every comma-separated entry of `values` is accepted.

        An entry is accepted when it matches at least one choice as a
        filename pattern, or when it is a JSON task spec.
        """
        return all(
            fnmatch.filter(self.choices, entry) or _is_json_task(entry)
            for entry in values.split(",")
        )
...@@ -55,6 +61,9 @@ def parse_args(): ...@@ -55,6 +61,9 @@ def parse_args():
def pattern_match(patterns, source_list):
    """Expand `patterns` into a sorted list of unique task names.

    Each pattern contributes every entry of `source_list` it matches as
    a filename pattern. A pattern that is a JSON task spec is also
    included verbatim.
    """
    selected = set()
    for pattern in patterns:
        if _is_json_task(pattern):
            selected.add(pattern)
        selected.update(fnmatch.filter(source_list, pattern))
    return sorted(selected)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment