gaoqiong / lm-evaluation-harness

Commit d9876b22
Authored Jul 22, 2025 by Baber

`check_gold_index_error` util; fix `process_results`; rm generate_until multiple-choice
parent d19bd889

Showing 4 changed files with 76 additions and 53 deletions:

  lm_eval/api/samplers.py   +30 -18
  lm_eval/api/task.py       +19 -31
  lm_eval/api/utils.py      +21 -0
  lm_eval/config/task.py    +6  -4
lm_eval/api/samplers.py

+from __future__ import annotations
+
 import logging
 import warnings
+from collections.abc import Iterable, Sequence
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any

 import datasets
@@ -18,9 +21,9 @@ class ContextSampler:
     def __init__(
         self,
         docs: list[dict],
-        task: Union["Task", "ConfigurableTask"],
-        fewshot_indices: Optional[Iterable] = None,
-        rnd: Optional["Random"] = None,
+        task: Task | ConfigurableTask,
+        fewshot_indices: Iterable | None = None,
+        rnd: Random | None = None,
     ) -> None:
         self.rnd = rnd
         if not self.rnd:
@@ -75,7 +78,7 @@ class ContextSampler:
             )
         self.docs = self.docs.select(fewshot_indices)

-    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
+    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
         # draw an extra fewshot sample if using same split as evaluating on
         prefix = gen_prefix + " " if gen_prefix else ""
         n_samples = (
@@ -95,9 +98,12 @@ class ContextSampler:
         for doc in selected_docs:
             doc_content = self.doc_to_text(doc)
             doc_target = self.doc_to_target(doc)
-            if self.config.doc_to_choice is None or isinstance(doc_content, str):
+            if (
+                self.config.doc_to_choice is None and isinstance(doc_content, str)
+            ) or isinstance(doc_content, str):
                 labeled_examples += doc_content
             else:
-                labeled_examples += self.doc_to_choice(doc)[doc_content]
+                if isinstance(doc_content, int):
+                    labeled_examples += self.doc_to_choice(doc)[doc_content]
             if doc_target != "":
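
For illustration, a hedged sketch of the labeling rule in the hunk above: when `doc_to_text` returns an int, the few-shot example text is now taken from `doc_to_choice` only under an explicit isinstance check. All values here are invented:

    # hypothetical doc and choices, for illustration only
    doc = {"question": "2 + 2 = ?", "label": 1}
    choices = ["3", "4"]   # what doc_to_choice(doc) might return
    doc_content = 1        # what doc_to_text(doc) might return: an index

    # mirrors the guarded else-branch: an int selects its choice string
    if isinstance(doc_content, int):
        labeled_example = choices[doc_content]  # -> "4"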
@@ -126,7 +132,7 @@ class ContextSampler:
         doc: dict,
         num_fewshot: int,
         fewshot_as_multiturn: bool = False,
-        gen_prefix: Optional[str] = None,
+        gen_prefix: str | None = None,
     ):
         # TODO: Do we need any other delimiter
         prefix = gen_prefix + " " if gen_prefix else ""
@@ -181,16 +187,22 @@ class ContextSampler:
         return chat_history

+    # @classmethod
+    # def from_fewshot_dfg(cls, cfg: FewshotConfig):
+    #     if not
     def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
+        assert self.rnd is not None, (
+            "Error: `rnd` must be set to a random.Random instance before sampling."
+        )
         return self.rnd.sample(self.docs, n)


 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> Sequence[dict]:
+    def sample(self, n: int) -> Sequence[dict[str, Any]]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
 class BalancedSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int):
         """
         TODO: this should return approximately class-balanced samples from our fewshot examples.
         TODO: what order should they be in? maybe random?
         """
-        pass
+        raise NotImplementedError


 class ManualSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int):
         """ """
-        pass
+        raise NotImplementedError


-SAMPLER_REGISTRY = {
+SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
     "default": ContextSampler,
     "first_n": FirstNSampler,
 }
@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
 def get_sampler(name: str):
     try:
         return SAMPLER_REGISTRY[name]
-    except KeyError:
-        raise ValueError(
+    except KeyError as e:
+        raise KeyError(
             f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
-        )
+        ) from e
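
A hedged usage sketch of the typed registry and the re-raised KeyError; the custom sampler below is invented for illustration:

    from lm_eval.api.samplers import SAMPLER_REGISTRY, ContextSampler, get_sampler

    class ReversedSampler(ContextSampler):
        # hypothetical sampler: the last n docs, newest first
        def sample(self, n: int):
            return list(self.docs)[-n:][::-1]

    # the dict[str, type[ContextSampler]] annotation covers registrations like this
    SAMPLER_REGISTRY["reversed"] = ReversedSampler
    assert get_sampler("reversed") is ReversedSampler

    try:
        get_sampler("no_such_sampler")
    except KeyError as err:
        # after this commit the failure is a KeyError chained `from` the dict lookup
        print(err.__cause__)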
lm_eval/api/task.py

@@ -21,6 +21,7 @@ from typing_extensions import deprecated
 from lm_eval import utils
 from lm_eval.api.instance import Instance, OutputType
 from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
+from lm_eval.api.utils import check_gold_index_error
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.task import TaskConfig
@@ -380,7 +381,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def process_results(self, doc: dict, results: list):
+    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
@@ -390,7 +391,7 @@ class Task(abc.ABC):
         :param results:
             The results of the requests created in construct_requests.
         """
-        pass
+        raise NotImplementedError

     @deprecated("not used anymore")
     def aggregation(self):
@@ -955,11 +956,13 @@ class ConfigurableTask(Task):
     def apply_filters(self) -> list[Instance] | None:
         """Iterates over FilterEnsembles and applies them to instances"""
-        if hasattr(self, "_filters"):
+        if hasattr(self, "_filters") and self._instances:
             for f in self._filters:
                 f.ensemble.apply(self._instances)
         else:
-            eval_logger.warning("No filter defined, passing through instances")
+            eval_logger.warning(
+                "No filter defined or instances found. Passing through instances"
+            )
             return self._instances

     def should_decontaminate(self):
@@ -993,13 +996,12 @@ class ConfigurableTask(Task):
         """
         return doc

-    def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
+    def doc_to_text(
+        self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
+    ) -> str:
         # if self.prompt is not None:
         #     doc_to_text = self.prompt
-        if doc_to_text is not None:
-            doc_to_text = doc_to_text
-        else:
-            doc_to_text = self.config.doc_to_text
+        doc_to_text = doc_to_text or self.config.doc_to_text
         if isinstance(doc_to_text, int):
             return doc_to_text
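
One subtlety of the `or` form above, sketched standalone with invented names: a falsy override (0 or "") now falls back to the config value, where the old `is not None` chain kept it.

    config_default = "Q: {question}"

    def resolve(override):
        # new form: any falsy override falls through to the default
        return override or config_default

    resolve(None)  # -> "Q: {question}"  (matches the old is-not-None logic)
    resolve("")    # -> "Q: {question}"  (the old logic would have kept "")
    resolve(0)     # -> "Q: {question}"  (the old logic would have kept 0)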
@@ -1261,7 +1263,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )

-    def process_results(self, doc: dict, results: list) -> dict:
+    def process_results(self, doc: dict, results: list) -> dict[str, Any]:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
@@ -1275,9 +1277,12 @@ class ConfigurableTask(Task):
                 **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
             }
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
-            (loglikelihood,) = results
-            _words = self.count_words(self.doc_to_target(doc))
-            _bytes = self.count_bytes(self.doc_to_target(doc))
+            (loglikelihood, *_) = results
+            assert isinstance(_target := self.doc_to_target(doc), str), (
+                "Require target to be a string for loglikelihood_rolling"
+            )
+            _words = self.count_words(_target)
+            _bytes = self.count_bytes(_target)
             return {
                 **(
                     {"word_perplexity": (loglikelihood, _words)}
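
A standalone sketch of the bind-and-check pattern introduced above; the doc and helper bodies are invented, and count_words/count_bytes are only approximated:

    doc = {"text": "hello rolling world"}

    def doc_to_target(doc):
        return doc["text"]

    # := binds the target once, so the assert and both counts share one call
    assert isinstance(target := doc_to_target(doc), str), (
        "Require target to be a string for loglikelihood_rolling"
    )
    words = len(target.split())           # roughly what count_words computes
    nbytes = len(target.encode("utf-8"))  # roughly what count_bytes computes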
@@ -1322,19 +1327,7 @@ class ConfigurableTask(Task):
             else:
                 gold = self.doc_to_target(doc)

-            gold_index_error = False
-            if isinstance(gold, list):
-                gold = [i if i < len(choices) else -100 for i in gold]
-                if -100 in gold:
-                    gold_index_error = True
-            else:
-                if isinstance(gold, int):
-                    gold = gold if gold < len(choices) else -100
-                elif isinstance(gold, str):
-                    gold = choices.index(gold) if gold in choices else -100
-
-                if gold == -100:
-                    gold_index_error = True
+            gold, gold_index_error = check_gold_index_error(choices, gold)

             if gold_index_error:
                 eval_logger.warning(
@@ -1382,11 +1375,6 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
-            if self.config.doc_to_choice is not None:
-                # If you set doc_to_choice,
-                # it assumes that doc_to_target returns a number.
-                choices = self.doc_to_choice(doc)
-                gold = choices[gold]
             for metric in self._metric_fn_list:
                 try:
                     result_score = self._metric_fn_list[metric](
lm_eval/api/utils.py (new file, 0 → 100644)

+from __future__ import annotations
+
+
+def check_gold_index_error(
+    choices: list[int] | list[str], gold: list[int] | int | str
+) -> tuple[int | list[int], bool]:
+    gold_index_error = False
+    if isinstance(gold, list):
+        gold = [i if i < len(choices) else -100 for i in gold]
+        if -100 in gold:
+            gold_index_error = True
+        return gold, gold_index_error
+    else:
+        if isinstance(gold, int):
+            gold = gold if gold < len(choices) else -100
+        elif isinstance(gold, str):
+            gold = choices.index(gold) if gold in choices else -100
+
+        if gold == -100:
+            gold_index_error = True
+        return gold, gold_index_error
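
A hedged usage sketch of the new helper with invented inputs: valid gold values pass through (strings mapped to their index), while anything out of range becomes the -100 sentinel with the error flag set:

    from lm_eval.api.utils import check_gold_index_error

    choices = ["yes", "no"]

    check_gold_index_error(choices, 1)        # -> (1, False)         valid index
    check_gold_index_error(choices, 5)        # -> (-100, True)       index out of range
    check_gold_index_error(choices, "no")     # -> (1, False)         string mapped to its index
    check_gold_index_error(choices, "maybe")  # -> (-100, True)       string not among choices
    check_gold_index_error(choices, [0, 7])   # -> ([0, -100], True)  lists checked element-wise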
lm_eval/config/task.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
@@ -45,7 +45,9 @@ class FewshotConfig:
     split: str | None = None
     sampler: str | Callable = "default"
     samples: Callable[[], list[dict]] | list[dict] | None = None
-    process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
+    process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
+        None
+    )
     fewshot_indices: list[int] | None = None
     rnd: int = field(init=False, default=False)
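
For illustration, a hedged sketch of a process_docs callable that fits the tightened annotation; the filtering rule and field name are invented:

    from collections.abc import Iterable
    from typing import Any

    def drop_empty_questions(docs: list[dict[str, Any]]) -> Iterable[dict[str, Any]]:
        # keep only docs whose "question" field is non-empty
        return [d for d in docs if d.get("question")]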
@@ -82,7 +84,7 @@ class FewshotConfig:
             "samples must be either a list of dicts or a callable returning a list"
         )

-    def get_docs(self, dataset) -> Iterable[dict] | None:
+    def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
         """Get processed documents from configured source."""
         raw_docs = self._get_raw_docs(dataset)
         if raw_docs is None:
@@ -93,7 +95,7 @@ class FewshotConfig:
         return raw_docs

     @property
-    def get_sampler(self):
+    def get_sampler(self) -> Callable[..., Any] | None:
         from lm_eval.api import samplers

         if isinstance(self.sampler, str):
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment