Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
84ef60ee
Unverified
Commit
84ef60ee
authored
May 21, 2023
by
Stella Biderman
Committed by
GitHub
May 21, 2023
Browse files
Merge pull request #481 from janEbert/json-task
Add perplexity task on arbitrary JSON data
parents
bda68845
4de8a74e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
129 additions
and
1 deletion
+129
-1
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+34
-0
lm_eval/tasks/json.py
lm_eval/tasks/json.py
+62
-0
lm_eval/utils.py
lm_eval/utils.py
+23
-0
main.py
main.py
+10
-1
No files found.
lm_eval/tasks/__init__.py
View file @
84ef60ee
...
@@ -52,6 +52,7 @@ from . import gsm8k
...
@@ -52,6 +52,7 @@ from . import gsm8k
from
.
import
storycloze
from
.
import
storycloze
from
.
import
toxigen
from
.
import
toxigen
from
.
import
crowspairs
from
.
import
crowspairs
from
.
import
json
from
.
import
xcopa
from
.
import
xcopa
from
.
import
bigbench
from
.
import
bigbench
from
.
import
xstorycloze
from
.
import
xstorycloze
...
@@ -329,9 +330,42 @@ TASK_REGISTRY = {
...
@@ -329,9 +330,42 @@ TASK_REGISTRY = {
ALL_TASKS
=
sorted
(
list
(
TASK_REGISTRY
))
ALL_TASKS
=
sorted
(
list
(
TASK_REGISTRY
))
_EXAMPLE_JSON_PATH
=
"split:key:/absolute/path/to/data.json"
def
add_json_task
(
task_name
):
"""Add a JSON perplexity task if the given task name matches the
JSON task specification.
See `json.JsonPerplexity`.
"""
if
not
task_name
.
startswith
(
"json"
):
return
def
create_json_task
():
splits
=
task_name
.
split
(
"="
,
1
)
if
len
(
splits
)
!=
2
or
not
splits
[
1
]:
raise
ValueError
(
"json tasks need a path argument pointing to the local "
"dataset, specified like this: json="
+
_EXAMPLE_JSON_PATH
+
' (if there are no splits, use "train")'
)
json_path
=
splits
[
1
]
if
json_path
==
_EXAMPLE_JSON_PATH
:
raise
ValueError
(
"please do not copy the example path directly, but substitute "
"it with a path to your local dataset"
)
return
lambda
:
json
.
JsonPerplexity
(
json_path
)
TASK_REGISTRY
[
task_name
]
=
create_json_task
()
def
get_task
(
task_name
):
def
get_task
(
task_name
):
try
:
try
:
add_json_task
(
task_name
)
return
TASK_REGISTRY
[
task_name
]
return
TASK_REGISTRY
[
task_name
]
except
KeyError
:
except
KeyError
:
print
(
"Available tasks:"
)
print
(
"Available tasks:"
)
...
...
lm_eval/tasks/json.py
0 → 100644
View file @
84ef60ee
import
datasets
from
lm_eval.base
import
PerplexityTask
from
lm_eval.utils
import
escaped_split
class
JsonPerplexity
(
PerplexityTask
):
VERSION
=
0
DATASET_NAME
=
"json"
def
__init__
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
):
"""
:param data_dir: str
Use this to specify the path to manually downloaded JSON test data.
This also needs to include the split key and text key for the data
in the following format:
```
split:text:/absolute/path/to/data.json
```
If you do not have splits inside the JSON file, it should be "train".
Colons in the split or text key can be escaped by backslashes.
:param cache_dir: str
The directory to read/write the `Task` dataset. This follows the
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
by setting the shell environment variable, `HF_DATASETS_CACHE`,
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
How to treat pre-existing `Task` downloads and data.
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
Reuse download and reuse dataset.
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
Reuse download with fresh dataset.
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self
.
_split
,
self
.
_key
,
data_file
=
escaped_split
(
data_dir
,
":"
,
2
)
self
.
load
(
data_file
)
self
.
_training_docs
=
None
self
.
_fewshot_docs
=
None
def
download
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
):
raise
TypeError
(
"cannot download an arbitrary JSON dataset"
)
def
load
(
self
,
data_file
):
self
.
dataset
=
datasets
.
load_dataset
(
"json"
,
data_files
=
data_file
)
def
has_validation_docs
(
self
):
return
False
def
has_test_docs
(
self
):
return
True
def
test_docs
(
self
):
return
map
(
self
.
_process_doc
,
self
.
dataset
[
self
.
_split
])
def
_process_doc
(
self
,
doc
):
return
doc
[
self
.
_key
]
lm_eval/utils.py
View file @
84ef60ee
...
@@ -21,6 +21,29 @@ def sh(x):
...
@@ -21,6 +21,29 @@ def sh(x):
raise
ExitCodeError
()
raise
ExitCodeError
()
def
escaped_split
(
text
,
sep_char
,
maxsplit
=-
1
):
"""Split text into a list on occurrences of the given separation
character `sep_char`. The separation character may be escaped by a
backslash to avoid splitting at that location.
The separation character must be a string of size 1.
If `maxsplit` is given, at most `maxsplit` splits are done (thus,
the list will have at most `maxsplit + 1` elements). If `maxsplit`
is not specified or less than 0, then there is no limit on the
number of splits (all possible splits are made).
"""
assert
(
len
(
sep_char
)
==
1
),
"separation string must be a single character for escaped splitting"
if
maxsplit
==
0
:
return
text
maxsplit
=
max
(
0
,
maxsplit
)
return
re
.
split
(
r
"(?<!\\)"
+
sep_char
,
text
,
maxsplit
)
def
simple_parse_args_string
(
args_string
):
def
simple_parse_args_string
(
args_string
):
"""
"""
Parses something like
Parses something like
...
...
main.py
View file @
84ef60ee
...
@@ -9,6 +9,10 @@ from lm_eval import tasks, evaluator
...
@@ -9,6 +9,10 @@ from lm_eval import tasks, evaluator
logging
.
getLogger
(
"openai"
).
setLevel
(
logging
.
WARNING
)
logging
.
getLogger
(
"openai"
).
setLevel
(
logging
.
WARNING
)
def
_is_json_task
(
task_name
):
return
task_name
==
"json"
or
task_name
.
startswith
(
"json="
)
class
MultiChoice
:
class
MultiChoice
:
def
__init__
(
self
,
choices
):
def
__init__
(
self
,
choices
):
self
.
choices
=
choices
self
.
choices
=
choices
...
@@ -16,7 +20,9 @@ class MultiChoice:
...
@@ -16,7 +20,9 @@ class MultiChoice:
# Simple wildcard support (linux filename patterns)
# Simple wildcard support (linux filename patterns)
def
__contains__
(
self
,
values
):
def
__contains__
(
self
,
values
):
for
value
in
values
.
split
(
","
):
for
value
in
values
.
split
(
","
):
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
and
not
_is_json_task
(
value
):
return
False
return
False
return
True
return
True
...
@@ -55,6 +61,9 @@ def parse_args():
...
@@ -55,6 +61,9 @@ def parse_args():
def
pattern_match
(
patterns
,
source_list
):
def
pattern_match
(
patterns
,
source_list
):
task_names
=
set
()
task_names
=
set
()
for
pattern
in
patterns
:
for
pattern
in
patterns
:
if
_is_json_task
(
pattern
):
task_names
.
add
(
pattern
)
for
matching
in
fnmatch
.
filter
(
source_list
,
pattern
):
for
matching
in
fnmatch
.
filter
(
source_list
,
pattern
):
task_names
.
add
(
matching
)
task_names
.
add
(
matching
)
return
sorted
(
list
(
task_names
))
return
sorted
(
list
(
task_names
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment