Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
9c647fc1
Commit
9c647fc1
authored
Jun 30, 2025
by
Baber
Browse files
add FewshotConfig
parent
28c78d30
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
29 deletions
+72
-29
lm_eval/api/filter.py
lm_eval/api/filter.py
+2
-2
lm_eval/api/task.py
lm_eval/api/task.py
+65
-25
lm_eval/utils.py
lm_eval/utils.py
+5
-2
No files found.
lm_eval/api/filter.py
View file @
9c647fc1
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Iterable
,
List
,
Union
from
typing
import
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
...
@@ -40,7 +40,7 @@ class FilterEnsemble:
...
@@ -40,7 +40,7 @@ class FilterEnsemble:
"""
"""
name
:
str
name
:
str
filters
:
List
[
Callable
[[],
Filter
]]
filters
:
List
[
type
[
Filter
]]
def
apply
(
self
,
instances
:
List
[
Instance
])
->
None
:
def
apply
(
self
,
instances
:
List
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
...
...
lm_eval/api/task.py
View file @
9c647fc1
...
@@ -90,6 +90,12 @@ class FilterConfig:
...
@@ -90,6 +90,12 @@ class FilterConfig:
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Optional
[
dict
]
=
None
@
dataclass
class
FewshotConfig
:
sampler
:
str
samples
:
list
[
dict
]
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
...
@@ -185,6 +191,9 @@ class TaskConfig(dict):
...
@@ -185,6 +191,9 @@ class TaskConfig(dict):
metrics
=
[]
metrics
=
[]
if
self
.
metric_list
is
None
:
if
self
.
metric_list
is
None
:
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
eval_logger
.
info
(
f
"No metrics defined in config, using default metrics for
{
self
.
output_type
}
=
{
_metric_list
}
"
)
metrics
.
extend
(
metrics
.
extend
(
MetricConfig
(
MetricConfig
(
name
=
metric_name
,
name
=
metric_name
,
...
@@ -261,6 +270,35 @@ class TaskConfig(dict):
...
@@ -261,6 +270,35 @@ class TaskConfig(dict):
)
)
return
metrics
return
metrics
def
get_filters
(
self
):
if
self
.
filter_list
is
not
None
:
_filter_list
=
[]
if
isinstance
(
self
.
filter_list
,
dict
):
for
filter_config
in
self
.
filter_list
:
_filter_list
.
append
(
build_filter_ensemble
(
filter_name
=
filter_config
[
"name"
],
components
=
[
[
{
key
:
function
[
key
]
for
key
in
function
if
key
!=
"function"
}
]
for
function
in
filter_config
[
"filter"
]
],
)
)
else
:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger
.
debug
(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
_filter_list
=
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
return
_filter_list
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
return
getattr
(
self
,
item
)
...
@@ -908,31 +946,33 @@ class ConfigurableTask(Task):
...
@@ -908,31 +946,33 @@ class ConfigurableTask(Task):
self
.
_training_docs
=
None
self
.
_training_docs
=
None
self
.
_fewshot_docs
=
None
self
.
_fewshot_docs
=
None
if
self
.
config
.
filter_list
is
not
None
:
self
.
_filters
=
self
.
config
.
get_filters
()
self
.
_filters
=
[]
if
isinstance
(
self
.
config
.
filter_list
,
dict
):
# if self.config.filter_list is not None:
for
filter_config
in
self
.
config
.
filter_list
:
# self._filters = []
self
.
_filters
.
append
(
# if isinstance(self.config.filter_list, dict):
build_filter_ensemble
(
# for filter_config in self.config.filter_list:
filter_config
[
"name"
],
# self._filters.append(
[
# build_filter_ensemble(
[
# filter_config["name"],
{
# [
key
:
function
[
key
]
# [
for
key
in
function
# {
if
key
!=
"function"
# key: function[key]
}
# for key in function
]
# if key != "function"
for
function
in
filter_config
[
"filter"
]
# }
],
# ]
)
# for function in filter_config["filter"]
)
# ],
else
:
# )
# TODO: handle repeats in a more general way rather than just discarding
# )
eval_logger
.
debug
(
# else:
"No custom filters defined. Using default 'take_first' filter for handling repeats."
# # TODO: handle repeats in a more general way rather than just discarding
)
# eval_logger.debug(
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
# "No custom filters defined. Using default 'take_first' filter for handling repeats."
# )
# self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if
self
.
config
.
use_prompt
is
not
None
:
if
self
.
config
.
use_prompt
is
not
None
:
eval_logger
.
info
(
f
"loading prompt
{
self
.
config
.
use_prompt
}
"
)
eval_logger
.
info
(
f
"loading prompt
{
self
.
config
.
use_prompt
}
"
)
...
...
lm_eval/utils.py
View file @
9c647fc1
...
@@ -405,7 +405,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
...
@@ -405,7 +405,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
dic
=
result_dict
[
column
][
k
]
dic
=
result_dict
[
column
][
k
]
version
=
result_dict
[
"versions"
].
get
(
k
,
" N/A"
)
version
=
result_dict
[
"versions"
].
get
(
k
,
" N/A"
)
n
=
str
(
result_dict
.
get
(
"n-shot"
,
" "
).
get
(
k
,
" "
))
n
=
str
(
result_dict
.
get
(
"n-shot"
,
" "
).
get
(
k
,
" "
))
higher_is_better
=
result_dict
.
get
(
"higher_is_better"
,
{}).
get
(
k
,
{})
# TODO: fix this
# higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
if
"alias"
in
dic
:
if
"alias"
in
dic
:
k
=
dic
.
pop
(
"alias"
)
k
=
dic
.
pop
(
"alias"
)
...
@@ -418,7 +419,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
...
@@ -418,7 +419,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
if
m
.
endswith
(
"_stderr"
):
if
m
.
endswith
(
"_stderr"
):
continue
continue
hib
=
HIGHER_IS_BETTER_SYMBOLS
.
get
(
higher_is_better
.
get
(
m
),
""
)
# hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
# TODO: fix
hib
=
"↑"
v
=
"%.4f"
%
v
if
isinstance
(
v
,
float
)
else
v
v
=
"%.4f"
%
v
if
isinstance
(
v
,
float
)
else
v
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment