Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
87445e95
Commit
87445e95
authored
Jul 23, 2025
by
Baber
Browse files
types
parent
5fdeb436
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
28 deletions
+26
-28
lm_eval/api/task.py
lm_eval/api/task.py
+17
-21
lm_eval/config/task.py
lm_eval/config/task.py
+9
-7
No files found.
lm_eval/api/task.py
View file @
87445e95
...
@@ -7,12 +7,7 @@ import random
...
@@ -7,12 +7,7 @@ import random
import
re
import
re
from
collections.abc
import
Callable
,
Iterable
,
Iterator
,
Mapping
from
collections.abc
import
Callable
,
Iterable
,
Iterator
,
Mapping
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
(
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
TYPE_CHECKING
,
Any
,
Literal
,
overload
,
)
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
...
@@ -25,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
...
@@ -25,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from
lm_eval.api.utils
import
check_gold_index_error
from
lm_eval.api.utils
import
check_gold_index_error
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.task
import
TaskConfig
from
lm_eval.config.task
import
DataSet
,
TaskConfig
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
...
@@ -134,6 +129,7 @@ class Task(abc.ABC):
...
@@ -134,6 +129,7 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
Fresh download and fresh dataset.
"""
"""
assert
self
.
DATASET_PATH
is
not
None
,
"DATASET_PATH must be set in Task class"
self
.
dataset
=
datasets
.
load_dataset
(
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
name
=
self
.
DATASET_NAME
,
...
@@ -147,43 +143,40 @@ class Task(abc.ABC):
...
@@ -147,43 +143,40 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
"""Returns the TaskConfig associated with this class."""
return
self
.
_config
return
self
.
_config
@
abc
.
abstractmethod
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
"""Whether the task has a training set"""
"""Whether the task has a training set"""
pass
raise
NotImplementedError
@
abc
.
abstractmethod
def
has_validation_docs
(
self
)
->
bool
:
def
has_validation_docs
(
self
)
->
bool
:
"""Whether the task has a validation set"""
"""Whether the task has a validation set"""
pass
raise
NotImplementedError
@
abc
.
abstractmethod
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
"""Whether the task has a test set"""
"""Whether the task has a test set"""
pass
raise
NotImplementedError
def
training_docs
(
self
)
->
Iterabl
e
:
def
training_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
validation_docs
(
self
)
->
Iterabl
e
:
def
validation_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
test_docs
(
self
)
->
Iterabl
e
:
def
test_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
fewshot_docs
(
self
)
->
Iterabl
e
:
def
fewshot_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
...
@@ -587,7 +580,7 @@ class ConfigurableTask(Task):
...
@@ -587,7 +580,7 @@ class ConfigurableTask(Task):
data_dir
=
None
,
data_dir
=
None
,
cache_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
download_mode
=
None
,
config
:
dict
|
None
=
None
,
config
:
Mapping
[
str
,
Any
]
|
None
=
None
,
)
->
None
:
)
->
None
:
# Get pre-configured attributes
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
self
.
_config
=
self
.
CONFIG
...
@@ -722,6 +715,9 @@ class ConfigurableTask(Task):
...
@@ -722,6 +715,9 @@ class ConfigurableTask(Task):
)
)
self
.
dataset
=
df
(
**
(
self
.
config
.
dataset_kwargs
|
self
.
config
.
metadata
))
self
.
dataset
=
df
(
**
(
self
.
config
.
dataset_kwargs
|
self
.
config
.
metadata
))
else
:
else
:
assert
self
.
config
.
dataset_path
is
not
None
,
(
"dataset_path must be set in TaskConfig"
)
self
.
dataset
=
datasets
.
load_dataset
(
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
config
.
dataset_path
,
path
=
self
.
config
.
dataset_path
,
name
=
self
.
config
.
dataset_name
,
name
=
self
.
config
.
dataset_name
,
...
@@ -737,7 +733,7 @@ class ConfigurableTask(Task):
...
@@ -737,7 +733,7 @@ class ConfigurableTask(Task):
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
return
self
.
config
.
test_split
is
not
None
return
self
.
config
.
test_split
is
not
None
def
training_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
training_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_training_docs
():
if
self
.
has_training_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -745,7 +741,7 @@ class ConfigurableTask(Task):
...
@@ -745,7 +741,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
training_split
]
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
validation_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_validation_docs
():
if
self
.
has_validation_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -753,7 +749,7 @@ class ConfigurableTask(Task):
...
@@ -753,7 +749,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
validation_split
]
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
test_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
...
...
lm_eval/config/task.py
View file @
87445e95
...
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
...
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
DataSet
=
Union
[
datasets
.
Dataset
,
Iterable
[
dict
[
str
,
Any
]]]
DataSet
=
Union
[
datasets
.
Dataset
,
Iterable
[
dict
[
str
,
Any
]]]
DSplits
=
dict
[
str
,
DataSet
]
@
dataclass
@
dataclass
...
@@ -34,7 +35,7 @@ class RepeatConfig:
...
@@ -34,7 +35,7 @@ class RepeatConfig:
@
dataclass
@
dataclass
class
FilterConfig
:
class
FilterConfig
:
"""Encapsulates information about a single filter."""
"""Encapsulates information about a single filter
pipeline
."""
name
:
str
name
:
str
ensemble
:
FilterEnsemble
ensemble
:
FilterEnsemble
...
@@ -71,16 +72,17 @@ class FewshotConfig:
...
@@ -71,16 +72,17 @@ class FewshotConfig:
"""Check if any fewshot source is configured."""
"""Check if any fewshot source is configured."""
return
self
.
split
is
not
None
or
self
.
samples
is
not
None
return
self
.
split
is
not
None
or
self
.
samples
is
not
None
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
:
DSplits
)
->
DataSet
|
None
:
self
,
dataset
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
[
str
,
Any
]]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
)
:
if
isinstance
(
self
.
samples
,
list
):
return
self
.
samples
return
self
.
samples
elif
callable
(
self
.
samples
):
# If samples is a callable, it should return a list of dicts
return
self
.
samples
()
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
...
@@ -158,7 +160,7 @@ class TaskConfig:
...
@@ -158,7 +160,7 @@ class TaskConfig:
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
dict
[
str
,
Any
]
|
None
=
None
fewshot_config
:
dict
[
str
,
Any
]
|
None
=
None
# runtime configuration options
# runtime configuration options
num_fewshot
:
int
|
None
=
0
num_fewshot
:
int
|
None
=
None
generation_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
generation_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
# scoring options
# scoring options
metric_list
:
list
|
None
=
None
metric_list
:
list
|
None
=
None
...
@@ -359,7 +361,7 @@ class TaskConfig:
...
@@ -359,7 +361,7 @@ class TaskConfig:
return
x
return
x
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
TaskConfig
:
def
from_yaml
(
cls
,
data
:
dict
[
str
,
Any
]
)
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
return
cls
(
**
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment