Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
9deea116
Commit
9deea116
authored
Jun 02, 2023
by
lintangsutawika
Browse files
Merge branch 'local-file' into dataset-metric-log
parents
9f62d3ac
6acc5c47
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
15 deletions
+27
-15
lm_eval/api/task.py
lm_eval/api/task.py
+27
-15
No files found.
lm_eval/api/task.py
View file @
9deea116
...
@@ -18,6 +18,7 @@ from collections.abc import Callable
...
@@ -18,6 +18,7 @@ from collections.abc import Callable
from
lm_eval
import
utils
from
lm_eval
import
utils
from
lm_eval.api
import
samplers
from
lm_eval.api
import
samplers
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.metrics
import
(
from
lm_eval.api.metrics
import
(
METRIC_REGISTRY
,
METRIC_REGISTRY
,
AGGREGATION_REGISTRY
,
AGGREGATION_REGISTRY
,
...
@@ -187,13 +188,22 @@ class Task(abc.ABC):
...
@@ -187,13 +188,22 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
Fresh download and fresh dataset.
"""
"""
self
.
dataset
=
datasets
.
load_dataset
(
if
self
.
DATASET_PATH
in
[
"json"
,
"csv"
]:
path
=
self
.
DATASET_PATH
,
self
.
dataset
=
datasets
.
load_dataset
(
name
=
self
.
DATASET_NAME
,
path
=
self
.
DATASET_PATH
,
data_dir
=
data_dir
,
data_files
=
self
.
DATASET_NAME
,
cache_dir
=
cache_dir
,
data_dir
=
data_dir
,
download_mode
=
download_mode
,
cache_dir
=
cache_dir
,
)
download_mode
=
download_mode
,
)
else
:
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
data_dir
=
data_dir
,
cache_dir
=
cache_dir
,
download_mode
=
download_mode
,
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
...
@@ -436,8 +446,12 @@ class Task(abc.ABC):
...
@@ -436,8 +446,12 @@ class Task(abc.ABC):
def
apply_filters
(
self
):
def
apply_filters
(
self
):
for
f
in
self
.
_filters
:
if
hasattr
(
self
,
"_filters"
):
f
.
apply
(
self
.
_instances
)
for
f
in
self
.
_filters
:
f
.
apply
(
self
.
_instances
)
else
:
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
return
self
.
_instances
class
ConfigurableTask
(
Task
):
class
ConfigurableTask
(
Task
):
...
@@ -514,8 +528,8 @@ class ConfigurableTask(Task):
...
@@ -514,8 +528,8 @@ class ConfigurableTask(Task):
self
.
_training_docs
=
None
self
.
_training_docs
=
None
self
.
_fewshot_docs
=
None
self
.
_fewshot_docs
=
None
self
.
_filters
=
[]
if
self
.
_config
.
filter_list
is
not
None
:
if
self
.
_config
.
filter_list
is
not
None
:
self
.
_filters
=
[]
for
filter_config
in
self
.
_config
.
filter_list
:
for
filter_config
in
self
.
_config
.
filter_list
:
for
filter_pipeline
in
filter_config
:
for
filter_pipeline
in
filter_config
:
filter_name
=
filter_config
[
"name"
]
filter_name
=
filter_config
[
"name"
]
...
@@ -528,11 +542,9 @@ class ConfigurableTask(Task):
...
@@ -528,11 +542,9 @@ class ConfigurableTask(Task):
components
.
append
([
function
[
"function"
],
kwargs
])
components
.
append
([
function
[
"function"
],
kwargs
])
filter_pipeline
=
build_filter_ensemble
(
filter_name
,
components
)
filter_pipeline
=
build_filter_ensemble
(
filter_name
,
components
)
self
.
_filters
.
append
(
filter_pipeline
)
self
.
_filters
.
append
(
filter_pipeline
)
else
:
else
:
self
.
_filters
=
[
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[(
"none"
,
None
)])]
build_filter_ensemble
(
"take_first"
,
[[
"take_first"
,
None
]])
]
if
self
.
_config
.
use_prompt
is
not
None
:
if
self
.
_config
.
use_prompt
is
not
None
:
eval_logger
.
info
(
f
"loading prompt
{
self
.
_config
.
use_prompt
}
"
)
eval_logger
.
info
(
f
"loading prompt
{
self
.
_config
.
use_prompt
}
"
)
...
@@ -768,7 +780,7 @@ class ConfigurableTask(Task):
...
@@ -768,7 +780,7 @@ class ConfigurableTask(Task):
references
=
[
gold
],
predictions
=
[
result
],
**
self
.
_metric_kwargs
[
key
]
references
=
[
gold
],
predictions
=
[
result
],
**
self
.
_metric_kwargs
[
key
]
)
)
result_dict
[
key
]
=
_dict
[
key
]
result_dict
=
{
**
result_dict
,
**
_dict
}
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Passed invalid output_type '
{
self
.
OUTPUT_TYPE
}
' ! Please use one of "
,
f
"Passed invalid output_type '
{
self
.
OUTPUT_TYPE
}
' ! Please use one of "
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment