Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
4ad6cd9f
Commit
4ad6cd9f
authored
Jul 22, 2025
by
Baber
Browse files
remove deps; types
parent
689e0c91
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
240 additions
and
146 deletions
+240
-146
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
lm_eval/api/model.py
lm_eval/api/model.py
+20
-20
lm_eval/api/task.py
lm_eval/api/task.py
+96
-27
lm_eval/config/metric.py
lm_eval/config/metric.py
+3
-3
lm_eval/config/task.py
lm_eval/config/task.py
+30
-32
lm_eval/decontamination/archiver.py
lm_eval/decontamination/archiver.py
+15
-4
lm_eval/utils.py
lm_eval/utils.py
+25
-18
pyproject.toml
pyproject.toml
+50
-41
No files found.
.pre-commit-config.yaml
View file @
4ad6cd9f
...
@@ -33,7 +33,7 @@ repos:
...
@@ -33,7 +33,7 @@ repos:
hooks
:
hooks
:
# Run the linter.
# Run the linter.
-
id
:
ruff-check
-
id
:
ruff-check
args
:
[
--fix
]
args
:
[
--fix
]
# Run the formatter.
# Run the formatter.
-
id
:
ruff-format
-
id
:
ruff-format
-
repo
:
https://github.com/codespell-project/codespell
-
repo
:
https://github.com/codespell-project/codespell
...
...
lm_eval/api/model.py
View file @
4ad6cd9f
from
__future__
import
annotations
import
abc
import
abc
import
hashlib
import
hashlib
import
json
import
json
import
logging
import
logging
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Iterable
,
Optional
,
Type
,
TypeVar
,
Union
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -31,7 +34,7 @@ class LM(abc.ABC):
...
@@ -31,7 +34,7 @@ class LM(abc.ABC):
# set rank and world size to a single process, by default.
# set rank and world size to a single process, by default.
self
.
_rank
=
0
self
.
_rank
=
0
self
.
_world_size
=
1
self
.
_world_size
=
1
self
.
cache_hook
:
"
CacheHook
"
=
CacheHook
(
None
)
self
.
cache_hook
:
CacheHook
=
CacheHook
(
None
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
loglikelihood
(
self
,
requests
:
list
[
Instance
])
->
list
[
tuple
[
float
,
bool
]]:
def
loglikelihood
(
self
,
requests
:
list
[
Instance
])
->
list
[
tuple
[
float
,
bool
]]:
...
@@ -101,7 +104,7 @@ class LM(abc.ABC):
...
@@ -101,7 +104,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
# TODO: Add an optional max length
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
self
,
requests
:
list
[
"
Instance
"
])
->
list
[
str
]:
def
generate_until
(
self
,
requests
:
list
[
Instance
])
->
list
[
str
]:
"""Generate greedily until a stopping sequence
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
:param requests: list[Instance]
...
@@ -137,7 +140,7 @@ class LM(abc.ABC):
...
@@ -137,7 +140,7 @@ class LM(abc.ABC):
@
classmethod
@
classmethod
def
create_from_arg_string
(
def
create_from_arg_string
(
cls
:
T
ype
[
T
],
arg_string
:
str
,
additional_config
:
Optional
[
dict
]
=
None
cls
:
t
ype
[
T
],
arg_string
:
str
,
additional_config
:
dict
|
None
=
None
)
->
T
:
)
->
T
:
"""
"""
Creates an instance of the LM class using the given argument string and additional config.
Creates an instance of the LM class using the given argument string and additional config.
...
@@ -156,7 +159,7 @@ class LM(abc.ABC):
...
@@ -156,7 +159,7 @@ class LM(abc.ABC):
@
classmethod
@
classmethod
def
create_from_arg_obj
(
def
create_from_arg_obj
(
cls
:
T
ype
[
T
],
arg_dict
:
dict
,
additional_config
:
Optional
[
dict
]
=
None
cls
:
t
ype
[
T
],
arg_dict
:
dict
,
additional_config
:
dict
|
None
=
None
)
->
T
:
)
->
T
:
"""
"""
Creates an instance of the LM class using the given arg_obj
Creates an instance of the LM class using the given arg_obj
...
@@ -201,7 +204,7 @@ class LM(abc.ABC):
...
@@ -201,7 +204,7 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property."
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)
)
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]
:
def
chat_template
(
self
,
chat_template
:
bool
|
str
=
False
)
->
str
|
None
:
"""Returns the chat template structure for user/assistant messages if a template is provided.
"""Returns the chat template structure for user/assistant messages if a template is provided.
This method is intended to be overridden in a subclass to define a specific chat template format.
This method is intended to be overridden in a subclass to define a specific chat template format.
For models that do not support chat templates, this method returns None by default.
For models that do not support chat templates, this method returns None by default.
...
@@ -209,7 +212,7 @@ class LM(abc.ABC):
...
@@ -209,7 +212,7 @@ class LM(abc.ABC):
return
""
return
""
def
set_cache_hook
(
self
,
cache_hook
:
"
CacheHook
"
)
->
None
:
def
set_cache_hook
(
self
,
cache_hook
:
CacheHook
)
->
None
:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self
.
cache_hook
=
cache_hook
self
.
cache_hook
=
cache_hook
...
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
...
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class
CacheHook
:
class
CacheHook
:
def
__init__
(
self
,
cachinglm
:
Optional
[
"
CachingLM
"
]
)
->
None
:
def
__init__
(
self
,
cachinglm
:
CachingLM
|
None
)
->
None
:
"""CacheHook is used to cache responses from the LM."""
"""CacheHook is used to cache responses from the LM."""
if
cachinglm
is
None
:
if
cachinglm
is
None
:
self
.
dbdict
:
Optional
[
"
SqliteDict
"
]
=
None
self
.
dbdict
:
SqliteDict
|
None
=
None
return
return
self
.
dbdict
=
cachinglm
.
dbdict
self
.
dbdict
=
cachinglm
.
dbdict
...
@@ -238,7 +241,7 @@ class CacheHook:
...
@@ -238,7 +241,7 @@ class CacheHook:
class
CachingLM
:
class
CachingLM
:
def
__init__
(
self
,
lm
:
"
LM
"
,
cache_db
:
str
)
->
None
:
def
__init__
(
self
,
lm
:
LM
,
cache_db
:
str
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
:param lm: LM
...
@@ -263,7 +266,7 @@ class CachingLM:
...
@@ -263,7 +266,7 @@ class CachingLM:
eval_logger
.
debug
(
f
"Passing through attribute '
{
attr
}
' to underlying LM"
)
eval_logger
.
debug
(
f
"Passing through attribute '
{
attr
}
' to underlying LM"
)
return
lm_attr
return
lm_attr
def
_fn
(
requests
:
list
[
"
Instance
"
])
->
list
[
"
Instance
"
]:
def
_fn
(
requests
:
list
[
Instance
])
->
list
[
Instance
]:
res
=
[]
res
=
[]
remaining_reqs
=
[]
remaining_reqs
=
[]
warned
=
False
warned
=
False
...
@@ -295,11 +298,8 @@ class CachingLM:
...
@@ -295,11 +298,8 @@ class CachingLM:
eval_logger
.
info
(
eval_logger
.
info
(
f
"Cached requests:
{
len
(
requests
)
-
len
(
remaining_reqs
)
}
, Requests remaining:
{
len
(
remaining_reqs
)
}
"
f
"Cached requests:
{
len
(
requests
)
-
len
(
remaining_reqs
)
}
, Requests remaining:
{
len
(
remaining_reqs
)
}
"
)
)
if
remaining_reqs
:
# actually run the LM on the requests that do not have cached results
rem_res
=
getattr
(
self
.
lm
,
attr
)(
remaining_reqs
)
if
remaining_reqs
else
[]
rem_res
=
getattr
(
self
.
lm
,
attr
)(
remaining_reqs
)
else
:
rem_res
=
[]
# stick the new ones back into the list and also cache any of the new ones
# stick the new ones back into the list and also cache any of the new ones
resptr
=
0
resptr
=
0
...
@@ -318,7 +318,7 @@ class CachingLM:
...
@@ -318,7 +318,7 @@ class CachingLM:
return
_fn
return
_fn
def
get_cache_hook
(
self
)
->
"
CacheHook
"
:
def
get_cache_hook
(
self
)
->
CacheHook
:
return
CacheHook
(
self
)
return
CacheHook
(
self
)
...
@@ -399,7 +399,7 @@ class TemplateLM(LM):
...
@@ -399,7 +399,7 @@ class TemplateLM(LM):
return
context_enc
,
continuation_enc
return
context_enc
,
continuation_enc
def
loglikelihood
(
def
loglikelihood
(
self
,
requests
:
list
[
"
Instance
"
],
disable_tqdm
:
bool
=
False
self
,
requests
:
list
[
Instance
],
disable_tqdm
:
bool
=
False
)
->
list
[
tuple
[
float
,
bool
]]:
)
->
list
[
tuple
[
float
,
bool
]]:
"""Compute log-likelihood of generating a continuation from a context.
"""Compute log-likelihood of generating a continuation from a context.
...
@@ -432,7 +432,7 @@ class TemplateLM(LM):
...
@@ -432,7 +432,7 @@ class TemplateLM(LM):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
def
generate_until
(
self
,
requests
:
list
[
"
Instance
"
],
disable_tqdm
:
bool
=
False
self
,
requests
:
list
[
Instance
],
disable_tqdm
:
bool
=
False
)
->
list
[
str
]:
)
->
list
[
str
]:
"""Generate until a stopping sequence.
"""Generate until a stopping sequence.
...
@@ -453,7 +453,7 @@ class TemplateLM(LM):
...
@@ -453,7 +453,7 @@ class TemplateLM(LM):
"""
"""
pass
pass
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]
:
def
chat_template
(
self
,
chat_template
:
bool
|
str
=
False
)
->
str
|
None
:
"""
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
Set and get the appropriate chat template for the model.
...
...
lm_eval/api/task.py
View file @
4ad6cd9f
...
@@ -7,11 +7,7 @@ import random
...
@@ -7,11 +7,7 @@ import random
import
re
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
(
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
TYPE_CHECKING
,
Any
,
Literal
,
)
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
...
@@ -24,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
...
@@ -24,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from
lm_eval.api.utils
import
check_gold_index_error
from
lm_eval.api.utils
import
check_gold_index_error
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.task
import
TaskConfig
from
lm_eval.config.task
import
DataSet
,
TaskConfig
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
...
@@ -133,6 +129,7 @@ class Task(abc.ABC):
...
@@ -133,6 +129,7 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
Fresh download and fresh dataset.
"""
"""
assert
self
.
DATASET_PATH
is
not
None
,
"DATASET_PATH must be set in Task class"
self
.
dataset
=
datasets
.
load_dataset
(
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
,
name
=
self
.
DATASET_NAME
,
...
@@ -146,43 +143,40 @@ class Task(abc.ABC):
...
@@ -146,43 +143,40 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
"""Returns the TaskConfig associated with this class."""
return
self
.
_config
return
self
.
_config
@
abc
.
abstractmethod
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
"""Whether the task has a training set"""
"""Whether the task has a training set"""
pass
raise
NotImplementedError
@
abc
.
abstractmethod
def
has_validation_docs
(
self
)
->
bool
:
def
has_validation_docs
(
self
)
->
bool
:
"""Whether the task has a validation set"""
"""Whether the task has a validation set"""
pass
raise
NotImplementedError
@
abc
.
abstractmethod
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
"""Whether the task has a test set"""
"""Whether the task has a test set"""
pass
raise
NotImplementedError
def
training_docs
(
self
)
->
Iterabl
e
:
def
training_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
validation_docs
(
self
)
->
Iterabl
e
:
def
validation_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
test_docs
(
self
)
->
Iterabl
e
:
def
test_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
return
[]
return
[]
def
fewshot_docs
(
self
)
->
Iterabl
e
:
def
fewshot_docs
(
self
)
->
DataSet
|
Non
e
:
"""
"""
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
...
@@ -192,7 +186,7 @@ class Task(abc.ABC):
...
@@ -192,7 +186,7 @@ class Task(abc.ABC):
elif
self
.
has_validation_docs
():
elif
self
.
has_validation_docs
():
return
self
.
validation_docs
()
return
self
.
validation_docs
()
else
:
else
:
if
self
.
config
.
get
(
"
num_fewshot
"
,
0
)
>
0
:
if
self
.
config
.
num_fewshot
and
self
.
config
.
num_fewshot
>
0
:
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"[Task:
{
self
.
config
.
task
}
] has_training_docs and has_validation_docs are False"
f
"[Task:
{
self
.
config
.
task
}
] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
", using test_docs as fewshot_docs but this is not recommended."
...
@@ -331,7 +325,7 @@ class Task(abc.ABC):
...
@@ -331,7 +325,7 @@ class Task(abc.ABC):
inst
=
self
.
construct_requests
(
inst
=
self
.
construct_requests
(
doc
=
doc
,
doc
=
doc
,
ctx
=
fewshot_ctx
,
ctx
=
fewshot_ctx
,
metadata
=
(
self
.
config
[
"
task
"
]
,
doc_id
,
self
.
config
.
repeats
),
metadata
=
(
self
.
config
.
task
,
doc_id
,
self
.
config
.
repeats
),
apply_chat_template
=
apply_chat_template
,
apply_chat_template
=
apply_chat_template
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
)
)
...
@@ -586,7 +580,7 @@ class ConfigurableTask(Task):
...
@@ -586,7 +580,7 @@ class ConfigurableTask(Task):
data_dir
=
None
,
data_dir
=
None
,
cache_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
download_mode
=
None
,
config
:
dict
|
None
=
None
,
config
:
Mapping
[
str
,
Any
]
|
None
=
None
,
)
->
None
:
)
->
None
:
# Get pre-configured attributes
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
self
.
_config
=
self
.
CONFIG
...
@@ -727,6 +721,9 @@ class ConfigurableTask(Task):
...
@@ -727,6 +721,9 @@ class ConfigurableTask(Task):
)
)
self
.
dataset
=
df
(
**
(
self
.
config
.
dataset_kwargs
|
self
.
config
.
metadata
))
self
.
dataset
=
df
(
**
(
self
.
config
.
dataset_kwargs
|
self
.
config
.
metadata
))
else
:
else
:
assert
self
.
config
.
dataset_path
is
not
None
,
(
"dataset_path must be set in TaskConfig"
)
self
.
dataset
=
datasets
.
load_dataset
(
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
config
.
dataset_path
,
path
=
self
.
config
.
dataset_path
,
name
=
self
.
config
.
dataset_name
,
name
=
self
.
config
.
dataset_name
,
...
@@ -742,7 +739,7 @@ class ConfigurableTask(Task):
...
@@ -742,7 +739,7 @@ class ConfigurableTask(Task):
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
return
self
.
config
.
test_split
is
not
None
return
self
.
config
.
test_split
is
not
None
def
training_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
training_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_training_docs
():
if
self
.
has_training_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -750,7 +747,7 @@ class ConfigurableTask(Task):
...
@@ -750,7 +747,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
training_split
]
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
validation_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_validation_docs
():
if
self
.
has_validation_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -758,7 +755,7 @@ class ConfigurableTask(Task):
...
@@ -758,7 +755,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
validation_split
]
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
datasets
.
Data
s
et
|
None
:
def
test_docs
(
self
)
->
Data
S
et
|
None
:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
...
@@ -996,9 +993,21 @@ class ConfigurableTask(Task):
...
@@ -996,9 +993,21 @@ class ConfigurableTask(Task):
"""
"""
return
doc
return
doc
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
None
=
None
)
->
str
|
int
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
)
->
int
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
str
)
->
str
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
Callable
[...,
str
])
->
str
:
...
def
doc_to_text
(
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
)
->
str
:
)
->
str
|
int
:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_text = self.prompt
# doc_to_text = self.prompt
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
...
@@ -1031,6 +1040,25 @@ class ConfigurableTask(Task):
...
@@ -1031,6 +1040,25 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
print
(
type
(
doc_to_text
))
raise
TypeError
raise
TypeError
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
None
=
None
)
->
int
|
str
|
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
int
)
->
int
:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
str
)
->
int
|
str
|
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
list
)
->
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
Callable
[...,
int
|
str
|
list
[
int
]]
)
->
int
|
str
|
list
[
int
]:
...
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_target = self.prompt
# doc_to_target = self.prompt
...
@@ -1077,6 +1105,23 @@ class ConfigurableTask(Task):
...
@@ -1077,6 +1105,23 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
None
=
None
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
str
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
list
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
dict
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
Callable
[...,
list
[
str
]]
)
->
list
[
str
]:
...
def
doc_to_choice
(
def
doc_to_choice
(
self
,
self
,
doc
:
dict
,
doc
:
dict
,
...
@@ -1108,6 +1153,18 @@ class ConfigurableTask(Task):
...
@@ -1108,6 +1153,18 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
None
=
None
)
->
None
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
list
)
->
list
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
str
)
->
int
|
str
|
None
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
Callable
[...,
Any
])
->
Any
:
...
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_image
is
not
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
doc_to_image
=
doc_to_image
...
@@ -1131,6 +1188,18 @@ class ConfigurableTask(Task):
...
@@ -1131,6 +1188,18 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
None
=
None
)
->
None
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
list
)
->
list
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
str
)
->
int
|
str
|
None
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
Callable
[...,
Any
])
->
Any
:
...
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_audio
is
not
None
:
if
doc_to_audio
is
not
None
:
doc_to_audio
=
doc_to_audio
doc_to_audio
=
doc_to_audio
...
@@ -1375,15 +1444,15 @@ class ConfigurableTask(Task):
...
@@ -1375,15 +1444,15 @@ class ConfigurableTask(Task):
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
gold
=
self
.
doc_to_target
(
doc
)
gold
=
self
.
doc_to_target
(
doc
)
result
=
results
[
0
]
result
=
results
[
0
]
for
metric
in
self
.
_metric_
fn_
list
:
for
metric
in
self
.
config
.
_metric_list
:
try
:
try
:
result_score
=
self
.
_
metric
_
fn
_list
[
metric
]
(
result_score
=
metric
.
fn
(
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
predictions
=
[
result
],
predictions
=
[
result
],
**
self
.
_
metric
_fn_
kwargs
[
metric
]
,
**
metric
.
kwargs
,
)
)
except
TypeError
:
# needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
except
TypeError
:
# needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score
=
self
.
_
metric
_
fn
_list
[
metric
]
([
gold
,
result
])
result_score
=
metric
.
fn
([
gold
,
result
])
if
isinstance
(
result_score
,
dict
):
if
isinstance
(
result_score
,
dict
):
# TODO: this handles the case where HF evaluate returns a dict.
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
# This allows for multiple metrics to be returned from the same function
...
...
lm_eval/config/metric.py
View file @
4ad6cd9f
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Mapping
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
from
typing
import
Any
...
@@ -11,8 +11,8 @@ class MetricConfig:
...
@@ -11,8 +11,8 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
"""Encapsulates information about a single metric."""
name
:
str
name
:
str
fn
:
Callable
|
None
=
None
fn
:
Callable
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
kwargs
:
Mapping
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
aggregation_fn
:
Callable
|
None
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
...
...
lm_eval/config/task.py
View file @
4ad6cd9f
...
@@ -3,7 +3,9 @@ from __future__ import annotations
...
@@ -3,7 +3,9 @@ from __future__ import annotations
import
logging
import
logging
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Union
import
datasets
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
...
@@ -18,6 +20,9 @@ if TYPE_CHECKING:
...
@@ -18,6 +20,9 @@ if TYPE_CHECKING:
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
DataSet
=
Union
[
datasets
.
Dataset
,
Iterable
[
dict
[
str
,
Any
]]]
DSplits
=
dict
[
str
,
DataSet
]
@
dataclass
@
dataclass
class
RepeatConfig
:
class
RepeatConfig
:
...
@@ -30,7 +35,7 @@ class RepeatConfig:
...
@@ -30,7 +35,7 @@ class RepeatConfig:
@
dataclass
@
dataclass
class
FilterConfig
:
class
FilterConfig
:
"""Encapsulates information about a single filter."""
"""Encapsulates information about a single filter
pipeline
."""
name
:
str
name
:
str
ensemble
:
FilterEnsemble
ensemble
:
FilterEnsemble
...
@@ -44,10 +49,8 @@ class FewshotConfig:
...
@@ -44,10 +49,8 @@ class FewshotConfig:
num_fewshot
:
Callable
[[],
int
]
num_fewshot
:
Callable
[[],
int
]
split
:
str
|
None
=
None
split
:
str
|
None
=
None
sampler
:
str
|
Callable
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
samples
:
Callable
[[],
DataSet
]
|
DataSet
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
[
str
,
Any
]]],
Iterable
[
dict
[
str
,
Any
]]]
|
None
=
(
process_docs
:
Callable
[[
DataSet
],
DataSet
]
|
None
=
None
None
)
fewshot_indices
:
list
[
int
]
|
None
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
...
@@ -69,22 +72,23 @@ class FewshotConfig:
...
@@ -69,22 +72,23 @@ class FewshotConfig:
"""Check if any fewshot source is configured."""
"""Check if any fewshot source is configured."""
return
self
.
split
is
not
None
or
self
.
samples
is
not
None
return
self
.
split
is
not
None
or
self
.
samples
is
not
None
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
:
DSplits
)
->
DataSet
|
None
:
self
,
dataset
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
)
:
if
isinstance
(
self
.
samples
,
list
):
return
self
.
samples
return
self
.
samples
elif
callable
(
self
.
samples
):
# If samples is a callable, it should return a list of dicts
return
self
.
samples
()
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
)
)
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
[
str
,
Any
]]
|
None
:
def
get_docs
(
self
,
dataset
)
->
DataSet
|
None
:
"""Get processed documents from configured source."""
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
if
raw_docs
is
None
:
...
@@ -130,34 +134,34 @@ class TaskConfig:
...
@@ -130,34 +134,34 @@ class TaskConfig:
# HF dataset options.
# HF dataset options.
# which dataset to use,
# which dataset to use,
# and what splits for what purpose
# and what splits for what purpose
custom_dataset
:
Callable
|
None
=
None
custom_dataset
:
Callable
[...,
DataSet
]
|
None
=
None
dataset_path
:
str
|
None
=
None
dataset_path
:
str
|
None
=
None
dataset_name
:
str
|
None
=
None
dataset_name
:
str
|
None
=
None
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
training_split
:
str
|
None
=
None
training_split
:
str
|
None
=
None
validation_split
:
str
|
None
=
None
validation_split
:
str
|
None
=
None
test_split
:
str
|
None
=
None
test_split
:
str
|
None
=
None
fewshot_split
:
str
|
None
=
(
fewshot_split
:
str
|
None
=
None
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
# see docs/advanced_task_guide.md for more info
process_docs
:
Callable
|
None
=
None
process_docs
:
Callable
[[
DataSet
],
DataSet
]
|
None
=
None
doc_to_text
:
Callable
|
str
|
None
=
None
doc_to_text
:
Callable
[[
dict
[
str
,
Any
]],
Any
]
|
str
|
None
=
None
doc_to_target
:
Callable
|
str
|
None
=
None
doc_to_target
:
Callable
[[
dict
[
str
,
Any
]],
Any
]
|
str
|
None
=
None
doc_to_image
:
Callable
|
str
|
None
=
None
doc_to_image
:
Callable
[[
dict
[
str
,
Any
]],
Any
]
|
str
|
None
=
None
doc_to_audio
:
Callable
|
str
|
None
=
None
doc_to_audio
:
Callable
[[
dict
[
str
,
Any
]],
Any
]
|
str
|
None
=
None
unsafe_code
:
bool
=
False
unsafe_code
:
bool
=
False
doc_to_choice
:
Callable
|
str
|
dict
|
list
|
None
=
None
doc_to_choice
:
Callable
[[
dict
[
str
,
Any
]],
Any
]
|
str
|
dict
|
list
|
None
=
None
process_results
:
Callable
|
str
|
None
=
None
process_results
:
(
Callable
[[
dict
[
str
,
Any
],
list
[
Any
]],
dict
[
str
,
Any
]]
|
str
|
None
)
=
None
use_prompt
:
str
|
None
=
None
use_prompt
:
str
|
None
=
None
description
:
str
=
""
description
:
str
=
""
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
dict
|
None
=
None
fewshot_config
:
dict
[
str
,
Any
]
|
None
=
None
# runtime configuration options
# runtime configuration options
num_fewshot
:
int
|
None
=
0
num_fewshot
:
int
|
None
=
None
generation_kwargs
:
dict
|
None
=
None
generation_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
# scoring options
# scoring options
metric_list
:
list
|
None
=
None
metric_list
:
list
|
None
=
None
output_type
:
OutputType
=
"generate_until"
output_type
:
OutputType
=
"generate_until"
...
@@ -357,7 +361,7 @@ class TaskConfig:
...
@@ -357,7 +361,7 @@ class TaskConfig:
return
x
return
x
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
TaskConfig
:
def
from_yaml
(
cls
,
data
:
dict
[
str
,
Any
]
)
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
return
cls
(
**
data
)
...
@@ -425,12 +429,6 @@ class TaskConfig:
...
@@ -425,12 +429,6 @@ class TaskConfig:
# Create and return TaskConfig instance
# Create and return TaskConfig instance
return
cls
(
**
config_dict
)
return
cls
(
**
config_dict
)
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
_ser
(
x
):
def
_ser
(
x
):
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
...
...
lm_eval/decontamination/archiver.py
View file @
4ad6cd9f
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "jsonlines",
# "mmap",
# "tqdm",
# "zstandard",
# ]
# ///
# ruff: noqa
import
datetime
import
datetime
import
io
import
io
import
json
import
json
...
@@ -111,7 +122,7 @@ class TextReader:
...
@@ -111,7 +122,7 @@ class TextReader:
current_file_position
=
0
current_file_position
=
0
line_counter
=
0
line_counter
=
0
with
(
with
(
open
(
self
.
file_path
,
"r"
,
encoding
=
"utf-8"
)
as
fh
,
open
(
self
.
file_path
,
encoding
=
"utf-8"
)
as
fh
,
tqdm
.
tqdm
(
tqdm
.
tqdm
(
total
=
os
.
path
.
getsize
(
self
.
file_path
),
total
=
os
.
path
.
getsize
(
self
.
file_path
),
dynamic_ncols
=
True
,
dynamic_ncols
=
True
,
...
@@ -133,7 +144,7 @@ class TextReader:
...
@@ -133,7 +144,7 @@ class TextReader:
def
read_and_tell
(
self
):
def
read_and_tell
(
self
):
current_file_position
=
0
current_file_position
=
0
with
open
(
self
.
file_path
,
"r"
,
encoding
=
"utf8"
)
as
fh
:
with
open
(
self
.
file_path
,
encoding
=
"utf8"
)
as
fh
:
with
mmap
.
mmap
(
fh
.
fileno
(),
length
=
0
,
access
=
mmap
.
ACCESS_READ
)
as
mmap_obj
:
with
mmap
.
mmap
(
fh
.
fileno
(),
length
=
0
,
access
=
mmap
.
ACCESS_READ
)
as
mmap_obj
:
for
line
in
iter
(
mmap_obj
.
readline
,
b
""
):
for
line
in
iter
(
mmap_obj
.
readline
,
b
""
):
line
=
line
.
decode
(
"utf-8"
)
line
=
line
.
decode
(
"utf-8"
)
...
@@ -143,14 +154,14 @@ class TextReader:
...
@@ -143,14 +154,14 @@ class TextReader:
yield
line
[:
-
1
],
raw_bytes_read
yield
line
[:
-
1
],
raw_bytes_read
def
read
(
self
):
def
read
(
self
):
with
open
(
self
.
file_path
,
"r"
,
encoding
=
"utf8"
)
as
fh
:
with
open
(
self
.
file_path
,
encoding
=
"utf8"
)
as
fh
:
with
mmap
.
mmap
(
fh
.
fileno
(),
length
=
0
,
access
=
mmap
.
ACCESS_READ
)
as
mmap_obj
:
with
mmap
.
mmap
(
fh
.
fileno
(),
length
=
0
,
access
=
mmap
.
ACCESS_READ
)
as
mmap_obj
:
for
line
in
iter
(
mmap_obj
.
readline
,
b
""
):
for
line
in
iter
(
mmap_obj
.
readline
,
b
""
):
line
=
line
.
decode
(
"utf-8"
)
line
=
line
.
decode
(
"utf-8"
)
yield
line
[:
-
1
]
yield
line
[:
-
1
]
def
read_slow
(
self
):
def
read_slow
(
self
):
with
open
(
self
.
file_path
,
"r"
,
encoding
=
"utf8"
)
as
fh
:
with
open
(
self
.
file_path
,
encoding
=
"utf8"
)
as
fh
:
while
True
:
while
True
:
line
=
fh
.
readline
()
line
=
fh
.
readline
()
if
line
==
-
1
or
line
==
""
:
if
line
==
-
1
or
line
==
""
:
...
...
lm_eval/utils.py
View file @
4ad6cd9f
import
collections
import
collections
import
fnmatch
import
fnmatch
import
functools
import
hashlib
import
hashlib
import
importlib.util
import
importlib.util
import
inspect
import
inspect
...
@@ -8,10 +7,12 @@ import json
...
@@ -8,10 +7,12 @@ import json
import
logging
import
logging
import
os
import
os
import
re
import
re
from
collections.abc
import
Generator
from
dataclasses
import
asdict
,
is_dataclass
from
dataclasses
import
asdict
,
is_dataclass
from
functools
import
lru_cache
,
partial
,
wraps
from
itertools
import
islice
from
itertools
import
islice
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
import
numpy
as
np
import
numpy
as
np
import
yaml
import
yaml
...
@@ -108,7 +109,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
...
@@ -108,7 +109,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
return
text
return
text
maxsplit
=
max
(
0
,
maxsplit
)
maxsplit
=
max
(
0
,
maxsplit
)
return
re
.
split
(
r
"(?<!\\)"
+
sep_char
,
text
,
maxsplit
)
return
re
.
split
(
r
"(?<!\\)"
+
sep_char
,
text
,
maxsplit
=
maxsplit
)
def
handle_arg_string
(
arg
):
def
handle_arg_string
(
arg
):
...
@@ -125,7 +126,7 @@ def handle_arg_string(arg):
...
@@ -125,7 +126,7 @@ def handle_arg_string(arg):
def
handle_non_serializable
(
o
):
def
handle_non_serializable
(
o
):
if
isinstance
(
o
,
np
.
int
64
)
or
isinstance
(
o
,
np
.
int32
):
if
isinstance
(
o
,
np
.
int
eger
):
return
int
(
o
)
return
int
(
o
)
elif
isinstance
(
o
,
set
):
elif
isinstance
(
o
,
set
):
return
list
(
o
)
return
list
(
o
)
...
@@ -235,21 +236,21 @@ def sanitize_task_name(task_name: str) -> str:
...
@@ -235,21 +236,21 @@ def sanitize_task_name(task_name: str) -> str:
return
re
.
sub
(
r
"\W"
,
"_"
,
task_name
)
return
re
.
sub
(
r
"\W"
,
"_"
,
task_name
)
def get_latest_filename(filenames: list[str]) -> str:
    """Return the filename whose embedded datetime is the most recent.

    Args:
        filenames: Filenames containing an extractable datetime component.

    Returns:
        The single filename with the latest datetime.
    """
    # Use the datetime extractor directly as the key function.
    return max(filenames, key=get_file_datetime)
def get_results_filenames(filenames: list[str]) -> list[str]:
    """Select the filenames that correspond to aggregated results files.

    A results file is identified by containing both ``/results_`` and
    ``.json`` anywhere in its path.
    """
    selected = []
    for name in filenames:
        if "/results_" in name and ".json" in name:
            selected.append(name)
    return selected
def
get_sample_results_filenames
(
filenames
:
L
ist
[
str
])
->
L
ist
[
str
]:
def
get_sample_results_filenames
(
filenames
:
l
ist
[
str
])
->
l
ist
[
str
]:
"""
"""
Extracts filenames that correspond to sample results.
Extracts filenames that correspond to sample results.
"""
"""
...
@@ -257,8 +258,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
...
@@ -257,8 +258,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def
get_rolling_token_windows
(
def
get_rolling_token_windows
(
token_list
:
L
ist
[
int
],
prefix_token
:
int
,
max_seq_len
:
int
,
context_len
:
int
token_list
:
l
ist
[
int
],
prefix_token
:
int
,
max_seq_len
:
int
,
context_len
:
int
)
->
Generator
[
T
uple
[
L
ist
[
int
],
L
ist
[
int
]],
None
,
None
]:
)
->
Generator
[
t
uple
[
l
ist
[
int
],
l
ist
[
int
]],
None
,
None
]:
"""
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
condition on some context
...
@@ -300,8 +301,8 @@ def get_rolling_token_windows(
...
@@ -300,8 +301,8 @@ def get_rolling_token_windows(
def make_disjoint_window(
    pair: tuple[list[int], list[int]],
) -> tuple[list[int], list[int]]:
    """Make the context window not overlap with the continuation window.

    Takes one (context, continuation) pair produced by
    ``get_rolling_token_windows`` and trims the tail of the context so the
    two token lists are disjoint.
    """
    context, continuation = pair
    # All but the last continuation token overlap the context's tail.
    overlap = len(continuation) - 1
    trimmed_context = context[: len(context) - overlap]
    return trimmed_context, continuation
...
@@ -320,7 +321,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
...
@@ -320,7 +321,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
class
Reorderer
:
class
Reorderer
:
def
__init__
(
self
,
arr
:
L
ist
[
Any
],
fn
:
Callable
)
->
None
:
def
__init__
(
self
,
arr
:
l
ist
[
Any
],
fn
:
Callable
)
->
None
:
"""Reorder an array according to some function
"""Reorder an array according to some function
Args:
Args:
...
@@ -423,11 +424,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
...
@@ -423,11 +424,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
# TODO: fix
# TODO: fix
hib
=
"↑"
hib
=
"↑"
v
=
"%.4f"
%
v
if
isinstance
(
v
,
float
)
else
v
v
=
f
"
{
v
:.
4
f
}
"
if
isinstance
(
v
,
float
)
else
v
if
m
+
"_stderr"
+
","
+
f
in
dic
:
if
m
+
"_stderr"
+
","
+
f
in
dic
:
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
se
=
" N/A"
if
se
==
"N/A"
else
"%.4f"
%
se
se
=
" N/A"
if
se
==
"N/A"
else
f
"
{
se
:.
4
f
}
"
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
"±"
,
se
])
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
"±"
,
se
])
else
:
else
:
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
""
,
""
])
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
""
,
""
])
...
@@ -448,7 +449,8 @@ def positional_deprecated(fn):
...
@@ -448,7 +449,8 @@ def positional_deprecated(fn):
wrapped function, `fn`.
wrapped function, `fn`.
"""
"""
@
functools
.
wraps
(
fn
)
wraps
(
fn
)
def
_wrapper
(
*
args
,
**
kwargs
):
def
_wrapper
(
*
args
,
**
kwargs
):
if
len
(
args
)
!=
1
if
inspect
.
ismethod
(
fn
)
else
0
:
if
len
(
args
)
!=
1
if
inspect
.
ismethod
(
fn
)
else
0
:
print
(
print
(
...
@@ -494,7 +496,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
...
@@ -494,7 +496,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if
yaml_path
is
None
:
if
yaml_path
is
None
:
raise
ValueError
(
"yaml_path must be provided if mode is 'full'."
)
raise
ValueError
(
"yaml_path must be provided if mode is 'full'."
)
# Attach yaml_path to the import function so that it can be used later
# Attach yaml_path to the import function so that it can be used later
constructor_fn
=
functools
.
partial
(
import_function
,
yaml_path
=
Path
(
yaml_path
))
constructor_fn
=
partial
(
import_function
,
yaml_path
=
Path
(
yaml_path
))
loader
=
yaml
.
CLoader
if
yaml
.
__with_libyaml__
else
yaml
.
FullLoader
loader
=
yaml
.
CLoader
if
yaml
.
__with_libyaml__
else
yaml
.
FullLoader
# Add the import_function constructor to the YAML loader
# Add the import_function constructor to the YAML loader
...
@@ -543,13 +545,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
...
@@ -543,13 +545,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
env
=
Environment
(
env
=
Environment
(
loader
=
BaseLoader
,
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
loader
=
BaseLoader
()
,
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
)
)
env
.
filters
[
"regex_replace"
]
=
regex_replace
env
.
filters
[
"regex_replace"
]
=
regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str):
    """Compile a raw Jinja template string, memoizing compiled templates.

    Caching avoids re-parsing the same template text on every render call.
    """
    compiled = env.from_string(raw)
    return compiled
def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja ``template`` using the fields of ``doc``.

    Args:
        template: Jinja template source text.
        doc: Mapping whose items become template variables.

    Returns:
        The rendered string.
    """
    compiled = _compile(template)
    rendered = compiled.render(**doc)
    return rendered
...
...
pyproject.toml
View file @
4ad6cd9f
...
@@ -14,31 +14,25 @@ classifiers = [
...
@@ -14,31 +14,25 @@ classifiers = [
"Development Status :: 3 - Alpha"
,
"Development Status :: 3 - Alpha"
,
"Programming Language :: Python :: 3"
,
"Programming Language :: Python :: 3"
,
"License :: OSI Approved :: MIT License"
,
"License :: OSI Approved :: MIT License"
,
"Operating System :: OS Independent"
,
"Operating System :: OS Independent"
]
]
requires-python
=
">=3.9"
requires-python
=
">=3.9"
license
=
{
"text"
=
"MIT"
}
license
=
{
"text"
=
"MIT"
}
dependencies
=
[
dependencies
=
[
"accelerate>=0.26.0"
,
"accelerate>=0.26.0"
,
"evaluate"
,
"datasets>=2.16.0,<4.0"
,
"datasets>=2.16.0,<4.0"
,
"evaluate>=0.4.0"
,
"evaluate>=0.4.0"
,
"jsonlines"
,
"numexpr"
,
"peft>=0.2.0"
,
"peft>=0.2.0"
,
"pybind11>=2.6.2"
,
"pytablewriter"
,
"pytablewriter"
,
"rouge-score>=0.0.4"
,
"rouge-score>=0.0.4"
,
"sacrebleu>=1.5.0"
,
"sacrebleu>=1.5.0"
,
"scikit-learn>=0.24.1"
,
"scikit-learn>=0.24.1"
,
"sqlitedict"
,
"sqlitedict"
,
"torch>=1.8"
,
"torch>=1.8"
,
"tqdm-multiprocess"
,
"transformers>=4.1"
,
"transformers>=4.1"
,
"zstandard"
,
"dill"
,
"dill"
,
"word2number"
,
"word2number"
,
"more_itertools"
,
"more_itertools"
]
]
[tool.setuptools.packages.find]
[tool.setuptools.packages.find]
...
@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
...
@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
ifeval
=
[
"langdetect"
,
"immutabledict"
,
"nltk>=3.9.1"
]
ifeval
=
[
"langdetect"
,
"immutabledict"
,
"nltk>=3.9.1"
]
ipex
=
["optimum"]
ipex
=
["optimum"]
japanese_leaderboard
=
[
"emoji==2.14.0"
,
"neologdn==0.5.3"
,
"fugashi[unidic-lite]"
,
"rouge_score>=0.1.2"
]
japanese_leaderboard
=
[
"emoji==2.14.0"
,
"neologdn==0.5.3"
,
"fugashi[unidic-lite]"
,
"rouge_score>=0.1.2"
]
longbench
=
[
"jieba"
,
"fuzzywuzzy"
,
"rouge"
]
longbench
=
[
"jieba"
,
"fuzzywuzzy"
,
"rouge"
]
libra
=
["pymorphy2"]
libra
=
["pymorphy2"]
mamba
=
[
"mamba_ssm"
,
"causal-conv1d==1.0.2"
,
"torch"
]
mamba
=
[
"mamba_ssm"
,
"causal-conv1d==1.0.2"
,
"torch"
]
math
=
[
"sympy>=1.12"
,
"antlr4-python3-runtime==4.11"
,
"math_verify[antlr4_11_0]"
]
math
=
[
"sympy>=1.12"
,
"antlr4-python3-runtime==4.11"
,
"math_verify[antlr4_11_0]"
]
...
@@ -96,8 +90,21 @@ tasks = [
...
@@ -96,8 +90,21 @@ tasks = [
"lm_eval[mamba]"
,
"lm_eval[mamba]"
,
"lm_eval[math]"
,
"lm_eval[math]"
,
"lm_eval[multilingual]"
,
"lm_eval[multilingual]"
,
"lm_eval[ruler]"
,
"lm_eval[ruler]"
]
]
testing
=
[
"pytest"
,
"pytest-cov"
,
"pytest-xdist"
]
unitxt
=
["unitxt==1.22.0"]
vllm
=
["vllm>=0.4.2"]
wandb
=
[
"wandb>=0.16.3"
,
"pandas"
,
"numpy"
]
zeno
=
[
"pandas"
,
"zeno-client"
]
[project.scripts]
lm-eval
=
"lm_eval.__main__:cli_evaluate"
lm_eval
=
"lm_eval.__main__:cli_evaluate"
[project.urls]
Homepage
=
"https://github.com/EleutherAI/lm-evaluation-harness"
Repository
=
"https://github.com/EleutherAI/lm-evaluation-harness"
[tool.pymarkdown]
[tool.pymarkdown]
plugins.md013.enabled
=
false
# line-length
plugins.md013.enabled
=
false
# line-length
...
@@ -107,21 +114,23 @@ plugins.md028.enabled = false # no-blanks-blockquote
...
@@ -107,21 +114,23 @@ plugins.md028.enabled = false # no-blanks-blockquote
plugins.md029.allow_extended_start_values
=
true
# ol-prefix
plugins.md029.allow_extended_start_values
=
true
# ol-prefix
plugins.md034.enabled
=
false
# no-bare-urls
plugins.md034.enabled
=
false
# no-bare-urls
[tool.ruff]
[tool.ruff]
target-version
=
"py39"
target-version
=
"py39"
lint.extend-select
=
[
"I"
,
"UP"
,
"E"
,
"C419"
,
"F"
,
"B"
,
"SIM"
]
lint.extend-select
=
[
"I"
,
"UP"
,
"E"
,
"C419"
,
"F"
,
"B"
,
"SIM"
]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
]
lint.fixable
=
[
"I001"
,
"F401"
,
"UP"
]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
]
[tool.ruff.lint.isort]
[tool.ruff.lint.isort]
combine-as-imports
=
true
combine-as-imports
=
true
lines-after-imports
=
2
known-first-party
=
["lm_eval"]
known-first-party
=
["lm_eval"]
lines-after-imports
=
2
[tool.ruff.lint.extend-per-file-ignores]
# required to include yaml files in pip installation
"__init__.py"
=
["F401","F402","F403"]
[tool.setuptools.package-data]
lm_eval
=
[
"**/*.yaml"
,
"tasks/**/*"
]
[dependency-groups]
[tool.setuptools.packages.find]
dev
=
[
include
=
["lm_eval*"]
"api"
,
"dev"
,
"sentencepiece"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment