Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
28c78d30
Commit
28c78d30
authored
Jun 30, 2025
by
Baber
Browse files
add MetricConfig
parent
de496b80
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
223 additions
and
156 deletions
+223
-156
lm_eval/__main__.py
lm_eval/__main__.py
+4
-0
lm_eval/api/group.py
lm_eval/api/group.py
+1
-2
lm_eval/api/instance.py
lm_eval/api/instance.py
+17
-4
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+6
-1
lm_eval/api/registry.py
lm_eval/api/registry.py
+13
-10
lm_eval/api/samplers.py
lm_eval/api/samplers.py
+3
-3
lm_eval/api/task.py
lm_eval/api/task.py
+174
-131
lm_eval/evaluator.py
lm_eval/evaluator.py
+1
-1
lm_eval/evaluator_utils.py
lm_eval/evaluator_utils.py
+2
-2
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+2
-2
No files found.
lm_eval/__main__.py
View file @
28c78d30
...
@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
...
@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if
results
is
not
None
:
if
results
is
not
None
:
if
args
.
log_samples
:
if
args
.
log_samples
:
samples
=
results
.
pop
(
"samples"
)
samples
=
results
.
pop
(
"samples"
)
# TODO: fix this!
results
[
"higher_is_better"
]
=
{
k
:
True
for
k
,
v
in
results
[
"higher_is_better"
].
items
()
}
dumped
=
json
.
dumps
(
dumped
=
json
.
dumps
(
results
,
indent
=
2
,
default
=
handle_non_serializable
,
ensure_ascii
=
False
results
,
indent
=
2
,
default
=
handle_non_serializable
,
ensure_ascii
=
False
)
)
...
...
lm_eval/api/group.py
View file @
28c78d30
import
abc
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
...
@@ -84,7 +83,7 @@ class GroupConfig(dict):
...
@@ -84,7 +83,7 @@ class GroupConfig(dict):
return
str
(
value
)
return
str
(
value
)
class
ConfigurableGroup
(
abc
.
ABC
)
:
class
ConfigurableGroup
:
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Optional
[
dict
]
=
None
,
config
:
Optional
[
dict
]
=
None
,
...
...
lm_eval/api/instance.py
View file @
28c78d30
...
@@ -14,10 +14,23 @@ class Instance:
...
@@ -14,10 +14,23 @@ class Instance:
arguments
:
tuple
arguments
:
tuple
idx
:
int
idx
:
int
metadata
:
Tuple
[
Optional
[
str
],
Optional
[
int
],
Optional
[
int
]]
=
field
(
metadata
:
Tuple
[
Optional
[
str
],
Optional
[
int
],
Optional
[
int
]]
=
field
(
default_factory
=
lambda
:
(
None
,
None
,
None
)
default_factory
=
lambda
:
(
None
,
None
,
None
),
metadata
=
dict
(
description
=
"Metadata tuple containing task name, document ID, and number of repeats."
),
)
resps
:
list
=
field
(
default_factory
=
list
,
metadata
=
dict
(
description
=
"List of responses from the model for this instance."
),
)
filtered_resps
:
dict
=
field
(
default_factory
=
dict
,
metadata
=
dict
(
description
=
"List of filtered responses for this instance, keyed by filter name."
),
)
)
resps
:
list
=
field
(
default_factory
=
list
)
filtered_resps
:
dict
=
field
(
default_factory
=
dict
)
# initialized after init
# initialized after init
task_name
:
Optional
[
str
]
=
None
task_name
:
Optional
[
str
]
=
None
...
@@ -29,7 +42,7 @@ class Instance:
...
@@ -29,7 +42,7 @@ class Instance:
self
.
task_name
,
self
.
doc_id
,
self
.
repeats
=
self
.
metadata
self
.
task_name
,
self
.
doc_id
,
self
.
repeats
=
self
.
metadata
@
property
@
property
def
args
(
self
):
def
args
(
self
)
->
tuple
:
"""
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
"""
...
...
lm_eval/api/metrics.py
View file @
28c78d30
...
@@ -8,7 +8,6 @@ from collections.abc import Iterable
...
@@ -8,7 +8,6 @@ from collections.abc import Iterable
from
typing
import
Callable
,
List
,
Optional
,
Sequence
,
TypeVar
from
typing
import
Callable
,
List
,
Optional
,
Sequence
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
sacrebleu
from
lm_eval.api.registry
import
register_aggregation
,
register_metric
from
lm_eval.api.registry
import
register_aggregation
,
register_metric
...
@@ -92,6 +91,8 @@ def bleu(items):
...
@@ -92,6 +91,8 @@ def bleu(items):
Higher is better
Higher is better
"""
"""
import
sacrebleu
refs
=
list
(
zip
(
*
items
))[
0
]
refs
=
list
(
zip
(
*
items
))[
0
]
preds
=
list
(
zip
(
*
items
))[
1
]
preds
=
list
(
zip
(
*
items
))[
1
]
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
...
@@ -107,6 +108,8 @@ def chrf(items):
...
@@ -107,6 +108,8 @@ def chrf(items):
Higher is better # TODO I think
Higher is better # TODO I think
"""
"""
import
sacrebleu
refs
=
list
(
zip
(
*
items
))[
0
]
refs
=
list
(
zip
(
*
items
))[
0
]
preds
=
list
(
zip
(
*
items
))[
1
]
preds
=
list
(
zip
(
*
items
))[
1
]
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
...
@@ -123,6 +126,8 @@ def ter(items):
...
@@ -123,6 +126,8 @@ def ter(items):
Lower is better
Lower is better
"""
"""
import
sacrebleu
refs
=
list
(
zip
(
*
items
))[
0
]
refs
=
list
(
zip
(
*
items
))[
0
]
preds
=
list
(
zip
(
*
items
))[
1
]
preds
=
list
(
zip
(
*
items
))[
1
]
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
refs
,
preds
=
_sacreformat
(
refs
,
preds
)
...
...
lm_eval/api/registry.py
View file @
28c78d30
import
logging
import
logging
from
typing
import
Callable
,
Dict
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Union
import
evaluate
as
hf_evaluate
from
lm_eval.api.model
import
LM
if
TYPE_CHECKING
:
from
lm_eval.api.model
import
LM
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
...
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
def
register_model
(
*
names
):
def
register_model
(
*
names
):
from
lm_eval.api.model
import
LM
# either pass a list or a single alias.
# either pass a list or a single alias.
# function receives them as a tuple of strings
# function receives them as a tuple of strings
...
@@ -31,7 +32,7 @@ def register_model(*names):
...
@@ -31,7 +32,7 @@ def register_model(*names):
return
decorate
return
decorate
def
get_model
(
model_name
)
:
def
get_model
(
model_name
:
str
)
->
type
[
"LM"
]
:
try
:
try
:
return
MODEL_REGISTRY
[
model_name
]
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
:
except
KeyError
:
...
@@ -46,7 +47,7 @@ ALL_TASKS = set()
...
@@ -46,7 +47,7 @@ ALL_TASKS = set()
func2task_index
=
{}
func2task_index
=
{}
def
register_task
(
name
):
def
register_task
(
name
:
str
):
def
decorate
(
fn
):
def
decorate
(
fn
):
assert
name
not
in
TASK_REGISTRY
,
(
assert
name
not
in
TASK_REGISTRY
,
(
f
"task named '
{
name
}
' conflicts with existing registered task!"
f
"task named '
{
name
}
' conflicts with existing registered task!"
...
@@ -120,7 +121,7 @@ def register_metric(**args):
...
@@ -120,7 +121,7 @@ def register_metric(**args):
return
decorate
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Optional
[
Callable
]
:
if
not
hf_evaluate_metric
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
return
METRIC_REGISTRY
[
name
]
...
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
...
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
)
)
try
:
try
:
import
evaluate
as
hf_evaluate
metric_object
=
hf_evaluate
.
load
(
name
)
metric_object
=
hf_evaluate
.
load
(
name
)
return
metric_object
.
compute
return
metric_object
.
compute
except
Exception
:
except
Exception
:
...
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
...
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
return
decorate
return
decorate
def
get_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
def
get_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
Dict
[
str
,
Callable
]]
]
:
try
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
def
get_metric_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
Dict
[
str
,
Callable
]]
]
:
try
:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
)
->
bool
:
def
is_higher_better
(
metric_name
)
->
Optional
[
bool
]
:
try
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
except
KeyError
:
...
...
lm_eval/api/samplers.py
View file @
28c78d30
import
logging
import
logging
import
warnings
import
warnings
from
functools
import
partial
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Iterable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Iterable
,
Optional
,
Sequence
,
Union
import
datasets
import
datasets
...
@@ -181,7 +181,7 @@ class ContextSampler:
...
@@ -181,7 +181,7 @@ class ContextSampler:
return
chat_history
return
chat_history
def
sample
(
self
,
n
:
int
):
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]
:
"""
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
"""
...
@@ -190,7 +190,7 @@ class ContextSampler:
...
@@ -190,7 +190,7 @@ class ContextSampler:
class
FirstNSampler
(
ContextSampler
):
class
FirstNSampler
(
ContextSampler
):
def
sample
(
self
,
n
:
int
)
->
None
:
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]
:
"""
"""
Draw the first `n` samples in order from the specified split.
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...
...
lm_eval/api/task.py
View file @
28c78d30
This diff is collapsed.
Click to expand it.
lm_eval/evaluator.py
View file @
28c78d30
...
@@ -287,7 +287,7 @@ def simple_evaluate(
...
@@ -287,7 +287,7 @@ def simple_evaluate(
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
def
_adjust_config
(
task_dict
)
:
def
_adjust_config
(
task_dict
:
dict
[
str
,
"Task"
])
->
dict
[
str
,
"Task"
]
:
adjusted_task_dict
=
{}
adjusted_task_dict
=
{}
for
task_name
,
task_obj
in
task_dict
.
items
():
for
task_name
,
task_obj
in
task_dict
.
items
():
if
isinstance
(
task_obj
,
dict
):
if
isinstance
(
task_obj
,
dict
):
...
...
lm_eval/evaluator_utils.py
View file @
28c78d30
...
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
...
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr
,
pooled_sample_stderr
,
stderr_for_metric
,
stderr_for_metric
,
)
)
from
lm_eval.api.task
import
Task
from
lm_eval.api.task
import
ConfigurableTask
,
Task
from
lm_eval.utils
import
positional_deprecated
from
lm_eval.utils
import
positional_deprecated
...
@@ -58,7 +58,7 @@ class TaskOutput:
...
@@ -58,7 +58,7 @@ class TaskOutput:
group_alias
=
None
,
group_alias
=
None
,
is_group
=
None
,
is_group
=
None
,
):
):
self
.
task
=
task
self
.
task
:
Union
[
Task
,
ConfigurableTask
]
=
task
self
.
task_config
=
task_config
self
.
task_config
=
task_config
self
.
task_name
=
task_name
self
.
task_name
=
task_name
self
.
group_name
=
group_name
self
.
group_name
=
group_name
...
...
lm_eval/filters/__init__.py
View file @
28c78d30
from
functools
import
partial
from
functools
import
partial
from
typing
import
List
from
typing
import
List
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
get_filter
...
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
...
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def
build_filter_ensemble
(
def
build_filter_ensemble
(
filter_name
:
str
,
components
:
List
[
L
ist
[
str
]]
filter_name
:
str
,
components
:
list
[
Union
[
list
[
dict
],
l
ist
[
str
]]
]
)
->
FilterEnsemble
:
)
->
FilterEnsemble
:
"""
"""
Create a filtering pipeline.
Create a filtering pipeline.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment