Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
d19bd889
Commit
d19bd889
authored
Jul 21, 2025
by
Baber
Browse files
improve metric aggregation default and higher-better checks; add `TaskConfig.from_template`
parent
69d14fb3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
136 additions
and
9 deletions
+136
-9
lm_eval/api/registry.py
lm_eval/api/registry.py
+8
-4
lm_eval/config/task.py
lm_eval/config/task.py
+67
-2
lm_eval/config/template.py
lm_eval/config/template.py
+48
-3
lm_eval/config/utils.py
lm_eval/config/utils.py
+13
-0
No files found.
lm_eval/api/registry.py
View file @
d19bd889
...
...
@@ -167,20 +167,24 @@ def get_aggregation(name: str) -> Callable[..., Any] | None:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
]]
|
None
:
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
[...,
Any
]]]
:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!. Using default aggregation mean"
)
return
AGGREGATION_REGISTRY
[
"mean"
]
def
is_higher_better
(
metric_name
:
str
)
->
bool
|
None
:
def
is_higher_better
(
metric_name
:
str
)
->
bool
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"higher_is_better not specified for metric '
{
metric_name
}
'!"
f
"higher_is_better not specified for metric '
{
metric_name
}
'!
. Will default to True.
"
)
return
True
def
register_filter
(
name
:
str
):
...
...
lm_eval/config/task.py
View file @
d19bd889
...
...
@@ -14,6 +14,7 @@ from lm_eval.config.utils import maybe_serialize
if
TYPE_CHECKING
:
from
lm_eval.api.samplers
import
ContextSampler
from
lm_eval.api.task
import
Task
from
lm_eval.config.template
import
TemplateConfig
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -119,7 +120,7 @@ class FewshotConfig:
@
dataclass
class
TaskConfig
(
dict
)
:
class
TaskConfig
:
# task naming/registry
task
:
str
|
None
=
None
task_alias
:
str
|
None
=
None
...
...
@@ -240,7 +241,7 @@ class TaskConfig(dict):
name
=
metric_name
,
fn
=
get_metric
(
metric_name
),
aggregation_fn
=
get_metric_aggregation
(
metric_name
),
higher_is_better
=
is_higher_better
(
metric_name
),
higher_is_better
=
is_higher_better
(
metric_name
)
or
True
,
)
for
metric_name
in
_metric_list
)
...
...
@@ -357,6 +358,70 @@ class TaskConfig(dict):
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
@
classmethod
def
from_template
(
cls
,
template
:
TemplateConfig
,
**
kwargs
)
->
TaskConfig
:
"""Create a TaskConfig instance from a template.
Args:
template: TemplateConfig instance (MCQTemplateConfig or ClozeTemplateConfig)
**kwargs: Additional arguments to override template defaults
Returns:
TaskConfig instance configured from the template
"""
from
lm_eval.config.template
import
(
ClozeTemplateConfig
,
MCQTemplateConfig
,
)
# Extract base configuration from template
config_dict
=
{
"task"
:
template
.
task
,
"doc_to_text"
:
template
.
doc_to_text
,
"doc_to_choice"
:
template
.
doc_to_choice
,
"doc_to_target"
:
template
.
doc_to_target
,
"description"
:
template
.
description
,
"target_delimiter"
:
template
.
target_delimiter
,
"fewshot_delimiter"
:
template
.
fewshot_delimiter
,
"metric_list"
:
template
.
metric_list
,
}
# Add common template attributes if they exist
if
hasattr
(
template
,
"answer_suffix"
):
config_dict
[
"target_delimiter"
]
=
(
template
.
answer_suffix
+
template
.
target_delimiter
)
# Handle template-specific configurations
if
isinstance
(
template
,
MCQTemplateConfig
):
# For MCQ templates, set up multiple choice specific config
config_dict
[
"output_type"
]
=
"multiple_choice"
# MCQ templates typically use accuracy metrics
if
template
.
metric_list
is
None
:
config_dict
[
"metric_list"
]
=
[{
"metric"
:
"acc"
}]
elif
isinstance
(
template
,
ClozeTemplateConfig
):
# For Cloze templates, set up generation config
config_dict
[
"output_type"
]
=
"generate_until"
# Cloze templates typically use accuracy and normalized accuracy
if
template
.
metric_list
is
None
:
config_dict
[
"metric_list"
]
=
[{
"metric"
:
"acc"
},
{
"metric"
:
"acc_norm"
}]
else
:
# Generic template - try to infer output type
if
hasattr
(
template
,
"template"
):
if
template
.
template
==
"mcq"
:
config_dict
[
"output_type"
]
=
"multiple_choice"
elif
template
.
template
==
"cloze"
:
config_dict
[
"output_type"
]
=
"generate_until"
# Override with any user-provided kwargs
config_dict
.
update
(
kwargs
)
# Create and return TaskConfig instance
return
cls
(
**
config_dict
)
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
...
...
lm_eval/config/template.py
View file @
d19bd889
from
__future__
import
annotations
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
from
lm_eval.config.utils
import
create_mc_choices
if
TYPE_CHECKING
:
from
lm_eval.config.metric
import
MetricConfig
@
dataclass
class
TemplateConfig
:
class
TemplateConfig
(
ABC
)
:
"""Encapsulates information about a template."""
#
template
:
str
task
:
str
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
...
...
@@ -29,9 +34,22 @@ class TemplateConfig:
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
@
abstractmethod
def
_doc_to_text
(
self
,
doc
:
dict
)
->
str
:
"""Convert a document to text."""
raise
NotImplementedError
def
_doc_to_choice
(
self
,
doc
:
dict
)
->
str
:
"""Convert a document to choices."""
raise
NotImplementedError
def
_doc_to_target
(
self
,
doc
:
dict
)
->
int
|
str
:
"""Convert a document to target."""
raise
NotImplementedError
@
dataclass
class
MCQTemplateConfig
:
class
MCQTemplateConfig
(
TemplateConfig
)
:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
...
...
@@ -56,9 +74,36 @@ class MCQTemplateConfig:
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
])
def
_doc_to_text
(
self
,
doc
:
dict
)
->
str
:
"""Convert a document to text."""
doc_to_text
=
(
self
.
doc_to_text
if
isinstance
(
self
.
doc_to_text
,
str
)
else
self
.
doc_to_text
(
doc
)
)
return
self
.
context_prefix
+
doc_to_text
def
_doc_to_choice
(
self
,
doc
:
dict
)
->
str
:
if
callable
(
self
.
doc_to_choice
):
doc_to_choice
=
self
.
doc_to_choice
(
doc
)
elif
isinstance
(
self
.
doc_to_choice
,
str
):
doc_to_choice
=
doc
[
self
.
doc_to_choice
]
else
:
doc_to_choice
=
self
.
doc_to_choice
return
create_mc_choices
(
doc_to_choice
,
choice_delimiter
=
self
.
choice_delimiter
)
def
_doc_to_target
(
self
,
doc
:
dict
)
->
int
:
"""Convert a document to target."""
if
callable
(
self
.
doc_to_target
):
return
self
.
doc_to_target
(
doc
)
elif
isinstance
(
self
.
doc_to_target
,
str
):
return
doc
[
self
.
doc_to_target
]
else
:
return
self
.
doc_to_target
@
dataclass
class
ClozeTemplateConfig
:
class
ClozeTemplateConfig
(
TemplateConfig
)
:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
...
...
lm_eval/config/utils.py
View file @
d19bd889
...
...
@@ -28,3 +28,16 @@ def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
return
(
serialize_callable
(
val
,
keep_callable
=
keep_callable
)
if
callable
(
val
)
else
val
)
def
create_mc_choices
(
choices
:
list
[
str
],
choice_delimiter
:
str
|
None
=
"
\n
"
)
->
str
:
"""Creates a multiple-choice question format from a list of choices."""
if
len
(
choices
)
<
2
:
raise
ValueError
(
"At least two choices are required for a multiple-choice question."
)
if
choice_delimiter
is
None
:
choice_delimiter
=
"
\n
"
formatted_choices
=
[
f
"
{
chr
(
65
+
i
)
}
.
{
choice
}
"
for
i
,
choice
in
enumerate
(
choices
)]
return
choice_delimiter
.
join
(
formatted_choices
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment