gaoqiong / lm-evaluation-harness / Commits / 1768fd3b

Commit 1768fd3b, authored Jul 19, 2025 by Baber
Parent: f650197a

    ruff rules; types

Showing 7 changed files with 63 additions and 44 deletions (+63 / −44):
  .pre-commit-config.yaml    +3   −4
  lm_eval/api/group.py       +3   −3
  lm_eval/api/metrics.py     +31  −25
  lm_eval/api/task.py        +13  −4
  lm_eval/config/utils.py    +2   −2
  lm_eval/evaluator.py       +5   −3
  pyproject.toml             +6   −3
.pre-commit-config.yaml

```diff
@@ -32,10 +32,9 @@ repos:
     rev: v0.12.2
     hooks:
       # Run the linter.
-      - id: ruff
+      - id: ruff-check
         args:
-          [
-            --fix
-          ]
+          - --fix
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
```
lm_eval/api/group.py

```diff
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Optional, Union


 @dataclass
 class AggMetricConfig(dict):
     metric: Optional[str] = None
     aggregation: Optional[str] = "mean"
-    weight_by_size: Optional[str] = False
+    weight_by_size: bool = False
     # list of filter names which should be incorporated into the aggregated metric.
     filter_list: Optional[Union[str, list]] = "none"
...
@@ -27,7 +27,7 @@ class GroupConfig(dict):
     group_alias: Optional[str] = None
     task: Optional[Union[str, list]] = None
     aggregate_metric_list: Optional[
-        Union[List[AggMetricConfig], AggMetricConfig, dict]
+        Union[list[AggMetricConfig], AggMetricConfig, dict]
     ] = None
     version: Optional[str] = None
     metadata: Optional[dict] = (
...
```
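The `weight_by_size` fix (a bool default previously typed as `Optional[str]`) is the only change here with semantic weight; the rest drops `List` for the builtin `list`. A minimal sketch of how the dataclass is meant to be filled in, using only the fields visible in this hunk (the values are invented):

```python
# Minimal sketch, not repo code: construct AggMetricConfig with the fields shown
# in the diff above. Any fields the class defines beyond these keep their defaults.
from dataclasses import asdict

from lm_eval.api.group import AggMetricConfig

cfg = AggMetricConfig(
    metric="acc",
    aggregation="mean",
    weight_by_size=True,   # now annotated as bool rather than Optional[str]
    filter_list=["none"],  # a filter name or a list of filter names
)
print(asdict(cfg))
```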
lm_eval/api/metrics.py

```diff
+from __future__ import annotations
+
 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Callable, Iterable, Sequence
+from typing import Generic, TypeVar

 import numpy as np
...
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:
 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
...
@@ -70,7 +72,7 @@ def f1_score(items):
 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef

     unzipped_list = list(zip(*items))
...
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):
 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
...
@@ -117,7 +119,7 @@ def chrf(items):
 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
...
@@ -135,7 +137,9 @@ def ter(items):
 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
...
@@ -203,8 +207,8 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
 # See the License for the specific language governing permissions and
 # limitations under the License.
 def exact_match_hf_evaluate(
-    predictions,
-    references,
+    predictions: Iterable[str],
+    references: Iterable[str],
     regexes_to_ignore=None,
     ignore_case=False,
     ignore_punctuation=False,
...
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
...
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
...
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
...
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
...
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)


-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
...
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):
 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect
-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
...
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
         refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds
-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
...
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
 # stderr stuff


-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)`using a RNG seeded with `i`.
...
@@ -539,7 +543,7 @@ def bootstrap_stderr(
 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
...
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

-    return stderr.get(metric, None)
+    return stderr.get(metric)


-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
...
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))


-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
...
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
     return np.sqrt(variance)


-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
...
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)

-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
```
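Most of the churn above is annotation tightening (`List` → builtin `list`, `Iterable`/`Sequence`/`Callable` from `collections.abc`, a `Generic[T]` bootstrap worker). The aggregation helpers at the end are small enough to exercise directly. The sketch below copies the visible definitions so it runs without `lm_eval` installed; the elided `weight_by_size` branch is filled in with an assumed "treat every subtask equally" fallback, and the numbers are made up:

```python
# Self-contained sketch of the helpers whose bodies appear in the hunks above.
# Only the visible lines are verbatim; the weight_by_size fallback is an assumption.
from collections.abc import Sequence


def mean(arr: Sequence[float]) -> float:
    return sum(arr) / len(arr)


def weighted_mean(items: list[tuple[float, float]]) -> float:
    a, b = zip(*items)
    return sum(a) / sum(b)


def aggregate_subtask_metrics(
    metrics: list[float], sizes: list[float], weight_by_size: bool = True
):
    if not weight_by_size:
        sizes = [1] * len(sizes)  # assumed behaviour of the elided branch
    assert len(metrics) == len(sizes)
    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)


scores = [0.62, 0.71, 0.58]  # made-up per-subtask accuracies
sizes = [1000, 250, 4000]    # made-up per-subtask document counts

print(mean(scores))                              # ≈ 0.637, unweighted
print(aggregate_subtask_metrics(scores, sizes))  # ≈ 0.594, size-weighted

# weighted_mean backs the "weighted_perplexity" aggregation used above:
loglikelihoods = [(-35.2, 20), (-18.9, 11)]      # (total logprob, token count) pairs
print(weighted_mean(loglikelihoods))             # ≈ -1.745 average logprob per token
```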
lm_eval/api/task.py

```diff
@@ -1053,7 +1053,9 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(
+        self, doc: dict, doc_to_target=None
+    ) -> Union[int, str, list[int]]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
         if doc_to_target is not None:
...
@@ -1100,7 +1102,9 @@ class ConfigurableTask(Task):
             raise TypeError

     def doc_to_choice(
-        self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None
+        self,
+        doc: dict,
+        doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
     ) -> List[str]:
         # if self.prompt is not None:
         #     doc_to_choice = self.prompt
...
@@ -1123,8 +1127,8 @@ class ConfigurableTask(Task):
             return list(doc_to_choice.values())
         elif callable(doc_to_choice):
             return doc_to_choice(doc)
-        elif hasattr(doc_to_choice, "get_answer_choices_list"):
-            return doc_to_choice.get_answer_choices_list(doc)
+        # elif hasattr(doc_to_choice, "get_answer_choices_list"):
+        #     return doc_to_choice.get_answer_choices_list(doc)
         else:
             raise TypeError
...
@@ -1333,6 +1337,8 @@ class ConfigurableTask(Task):
                 raise ValueError
             # and this stores our "regular" conditional loglikelihoods
             lls = lls[: len(choices)]
+        else:
+            lls_unconditional = None

         pred = np.argmax(lls)
         pred_norm = np.argmax(lls / completion_len)
...
@@ -1390,6 +1396,9 @@ class ConfigurableTask(Task):
         }

         if "acc_mutual_info" in use_metric:
+            assert lls_unconditional is not None, (
+                "lls_unconditional should not be None if acc_mutual_info is in use_metric"
+            )
             lls_mutual_info = [
                 ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
             ]
...
```
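Two behavioural notes fall out of this file: the promptsource `get_answer_choices_list` branch of `doc_to_choice` is now commented out, and `lls_unconditional` is explicitly set to `None` and asserted on, so the `acc_mutual_info` path fails with a clear message instead of a stray `NameError`. Below is a reduced sketch of the choice dispatch, written as a free function so it runs outside the class; `resolve_choices` and the sample documents are hypothetical, and the string/template branch the method also supports is omitted:

```python
# Reduced, standalone sketch of the dispatch doc_to_choice() performs after this commit.
from typing import Callable, Union


def resolve_choices(
    doc: dict,
    doc_to_choice: Union[list, dict, Callable[..., list[str]]],
) -> list[str]:
    if isinstance(doc_to_choice, list):
        return doc_to_choice                 # already a list of answer strings
    elif isinstance(doc_to_choice, dict):
        return list(doc_to_choice.values())  # e.g. {"A": "yes", "B": "no"}
    elif callable(doc_to_choice):
        return doc_to_choice(doc)            # callable of the document, now in the annotation
    else:
        raise TypeError(f"unsupported doc_to_choice: {type(doc_to_choice)}")


print(resolve_choices({}, ["yes", "no"]))
print(resolve_choices({}, {"A": "yes", "B": "no"}))
print(resolve_choices({"options": ["up", "down"]}, lambda d: d["options"]))
```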
lm_eval/config/utils.py

```diff
@@ -3,8 +3,8 @@ from typing import Any, Callable, Union


 def serialize_callable(
-    value: Union[Callable, str], keep_callable=False
-) -> Union[Callable, str]:
+    value: Union[Callable[..., Any], str], keep_callable=False
+) -> Union[Callable[..., Any], str]:
     """Serializes a given function or string.

     If 'keep_callable' is True, the original callable is returned.
...
```
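Only the signature of `serialize_callable` is touched; its body is not shown in this diff. For orientation, a hypothetical implementation consistent with the signature and docstring might look like the following (a guess, not the repo's code):

```python
# Hypothetical sketch based only on the signature and docstring shown above.
from inspect import getsource
from typing import Any, Callable, Union


def serialize_callable(
    value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
    """Serializes a given function or string.

    If 'keep_callable' is True, the original callable is returned.
    """
    if keep_callable or isinstance(value, str):
        return value
    try:
        return getsource(value)  # store the function's source text
    except (TypeError, OSError):
        return getattr(value, "__name__", repr(value))
```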
lm_eval/evaluator.py

```diff
+from __future__ import annotations
+
 import itertools
 import json
 import logging
...
@@ -5,7 +7,7 @@ import os
 import random
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import numpy as np
 import torch
...
@@ -49,7 +51,7 @@ eval_logger = logging.getLogger(__name__)
 @positional_deprecated
 def simple_evaluate(
     model,
-    model_args: Optional[Union[str, dict]] = None,
+    model_args: Optional[Union[str, dict[str, Any]]] = None,
     tasks: Optional[List[Union[str, dict, object]]] = None,
     num_fewshot: Optional[int] = None,
     batch_size: Optional[Union[int, str]] = None,
...
@@ -420,7 +422,7 @@ def simple_evaluate(
 def evaluate(
     lm: "LM",
     task_dict,
-    limit: Optional[int] = None,
+    limit: int | float | None = None,
     samples: Optional[dict] = None,
     cache_requests: bool = False,
     rewrite_requests_cache: bool = False,
...
```
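The typed parameters above belong to the public entry point. A usage sketch, assuming the standard `hf` backend and a small public checkpoint; the parameter names come from the hunks, the values are only examples:

```python
# Usage sketch; model choice, batch size, and limit are illustrative, not repo defaults.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args={"pretrained": "EleutherAI/pythia-160m"},  # dict[str, Any] now matches the annotation
    tasks=["lambada_openai"],
    num_fewshot=0,
    batch_size=8,
    limit=10,  # evaluate() now types this as int | float | None
)
print(results["results"])
```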
pyproject.toml

```diff
@@ -107,16 +107,19 @@ plugins.md028.enabled = false # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true # ol-prefix
 plugins.md034.enabled = false # no-bare-urls

-[tool.ruff.lint]
-extend-select = ["I", "W605"]
+[tool.ruff]
+target-version = "py39"
+lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]

 [tool.ruff.lint.isort]
+combine-as-imports = true
 lines-after-imports = 2
 known-first-party = ["lm_eval"]

 [tool.ruff.lint.extend-per-file-ignores]
 "__init__.py" = ["F401","F402","F403"]
+"utils.py" = ["F401"]

 [dependency-groups]
 dev = [
...
```
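The ruff configuration moves from a bare `[tool.ruff.lint]` table to dotted keys under `[tool.ruff]`, pins `target-version = "py39"`, and enables several rule families: `I` (isort), `UP` (pyupgrade), `E` (pycodestyle errors, with a handful of codes such as E501 line-too-long ignored), `C419` (unnecessary comprehension in a call), `F` (Pyflakes), `B` (bugbear), and `SIM` (simplify). The before/after pairs below are invented, but they mirror the mechanical rewrites these rules drove elsewhere in this commit:

```python
# Illustrative before/after pairs; the function names are made up.
from __future__ import annotations

from typing import List, Optional


# UP (pyupgrade): prefer builtin generics and PEP 604 unions over typing.List/Optional.
def old_style(xs: List[float]) -> Optional[float]:
    return xs[0] if xs else None


def new_style(xs: list[float]) -> float | None:
    return xs[0] if xs else None


# Comprehension/simplify rules: drop the intermediate list when summing,
# the same rewrite applied to aggregate_subtask_metrics in lm_eval/api/metrics.py.
sizes = [3, 4, 5]
total_old = sum([s * 2 for s in sizes])
total_new = sum(s * 2 for s in sizes)
assert total_old == total_new == 24
```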