gaoqiong/lm-evaluation-harness

Commit 4ad6cd9f (parent 689e0c91)
Author: Baber
Date:   Jul 22, 2025

    remove deps; types
Showing 8 changed files with 240 additions and 146 deletions.
.pre-commit-config.yaml                +1   -1
lm_eval/api/model.py                  +20  -20
lm_eval/api/task.py                   +96  -27
lm_eval/config/metric.py               +3   -3
lm_eval/config/task.py                +30  -32
lm_eval/decontamination/archiver.py   +15   -4
lm_eval/utils.py                      +25  -18
pyproject.toml                        +50  -41
.pre-commit-config.yaml  (view file @ 4ad6cd9f)

...
@@ -33,7 +33,7 @@ repos:
     hooks:
       # Run the linter.
       - id: ruff-check
-        args: [ --fix ]
+        args: [--fix]
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
...
lm_eval/api/model.py  (view file @ 4ad6cd9f)
+from __future__ import annotations
+
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar

 from tqdm import tqdm
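Note: this import swap is what the rest of the commit leans on. With "from __future__ import annotations" (PEP 563), annotations are stored as strings, so PEP 604 unions (X | None) and the lowercase type[T]/list[...] generics type-check fine on the package's Python 3.9 floor without Optional, Union, or Type. A minimal standalone sketch (illustrative names, not from this commit):

    from __future__ import annotations


    def lookup(key: str, table: dict[str, int] | None = None) -> int | None:
        # Deferred under PEP 563, so the `| None` spelling parses on 3.9
        # even though runtime union syntax only arrived in 3.10.
        return (table or {}).get(key)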
...
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)

     @abc.abstractmethod
     def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
...
@@ -101,7 +104,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests: list["Instance"]) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]
...
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
...
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
...
@@ -201,7 +204,7 @@ class LM(abc.ABC):
                 "To use this model with chat templates, please implement the 'tokenizer_name' property."
             )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.

         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
...
@@ -209,7 +212,7 @@ class LM(abc.ABC):
         return ""

-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
...
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return

         self.dbdict = cachinglm.dbdict
...
@@ -238,7 +241,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm: "LM", cache_db: str) -> None:
+    def __init__(self, lm: LM, cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
...
@@ -263,7 +266,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
...
@@ -295,11 +298,8 @@ class CachingLM:
             eval_logger.info(
                 f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
             )
-            if remaining_reqs:
-                # actually run the LM on the requests that do not have cached results
-                rem_res = getattr(self.lm, attr)(remaining_reqs)
-            else:
-                rem_res = []
+            rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []

             # stick the new ones back into the list and also cache any of the new ones
             resptr = 0
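Note: the collapsed branch is a pure refactor; the conditional expression assigns exactly what the old if/else did. In isolation (fn and remaining are illustrative stand-ins, not names from the commit):

    def dispatch(fn, remaining: list) -> list:
        # Old statement form:
        #     if remaining:
        #         out = fn(remaining)
        #     else:
        #         out = []
        # New expression form, same semantics:
        out = fn(remaining) if remaining else []
        return out


    assert dispatch(lambda xs: [2 * x for x in xs], [1, 2]) == [2, 4]
    assert dispatch(lambda xs: xs, []) == []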
...
@@ -318,7 +318,7 @@ class CachingLM:
         return _fn

-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
         return CacheHook(self)
...
@@ -399,7 +399,7 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
...
@@ -432,7 +432,7 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def generate_until(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[str]:
         """Generate until a stopping sequence.
...
@@ -453,7 +453,7 @@ class TemplateLM(LM):
         """
         pass

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """
         Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
...
lm_eval/api/task.py  (view file @ 4ad6cd9f)
...
@@ -7,11 +7,7 @@ import random
 import re
 from collections.abc import Callable
 from copy import deepcopy
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Literal,
-)
+from typing import TYPE_CHECKING, Any, Literal, overload

 import datasets
 import numpy as np
...
@@ -24,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
 from lm_eval.api.utils import check_gold_index_error
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.config.metric import MetricConfig
-from lm_eval.config.task import TaskConfig
+from lm_eval.config.task import DataSet, TaskConfig
 from lm_eval.filters import build_filter_ensemble
...
@@ -133,6 +129,7 @@ class Task(abc.ABC):
             - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                 Fresh download and fresh dataset.
         """
+        assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
...
@@ -146,43 +143,40 @@ class Task(abc.ABC):
         """Returns the TaskConfig associated with this class."""
         return self._config

     @abc.abstractmethod
     def has_training_docs(self) -> bool:
         """Whether the task has a training set"""
-        pass
+        raise NotImplementedError

     @abc.abstractmethod
     def has_validation_docs(self) -> bool:
         """Whether the task has a validation set"""
-        pass
+        raise NotImplementedError

     @abc.abstractmethod
     def has_test_docs(self) -> bool:
         """Whether the task has a test set"""
-        pass
+        raise NotImplementedError

-    def training_docs(self) -> Iterable:
+    def training_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def validation_docs(self) -> Iterable:
+    def validation_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def test_docs(self) -> Iterable:
+    def test_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def fewshot_docs(self) -> Iterable:
+    def fewshot_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
...
@@ -192,7 +186,7 @@ class Task(abc.ABC):
         elif self.has_validation_docs():
             return self.validation_docs()
         else:
-            if self.config.get("num_fewshot", 0) > 0:
+            if self.config.num_fewshot and self.config.num_fewshot > 0:
                 eval_logger.warning(
                     f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
                     ", using test_docs as fewshot_docs but this is not recommended."
...
@@ -331,7 +325,7 @@ class Task(abc.ABC):
                 inst = self.construct_requests(
                     doc=doc,
                     ctx=fewshot_ctx,
-                    metadata=(self.config["task"], doc_id, self.config.repeats),
+                    metadata=(self.config.task, doc_id, self.config.repeats),
                     apply_chat_template=apply_chat_template,
                     chat_template=chat_template,
                 )
...
@@ -586,7 +580,7 @@ class ConfigurableTask(Task):
         data_dir=None,
         cache_dir=None,
         download_mode=None,
-        config: dict | None = None,
+        config: Mapping[str, Any] | None = None,
     ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
...
@@ -727,6 +721,9 @@ class ConfigurableTask(Task):
             )
             self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
+            assert self.config.dataset_path is not None, (
+                "dataset_path must be set in TaskConfig"
+            )
             self.dataset = datasets.load_dataset(
                 path=self.config.dataset_path,
                 name=self.config.dataset_name,
...
@@ -742,7 +739,7 @@ class ConfigurableTask(Task):
     def has_test_docs(self) -> bool:
         return self.config.test_split is not None

-    def training_docs(self) -> datasets.Dataset | None:
+    def training_docs(self) -> DataSet | None:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
...
@@ -750,7 +747,7 @@ class ConfigurableTask(Task):
             )
             return self.dataset[self.config.training_split]

-    def validation_docs(self) -> datasets.Dataset | None:
+    def validation_docs(self) -> DataSet | None:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
...
@@ -758,7 +755,7 @@ class ConfigurableTask(Task):
             )
             return self.dataset[self.config.validation_split]

-    def test_docs(self) -> datasets.Dataset | None:
+    def test_docs(self) -> DataSet | None:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
...
@@ -996,9 +993,21 @@ class ConfigurableTask(Task):
         """
         return doc

+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
     def doc_to_text(
         self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
-    ) -> str:
+    ) -> str | int:
         # if self.prompt is not None:
         #     doc_to_text = self.prompt
         doc_to_text = doc_to_text or self.config.doc_to_text
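Note: the @overload stubs added here (and for doc_to_target, doc_to_choice, doc_to_image, and doc_to_audio below) are erased at runtime; only the final undecorated definition executes. Their job is to tell a type checker how the return type narrows with the argument type. The mechanism, in a standalone sketch with made-up names:

    from typing import Callable, overload


    @overload
    def render(spec: int) -> int: ...
    @overload
    def render(spec: str) -> str: ...
    @overload
    def render(spec: Callable[[], str]) -> str: ...
    def render(spec):
        # The one real implementation; the stubs above never run.
        return spec() if callable(spec) else spec


    n = render(3)        # a checker infers int here
    s = render("q: ")    # ...and str here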
...
@@ -1031,6 +1040,25 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

+    @overload
+    def doc_to_target(
+        self, doc: dict, doc_to_target: None = None
+    ) -> int | str | list[int]: ...
+    @overload
+    def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
+    @overload
+    def doc_to_target(
+        self, doc: dict, doc_to_target: str
+    ) -> int | str | list[int]: ...
+    @overload
+    def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
+    @overload
+    def doc_to_target(
+        self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
+    ) -> int | str | list[int]: ...
     def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
...
@@ -1077,6 +1105,23 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
+    @overload
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Callable[..., list[str]]
+    ) -> list[str]: ...
     def doc_to_choice(
         self,
         doc: dict,
...
@@ -1108,6 +1153,18 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
     def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
...
@@ -1131,6 +1188,18 @@ class ConfigurableTask(Task):
         else:
             return None

+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
     def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
         if doc_to_audio is not None:
             doc_to_audio = doc_to_audio
...
@@ -1375,15 +1444,15 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]

-            for metric in self._metric_fn_list:
+            for metric in self.config._metric_list:
                 try:
-                    result_score = self._metric_fn_list[metric](
+                    result_score = metric.fn(
                         references=[gold] if not isinstance(gold, list) else gold,
                         predictions=[result],
-                        **self._metric_fn_kwargs[metric],
+                        **metric.kwargs,
                     )
                 except TypeError:
                     # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                    result_score = self._metric_fn_list[metric]([gold, result])
+                    result_score = metric.fn([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function
...
lm_eval/config/metric.py  (view file @ 4ad6cd9f)
 from __future__ import annotations

 from collections.abc import Callable, Mapping
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from typing import Any
...
@@ -11,8 +11,8 @@ class MetricConfig:
     """Encapsulates information about a single metric."""

     name: str
-    fn: Callable | None = None
-    kwargs: Mapping[str, Any] | None = None
+    fn: Callable
+    kwargs: Mapping[str, Any] = field(default_factory=dict)
     aggregation_fn: Callable | None = None
     higher_is_better: bool = True
     hf_evaluate: bool = False
...
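Note: two things change in MetricConfig. fn loses its None default and becomes required, and kwargs now defaults to an empty mapping via field(default_factory=dict), the dataclass-safe way to get a fresh dict per instance (a literal {} default is rejected as mutable, and None forced call sites to guard). The pattern in isolation, with illustrative names:

    from collections.abc import Callable, Mapping
    from dataclasses import dataclass, field
    from typing import Any


    @dataclass
    class Metric:
        name: str
        fn: Callable                       # required, no None sentinel to check
        kwargs: Mapping[str, Any] = field(default_factory=dict)  # fresh per instance


    m = Metric("exact_match", fn=lambda refs, preds: refs == preds)
    print(m.kwargs)  # {} -- safe to ** straight into the metric call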
lm_eval/config/task.py  (view file @ 4ad6cd9f)
...
@@ -3,7 +3,9 @@ from __future__ import annotations
 import logging
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Union
+
+import datasets

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
...
@@ -18,6 +20,9 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)

+DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
+DSplits = dict[str, DataSet]
+

 @dataclass
 class RepeatConfig:
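Note: the new alias deliberately uses typing.Union rather than the | operator used everywhere else in this commit. "from __future__ import annotations" only defers annotations; this assignment is ordinary runtime code, where datasets.Dataset | Iterable[...] would raise TypeError on Python 3.9. The distinction, standalone:

    from __future__ import annotations

    from collections.abc import Iterable
    from typing import Any, Union

    # Runtime expression: needs Union on 3.9 (| between types is 3.10+).
    Rows = Union[list[dict[str, Any]], Iterable[dict[str, Any]]]


    def first(rows: Rows | None) -> dict[str, Any] | None:
        # Annotation position is deferred, so `Rows | None` is fine on 3.9.
        return next(iter(rows), None) if rows is not None else None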
...
@@ -30,7 +35,7 @@ class RepeatConfig:
 @dataclass
 class FilterConfig:
-    """Encapsulates information about a single filter."""
+    """Encapsulates information about a single filter pipeline."""

     name: str
     ensemble: FilterEnsemble
...
@@ -44,10 +49,8 @@ class FewshotConfig:
     num_fewshot: Callable[[], int]
     split: str | None = None
     sampler: str | Callable = "default"
-    samples: Callable[[], list[dict]] | list[dict] | None = None
-    process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
-        None
-    )
+    samples: Callable[[], DataSet] | DataSet | None = None
+    process_docs: Callable[[DataSet], DataSet] | None = None
     fewshot_indices: list[int] | None = None
     rnd: int = field(init=False, default=False)
...
@@ -69,22 +72,23 @@ class FewshotConfig:
         """Check if any fewshot source is configured."""
         return self.split is not None or self.samples is not None

-    def _get_raw_docs(self, dataset) -> list[dict] | Callable[[], Iterable[dict]] | None:
+    def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
         """Get raw documents from configured source."""
         if self.split is not None:
             return dataset[self.split]
         if self.samples is not None:
-            if isinstance(self.samples, list) or callable(self.samples):
+            if isinstance(self.samples, list):
                 return self.samples
+            elif callable(self.samples):
+                # If samples is a callable, it should return a list of dicts
+                return self.samples()
+            else:
+                raise TypeError(
+                    "samples must be either a list of dicts or a callable returning a list"
+                )

-    def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
+    def get_docs(self, dataset) -> DataSet | None:
         """Get processed documents from configured source."""
         raw_docs = self._get_raw_docs(dataset)
         if raw_docs is None:
...
@@ -130,34 +134,34 @@ class TaskConfig:
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
-    custom_dataset: Callable | None = None
+    custom_dataset: Callable[..., DataSet] | None = None
     dataset_path: str | None = None
     dataset_name: str | None = None
     dataset_kwargs: dict | None = field(default_factory=dict)
     training_split: str | None = None
     validation_split: str | None = None
     test_split: str | None = None
-    fewshot_split: str | None = (
-        None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
-    )
+    fewshot_split: str | None = None
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
-    process_docs: Callable | None = None
-    doc_to_text: Callable | str | None = None
-    doc_to_target: Callable | str | None = None
-    doc_to_image: Callable | str | None = None
-    doc_to_audio: Callable | str | None = None
+    process_docs: Callable[[DataSet], DataSet] | None = None
+    doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
     unsafe_code: bool = False
-    doc_to_choice: Callable | str | dict | list | None = None
-    process_results: Callable | str | None = None
+    doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
+    process_results: (
+        Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
+    ) = None
     use_prompt: str | None = None
     description: str = ""
     target_delimiter: str = " "
     fewshot_delimiter: str = "\n\n"
-    fewshot_config: dict | None = None
+    fewshot_config: dict[str, Any] | None = None
     # runtime configuration options
-    num_fewshot: int | None = 0
-    generation_kwargs: dict | None = None
+    num_fewshot: int | None = None
+    generation_kwargs: dict[str, Any] | None = None
     # scoring options
     metric_list: list | None = None
     output_type: OutputType = "generate_until"
...
@@ -357,7 +361,7 @@ class TaskConfig:
         return x

     @classmethod
-    def from_yaml(cls, data: dict) -> TaskConfig:
+    def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
         """Create a TaskConfig instance from a YAML-like dictionary."""
         return cls(**data)
...
@@ -425,12 +429,6 @@ class TaskConfig:
         # Create and return TaskConfig instance
         return cls(**config_dict)

-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    def __setitem__(self, item, value):
-        return setattr(self, item, value)
-
     def to_dict(self, keep_callable: bool = False) -> dict:
         def _ser(x):
             if isinstance(x, dict):
...
lm_eval/decontamination/archiver.py  (view file @ 4ad6cd9f)
+# /// script
+# requires-python = ">=3.8"
+# dependencies = [
+#     "jsonlines",
+#     "mmap",
+#     "tqdm",
+#     "zstandard",
+# ]
+# ///
+# ruff: noqa
 import datetime
 import io
 import json
...
@@ -111,7 +122,7 @@ class TextReader:
         current_file_position = 0
         line_counter = 0
         with (
-            open(self.file_path, "r", encoding="utf-8") as fh,
+            open(self.file_path, encoding="utf-8") as fh,
             tqdm.tqdm(
                 total=os.path.getsize(self.file_path),
                 dynamic_ncols=True,
...
@@ -133,7 +144,7 @@ class TextReader:
     def read_and_tell(self):
         current_file_position = 0
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
...
@@ -143,14 +154,14 @@ class TextReader:
                     yield line[:-1], raw_bytes_read

     def read(self):
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
                     yield line[:-1]

     def read_slow(self):
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             while True:
                 line = fh.readline()
                 if line == -1 or line == "":
...
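Note: every open() edit in this file only drops the explicit "r": read/text is already the default mode, so behavior is unchanged (this matches ruff's UP015, unnecessary open mode). A quick self-contained check:

    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("hello\n")
        path = tmp.name

    # The two spellings are equivalent; the commit just drops the redundant "r".
    with open(path, "r", encoding="utf8") as a, open(path, encoding="utf8") as b:
        assert a.read() == b.read() == "hello\n"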
lm_eval/utils.py  (view file @ 4ad6cd9f)
 import collections
 import fnmatch
-import functools
 import hashlib
 import importlib.util
 import inspect
...
@@ -8,10 +7,12 @@ import json
 import logging
 import os
 import re
+from collections.abc import Generator
 from dataclasses import asdict, is_dataclass
+from functools import lru_cache, partial, wraps
 from itertools import islice
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import numpy as np
 import yaml
...
@@ -108,7 +109,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
         return text
     maxsplit = max(0, maxsplit)

-    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
+    return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)


 def handle_arg_string(arg):
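Note: re.split's third positional parameter is maxsplit, so the old call was already correct, but the keyword form is clearer (positional use is easy to misread as re.sub's count) and passing maxsplit/flags positionally here is deprecated as of Python 3.13. Behavior is identical:

    import re

    text = r"a,b\,c,d"
    # Split on unescaped commas only, at most one split.
    print(re.split(r"(?<!\\),", text, maxsplit=1))  # ['a', 'b\\,c,d']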
...
@@ -125,7 +126,7 @@ def handle_arg_string(arg):
 def handle_non_serializable(o):
-    if isinstance(o, np.int64) or isinstance(o, np.int32):
+    if isinstance(o, np.integer):
         return int(o)
     elif isinstance(o, set):
         return list(o)
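Note: np.integer is the abstract base of every NumPy integer scalar, so the single isinstance check now also catches widths the old int64-or-int32 test missed (int8, int16, unsigned types, np.intp):

    import numpy as np

    for value in (np.int8(1), np.int16(2), np.int32(3), np.int64(4), np.uint32(5)):
        assert isinstance(value, np.integer)
        print(int(value))  # plain Python int, JSON-serializable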
...
@@ -235,21 +236,21 @@ def sanitize_task_name(task_name: str) -> str:
     return re.sub(r"\W", "_", task_name)


-def get_latest_filename(filenames: List[str]) -> str:
+def get_latest_filename(filenames: list[str]) -> str:
     """
     Given a list of filenames, returns the filename with the latest datetime.
     """
     return max(filenames, key=lambda f: get_file_datetime(f))


-def get_results_filenames(filenames: List[str]) -> List[str]:
+def get_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to aggregated results.
     """
     return [f for f in filenames if "/results_" in f and ".json" in f]


-def get_sample_results_filenames(filenames: List[str]) -> List[str]:
+def get_sample_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to sample results.
     """
...
@@ -257,8 +258,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
 def get_rolling_token_windows(
-    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
-) -> Generator[Tuple[List[int], List[int]], None, None]:
+    token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[tuple[list[int], list[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
       condition on some context
...
@@ -300,8 +301,8 @@ def get_rolling_token_windows(
 def make_disjoint_window(
-    pair: Tuple[List[int], List[int]],
-) -> Tuple[List[int], List[int]]:
+    pair: tuple[list[int], list[int]],
+) -> tuple[list[int], list[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
...
@@ -320,7 +321,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
 class Reorderer:
-    def __init__(self, arr: List[Any], fn: Callable) -> None:
+    def __init__(self, arr: list[Any], fn: Callable) -> None:
         """Reorder an array according to some function

         Args:
...
@@ -423,11 +424,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
             # TODO: fix
             hib = "↑"

-            v = "%.4f" % v if isinstance(v, float) else v
+            v = f"{v:.4f}" if isinstance(v, float) else v

             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
-                se = " N/A" if se == "N/A" else "%.4f" % se
+                se = " N/A" if se == "N/A" else f"{se:.4f}"
                 values.append([k, version, f, n, m, hib, v, "±", se])
             else:
                 values.append([k, version, f, n, m, hib, v, "", ""])
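Note: the %-to-f-string change is cosmetic; for floats the two spellings render identically:

    v = 0.123456
    assert f"{v:.4f}" == "%.4f" % v == "0.1235"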
...
@@ -448,7 +449,8 @@ def positional_deprecated(fn):
     wrapped function, `fn`.
     """

-    @functools.wraps(fn)
+    @wraps(fn)
     def _wrapper(*args, **kwargs):
         if len(args) != 1 if inspect.ismethod(fn) else 0:
             print(
...
@@ -494,7 +496,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
     if yaml_path is None:
         raise ValueError("yaml_path must be provided if mode is 'full'.")
     # Attach yaml_path to the import function so that it can be used later
-    constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
+    constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
     loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
     # Add the import_function constructor to the YAML loader
...
@@ -543,13 +545,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
 env = Environment(
-    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
+    loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
 )
 env.filters["regex_replace"] = regex_replace


+@lru_cache(maxsize=128)
+def _compile(raw: str):
+    return env.from_string(raw)
+
+
 def apply_template(template: str, doc: dict) -> str:
-    rtemplate = env.from_string(template)
+    rtemplate = _compile(template)
     return rtemplate.render(**doc)
...
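Note: two fixes land in this hunk. BaseLoader is now instantiated (Jinja2 expects a loader instance; passing the class mostly worked only because from_string never consults it), and compiled templates are memoized with lru_cache, so a prompt template reused across thousands of docs is parsed once rather than on every call. A standalone version of the same shape:

    from functools import lru_cache

    from jinja2 import BaseLoader, Environment, StrictUndefined

    env = Environment(
        loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
    )


    @lru_cache(maxsize=128)
    def _compile(raw: str):
        # Parsing is the expensive step; rendering a cached Template is cheap.
        return env.from_string(raw)


    def apply_template(template: str, doc: dict) -> str:
        return _compile(template).render(**doc)


    print(apply_template("{{ question }} =", {"question": "2+2"}))  # 2+2 =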
pyproject.toml  (view file @ 4ad6cd9f)
...
@@ -11,34 +11,28 @@ authors = [
 description = "A framework for evaluating language models"
 readme = "README.md"
 classifiers = [
-    "Development Status :: 3 - Alpha",
-    "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent",
+    "Development Status :: 3 - Alpha",
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent"
 ]
 requires-python = ">=3.9"
 license = { "text" = "MIT" }
 dependencies = [
-    "accelerate>=0.26.0",
-    "evaluate",
-    "datasets>=2.16.0,<4.0",
-    "evaluate>=0.4.0",
-    "jsonlines",
-    "numexpr",
-    "peft>=0.2.0",
-    "pybind11>=2.6.2",
-    "pytablewriter",
-    "rouge-score>=0.0.4",
-    "sacrebleu>=1.5.0",
-    "scikit-learn>=0.24.1",
-    "sqlitedict",
-    "torch>=1.8",
-    "tqdm-multiprocess",
-    "transformers>=4.1",
-    "zstandard",
-    "dill",
-    "word2number",
-    "more_itertools",
+    "accelerate>=0.26.0",
+    "datasets>=2.16.0,<4.0",
+    "evaluate>=0.4.0",
+    "peft>=0.2.0",
+    "pytablewriter",
+    "rouge-score>=0.0.4",
+    "sacrebleu>=1.5.0",
+    "scikit-learn>=0.24.1",
+    "sqlitedict",
+    "torch>=1.8",
+    "transformers>=4.1",
+    "dill",
+    "word2number",
+    "more_itertools"
 ]

 [tool.setuptools.packages.find]
...
@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
 ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
 ipex = ["optimum"]
 japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
-longbench = ["jieba", "fuzzywuzzy", "rouge"]
+longbench = ["jieba", "fuzzywuzzy", "rouge"]
 libra = ["pymorphy2"]
 mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
...
@@ -87,17 +81,30 @@ vllm = ["vllm>=0.4.2"]
-wandb = ["wandb>=0.16.3", "pandas", "numpy"]
-zeno = ["pandas", "zeno-client"]
 tasks = [
-    "lm_eval[acpbench]",
-    "lm_eval[discrim_eval]",
-    "lm_eval[japanese_leaderboard]",
-    "lm_eval[longbench]",
-    "lm_eval[math]",
-    "lm_eval[multilingual]",
-    "lm_eval[ruler]",
+    "lm_eval[acpbench]",
+    "lm_eval[discrim_eval]",
+    "lm_eval[ifeval]",
+    "lm_eval[japanese_leaderboard]",
+    "lm_eval[longbench]",
+    "lm_eval[libra]",
+    "lm_eval[mamba]",
+    "lm_eval[math]",
+    "lm_eval[multilingual]",
+    "lm_eval[ruler]"
 ]
 testing = ["pytest", "pytest-cov", "pytest-xdist"]
 unitxt = ["unitxt==1.22.0"]
 vllm = ["vllm>=0.4.2"]
+wandb = ["wandb>=0.16.3", "pandas", "numpy"]
+zeno = ["pandas", "zeno-client"]

 [project.scripts]
 lm-eval = "lm_eval.__main__:cli_evaluate"
 lm_eval = "lm_eval.__main__:cli_evaluate"

 [project.urls]
 Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
 Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [tool.pymarkdown]
 plugins.md013.enabled = false  # line-length
...
@@ -107,21 +114,23 @@ plugins.md028.enabled = false  # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true  # ol-prefix
 plugins.md034.enabled = false  # no-bare-urls

 [tool.ruff]
 target-version = "py39"
 lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
-lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]
-lint.fixable = ["I001", "F401", "UP"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
+
+[tool.ruff.lint.extend-per-file-ignores]
+"__init__.py" = ["F401", "F402", "F403"]

 [tool.ruff.lint.isort]
 combine-as-imports = true
-lines-after-imports = 2
 known-first-party = ["lm_eval"]
+lines-after-imports = 2

-[tool.ruff.lint.extend-per-file-ignores]
-"__init__.py" = ["F401","F402","F403"]

 # required to include yaml files in pip installation
 [tool.setuptools.package-data]
 lm_eval = ["**/*.yaml", "tasks/**/*"]

+[dependency-groups]
+dev = ["api", "dev", "sentencepiece"]

 [tool.setuptools.packages.find]
 include = ["lm_eval*"]