Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
d9876b22
Commit
d9876b22
authored
Jul 22, 2025
by
Baber
Browse files
`check_gold_index_error` util; fix `process_results`; rm generate_until multiple-choice
parent
d19bd889
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
76 additions
and
53 deletions
+76
-53
lm_eval/api/samplers.py
lm_eval/api/samplers.py
+30
-18
lm_eval/api/task.py
lm_eval/api/task.py
+19
-31
lm_eval/api/utils.py
lm_eval/api/utils.py
+21
-0
lm_eval/config/task.py
lm_eval/config/task.py
+6
-4
No files found.
lm_eval/api/samplers.py
View file @
d9876b22
from
__future__
import
annotations
import
logging
import
warnings
from
collections.abc
import
Iterable
,
Sequence
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Iterable
,
Optional
,
Sequence
,
Union
from
typing
import
TYPE_CHECKING
,
Any
import
datasets
...
...
@@ -18,9 +21,9 @@ class ContextSampler:
def
__init__
(
self
,
docs
:
list
[
dict
],
task
:
Union
[
"Task"
,
"
ConfigurableTask
"
]
,
fewshot_indices
:
Optional
[
Iterable
]
=
None
,
rnd
:
Optional
[
"Random"
]
=
None
,
task
:
Task
|
ConfigurableTask
,
fewshot_indices
:
Iterable
|
None
=
None
,
rnd
:
Random
|
None
=
None
,
)
->
None
:
self
.
rnd
=
rnd
if
not
self
.
rnd
:
...
...
@@ -75,7 +78,7 @@ class ContextSampler:
)
self
.
docs
=
self
.
docs
.
select
(
fewshot_indices
)
def
get_context
(
self
,
doc
:
dict
,
num_fewshot
:
int
,
gen_prefix
:
str
=
None
):
def
get_context
(
self
,
doc
:
dict
,
num_fewshot
:
int
,
gen_prefix
:
str
|
None
=
None
):
# draw an extra fewshot sample if using same split as evaluating on
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
n_samples
=
(
...
...
@@ -95,10 +98,13 @@ class ContextSampler:
for
doc
in
selected_docs
:
doc_content
=
self
.
doc_to_text
(
doc
)
doc_target
=
self
.
doc_to_target
(
doc
)
if
self
.
config
.
doc_to_choice
is
None
or
isinstance
(
doc_content
,
str
):
if
(
self
.
config
.
doc_to_choice
is
None
and
isinstance
(
doc_content
,
str
)
)
or
isinstance
(
doc_content
,
str
):
labeled_examples
+=
doc_content
else
:
labeled_examples
+=
self
.
doc_to_choice
(
doc
)[
doc_content
]
if
isinstance
(
doc_content
,
int
):
labeled_examples
+=
self
.
doc_to_choice
(
doc
)[
doc_content
]
if
doc_target
!=
""
:
if
self
.
target_delimiter
.
isspace
()
and
str
(
doc_target
)[
0
].
isspace
():
...
...
@@ -126,7 +132,7 @@ class ContextSampler:
doc
:
dict
,
num_fewshot
:
int
,
fewshot_as_multiturn
:
bool
=
False
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
str
|
None
=
None
,
):
# TODO: Do we need any other delimiter
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
...
...
@@ -181,16 +187,22 @@ class ContextSampler:
return
chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
    """Randomly draw `n` documents from the fewshot pool.

    Base-class behavior is uniform sampling without replacement via
    ``self.rnd``; subclasses override this to implement other strategies.
    """
    assert self.rnd is not None, (
        "Error: `rnd` must be set to a random.Random instance before sampling."
    )
    drawn = self.rnd.sample(self.docs, n)
    return drawn
class
FirstNSampler
(
ContextSampler
):
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]:
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
[
str
,
Any
]
]:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...
...
@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
    # Placeholder strategy: intended to draw approximately class-balanced
    # fewshot examples, but no implementation exists yet.
    def sample(self, n: int):
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        raise NotImplementedError
class ManualSampler(ContextSampler):
    # Placeholder strategy; not implemented. Presumably meant for
    # hand-picked fewshot examples — confirm intended semantics before use.
    def sample(self, n: int):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError
# Maps sampler-strategy names (as referenced by name in `get_sampler`)
# to their ContextSampler subclasses.
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}
...
...
@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def get_sampler(name: str) -> type[ContextSampler]:
    """Look up a fewshot sampler class by its registry name.

    :param name: Key into ``SAMPLER_REGISTRY`` (e.g. ``"default"``, ``"first_n"``).
    :returns: The registered ``ContextSampler`` subclass.
    :raises KeyError: If no sampler is registered under ``name``; the original
        ``KeyError`` is chained as the cause.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as e:
        # Fixed message: it lists supported *sampler* names, not model names.
        raise KeyError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! "
            f"Supported sampler names: {', '.join(SAMPLER_REGISTRY.keys())}"
        ) from e
lm_eval/api/task.py
View file @
d9876b22
...
...
@@ -21,6 +21,7 @@ from typing_extensions import deprecated
from
lm_eval
import
utils
from
lm_eval.api.instance
import
Instance
,
OutputType
from
lm_eval.api.metrics
import
bits_per_byte
,
mean
,
weighted_perplexity
from
lm_eval.api.utils
import
check_gold_index_error
from
lm_eval.caching.cache
import
load_from_cache
,
save_to_cache
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.task
import
TaskConfig
...
...
@@ -380,7 +381,7 @@ class Task(abc.ABC):
pass
@
abc
.
abstractmethod
def
process_results
(
self
,
doc
:
dict
,
results
:
list
):
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
[
str
,
Any
]
:
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
...
...
@@ -390,7 +391,7 @@ class Task(abc.ABC):
:param results:
The results of the requests created in construct_requests.
"""
pass
raise
NotImplementedError
@
deprecated
(
"not used anymore"
)
def
aggregation
(
self
):
...
...
@@ -955,11 +956,13 @@ class ConfigurableTask(Task):
def apply_filters(self) -> list[Instance] | None:
    """Run every configured FilterEnsemble over this task's instances.

    When ``self._filters`` exists and ``self._instances`` is non-empty, each
    filter's ensemble is applied to the instance list in place; otherwise a
    warning is logged and the instances pass through untouched.

    Returns:
        Whatever ``self._instances`` currently holds.
    """
    if not (hasattr(self, "_filters") and self._instances):
        # Nothing to filter — either no filters were configured or there
        # are no instances; pass instances through unchanged.
        eval_logger.warning(
            "No filter defined or instances found. Passing through instances"
        )
        return self._instances
    for filter_entry in self._filters:
        filter_entry.ensemble.apply(self._instances)
    return self._instances
def
should_decontaminate
(
self
):
...
...
@@ -993,13 +996,12 @@ class ConfigurableTask(Task):
"""
return
doc
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
|
None
=
None
):
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
)
->
str
:
# if self.prompt is not None:
# doc_to_text = self.prompt
if
doc_to_text
is
not
None
:
doc_to_text
=
doc_to_text
else
:
doc_to_text
=
self
.
config
.
doc_to_text
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
if
isinstance
(
doc_to_text
,
int
):
return
doc_to_text
...
...
@@ -1261,7 +1263,7 @@ class ConfigurableTask(Task):
**
kwargs
,
)
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
:
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
[
str
,
Any
]
:
if
callable
(
self
.
config
.
process_results
):
return
self
.
config
.
process_results
(
doc
,
results
)
...
...
@@ -1275,9 +1277,12 @@ class ConfigurableTask(Task):
**
({
"acc"
:
int
(
is_greedy
)}
if
"acc"
in
use_metric
else
{}),
}
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
(
loglikelihood
,)
=
results
_words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
_bytes
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
(
loglikelihood
,
*
_
)
=
results
assert
isinstance
(
_target
:
=
self
.
doc_to_target
(
doc
),
str
),
(
"Require target to be a string for loglikelihood_rolling"
)
_words
=
self
.
count_words
(
_target
)
_bytes
=
self
.
count_bytes
(
_target
)
return
{
**
(
{
"word_perplexity"
:
(
loglikelihood
,
_words
)}
...
...
@@ -1322,19 +1327,7 @@ class ConfigurableTask(Task):
else
:
gold
=
self
.
doc_to_target
(
doc
)
gold_index_error
=
False
if
isinstance
(
gold
,
list
):
gold
=
[
i
if
i
<
len
(
choices
)
else
-
100
for
i
in
gold
]
if
-
100
in
gold
:
gold_index_error
=
True
else
:
if
isinstance
(
gold
,
int
):
gold
=
gold
if
gold
<
len
(
choices
)
else
-
100
elif
isinstance
(
gold
,
str
):
gold
=
choices
.
index
(
gold
)
if
gold
in
choices
else
-
100
if
gold
==
-
100
:
gold_index_error
=
True
gold
,
gold_index_error
=
check_gold_index_error
(
choices
,
gold
)
if
gold_index_error
:
eval_logger
.
warning
(
...
...
@@ -1382,11 +1375,6 @@ class ConfigurableTask(Task):
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
gold
=
self
.
doc_to_target
(
doc
)
result
=
results
[
0
]
if
self
.
config
.
doc_to_choice
is
not
None
:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices
=
self
.
doc_to_choice
(
doc
)
gold
=
choices
[
gold
]
for
metric
in
self
.
_metric_fn_list
:
try
:
result_score
=
self
.
_metric_fn_list
[
metric
](
...
...
lm_eval/api/utils.py
0 → 100644
View file @
d9876b22
from
__future__
import
annotations
def check_gold_index_error(
    choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
    """Sanitize gold answer indices against the available choices.

    A string gold is resolved to its position in ``choices``; an integer gold
    (single or in a list) that is ``>= len(choices)`` is replaced with the
    sentinel ``-100``.

    Returns:
        A ``(gold, had_error)`` pair, where ``had_error`` is True iff any gold
        entry was out of range or not found among the choices.
    """
    limit = len(choices)
    if isinstance(gold, list):
        # Clamp every out-of-range index in the list to the sentinel.
        clamped = [entry if entry < limit else -100 for entry in gold]
        return clamped, -100 in clamped
    if isinstance(gold, int):
        resolved = gold if gold < limit else -100
    elif isinstance(gold, str):
        resolved = choices.index(gold) if gold in choices else -100
    else:
        resolved = gold
    return resolved, resolved == -100
lm_eval/config/task.py
View file @
d9876b22
...
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import
logging
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
...
...
@@ -45,7 +45,9 @@ class FewshotConfig:
split
:
str
|
None
=
None
sampler
:
str
|
Callable
=
"default"
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
[
str
,
Any
]]],
Iterable
[
dict
[
str
,
Any
]]]
|
None
=
(
None
)
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
...
...
@@ -82,7 +84,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list"
)
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
[
str
,
Any
]
]
|
None
:
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
...
...
@@ -93,7 +95,7 @@ class FewshotConfig:
return
raw_docs
@
property
def
get_sampler
(
self
):
def
get_sampler
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api
import
samplers
if
isinstance
(
self
.
sampler
,
str
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment