Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
a339ffd8
Commit
a339ffd8
authored
Jun 05, 2023
by
lintangsutawika
Browse files
allow to use alternative methods to use hf datasets, allow configuration with dataset_kwargs
parent
36da9c66
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
24 deletions
+24
-24
lm_eval/api/task.py
lm_eval/api/task.py
+24
-24
No files found.
lm_eval/api/task.py
View file @
a339ffd8
...
@@ -45,15 +45,16 @@ class TaskConfig(dict):
...
@@ -45,15 +45,16 @@ class TaskConfig(dict):
task_name
:
str
=
(
task_name
:
str
=
(
None
# TODO: deprecate this, it'll be set in __post_init__ to be names[0]
None
# TODO: deprecate this, it'll be set in __post_init__ to be names[0]
)
)
base_task
:
str
=
None
dataset_path
:
str
=
None
dataset_path
:
str
=
None
dataset_name
:
str
=
None
dataset_name
:
str
=
None
dataset_kwargs
:
dict
=
None
training_split
:
str
=
None
training_split
:
str
=
None
validation_split
:
str
=
None
validation_split
:
str
=
None
test_split
:
str
=
None
test_split
:
str
=
None
fewshot_split
:
str
=
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
fewshot_split
:
str
=
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases
:
str
=
None
template_aliases
:
str
=
None
aliases
:
Union
[
str
,
list
]
=
None
doc_to_text
:
Union
[
Callable
,
str
]
=
None
doc_to_text
:
Union
[
Callable
,
str
]
=
None
doc_to_target
:
Union
[
Callable
,
str
]
=
None
doc_to_target
:
Union
[
Callable
,
str
]
=
None
...
@@ -79,12 +80,12 @@ class TaskConfig(dict):
...
@@ -79,12 +80,12 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# force prompt-compatibility for some prompt regardless of
# field names in prompt
# field names in prompt
if
self
.
template_aliases
is
not
None
:
#
if self.template_aliases is not None:
if
type
(
self
.
doc_to_text
)
==
str
:
#
if type(self.doc_to_text) == str:
self
.
doc_to_text
=
self
.
template_aliases
+
self
.
doc_to_text
#
self.doc_to_text = self.template_aliases + self.doc_to_text
if
type
(
self
.
doc_to_target
)
==
str
:
#
if type(self.doc_to_target) == str:
self
.
doc_to_target
=
self
.
template_aliases
+
self
.
doc_to_target
#
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
# set "task_name" metadata field based on the "primary" name set
if
self
.
names
:
if
self
.
names
:
...
@@ -188,15 +189,6 @@ class Task(abc.ABC):
...
@@ -188,15 +189,6 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
Fresh download and fresh dataset.
"""
"""
if
self
.
DATASET_PATH
in
[
"json"
,
"csv"
]:
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
data_files
=
self
.
DATASET_NAME
,
data_dir
=
data_dir
,
cache_dir
=
cache_dir
,
download_mode
=
download_mode
,
)
else
:
self
.
dataset
=
datasets
.
load_dataset
(
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
name
=
self
.
DATASET_NAME
,
...
@@ -524,7 +516,7 @@ class ConfigurableTask(Task):
...
@@ -524,7 +516,7 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric"
,
"Please check https://huggingface.co/evaluate-metric"
,
)
)
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
download
(
self
.
_config
.
dataset_kwargs
)
self
.
_training_docs
=
None
self
.
_training_docs
=
None
self
.
_fewshot_docs
=
None
self
.
_fewshot_docs
=
None
...
@@ -559,6 +551,14 @@ class ConfigurableTask(Task):
...
@@ -559,6 +551,14 @@ class ConfigurableTask(Task):
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
()
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
()
)
# TODO: pass the correct docs in here
)
# TODO: pass the correct docs in here
def
download
(
self
,
dataset_kwargs
=
None
):
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
**
dataset_kwargs
if
dataset_kwargs
is
not
None
else
{},
)
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
if
self
.
_config
.
training_split
is
not
None
:
if
self
.
_config
.
training_split
is
not
None
:
return
True
return
True
...
@@ -710,7 +710,7 @@ class ConfigurableTask(Task):
...
@@ -710,7 +710,7 @@ class ConfigurableTask(Task):
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
results
=
results
[
0
]
results
=
results
[
0
]
ll
,
is_greedy
=
results
ll
,
is_greedy
=
results
result_dict
=
{
"perplexity"
:
ll
,
"acc
uracy
"
:
int
(
is_greedy
)}
result_dict
=
{
"perplexity"
:
ll
,
"acc"
:
int
(
is_greedy
)}
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
(
loglikelihood
,)
=
results
(
loglikelihood
,)
=
results
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment